qr fix
This commit is contained in:
parent
daa0777182
commit
1f19ac88ae
@ -14,7 +14,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
|
|||||||
public fun TensorType.cholesky(): TensorType
|
public fun TensorType.cholesky(): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
||||||
public fun TensorType.qr(): TensorType
|
public fun TensorType.qr(): Pair<TensorType, TensorType>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||||
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.linear.Matrix
|
||||||
import space.kscience.kmath.nd.MutableStructure2D
|
import space.kscience.kmath.nd.MutableStructure2D
|
||||||
import space.kscience.kmath.nd.Structure1D
|
import space.kscience.kmath.nd.Structure1D
|
||||||
import space.kscience.kmath.nd.Structure2D
|
import space.kscience.kmath.nd.Structure2D
|
||||||
|
import space.kscience.kmath.nd.asND
|
||||||
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
@ -141,7 +143,15 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
return lTensor
|
return lTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.qr(): DoubleTensor {
|
private fun matrixQR(
|
||||||
|
matrix: Structure2D<Double>,
|
||||||
|
q: MutableStructure2D<Double>,
|
||||||
|
r: MutableStructure2D<Double>
|
||||||
|
) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
|
||||||
TODO("ANDREI")
|
TODO("ANDREI")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -154,6 +164,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
private fun luMatrixDet(lu: Structure2D<Double>, pivots: Structure1D<Int>): Double {
|
private fun luMatrixDet(lu: Structure2D<Double>, pivots: Structure1D<Int>): Double {
|
||||||
|
// todo check
|
||||||
val m = lu.shape[0]
|
val m = lu.shape[0]
|
||||||
val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
||||||
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
||||||
@ -184,6 +195,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
pivots: Structure1D<Int>,
|
pivots: Structure1D<Int>,
|
||||||
invMatrix : MutableStructure2D<Double>
|
invMatrix : MutableStructure2D<Double>
|
||||||
): Unit {
|
): Unit {
|
||||||
|
//todo check
|
||||||
val m = lu.shape[0]
|
val m = lu.shape[0]
|
||||||
|
|
||||||
for (j in 0 until m) {
|
for (j in 0 until m) {
|
||||||
|
Loading…
Reference in New Issue
Block a user