Complete lu + buffered tensor #240

Merged
AndreiKingsley merged 5 commits from andrew into feature/tensor-algebra 2021-03-13 22:16:23 +03:00
Showing only changes of commit 8f88a101d2 - Show all commits

View File

@ -35,7 +35,10 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
} }
override fun zeroesLike(other: RealTensor): RealTensor { override fun zeroesLike(other: RealTensor): RealTensor {
TODO("Not yet implemented") // TODO refactor very bad!!!
val shape = other.shape
val buffer = DoubleArray(other.buffer.size) { 0.0 }
return RealTensor(shape, buffer)
} }
override fun ones(shape: IntArray): RealTensor { override fun ones(shape: IntArray): RealTensor {
@ -340,9 +343,30 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
return Pair(lu, IntTensor(intArrayOf(m), pivot)) return Pair(lu, IntTensor(intArrayOf(m), pivot))
} }
override fun luUnpack(A_LU: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> { override fun luUnpack(lu: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
// todo checks // todo checks
val n = lu.shape[0]
val p = zeroesLike(lu)
pivots.buffer.array.forEachIndexed { i, pivot ->
p[i, pivot] = 1.0
}
val l = zeroesLike(lu)
val u = zeroesLike(lu)
for (i in 0 until n){
for (j in 0 until n){
if (i == j) {
l[i, j] = 1.0
}
if (j < i) {
l[i, j] = lu[i, j]
}
if (j >= i) {
u[i, j] = lu[i, j]
}
}
}
return Triple(p, l, u)
} }
override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> { override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> {