refactor lu + docs

This commit is contained in:
Andrei Kislitsyn 2021-04-30 11:08:22 +03:00
parent cba62a9468
commit 09f0a2879e
4 changed files with 23 additions and 23 deletions

View File

@ -42,8 +42,7 @@ fun main () {
// solve `Ax = b` system using LUP decomposition // solve `Ax = b` system using LUP decomposition
// get P, L, U such that PA = LU // get P, L, U such that PA = LU
val (lu, pivots) = a.lu() val (p, l, u) = a.lu()
val (p, l, u) = luPivot(lu, pivots)
// check that P is permutation matrix // check that P is permutation matrix
println("P:\n$p") println("P:\n$p")

View File

@ -60,23 +60,17 @@ public interface LinearOpsTensorAlgebra<T> :
public fun TensorStructure<T>.qr(): Pair<TensorStructure<T>, TensorStructure<T>> public fun TensorStructure<T>.qr(): Pair<TensorStructure<T>, TensorStructure<T>>
/** /**
* TODO('Andrew') * LUP decomposition
* For more information: https://pytorch.org/docs/stable/generated/torch.lu.html
* *
* @return ... * Computes the LUP decomposition of a matrix or a batch of matrices.
*/ * Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
public fun TensorStructure<T>.lu(): Pair<TensorStructure<T>, TensorStructure<Int>> * with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
/** * `U` being an upper triangular matrix or batch of matrices.
* TODO('Andrew')
* For more information: https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
* *
* @param luTensor ... * * @return triple of P, L and U tensors
* @param pivotsTensor ...
* @return ...
*/ */
public fun luPivot(luTensor: TensorStructure<T>, pivotsTensor: TensorStructure<Int>): public fun TensorStructure<T>.lu(): Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
/** /**
* Singular Value Decomposition. * Singular Value Decomposition.

View File

@ -29,13 +29,13 @@ public class DoubleLinearOpsTensorAlgebra :
override fun TensorStructure<Double>.det(): DoubleTensor = detLU(1e-9) override fun TensorStructure<Double>.det(): DoubleTensor = detLU(1e-9)
public fun TensorStructure<Double>.lu(epsilon: Double): Pair<DoubleTensor, IntTensor> = public fun TensorStructure<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
computeLU(tensor, epsilon) ?: computeLU(tensor, epsilon) ?:
throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon") throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon")
override fun TensorStructure<Double>.lu(): Pair<DoubleTensor, IntTensor> = lu(1e-9) public fun TensorStructure<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
override fun luPivot( public fun luPivot(
luTensor: TensorStructure<Double>, luTensor: TensorStructure<Double>,
pivotsTensor: TensorStructure<Int> pivotsTensor: TensorStructure<Int>
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { ): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
@ -156,7 +156,7 @@ public class DoubleLinearOpsTensorAlgebra :
} }
public fun TensorStructure<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor { public fun TensorStructure<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
val (luTensor, pivotsTensor) = lu(epsilon) val (luTensor, pivotsTensor) = luFactor(epsilon)
val invTensor = luTensor.zeroesLike() val invTensor = luTensor.zeroesLike()
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence()) val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
@ -167,6 +167,15 @@ public class DoubleLinearOpsTensorAlgebra :
return invTensor return invTensor
} }
public fun TensorStructure<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val (lu, pivots) = this.luFactor(epsilon)
return luPivot(lu, pivots)
}
override fun TensorStructure<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
} }
public inline fun <R> DoubleLinearOpsTensorAlgebra(block: DoubleLinearOpsTensorAlgebra.() -> R): R = public inline fun <R> DoubleLinearOpsTensorAlgebra(block: DoubleLinearOpsTensorAlgebra.() -> R): R =

View File

@ -124,9 +124,7 @@ class TestDoubleLinearOpsTensorAlgebra {
) )
val tensor = fromArray(shape, buffer) val tensor = fromArray(shape, buffer)
val (lu, pivots) = tensor.lu() val (p, l, u) = tensor.lu()
val (p, l, u) = luPivot(lu, pivots)
assertTrue { p.shape contentEquals shape } assertTrue { p.shape contentEquals shape }
assertTrue { l.shape contentEquals shape } assertTrue { l.shape contentEquals shape }