BufferedTensor revisited

This commit is contained in:
Roland Grinis 2021-03-15 22:11:15 +00:00
parent f8e0d4be17
commit 7cb5cd8f71
6 changed files with 78 additions and 62 deletions

View File

@ -1,32 +1,34 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.nd.MutableNDBuffer
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
public open class BufferedTensor<T>( public open class BufferedTensor<T>(
override val shape: IntArray, override val shape: IntArray,
buffer: MutableBuffer<T> public val buffer: MutableBuffer<T>,
) : internal val bufferStart: Int
TensorStructure<T>, ) : TensorStructure<T>
MutableNDBuffer<T>( {
TensorStrides(shape), public val strides: TensorStrides
buffer get() = TensorStrides(shape)
) {
override fun get(index: IntArray): T = buffer[bufferStart + strides.offset(index)]
public operator fun get(i: Int, j: Int): T { override fun set(index: IntArray, value: T) {
check(this.dimension == 2) { "Not matrix" } buffer[bufferStart + strides.offset(index)] = value
return this[intArrayOf(i, j)]
} }
public operator fun set(i: Int, j: Int, value: T): Unit { override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map {
check(this.dimension == 2) { "Not matrix" } it to this[it]
this[intArrayOf(i, j)] = value
} }
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
} }
/*
//todo make generator mb nextMatrixIndex? //todo make generator mb nextMatrixIndex?
public class InnerMatrix<T>(private val tensor: BufferedTensor<T>){ public class InnerMatrix<T>(private val tensor: BufferedTensor<T>){
private var offset: Int = 0 private var offset: Int = 0
@ -75,25 +77,29 @@ public class InnerVector<T>(private val tensor: BufferedTensor<T>){
offset += step offset += step
} }
} }
//todo default buffer = arrayOf(0)??? //todo default buffer = arrayOf(0)???
*/
public class IntTensor( public class IntTensor(
shape: IntArray, shape: IntArray,
buffer: IntArray buffer: IntArray,
) : BufferedTensor<Int>(shape, IntBuffer(buffer)) offset: Int = 0
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
public class LongTensor( public class LongTensor(
shape: IntArray, shape: IntArray,
buffer: LongArray buffer: LongArray,
) : BufferedTensor<Long>(shape, LongBuffer(buffer)) offset: Int = 0
) : BufferedTensor<Long>(shape, LongBuffer(buffer), offset)
public class FloatTensor( public class FloatTensor(
shape: IntArray, shape: IntArray,
buffer: FloatArray buffer: FloatArray,
) : BufferedTensor<Float>(shape, FloatBuffer(buffer)) offset: Int = 0
) : BufferedTensor<Float>(shape, FloatBuffer(buffer), offset)
public class DoubleTensor( public class DoubleTensor(
shape: IntArray, shape: IntArray,
buffer: DoubleArray buffer: DoubleArray,
) : BufferedTensor<Double>(shape, RealBuffer(buffer)) offset: Int = 0
) : BufferedTensor<Double>(shape, RealBuffer(buffer), offset)

View File

@ -9,6 +9,7 @@ public class DoubleLinearOpsTensorAlgebra :
} }
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> { override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
/*
// todo checks // todo checks
val luTensor = this.copy() val luTensor = this.copy()
val lu = InnerMatrix(luTensor) val lu = InnerMatrix(luTensor)
@ -69,11 +70,13 @@ public class DoubleLinearOpsTensorAlgebra :
pivot.makeStep() pivot.makeStep()
} }
return Pair(luTensor, pivotsTensor) return Pair(luTensor, pivotsTensor)*/
TODO("Andrei, first we need to view and get(Int)")
} }
override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
/*
// todo checks // todo checks
val n = lu.shape[0] val n = lu.shape[0]
val p = lu.zeroesLike() val p = lu.zeroesLike()
@ -97,7 +100,8 @@ public class DoubleLinearOpsTensorAlgebra :
} }
} }
return Triple(p, l, u) return Triple(p, l, u)*/
TODO("Andrei, first we need implement get(Int)")
} }
override fun DoubleTensor.cholesky(): DoubleTensor { override fun DoubleTensor.cholesky(): DoubleTensor {

View File

@ -7,11 +7,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
check(this.shape contentEquals intArrayOf(1)) { check(this.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}" "Inconsistent value for tensor of shape ${shape.toList()}"
} }
return this.buffer.unsafeToDoubleArray()[0] return this.buffer.unsafeToDoubleArray()[this.bufferStart]
} }
override fun DoubleTensor.get(i: Int): DoubleTensor { override fun DoubleTensor.get(i: Int): DoubleTensor {
TODO("Not yet implemented") TODO("TOP PRIORITY")
} }
override fun zeros(shape: IntArray): DoubleTensor { override fun zeros(shape: IntArray): DoubleTensor {
@ -20,7 +20,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.zeroesLike(): DoubleTensor { override fun DoubleTensor.zeroesLike(): DoubleTensor {
val shape = this.shape val shape = this.shape
val buffer = DoubleArray(this.buffer.size) { 0.0 } val buffer = DoubleArray(this.strides.linearSize) { 0.0 }
return DoubleTensor(shape, buffer) return DoubleTensor(shape, buffer)
} }
@ -31,6 +31,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.onesLike(): DoubleTensor { override fun DoubleTensor.onesLike(): DoubleTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun eye(n: Int): DoubleTensor { override fun eye(n: Int): DoubleTensor {
val shape = intArrayOf(n, n) val shape = intArrayOf(n, n)
val buffer = DoubleArray(n * n) { 0.0 } val buffer = DoubleArray(n * n) { 0.0 }
@ -42,14 +43,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
} }
override fun DoubleTensor.copy(): DoubleTensor { override fun DoubleTensor.copy(): DoubleTensor {
// should be rework as soon as copy() method for NDBuffer will be available return DoubleTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf(), this.bufferStart)
return DoubleTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf())
} }
override fun Double.plus(other: DoubleTensor): DoubleTensor { override fun Double.plus(other: DoubleTensor): DoubleTensor {
val resBuffer = DoubleArray(other.buffer.size) { i -> val resBuffer = DoubleArray(other.strides.linearSize) { i ->
other.buffer.unsafeToDoubleArray()[i] + this other.buffer.unsafeToDoubleArray()[other.bufferStart + i] + this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
@ -60,35 +60,36 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
val broadcast = broadcastTensors(this, other) val broadcast = broadcastTensors(this, other)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.buffer.size) { i -> val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i] newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun DoubleTensor.plusAssign(value: Double) { override fun DoubleTensor.plusAssign(value: Double) {
for (i in this.buffer.unsafeToDoubleArray().indices) { for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[i] += value this.buffer.unsafeToDoubleArray()[this.bufferStart + i] += value
} }
} }
override fun DoubleTensor.plusAssign(other: DoubleTensor) { override fun DoubleTensor.plusAssign(other: DoubleTensor) {
//todo should be change with broadcasting //todo should be change with broadcasting
for (i in this.buffer.unsafeToDoubleArray().indices) { for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[i] += other.buffer.unsafeToDoubleArray()[i] this.buffer.unsafeToDoubleArray()[this.bufferStart + i] +=
other.buffer.unsafeToDoubleArray()[this.bufferStart + i]
} }
} }
override fun Double.minus(other: DoubleTensor): DoubleTensor { override fun Double.minus(other: DoubleTensor): DoubleTensor {
val resBuffer = DoubleArray(other.buffer.size) { i -> val resBuffer = DoubleArray(other.strides.linearSize) { i ->
this - other.buffer.unsafeToDoubleArray()[i] this - other.buffer.unsafeToDoubleArray()[other.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun DoubleTensor.minus(value: Double): DoubleTensor { override fun DoubleTensor.minus(value: Double): DoubleTensor {
val resBuffer = DoubleArray(this.buffer.size) { i -> val resBuffer = DoubleArray(this.strides.linearSize) { i ->
this.buffer.unsafeToDoubleArray()[i] - value this.buffer.unsafeToDoubleArray()[this.bufferStart + i] - value
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
} }
@ -97,15 +98,15 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
val broadcast = broadcastTensors(this, other) val broadcast = broadcastTensors(this, other)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.buffer.size) { i -> val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i] newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun DoubleTensor.minusAssign(value: Double) { override fun DoubleTensor.minusAssign(value: Double) {
for (i in this.buffer.unsafeToDoubleArray().indices) { for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[i] -= value this.buffer.unsafeToDoubleArray()[this.bufferStart + i] -= value
} }
} }
@ -115,8 +116,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun Double.times(other: DoubleTensor): DoubleTensor { override fun Double.times(other: DoubleTensor): DoubleTensor {
//todo should be change with broadcasting //todo should be change with broadcasting
val resBuffer = DoubleArray(other.buffer.size) { i -> val resBuffer = DoubleArray(other.strides.linearSize) { i ->
other.buffer.unsafeToDoubleArray()[i] * this other.buffer.unsafeToDoubleArray()[other.bufferStart + i] * this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
@ -126,36 +127,38 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor { override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
//todo should be change with broadcasting //todo should be change with broadcasting
val resBuffer = DoubleArray(this.buffer.size) { i -> val resBuffer = DoubleArray(this.strides.linearSize) { i ->
this.buffer.unsafeToDoubleArray()[i] * other.buffer.unsafeToDoubleArray()[i] this.buffer.unsafeToDoubleArray()[other.bufferStart + i] *
other.buffer.unsafeToDoubleArray()[other.bufferStart + i]
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
} }
override fun DoubleTensor.timesAssign(value: Double) { override fun DoubleTensor.timesAssign(value: Double) {
//todo should be change with broadcasting //todo should be change with broadcasting
for (i in this.buffer.unsafeToDoubleArray().indices) { for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[i] *= value this.buffer.unsafeToDoubleArray()[this.bufferStart + i] *= value
} }
} }
override fun DoubleTensor.timesAssign(other: DoubleTensor) { override fun DoubleTensor.timesAssign(other: DoubleTensor) {
//todo should be change with broadcasting //todo should be change with broadcasting
for (i in this.buffer.unsafeToDoubleArray().indices) { for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[i] *= other.buffer.unsafeToDoubleArray()[i] this.buffer.unsafeToDoubleArray()[this.bufferStart + i] *=
other.buffer.unsafeToDoubleArray()[this.bufferStart + i]
} }
} }
override fun DoubleTensor.unaryMinus(): DoubleTensor { override fun DoubleTensor.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(this.buffer.size) { i -> val resBuffer = DoubleArray(this.strides.linearSize) { i ->
this.buffer.unsafeToDoubleArray()[i].unaryMinus() this.buffer.unsafeToDoubleArray()[this.bufferStart + i].unaryMinus()
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
} }
override fun DoubleTensor.transpose(i: Int, j: Int): DoubleTensor { override fun DoubleTensor.transpose(i: Int, j: Int): DoubleTensor {
checkTranspose(this.dimension, i, j) checkTranspose(this.dimension, i, j)
val n = this.buffer.size val n = this.strides.linearSize
val resBuffer = DoubleArray(n) val resBuffer = DoubleArray(n)
val resShape = this.shape.copyOf() val resShape = this.shape.copyOf()
@ -169,14 +172,16 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] } newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
val linearIndex = resTensor.strides.offset(newMultiIndex) val linearIndex = resTensor.strides.offset(newMultiIndex)
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = this.buffer.unsafeToDoubleArray()[offset] resTensor.buffer.unsafeToDoubleArray()[linearIndex] =
this.buffer.unsafeToDoubleArray()[this.bufferStart + offset]
} }
return resTensor return resTensor
} }
override fun DoubleTensor.view(shape: IntArray): DoubleTensor { override fun DoubleTensor.view(shape: IntArray): DoubleTensor {
return DoubleTensor(shape, this.buffer.unsafeToDoubleArray()) checkView(this, shape)
return DoubleTensor(shape, this.buffer.unsafeToDoubleArray(), this.bufferStart)
} }
override fun DoubleTensor.viewAs(other: DoubleTensor): DoubleTensor { override fun DoubleTensor.viewAs(other: DoubleTensor): DoubleTensor {

View File

@ -55,7 +55,8 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
} }
val curLinearIndex = tensor.strides.offset(curMultiIndex) val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = tensor.buffer.unsafeToDoubleArray()[curLinearIndex] resTensor.buffer.unsafeToDoubleArray()[linearIndex] =
tensor.buffer.unsafeToDoubleArray()[tensor.bufferStart + curLinearIndex]
} }
res.add(resTensor) res.add(resTensor)
} }

View File

@ -6,7 +6,7 @@ import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
class TestRealTensor { class TestDoubleTensor {
@Test @Test
fun valueTest() = DoubleTensorAlgebra { fun valueTest() = DoubleTensorAlgebra {

View File

@ -4,7 +4,7 @@ package space.kscience.kmath.tensors
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertTrue import kotlin.test.assertTrue
class TestRealTensorAlgebra { class TestDoubleTensorAlgebra {
@Test @Test
fun doublePlus() = DoubleTensorAlgebra { fun doublePlus() = DoubleTensorAlgebra {