refactor lu + docs

This commit is contained in:
Andrei Kislitsyn 2021-04-30 11:08:22 +03:00
parent cba62a9468
commit 09f0a2879e
4 changed files with 23 additions and 23 deletions

View File

@ -42,8 +42,7 @@ fun main () {
// solve `Ax = b` system using LUP decomposition
// get P, L, U such that PA = LU
val (lu, pivots) = a.lu()
val (p, l, u) = luPivot(lu, pivots)
val (p, l, u) = a.lu()
// check that P is permutation matrix
println("P:\n$p")

View File

@ -60,23 +60,17 @@ public interface LinearOpsTensorAlgebra<T> :
public fun TensorStructure<T>.qr(): Pair<TensorStructure<T>, TensorStructure<T>>
/**
* TODO('Andrew')
* For more information: https://pytorch.org/docs/stable/generated/torch.lu.html
* LUP decomposition
*
* @return ...
*/
public fun TensorStructure<T>.lu(): Pair<TensorStructure<T>, TensorStructure<Int>>
/**
* TODO('Andrew')
* For more information: https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
* Computes the LUP decomposition of a matrix or a batch of matrices.
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
* with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices.
*
* @param luTensor ...
* @param pivotsTensor ...
* @return ...
* * @return triple of P, L and U tensors
*/
public fun luPivot(luTensor: TensorStructure<T>, pivotsTensor: TensorStructure<Int>):
Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
public fun TensorStructure<T>.lu(): Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
/**
* Singular Value Decomposition.

View File

@ -29,13 +29,13 @@ public class DoubleLinearOpsTensorAlgebra :
override fun TensorStructure<Double>.det(): DoubleTensor = detLU(1e-9)
public fun TensorStructure<Double>.lu(epsilon: Double): Pair<DoubleTensor, IntTensor> =
public fun TensorStructure<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
computeLU(tensor, epsilon) ?:
throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon")
override fun TensorStructure<Double>.lu(): Pair<DoubleTensor, IntTensor> = lu(1e-9)
public fun TensorStructure<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
override fun luPivot(
public fun luPivot(
luTensor: TensorStructure<Double>,
pivotsTensor: TensorStructure<Int>
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
@ -156,7 +156,7 @@ public class DoubleLinearOpsTensorAlgebra :
}
public fun TensorStructure<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
val (luTensor, pivotsTensor) = lu(epsilon)
val (luTensor, pivotsTensor) = luFactor(epsilon)
val invTensor = luTensor.zeroesLike()
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
@ -167,6 +167,15 @@ public class DoubleLinearOpsTensorAlgebra :
return invTensor
}
public fun TensorStructure<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val (lu, pivots) = this.luFactor(epsilon)
return luPivot(lu, pivots)
}
override fun TensorStructure<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
}
public inline fun <R> DoubleLinearOpsTensorAlgebra(block: DoubleLinearOpsTensorAlgebra.() -> R): R =

View File

@ -124,9 +124,7 @@ class TestDoubleLinearOpsTensorAlgebra {
)
val tensor = fromArray(shape, buffer)
val (lu, pivots) = tensor.lu()
val (p, l, u) = luPivot(lu, pivots)
val (p, l, u) = tensor.lu()
assertTrue { p.shape contentEquals shape }
assertTrue { l.shape contentEquals shape }