Inv and det #262
@ -13,7 +13,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
|
||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
||||
public fun TensorType.qr(): TensorType
|
||||
|
||||
//htt ps://pytorch.org/docs/stable/generated/torch.lu.html
|
||||
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
||||
|
@ -10,9 +10,9 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
||||
DoubleTensorAlgebra() {
|
||||
|
||||
override fun DoubleTensor.inv(): DoubleTensor {
|
||||
TODO("ANDREI")
|
||||
}
|
||||
override fun DoubleTensor.inv(): DoubleTensor = invLU()
|
||||
|
||||
override fun DoubleTensor.det(): DoubleTensor = detLU()
|
||||
|
||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||
|
||||
@ -156,15 +156,11 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
private fun luMatrixDet(lu: Structure2D<Double>, pivots: Structure1D<Int>): Double {
|
||||
val m = lu.shape[0]
|
||||
val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
||||
var det = sign
|
||||
for (i in 0 until m){
|
||||
det *= lu[i, i]
|
||||
}
|
||||
return det
|
||||
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
||||
}
|
||||
|
||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||
val (luTensor, pivotsTensor) = this.lu()
|
||||
val (luTensor, pivotsTensor) = lu()
|
||||
val n = shape.size
|
||||
|
||||
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
|
||||
@ -211,14 +207,12 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
}
|
||||
|
||||
public fun DoubleTensor.invLU(): DoubleTensor {
|
||||
val (luTensor, pivotsTensor) = this.lu()
|
||||
val (luTensor, pivotsTensor) = lu()
|
||||
val n = shape.size
|
||||
val invTensor = luTensor.zeroesLike()
|
||||
|
||||
for (
|
||||
(luP, invMatrix) in
|
||||
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
||||
) {
|
||||
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
||||
for ((luP, invMatrix) in seq) {
|
||||
val (lu, pivots) = luP
|
||||
luMatrixInv(lu, pivots, invMatrix)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user