diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt index f551d524a..e412af7c6 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt @@ -13,7 +13,7 @@ public interface LinearOpsTensorAlgebra, Inde //https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr public fun TensorType.qr(): TensorType - //htt ps://pytorch.org/docs/stable/generated/torch.lu.html + //https://pytorch.org/docs/stable/generated/torch.lu.html public fun TensorType.lu(): Pair //https://pytorch.org/docs/stable/generated/torch.lu_unpack.html diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index c8c4c413f..0b9d824f0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -10,9 +10,9 @@ public class DoubleLinearOpsTensorAlgebra : LinearOpsTensorAlgebra, DoubleTensorAlgebra() { - override fun DoubleTensor.inv(): DoubleTensor { - TODO("ANDREI") - } + override fun DoubleTensor.inv(): DoubleTensor = invLU() + + override fun DoubleTensor.det(): DoubleTensor = detLU() override fun DoubleTensor.lu(): Pair { @@ -156,15 +156,11 @@ public class DoubleLinearOpsTensorAlgebra : private fun luMatrixDet(lu: Structure2D, pivots: Structure1D): Double { val m = lu.shape[0] val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0 - var det = sign - for (i in 0 until m){ - det *= lu[i, i] - } - return det + return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right } } public fun DoubleTensor.detLU(): DoubleTensor { - val (luTensor, pivotsTensor) = this.lu() + val (luTensor, pivotsTensor) = lu() val n = shape.size val detTensorShape = IntArray(n - 1) { i -> shape[i] } @@ -211,14 +207,12 @@ public class DoubleLinearOpsTensorAlgebra : } public fun DoubleTensor.invLU(): DoubleTensor { - val (luTensor, pivotsTensor) = this.lu() + val (luTensor, pivotsTensor) = lu() val n = shape.size val invTensor = luTensor.zeroesLike() - for ( - (luP, invMatrix) in - luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence()) - ) { + val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence()) + for ((luP, invMatrix) in seq) { val (lu, pivots) = luP luMatrixInv(lu, pivots, invMatrix) }