refactor linops

This commit is contained in:
Andrei Kislitsyn 2021-04-30 19:45:31 +03:00
parent 1695fc5075
commit e2c7751c7e

View File

@ -53,14 +53,15 @@ public object DoubleLinearOpsTensorAlgebra :
val lTensor = luTensor.zeroesLike()
val uTensor = luTensor.zeroesLike()
for ((pairLU, lu) in lTensor.matrixSequence().zip(uTensor.matrixSequence())
.zip(luTensor.tensor.matrixSequence())) {
lTensor.matrixSequence()
.zip(uTensor.matrixSequence())
.zip(luTensor.tensor.matrixSequence())
.forEach { (pairLU, lu) ->
val (l, u) = pairLU
luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n)
}
return Triple(pTensor, lTensor, uTensor)
}
public fun TensorStructure<Double>.cholesky(epsilon: Double): DoubleTensor {
@ -82,11 +83,13 @@ public object DoubleLinearOpsTensorAlgebra :
checkSquareMatrix(shape)
val qTensor = zeroesLike()
val rTensor = zeroesLike()
val seq = tensor.matrixSequence().zip((qTensor.matrixSequence().zip(rTensor.matrixSequence())))
for ((matrix, qr) in seq) {
tensor.matrixSequence()
.zip((qTensor.matrixSequence()
.zip(rTensor.matrixSequence()))).forEach { (matrix, qr) ->
val (q, r) = qr
qrHelper(matrix.asTensor(), q.asTensor(), r.as2D())
}
return qTensor to rTensor
}
@ -97,20 +100,24 @@ public object DoubleLinearOpsTensorAlgebra :
val size = tensor.linearStructure.dim
val commonShape = tensor.shape.sliceArray(0 until size - 2)
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
val resU = zeros(commonShape + intArrayOf(min(n, m), n))
val resS = zeros(commonShape + intArrayOf(min(n, m)))
val resV = zeros(commonShape + intArrayOf(min(n, m), m))
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
val sTensor = zeros(commonShape + intArrayOf(min(n, m)))
val vTensor = zeros(commonShape + intArrayOf(min(n, m), m))
for ((matrix, USV) in tensor.matrixSequence()
.zip(resU.matrixSequence().zip(resS.vectorSequence().zip(resV.matrixSequence())))) {
tensor.matrixSequence()
.zip(uTensor.matrixSequence()
.zip(sTensor.vectorSequence()
.zip(vTensor.matrixSequence()))).forEach { (matrix, USV) ->
val matrixSize = matrix.shape.reduce { acc, i -> acc * i }
val curMatrix = DoubleTensor(
matrix.shape,
matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize).toDoubleArray()
matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize)
.toDoubleArray()
)
svdHelper(curMatrix, USV, m, n, epsilon)
}
return Triple(resU.transpose(), resS, resV.transpose())
return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
}
override fun TensorStructure<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
@ -127,7 +134,7 @@ public object DoubleLinearOpsTensorAlgebra :
cleanSymHelper(matrix.as2D(), n)
val eig = (utv dot s.view(shp)).view(s.shape)
return Pair(eig, v)
return eig to v
}
public fun TensorStructure<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {