Inv and det #262
@ -13,7 +13,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
|
|||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
||||||
public fun TensorType.qr(): TensorType
|
public fun TensorType.qr(): TensorType
|
||||||
|
|
||||||
//htt ps://pytorch.org/docs/stable/generated/torch.lu.html
|
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||||
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
||||||
|
@ -10,9 +10,9 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
||||||
DoubleTensorAlgebra() {
|
DoubleTensorAlgebra() {
|
||||||
|
|
||||||
override fun DoubleTensor.inv(): DoubleTensor {
|
override fun DoubleTensor.inv(): DoubleTensor = invLU()
|
||||||
TODO("ANDREI")
|
|
||||||
}
|
override fun DoubleTensor.det(): DoubleTensor = detLU()
|
||||||
|
|
||||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||||
|
|
||||||
@ -156,15 +156,11 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
private fun luMatrixDet(lu: Structure2D<Double>, pivots: Structure1D<Int>): Double {
|
private fun luMatrixDet(lu: Structure2D<Double>, pivots: Structure1D<Int>): Double {
|
||||||
val m = lu.shape[0]
|
val m = lu.shape[0]
|
||||||
val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
||||||
var det = sign
|
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
||||||
for (i in 0 until m){
|
|
||||||
det *= lu[i, i]
|
|
||||||
}
|
|
||||||
return det
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||||
val (luTensor, pivotsTensor) = this.lu()
|
val (luTensor, pivotsTensor) = lu()
|
||||||
val n = shape.size
|
val n = shape.size
|
||||||
|
|
||||||
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
|
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
|
||||||
@ -211,14 +207,12 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.invLU(): DoubleTensor {
|
public fun DoubleTensor.invLU(): DoubleTensor {
|
||||||
val (luTensor, pivotsTensor) = this.lu()
|
val (luTensor, pivotsTensor) = lu()
|
||||||
val n = shape.size
|
val n = shape.size
|
||||||
val invTensor = luTensor.zeroesLike()
|
val invTensor = luTensor.zeroesLike()
|
||||||
|
|
||||||
for (
|
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
||||||
(luP, invMatrix) in
|
for ((luP, invMatrix) in seq) {
|
||||||
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
|
||||||
) {
|
|
||||||
val (lu, pivots) = luP
|
val (lu, pivots) = luP
|
||||||
luMatrixInv(lu, pivots, invMatrix)
|
luMatrixInv(lu, pivots, invMatrix)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user