forked from kscience/kmath
qr fix
This commit is contained in:
parent
daa0777182
commit
1f19ac88ae
@ -14,7 +14,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
|
||||
public fun TensorType.cholesky(): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
||||
public fun TensorType.qr(): TensorType
|
||||
public fun TensorType.qr(): Pair<TensorType, TensorType>
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
||||
|
@ -1,8 +1,10 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.linear.Matrix
|
||||
import space.kscience.kmath.nd.MutableStructure2D
|
||||
import space.kscience.kmath.nd.Structure1D
|
||||
import space.kscience.kmath.nd.Structure2D
|
||||
import space.kscience.kmath.nd.asND
|
||||
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
||||
import kotlin.math.sqrt
|
||||
|
||||
@ -141,7 +143,15 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
return lTensor
|
||||
}
|
||||
|
||||
override fun DoubleTensor.qr(): DoubleTensor {
|
||||
private fun matrixQR(
|
||||
matrix: Structure2D<Double>,
|
||||
q: MutableStructure2D<Double>,
|
||||
r: MutableStructure2D<Double>
|
||||
) {
|
||||
|
||||
}
|
||||
|
||||
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
|
||||
TODO("ANDREI")
|
||||
}
|
||||
|
||||
@ -154,6 +164,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
}
|
||||
|
||||
private fun luMatrixDet(lu: Structure2D<Double>, pivots: Structure1D<Int>): Double {
|
||||
// todo check
|
||||
val m = lu.shape[0]
|
||||
val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
||||
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
||||
@ -184,6 +195,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
pivots: Structure1D<Int>,
|
||||
invMatrix : MutableStructure2D<Double>
|
||||
): Unit {
|
||||
//todo check
|
||||
val m = lu.shape[0]
|
||||
|
||||
for (j in 0 until m) {
|
||||
|
Loading…
Reference in New Issue
Block a user