add buffered tensor + lu

This commit is contained in:
Andrei Kislitsyn 2021-03-13 19:30:58 +03:00
parent 95b814e163
commit 0911efd4aa
4 changed files with 92 additions and 48 deletions

View File

@ -1,33 +0,0 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.linear.BufferMatrix
import space.kscience.kmath.linear.RealMatrixContext.toBufferMatrix
import space.kscience.kmath.nd.Matrix
import space.kscience.kmath.nd.MutableNDBuffer
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferAccessor2D
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.toList
public open class BufferTensor<T>(
override val shape: IntArray,
buffer: MutableBuffer<T>
) :
TensorStructure<T>,
MutableNDBuffer<T>(
TensorStrides(shape),
buffer
)
public fun <T : Any> BufferTensor<T>.toBufferMatrix(): BufferMatrix<T> {
return BufferMatrix(shape[0], shape[1], this.buffer)
}
// T???
public fun BufferMatrix<Double>.BufferTensor(): BufferTensor<Double> {
return BufferTensor(intArrayOf(rowNum, colNum), BufferAccessor2D(rowNum, colNum, Buffer.Companion::real).create(this))
}

View File

@ -0,0 +1,33 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.linear.BufferMatrix
import space.kscience.kmath.nd.MutableNDBuffer
import space.kscience.kmath.structures.*
import space.kscience.kmath.structures.BufferAccessor2D
public open class BufferedTensor<T>(
override val shape: IntArray,
buffer: MutableBuffer<T>
) :
TensorStructure<T>,
MutableNDBuffer<T>(
TensorStrides(shape),
buffer
) {
public operator fun get(i: Int, j: Int): T{
check(this.dimension == 2) {"Not matrix"}
return this[intArrayOf(i, j)]
}
public operator fun set(i: Int, j: Int, value: T): Unit{
check(this.dimension == 2) {"Not matrix"}
this[intArrayOf(i, j)] = value
}
}
public class IntTensor(
shape: IntArray,
buffer: IntArray
) : BufferedTensor<Int>(shape, IntBuffer(buffer))

View File

@ -1,24 +1,19 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.linear.BufferMatrix
import space.kscience.kmath.linear.RealMatrixContext.toBufferMatrix
import space.kscience.kmath.nd.MutableNDBuffer
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.RealBuffer import space.kscience.kmath.structures.RealBuffer
import space.kscience.kmath.structures.array import space.kscience.kmath.structures.array
import space.kscience.kmath.structures.toList import kotlin.math.abs
public class RealTensor( public class RealTensor(
shape: IntArray, shape: IntArray,
buffer: DoubleArray buffer: DoubleArray
) : BufferTensor<Double>(shape, RealBuffer(buffer)) ) : BufferedTensor<Double>(shape, RealBuffer(buffer))
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> { public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
override fun RealTensor.value(): Double { override fun RealTensor.value(): Double {
check(this.shape contentEquals intArrayOf(1)) { check(this.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}" "Inconsistent value for tensor of shape ${shape.toList()}"
} }
return this.buffer.array[0] return this.buffer.array[0]
@ -51,7 +46,8 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
} }
override fun RealTensor.copy(): RealTensor { override fun RealTensor.copy(): RealTensor {
TODO("Not yet implemented") // should be rework as soon as copy() method for NDBuffer will be available
return RealTensor(this.shape, this.buffer.array.copyOf())
} }
override fun Double.plus(other: RealTensor): RealTensor { override fun Double.plus(other: RealTensor): RealTensor {
@ -249,12 +245,59 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun RealTensor.lu(): Pair<RealTensor, RealTensor> { override fun RealTensor.lu(): Pair<RealTensor, IntTensor> {
TODO() // todo checks
val lu = this.copy()
val m = this.shape[0]
val pivot = IntArray(m)
// Initialize permutation array and parity
for (row in 0 until m) pivot[row] = row
var even = true
for (i in 0 until m) {
var maxA = -1.0
var iMax = i
for (k in i until m) {
val absA = abs(lu[k, i])
if (absA > maxA) {
maxA = absA
iMax = k
}
}
//todo check singularity
if (iMax != i) {
val j = pivot[i]
pivot[i] = pivot[iMax]
pivot[iMax] = j
even != even
for (k in 0 until m) {
val tmp = lu[i, k]
lu[i, k] = lu[iMax, k]
lu[iMax, k] = tmp
}
}
for (j in i + 1 until m) {
lu[j, i] /= lu[i, i]
for (k in i + 1 until m) {
lu[j, k] -= lu[j, i] * lu[i, k]
}
}
}
return Pair(lu, IntTensor(intArrayOf(m), pivot))
} }
override fun luUnpack(A_LU: RealTensor, pivots: RealTensor): Triple<RealTensor, RealTensor, RealTensor> { override fun luUnpack(A_LU: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
TODO("Not yet implemented") // todo checks
} }
override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> { override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> {

View File

@ -78,11 +78,12 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
public fun TensorType.log(): TensorType public fun TensorType.log(): TensorType
public fun TensorType.logAssign(): Unit public fun TensorType.logAssign(): Unit
// todo change type of pivots
//https://pytorch.org/docs/stable/generated/torch.lu.html //https://pytorch.org/docs/stable/generated/torch.lu.html
public fun TensorType.lu(): Pair<TensorType, TensorType> public fun TensorType.lu(): Pair<TensorType, IntTensor>
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html //https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
public fun luUnpack(A_LU: TensorType, pivots: TensorType): Triple<TensorType, TensorType, TensorType> public fun luUnpack(A_LU: TensorType, pivots: IntTensor): Triple<TensorType, TensorType, TensorType>
//https://pytorch.org/docs/stable/generated/torch.svd.html //https://pytorch.org/docs/stable/generated/torch.svd.html
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType> public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>