This commit is contained in:
Andrei Kislitsyn 2021-03-24 18:42:41 +03:00
parent 96a755071f
commit ab81370001
2 changed files with 9 additions and 15 deletions

View File

@ -13,7 +13,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
public fun TensorType.qr(): TensorType
//htt ps://pytorch.org/docs/stable/generated/torch.lu.html
//https://pytorch.org/docs/stable/generated/torch.lu.html
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html

View File

@ -10,9 +10,9 @@ public class DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
DoubleTensorAlgebra() {
override fun DoubleTensor.inv(): DoubleTensor {
TODO("ANDREI")
}
override fun DoubleTensor.inv(): DoubleTensor = invLU()
override fun DoubleTensor.det(): DoubleTensor = detLU()
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
@ -156,15 +156,11 @@ public class DoubleLinearOpsTensorAlgebra :
private fun luMatrixDet(lu: Structure2D<Double>, pivots: Structure1D<Int>): Double {
val m = lu.shape[0]
val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0
var det = sign
for (i in 0 until m){
det *= lu[i, i]
}
return det
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
}
public fun DoubleTensor.detLU(): DoubleTensor {
val (luTensor, pivotsTensor) = this.lu()
val (luTensor, pivotsTensor) = lu()
val n = shape.size
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
@ -211,14 +207,12 @@ public class DoubleLinearOpsTensorAlgebra :
}
public fun DoubleTensor.invLU(): DoubleTensor {
val (luTensor, pivotsTensor) = this.lu()
val (luTensor, pivotsTensor) = lu()
val n = shape.size
val invTensor = luTensor.zeroesLike()
for (
(luP, invMatrix) in
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
) {
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
for ((luP, invMatrix) in seq) {
val (lu, pivots) = luP
luMatrixInv(lu, pivots, invMatrix)
}