refactor linops

This commit is contained in:
Andrei Kislitsyn 2021-04-30 19:45:31 +03:00
parent 1695fc5075
commit e2c7751c7e

View File

@ -53,14 +53,15 @@ public object DoubleLinearOpsTensorAlgebra :
val lTensor = luTensor.zeroesLike() val lTensor = luTensor.zeroesLike()
val uTensor = luTensor.zeroesLike() val uTensor = luTensor.zeroesLike()
for ((pairLU, lu) in lTensor.matrixSequence().zip(uTensor.matrixSequence()) lTensor.matrixSequence()
.zip(luTensor.tensor.matrixSequence())) { .zip(uTensor.matrixSequence())
.zip(luTensor.tensor.matrixSequence())
.forEach { (pairLU, lu) ->
val (l, u) = pairLU val (l, u) = pairLU
luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n) luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n)
} }
return Triple(pTensor, lTensor, uTensor) return Triple(pTensor, lTensor, uTensor)
} }
public fun TensorStructure<Double>.cholesky(epsilon: Double): DoubleTensor { public fun TensorStructure<Double>.cholesky(epsilon: Double): DoubleTensor {
@ -82,11 +83,13 @@ public object DoubleLinearOpsTensorAlgebra :
checkSquareMatrix(shape) checkSquareMatrix(shape)
val qTensor = zeroesLike() val qTensor = zeroesLike()
val rTensor = zeroesLike() val rTensor = zeroesLike()
val seq = tensor.matrixSequence().zip((qTensor.matrixSequence().zip(rTensor.matrixSequence()))) tensor.matrixSequence()
for ((matrix, qr) in seq) { .zip((qTensor.matrixSequence()
.zip(rTensor.matrixSequence()))).forEach { (matrix, qr) ->
val (q, r) = qr val (q, r) = qr
qrHelper(matrix.asTensor(), q.asTensor(), r.as2D()) qrHelper(matrix.asTensor(), q.asTensor(), r.as2D())
} }
return qTensor to rTensor return qTensor to rTensor
} }
@ -97,20 +100,24 @@ public object DoubleLinearOpsTensorAlgebra :
val size = tensor.linearStructure.dim val size = tensor.linearStructure.dim
val commonShape = tensor.shape.sliceArray(0 until size - 2) val commonShape = tensor.shape.sliceArray(0 until size - 2)
val (n, m) = tensor.shape.sliceArray(size - 2 until size) val (n, m) = tensor.shape.sliceArray(size - 2 until size)
val resU = zeros(commonShape + intArrayOf(min(n, m), n)) val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
val resS = zeros(commonShape + intArrayOf(min(n, m))) val sTensor = zeros(commonShape + intArrayOf(min(n, m)))
val resV = zeros(commonShape + intArrayOf(min(n, m), m)) val vTensor = zeros(commonShape + intArrayOf(min(n, m), m))
for ((matrix, USV) in tensor.matrixSequence() tensor.matrixSequence()
.zip(resU.matrixSequence().zip(resS.vectorSequence().zip(resV.matrixSequence())))) { .zip(uTensor.matrixSequence()
.zip(sTensor.vectorSequence()
.zip(vTensor.matrixSequence()))).forEach { (matrix, USV) ->
val matrixSize = matrix.shape.reduce { acc, i -> acc * i } val matrixSize = matrix.shape.reduce { acc, i -> acc * i }
val curMatrix = DoubleTensor( val curMatrix = DoubleTensor(
matrix.shape, matrix.shape,
matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize).toDoubleArray() matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize)
.toDoubleArray()
) )
svdHelper(curMatrix, USV, m, n, epsilon) svdHelper(curMatrix, USV, m, n, epsilon)
} }
return Triple(resU.transpose(), resS, resV.transpose())
return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
} }
override fun TensorStructure<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = override fun TensorStructure<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
@ -127,7 +134,7 @@ public object DoubleLinearOpsTensorAlgebra :
cleanSymHelper(matrix.as2D(), n) cleanSymHelper(matrix.as2D(), n)
val eig = (utv dot s.view(shp)).view(s.shape) val eig = (utv dot s.view(shp)).view(s.shape)
return Pair(eig, v) return eig to v
} }
public fun TensorStructure<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor { public fun TensorStructure<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {