refactor linops
This commit is contained in:
parent
1695fc5075
commit
e2c7751c7e
@ -53,14 +53,15 @@ public object DoubleLinearOpsTensorAlgebra :
|
||||
val lTensor = luTensor.zeroesLike()
|
||||
val uTensor = luTensor.zeroesLike()
|
||||
|
||||
for ((pairLU, lu) in lTensor.matrixSequence().zip(uTensor.matrixSequence())
|
||||
.zip(luTensor.tensor.matrixSequence())) {
|
||||
lTensor.matrixSequence()
|
||||
.zip(uTensor.matrixSequence())
|
||||
.zip(luTensor.tensor.matrixSequence())
|
||||
.forEach { (pairLU, lu) ->
|
||||
val (l, u) = pairLU
|
||||
luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n)
|
||||
}
|
||||
|
||||
return Triple(pTensor, lTensor, uTensor)
|
||||
|
||||
}
|
||||
|
||||
public fun TensorStructure<Double>.cholesky(epsilon: Double): DoubleTensor {
|
||||
@ -82,11 +83,13 @@ public object DoubleLinearOpsTensorAlgebra :
|
||||
checkSquareMatrix(shape)
|
||||
val qTensor = zeroesLike()
|
||||
val rTensor = zeroesLike()
|
||||
val seq = tensor.matrixSequence().zip((qTensor.matrixSequence().zip(rTensor.matrixSequence())))
|
||||
for ((matrix, qr) in seq) {
|
||||
tensor.matrixSequence()
|
||||
.zip((qTensor.matrixSequence()
|
||||
.zip(rTensor.matrixSequence()))).forEach { (matrix, qr) ->
|
||||
val (q, r) = qr
|
||||
qrHelper(matrix.asTensor(), q.asTensor(), r.as2D())
|
||||
}
|
||||
|
||||
return qTensor to rTensor
|
||||
}
|
||||
|
||||
@ -97,20 +100,24 @@ public object DoubleLinearOpsTensorAlgebra :
|
||||
val size = tensor.linearStructure.dim
|
||||
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
||||
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
|
||||
val resU = zeros(commonShape + intArrayOf(min(n, m), n))
|
||||
val resS = zeros(commonShape + intArrayOf(min(n, m)))
|
||||
val resV = zeros(commonShape + intArrayOf(min(n, m), m))
|
||||
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
|
||||
val sTensor = zeros(commonShape + intArrayOf(min(n, m)))
|
||||
val vTensor = zeros(commonShape + intArrayOf(min(n, m), m))
|
||||
|
||||
for ((matrix, USV) in tensor.matrixSequence()
|
||||
.zip(resU.matrixSequence().zip(resS.vectorSequence().zip(resV.matrixSequence())))) {
|
||||
tensor.matrixSequence()
|
||||
.zip(uTensor.matrixSequence()
|
||||
.zip(sTensor.vectorSequence()
|
||||
.zip(vTensor.matrixSequence()))).forEach { (matrix, USV) ->
|
||||
val matrixSize = matrix.shape.reduce { acc, i -> acc * i }
|
||||
val curMatrix = DoubleTensor(
|
||||
matrix.shape,
|
||||
matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize).toDoubleArray()
|
||||
matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize)
|
||||
.toDoubleArray()
|
||||
)
|
||||
svdHelper(curMatrix, USV, m, n, epsilon)
|
||||
}
|
||||
return Triple(resU.transpose(), resS, resV.transpose())
|
||||
|
||||
return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
|
||||
}
|
||||
|
||||
override fun TensorStructure<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
|
||||
@ -127,7 +134,7 @@ public object DoubleLinearOpsTensorAlgebra :
|
||||
cleanSymHelper(matrix.as2D(), n)
|
||||
|
||||
val eig = (utv dot s.view(shp)).view(s.shape)
|
||||
return Pair(eig, v)
|
||||
return eig to v
|
||||
}
|
||||
|
||||
public fun TensorStructure<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
|
||||
|
Loading…
Reference in New Issue
Block a user