refactor lu + docs
This commit is contained in:
parent
cba62a9468
commit
09f0a2879e
@ -42,8 +42,7 @@ fun main () {
|
||||
// solve `Ax = b` system using LUP decomposition
|
||||
|
||||
// get P, L, U such that PA = LU
|
||||
val (lu, pivots) = a.lu()
|
||||
val (p, l, u) = luPivot(lu, pivots)
|
||||
val (p, l, u) = a.lu()
|
||||
|
||||
// check that P is permutation matrix
|
||||
println("P:\n$p")
|
||||
|
@ -60,23 +60,17 @@ public interface LinearOpsTensorAlgebra<T> :
|
||||
public fun TensorStructure<T>.qr(): Pair<TensorStructure<T>, TensorStructure<T>>
|
||||
|
||||
/**
|
||||
* TODO('Andrew')
|
||||
* For more information: https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||
* LUP decomposition
|
||||
*
|
||||
* @return ...
|
||||
*/
|
||||
public fun TensorStructure<T>.lu(): Pair<TensorStructure<T>, TensorStructure<Int>>
|
||||
|
||||
/**
|
||||
* TODO('Andrew')
|
||||
* For more information: https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
||||
* Computes the LUP decomposition of a matrix or a batch of matrices.
|
||||
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
|
||||
* with `P` being a permutation matrix or batch of matrices,
|
||||
* `L` being a lower triangular matrix or batch of matrices,
|
||||
* `U` being an upper triangular matrix or batch of matrices.
|
||||
*
|
||||
* @param luTensor ...
|
||||
* @param pivotsTensor ...
|
||||
* @return ...
|
||||
* * @return triple of P, L and U tensors
|
||||
*/
|
||||
public fun luPivot(luTensor: TensorStructure<T>, pivotsTensor: TensorStructure<Int>):
|
||||
Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
|
||||
public fun TensorStructure<T>.lu(): Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
|
||||
|
||||
/**
|
||||
* Singular Value Decomposition.
|
||||
|
@ -29,13 +29,13 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
|
||||
override fun TensorStructure<Double>.det(): DoubleTensor = detLU(1e-9)
|
||||
|
||||
public fun TensorStructure<Double>.lu(epsilon: Double): Pair<DoubleTensor, IntTensor> =
|
||||
public fun TensorStructure<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
|
||||
computeLU(tensor, epsilon) ?:
|
||||
throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon")
|
||||
|
||||
override fun TensorStructure<Double>.lu(): Pair<DoubleTensor, IntTensor> = lu(1e-9)
|
||||
public fun TensorStructure<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
|
||||
|
||||
override fun luPivot(
|
||||
public fun luPivot(
|
||||
luTensor: TensorStructure<Double>,
|
||||
pivotsTensor: TensorStructure<Int>
|
||||
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
@ -156,7 +156,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
}
|
||||
|
||||
public fun TensorStructure<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
|
||||
val (luTensor, pivotsTensor) = lu(epsilon)
|
||||
val (luTensor, pivotsTensor) = luFactor(epsilon)
|
||||
val invTensor = luTensor.zeroesLike()
|
||||
|
||||
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
||||
@ -167,6 +167,15 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
|
||||
return invTensor
|
||||
}
|
||||
|
||||
public fun TensorStructure<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
val (lu, pivots) = this.luFactor(epsilon)
|
||||
return luPivot(lu, pivots)
|
||||
}
|
||||
|
||||
override fun TensorStructure<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
|
||||
|
||||
|
||||
}
|
||||
|
||||
public inline fun <R> DoubleLinearOpsTensorAlgebra(block: DoubleLinearOpsTensorAlgebra.() -> R): R =
|
||||
|
@ -124,9 +124,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
)
|
||||
val tensor = fromArray(shape, buffer)
|
||||
|
||||
val (lu, pivots) = tensor.lu()
|
||||
|
||||
val (p, l, u) = luPivot(lu, pivots)
|
||||
val (p, l, u) = tensor.lu()
|
||||
|
||||
assertTrue { p.shape contentEquals shape }
|
||||
assertTrue { l.shape contentEquals shape }
|
||||
|
Loading…
Reference in New Issue
Block a user