Complete lu + buffered tensor #240
@ -35,7 +35,10 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun zeroesLike(other: RealTensor): RealTensor {
|
override fun zeroesLike(other: RealTensor): RealTensor {
|
||||||
TODO("Not yet implemented")
|
// TODO refactor very bad!!!
|
||||||
|
val shape = other.shape
|
||||||
|
val buffer = DoubleArray(other.buffer.size) { 0.0 }
|
||||||
|
return RealTensor(shape, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun ones(shape: IntArray): RealTensor {
|
override fun ones(shape: IntArray): RealTensor {
|
||||||
@ -340,9 +343,30 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
return Pair(lu, IntTensor(intArrayOf(m), pivot))
|
return Pair(lu, IntTensor(intArrayOf(m), pivot))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun luUnpack(A_LU: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
override fun luUnpack(lu: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
||||||
// todo checks
|
// todo checks
|
||||||
|
val n = lu.shape[0]
|
||||||
|
val p = zeroesLike(lu)
|
||||||
|
pivots.buffer.array.forEachIndexed { i, pivot ->
|
||||||
|
p[i, pivot] = 1.0
|
||||||
|
}
|
||||||
|
val l = zeroesLike(lu)
|
||||||
|
val u = zeroesLike(lu)
|
||||||
|
|
||||||
|
for (i in 0 until n){
|
||||||
|
for (j in 0 until n){
|
||||||
|
if (i == j) {
|
||||||
|
l[i, j] = 1.0
|
||||||
|
}
|
||||||
|
if (j < i) {
|
||||||
|
l[i, j] = lu[i, j]
|
||||||
|
}
|
||||||
|
if (j >= i) {
|
||||||
|
u[i, j] = lu[i, j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Triple(p, l, u)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> {
|
override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> {
|
||||||
|
Loading…
Reference in New Issue
Block a user