diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 1fd46bd57..f854beb29 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -5,8 +5,10 @@ package space.kscience.kmath.tensors.core +import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as2D +import space.kscience.kmath.structures.indices import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.Tensor @@ -813,28 +815,32 @@ public open class DoubleTensorAlgebra : val sTensor = zeros(commonShape + intArrayOf(min(n, m))) val vTensor = zeros(commonShape + intArrayOf(min(n, m), m)) - tensor.matrixSequence() - .zip( - uTensor.matrixSequence() - .zip( - sTensor.vectorSequence() - .zip(vTensor.matrixSequence()) - ) - ).forEach { (matrix, USV) -> - val matrixSize = matrix.shape.reduce { acc, i -> acc * i } - val curMatrix = DoubleTensor( - matrix.shape, - matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize) - .toDoubleArray() - ) - svdHelper(curMatrix, USV, m, n, epsilon) - } + val matrices = tensor.matrices + val uTensors = uTensor.matrices + val sTensorVectors = sTensor.vectors + val vTensors = vTensor.matrices + + for (index in matrices.indices) { + val matrix = matrices[index] + val usv = Triple( + uTensors[index], + sTensorVectors[index], + vTensors[index] + ) + val matrixSize = matrix.shape.reduce { acc, i -> acc * i } + val curMatrix = DoubleTensor( + matrix.shape, + matrix.mutableBuffer.array() + .slice(matrix.bufferStart until matrix.bufferStart + matrixSize) + .toDoubleArray() + ) + svdHelper(curMatrix, usv, m, n, epsilon) + } return Triple(uTensor.transpose(), sTensor, vTensor.transpose()) } - override fun Tensor.symEig(): Pair = - symEig(epsilon = 1e-15) + override fun Tensor.symEig(): Pair = symEig(epsilon = 1e-15) /** * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices, @@ -846,12 +852,26 @@ public open class DoubleTensorAlgebra : */ public fun Tensor.symEig(epsilon: Double): Pair { checkSymmetric(tensor, epsilon) + + fun MutableStructure2D.cleanSym(n: Int) { + for (i in 0 until n) { + for (j in 0 until n) { + if (i == j) { + this[i, j] = sign(this[i, j]) + } else { + this[i, j] = 0.0 + } + } + } + } + val (u, s, v) = tensor.svd(epsilon) val shp = s.shape + intArrayOf(1) val utv = u.transpose() dot v val n = s.shape.last() - for (matrix in utv.matrixSequence()) - cleanSymHelper(matrix.as2D(), n) + for (matrix in utv.matrixSequence()) { + matrix.as2D().cleanSym(n) + } val eig = (utv dot s.view(shp)).view(s.shape) return eig to v diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt index 7d3617547..0434bf96f 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt @@ -10,41 +10,54 @@ import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as2D import space.kscience.kmath.operations.invoke -import space.kscience.kmath.tensors.core.* +import space.kscience.kmath.structures.VirtualBuffer +import space.kscience.kmath.structures.asSequence +import space.kscience.kmath.tensors.core.BufferedTensor +import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensorAlgebra -import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.valueOrNull +import space.kscience.kmath.tensors.core.IntTensor import kotlin.math.abs import kotlin.math.min -import kotlin.math.sign import kotlin.math.sqrt +internal val BufferedTensor.vectors: VirtualBuffer> + get() { + val n = shape.size + val vectorOffset = shape[n - 1] + val vectorShape = intArrayOf(shape.last()) -internal fun BufferedTensor.vectorSequence(): Sequence> = sequence { - val n = shape.size - val vectorOffset = shape[n - 1] - val vectorShape = intArrayOf(shape.last()) - for (offset in 0 until numElements step vectorOffset) { - val vector = BufferedTensor(vectorShape, mutableBuffer, bufferStart + offset) - yield(vector) + return VirtualBuffer(numElements / vectorOffset) { index -> + val offset = index * vectorOffset + BufferedTensor(vectorShape, mutableBuffer, bufferStart + offset) + } } -} -internal fun BufferedTensor.matrixSequence(): Sequence> = sequence { - val n = shape.size - check(n >= 2) { "Expected tensor with 2 or more dimensions, got size $n" } - val matrixOffset = shape[n - 1] * shape[n - 2] - val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) - for (offset in 0 until numElements step matrixOffset) { - val matrix = BufferedTensor(matrixShape, mutableBuffer, bufferStart + offset) - yield(matrix) + +internal fun BufferedTensor.vectorSequence(): Sequence> = vectors.asSequence() + +/** + * A random access alternative to [matrixSequence] + */ +internal val BufferedTensor.matrices: VirtualBuffer> + get() { + val n = shape.size + check(n >= 2) { "Expected tensor with 2 or more dimensions, got size $n" } + val matrixOffset = shape[n - 1] * shape[n - 2] + val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) + + return VirtualBuffer(numElements / matrixOffset) { index -> + val offset = index * matrixOffset + BufferedTensor(matrixShape, mutableBuffer, bufferStart + offset) + } } -} + +internal fun BufferedTensor.matrixSequence(): Sequence> = matrices.asSequence() internal fun dotHelper( a: MutableStructure2D, b: MutableStructure2D, res: MutableStructure2D, - l: Int, m: Int, n: Int + l: Int, m: Int, n: Int, ) { for (i in 0 until l) { for (j in 0 until n) { @@ -60,7 +73,7 @@ internal fun dotHelper( internal fun luHelper( lu: MutableStructure2D, pivots: MutableStructure1D, - epsilon: Double + epsilon: Double, ): Boolean { val m = lu.rowNum @@ -122,7 +135,7 @@ internal fun BufferedTensor.setUpPivots(): IntTensor { internal fun DoubleTensorAlgebra.computeLU( tensor: DoubleTensor, - epsilon: Double + epsilon: Double, ): Pair? { checkSquareMatrix(tensor.shape) @@ -139,7 +152,7 @@ internal fun DoubleTensorAlgebra.computeLU( internal fun pivInit( p: MutableStructure2D, pivot: MutableStructure1D, - n: Int + n: Int, ) { for (i in 0 until n) { p[i, pivot[i]] = 1.0 @@ -150,7 +163,7 @@ internal fun luPivotHelper( l: MutableStructure2D, u: MutableStructure2D, lu: MutableStructure2D, - n: Int + n: Int, ) { for (i in 0 until n) { for (j in 0 until n) { @@ -170,7 +183,7 @@ internal fun luPivotHelper( internal fun choleskyHelper( a: MutableStructure2D, l: MutableStructure2D, - n: Int + n: Int, ) { for (i in 0 until n) { for (j in 0 until i) { @@ -200,7 +213,7 @@ internal fun luMatrixDet(lu: MutableStructure2D, pivots: MutableStructur internal fun luMatrixInv( lu: MutableStructure2D, pivots: MutableStructure1D, - invMatrix: MutableStructure2D + invMatrix: MutableStructure2D, ) { val m = lu.shape[0] @@ -227,7 +240,7 @@ internal fun luMatrixInv( internal fun DoubleTensorAlgebra.qrHelper( matrix: DoubleTensor, q: DoubleTensor, - r: MutableStructure2D + r: MutableStructure2D, ) { checkSquareMatrix(matrix.shape) val n = matrix.shape[0] @@ -280,12 +293,11 @@ internal fun DoubleTensorAlgebra.svd1d(a: DoubleTensor, epsilon: Double = 1e-10) internal fun DoubleTensorAlgebra.svdHelper( matrix: DoubleTensor, - USV: Pair, Pair, BufferedTensor>>, - m: Int, n: Int, epsilon: Double + USV: Triple, BufferedTensor, BufferedTensor>, + m: Int, n: Int, epsilon: Double, ) { val res = ArrayList>(0) - val (matrixU, SV) = USV - val (matrixS, matrixV) = SV + val (matrixU, matrixS, matrixV) = USV for (k in 0 until min(n, m)) { var a = matrix.copy() @@ -329,14 +341,3 @@ internal fun DoubleTensorAlgebra.svdHelper( matrixV.mutableBuffer.array()[matrixV.bufferStart + i] = vBuffer[i] } } - -internal fun cleanSymHelper(matrix: MutableStructure2D, n: Int) { - for (i in 0 until n) - for (j in 0 until n) { - if (i == j) { - matrix[i, j] = sign(matrix[i, j]) - } else { - matrix[i, j] = 0.0 - } - } -}