v0.3.0-dev-9 #324
@ -35,7 +35,10 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
||||
}
|
||||
|
||||
override fun zeroesLike(other: RealTensor): RealTensor {
|
||||
TODO("Not yet implemented")
|
||||
// TODO refactor very bad!!!
|
||||
val shape = other.shape
|
||||
val buffer = DoubleArray(other.buffer.size) { 0.0 }
|
||||
return RealTensor(shape, buffer)
|
||||
}
|
||||
|
||||
override fun ones(shape: IntArray): RealTensor {
|
||||
@ -340,9 +343,30 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
||||
return Pair(lu, IntTensor(intArrayOf(m), pivot))
|
||||
}
|
||||
|
||||
override fun luUnpack(A_LU: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
||||
override fun luUnpack(lu: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
||||
// todo checks
|
||||
val n = lu.shape[0]
|
||||
val p = zeroesLike(lu)
|
||||
pivots.buffer.array.forEachIndexed { i, pivot ->
|
||||
p[i, pivot] = 1.0
|
||||
}
|
||||
val l = zeroesLike(lu)
|
||||
val u = zeroesLike(lu)
|
||||
|
||||
for (i in 0 until n){
|
||||
for (j in 0 until n){
|
||||
if (i == j) {
|
||||
l[i, j] = 1.0
|
||||
}
|
||||
if (j < i) {
|
||||
l[i, j] = lu[i, j]
|
||||
}
|
||||
if (j >= i) {
|
||||
u[i, j] = lu[i, j]
|
||||
}
|
||||
}
|
||||
}
|
||||
return Triple(p, l, u)
|
||||
}
|
||||
|
||||
override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> {
|
||||
|
Loading…
Reference in New Issue
Block a user