refactor lu + docs
This commit is contained in:
parent
cba62a9468
commit
09f0a2879e
@ -42,8 +42,7 @@ fun main () {
|
|||||||
// solve `Ax = b` system using LUP decomposition
|
// solve `Ax = b` system using LUP decomposition
|
||||||
|
|
||||||
// get P, L, U such that PA = LU
|
// get P, L, U such that PA = LU
|
||||||
val (lu, pivots) = a.lu()
|
val (p, l, u) = a.lu()
|
||||||
val (p, l, u) = luPivot(lu, pivots)
|
|
||||||
|
|
||||||
// check that P is permutation matrix
|
// check that P is permutation matrix
|
||||||
println("P:\n$p")
|
println("P:\n$p")
|
||||||
|
@ -60,23 +60,17 @@ public interface LinearOpsTensorAlgebra<T> :
|
|||||||
public fun TensorStructure<T>.qr(): Pair<TensorStructure<T>, TensorStructure<T>>
|
public fun TensorStructure<T>.qr(): Pair<TensorStructure<T>, TensorStructure<T>>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO('Andrew')
|
* LUP decomposition
|
||||||
* For more information: https://pytorch.org/docs/stable/generated/torch.lu.html
|
|
||||||
*
|
*
|
||||||
* @return ...
|
* Computes the LUP decomposition of a matrix or a batch of matrices.
|
||||||
*/
|
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
|
||||||
public fun TensorStructure<T>.lu(): Pair<TensorStructure<T>, TensorStructure<Int>>
|
* with `P` being a permutation matrix or batch of matrices,
|
||||||
|
* `L` being a lower triangular matrix or batch of matrices,
|
||||||
/**
|
* `U` being an upper triangular matrix or batch of matrices.
|
||||||
* TODO('Andrew')
|
|
||||||
* For more information: https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
|
||||||
*
|
*
|
||||||
* @param luTensor ...
|
* * @return triple of P, L and U tensors
|
||||||
* @param pivotsTensor ...
|
|
||||||
* @return ...
|
|
||||||
*/
|
*/
|
||||||
public fun luPivot(luTensor: TensorStructure<T>, pivotsTensor: TensorStructure<Int>):
|
public fun TensorStructure<T>.lu(): Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
|
||||||
Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Singular Value Decomposition.
|
* Singular Value Decomposition.
|
||||||
|
@ -29,13 +29,13 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
override fun TensorStructure<Double>.det(): DoubleTensor = detLU(1e-9)
|
override fun TensorStructure<Double>.det(): DoubleTensor = detLU(1e-9)
|
||||||
|
|
||||||
public fun TensorStructure<Double>.lu(epsilon: Double): Pair<DoubleTensor, IntTensor> =
|
public fun TensorStructure<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
|
||||||
computeLU(tensor, epsilon) ?:
|
computeLU(tensor, epsilon) ?:
|
||||||
throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon")
|
throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon")
|
||||||
|
|
||||||
override fun TensorStructure<Double>.lu(): Pair<DoubleTensor, IntTensor> = lu(1e-9)
|
public fun TensorStructure<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
|
||||||
|
|
||||||
override fun luPivot(
|
public fun luPivot(
|
||||||
luTensor: TensorStructure<Double>,
|
luTensor: TensorStructure<Double>,
|
||||||
pivotsTensor: TensorStructure<Int>
|
pivotsTensor: TensorStructure<Int>
|
||||||
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
@ -156,7 +156,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
public fun TensorStructure<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
|
public fun TensorStructure<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
|
||||||
val (luTensor, pivotsTensor) = lu(epsilon)
|
val (luTensor, pivotsTensor) = luFactor(epsilon)
|
||||||
val invTensor = luTensor.zeroesLike()
|
val invTensor = luTensor.zeroesLike()
|
||||||
|
|
||||||
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
||||||
@ -167,6 +167,15 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
return invTensor
|
return invTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun TensorStructure<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
|
val (lu, pivots) = this.luFactor(epsilon)
|
||||||
|
return luPivot(lu, pivots)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TensorStructure<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <R> DoubleLinearOpsTensorAlgebra(block: DoubleLinearOpsTensorAlgebra.() -> R): R =
|
public inline fun <R> DoubleLinearOpsTensorAlgebra(block: DoubleLinearOpsTensorAlgebra.() -> R): R =
|
||||||
|
@ -124,9 +124,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
)
|
)
|
||||||
val tensor = fromArray(shape, buffer)
|
val tensor = fromArray(shape, buffer)
|
||||||
|
|
||||||
val (lu, pivots) = tensor.lu()
|
val (p, l, u) = tensor.lu()
|
||||||
|
|
||||||
val (p, l, u) = luPivot(lu, pivots)
|
|
||||||
|
|
||||||
assertTrue { p.shape contentEquals shape }
|
assertTrue { p.shape contentEquals shape }
|
||||||
assertTrue { l.shape contentEquals shape }
|
assertTrue { l.shape contentEquals shape }
|
||||||
|
Loading…
Reference in New Issue
Block a user