Complete lu + buffered tensor #240

Merged
AndreiKingsley merged 5 commits from andrew into feature/tensor-algebra 2021-03-13 22:16:23 +03:00
3 changed files with 118 additions and 17 deletions

View File

@ -0,0 +1,33 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.linear.BufferMatrix
import space.kscience.kmath.nd.MutableNDBuffer
import space.kscience.kmath.structures.*
import space.kscience.kmath.structures.BufferAccessor2D
public open class BufferedTensor<T>(
override val shape: IntArray,
buffer: MutableBuffer<T>
) :
TensorStructure<T>,
MutableNDBuffer<T>(
TensorStrides(shape),
buffer
) {
public operator fun get(i: Int, j: Int): T{
check(this.dimension == 2) {"Not matrix"}
return this[intArrayOf(i, j)]
}
public operator fun set(i: Int, j: Int, value: T): Unit{
check(this.dimension == 2) {"Not matrix"}
this[intArrayOf(i, j)] = value
}
}
public class IntTensor(
shape: IntArray,
buffer: IntArray
) : BufferedTensor<Int>(shape, IntBuffer(buffer))

View File

@ -1,20 +1,15 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.nd.MutableNDBuffer
import space.kscience.kmath.structures.RealBuffer import space.kscience.kmath.structures.RealBuffer
import space.kscience.kmath.structures.array import space.kscience.kmath.structures.array
import kotlin.math.abs
import kotlin.math.max import kotlin.math.max
public class RealTensor( public class RealTensor(
override val shape: IntArray, shape: IntArray,
buffer: DoubleArray buffer: DoubleArray
) : ) : BufferedTensor<Double>(shape, RealBuffer(buffer))
TensorStructure<Double>,
MutableNDBuffer<Double>(
TensorStrides(shape),
RealBuffer(buffer)
)
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> { public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
@ -40,7 +35,10 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
} }
override fun zeroesLike(other: RealTensor): RealTensor { override fun zeroesLike(other: RealTensor): RealTensor {
TODO("Not yet implemented") // TODO refactor very bad!!!
val shape = other.shape
val buffer = DoubleArray(other.buffer.size) { 0.0 }
return RealTensor(shape, buffer)
} }
override fun ones(shape: IntArray): RealTensor { override fun ones(shape: IntArray): RealTensor {
@ -52,7 +50,8 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
} }
override fun RealTensor.copy(): RealTensor { override fun RealTensor.copy(): RealTensor {
TODO("Not yet implemented") // should be rework as soon as copy() method for NDBuffer will be available
return RealTensor(this.shape, this.buffer.array.copyOf())
} }
override fun broadcastShapes(vararg shapes: IntArray): IntArray { override fun broadcastShapes(vararg shapes: IntArray): IntArray {
@ -294,12 +293,80 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
} }
override fun RealTensor.lu(): Pair<RealTensor, RealTensor> { override fun RealTensor.lu(): Pair<RealTensor, IntTensor> {
TODO() // todo checks
val lu = this.copy()
val m = this.shape[0]
val pivot = IntArray(m)
// Initialize permutation array and parity
for (row in 0 until m) pivot[row] = row
var even = true
for (i in 0 until m) {
var maxA = -1.0
var iMax = i
for (k in i until m) {
val absA = abs(lu[k, i])
if (absA > maxA) {
maxA = absA
iMax = k
}
} }
override fun luUnpack(A_LU: RealTensor, pivots: RealTensor): Triple<RealTensor, RealTensor, RealTensor> { //todo check singularity
TODO("Not yet implemented")
if (iMax != i) {
val j = pivot[i]
pivot[i] = pivot[iMax]
pivot[iMax] = j
even != even
for (k in 0 until m) {
val tmp = lu[i, k]
lu[i, k] = lu[iMax, k]
lu[iMax, k] = tmp
}
}
for (j in i + 1 until m) {
lu[j, i] /= lu[i, i]
for (k in i + 1 until m) {
lu[j, k] -= lu[j, i] * lu[i, k]
}
}
}
return Pair(lu, IntTensor(intArrayOf(m), pivot))
}
override fun luUnpack(lu: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
// todo checks
val n = lu.shape[0]
val p = zeroesLike(lu)
pivots.buffer.array.forEachIndexed { i, pivot ->
p[i, pivot] = 1.0
}
val l = zeroesLike(lu)
val u = zeroesLike(lu)
for (i in 0 until n){
for (j in 0 until n){
if (i == j) {
l[i, j] = 1.0
}
if (j < i) {
l[i, j] = lu[i, j]
}
if (j >= i) {
u[i, j] = lu[i, j]
}
}
}
return Triple(p, l, u)
} }
override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> { override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> {

View File

@ -74,11 +74,12 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
//https://pytorch.org/docs/stable/generated/torch.log.html //https://pytorch.org/docs/stable/generated/torch.log.html
public fun TensorType.log(): TensorType public fun TensorType.log(): TensorType
// todo change type of pivots
//https://pytorch.org/docs/stable/generated/torch.lu.html //https://pytorch.org/docs/stable/generated/torch.lu.html
public fun TensorType.lu(): Pair<TensorType, TensorType> public fun TensorType.lu(): Pair<TensorType, IntTensor>
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html //https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
public fun luUnpack(A_LU: TensorType, pivots: TensorType): Triple<TensorType, TensorType, TensorType> public fun luUnpack(A_LU: TensorType, pivots: IntTensor): Triple<TensorType, TensorType, TensorType>
//https://pytorch.org/docs/stable/generated/torch.svd.html //https://pytorch.org/docs/stable/generated/torch.svd.html
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType> public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>