forked from kscience/kmath
BufferedTensor revisited
This commit is contained in:
parent
f8e0d4be17
commit
7cb5cd8f71
@ -1,32 +1,34 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.nd.MutableNDBuffer
|
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
|
|
||||||
|
|
||||||
public open class BufferedTensor<T>(
|
public open class BufferedTensor<T>(
|
||||||
override val shape: IntArray,
|
override val shape: IntArray,
|
||||||
buffer: MutableBuffer<T>
|
public val buffer: MutableBuffer<T>,
|
||||||
) :
|
internal val bufferStart: Int
|
||||||
TensorStructure<T>,
|
) : TensorStructure<T>
|
||||||
MutableNDBuffer<T>(
|
{
|
||||||
TensorStrides(shape),
|
public val strides: TensorStrides
|
||||||
buffer
|
get() = TensorStrides(shape)
|
||||||
) {
|
|
||||||
|
|
||||||
|
override fun get(index: IntArray): T = buffer[bufferStart + strides.offset(index)]
|
||||||
|
|
||||||
public operator fun get(i: Int, j: Int): T {
|
override fun set(index: IntArray, value: T) {
|
||||||
check(this.dimension == 2) { "Not matrix" }
|
buffer[bufferStart + strides.offset(index)] = value
|
||||||
return this[intArrayOf(i, j)]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun set(i: Int, j: Int, value: T): Unit {
|
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map {
|
||||||
check(this.dimension == 2) { "Not matrix" }
|
it to this[it]
|
||||||
this[intArrayOf(i, j)] = value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean = false
|
||||||
|
|
||||||
|
override fun hashCode(): Int = 0
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
//todo make generator mb nextMatrixIndex?
|
//todo make generator mb nextMatrixIndex?
|
||||||
public class InnerMatrix<T>(private val tensor: BufferedTensor<T>){
|
public class InnerMatrix<T>(private val tensor: BufferedTensor<T>){
|
||||||
private var offset: Int = 0
|
private var offset: Int = 0
|
||||||
@ -75,25 +77,29 @@ public class InnerVector<T>(private val tensor: BufferedTensor<T>){
|
|||||||
offset += step
|
offset += step
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//todo default buffer = arrayOf(0)???
|
//todo default buffer = arrayOf(0)???
|
||||||
|
*/
|
||||||
|
|
||||||
public class IntTensor(
|
public class IntTensor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: IntArray
|
buffer: IntArray,
|
||||||
) : BufferedTensor<Int>(shape, IntBuffer(buffer))
|
offset: Int = 0
|
||||||
|
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
|
||||||
|
|
||||||
public class LongTensor(
|
public class LongTensor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: LongArray
|
buffer: LongArray,
|
||||||
) : BufferedTensor<Long>(shape, LongBuffer(buffer))
|
offset: Int = 0
|
||||||
|
) : BufferedTensor<Long>(shape, LongBuffer(buffer), offset)
|
||||||
|
|
||||||
public class FloatTensor(
|
public class FloatTensor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: FloatArray
|
buffer: FloatArray,
|
||||||
) : BufferedTensor<Float>(shape, FloatBuffer(buffer))
|
offset: Int = 0
|
||||||
|
) : BufferedTensor<Float>(shape, FloatBuffer(buffer), offset)
|
||||||
|
|
||||||
public class DoubleTensor(
|
public class DoubleTensor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: DoubleArray
|
buffer: DoubleArray,
|
||||||
) : BufferedTensor<Double>(shape, RealBuffer(buffer))
|
offset: Int = 0
|
||||||
|
) : BufferedTensor<Double>(shape, RealBuffer(buffer), offset)
|
@ -9,6 +9,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||||
|
/*
|
||||||
// todo checks
|
// todo checks
|
||||||
val luTensor = this.copy()
|
val luTensor = this.copy()
|
||||||
val lu = InnerMatrix(luTensor)
|
val lu = InnerMatrix(luTensor)
|
||||||
@ -69,11 +70,13 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
pivot.makeStep()
|
pivot.makeStep()
|
||||||
}
|
}
|
||||||
|
|
||||||
return Pair(luTensor, pivotsTensor)
|
return Pair(luTensor, pivotsTensor)*/
|
||||||
|
|
||||||
|
TODO("Andrei, first we need to view and get(Int)")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
|
/*
|
||||||
// todo checks
|
// todo checks
|
||||||
val n = lu.shape[0]
|
val n = lu.shape[0]
|
||||||
val p = lu.zeroesLike()
|
val p = lu.zeroesLike()
|
||||||
@ -97,7 +100,8 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Triple(p, l, u)
|
return Triple(p, l, u)*/
|
||||||
|
TODO("Andrei, first we need implement get(Int)")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||||
|
@ -7,11 +7,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
check(this.shape contentEquals intArrayOf(1)) {
|
check(this.shape contentEquals intArrayOf(1)) {
|
||||||
"Inconsistent value for tensor of shape ${shape.toList()}"
|
"Inconsistent value for tensor of shape ${shape.toList()}"
|
||||||
}
|
}
|
||||||
return this.buffer.unsafeToDoubleArray()[0]
|
return this.buffer.unsafeToDoubleArray()[this.bufferStart]
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.get(i: Int): DoubleTensor {
|
override fun DoubleTensor.get(i: Int): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("TOP PRIORITY")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun zeros(shape: IntArray): DoubleTensor {
|
override fun zeros(shape: IntArray): DoubleTensor {
|
||||||
@ -20,7 +20,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
|
|
||||||
override fun DoubleTensor.zeroesLike(): DoubleTensor {
|
override fun DoubleTensor.zeroesLike(): DoubleTensor {
|
||||||
val shape = this.shape
|
val shape = this.shape
|
||||||
val buffer = DoubleArray(this.buffer.size) { 0.0 }
|
val buffer = DoubleArray(this.strides.linearSize) { 0.0 }
|
||||||
return DoubleTensor(shape, buffer)
|
return DoubleTensor(shape, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -31,6 +31,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
override fun DoubleTensor.onesLike(): DoubleTensor {
|
override fun DoubleTensor.onesLike(): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun eye(n: Int): DoubleTensor {
|
override fun eye(n: Int): DoubleTensor {
|
||||||
val shape = intArrayOf(n, n)
|
val shape = intArrayOf(n, n)
|
||||||
val buffer = DoubleArray(n * n) { 0.0 }
|
val buffer = DoubleArray(n * n) { 0.0 }
|
||||||
@ -42,14 +43,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.copy(): DoubleTensor {
|
override fun DoubleTensor.copy(): DoubleTensor {
|
||||||
// should be rework as soon as copy() method for NDBuffer will be available
|
return DoubleTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf(), this.bufferStart)
|
||||||
return DoubleTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun Double.plus(other: DoubleTensor): DoubleTensor {
|
override fun Double.plus(other: DoubleTensor): DoubleTensor {
|
||||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
||||||
other.buffer.unsafeToDoubleArray()[i] + this
|
other.buffer.unsafeToDoubleArray()[other.bufferStart + i] + this
|
||||||
}
|
}
|
||||||
return DoubleTensor(other.shape, resBuffer)
|
return DoubleTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -60,35 +60,36 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
val broadcast = broadcastTensors(this, other)
|
val broadcast = broadcastTensors(this, other)
|
||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
|
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||||
newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i]
|
newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.plusAssign(value: Double) {
|
override fun DoubleTensor.plusAssign(value: Double) {
|
||||||
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.unsafeToDoubleArray()[i] += value
|
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] += value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.unsafeToDoubleArray()[i] += other.buffer.unsafeToDoubleArray()[i]
|
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] +=
|
||||||
|
other.buffer.unsafeToDoubleArray()[this.bufferStart + i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Double.minus(other: DoubleTensor): DoubleTensor {
|
override fun Double.minus(other: DoubleTensor): DoubleTensor {
|
||||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
||||||
this - other.buffer.unsafeToDoubleArray()[i]
|
this - other.buffer.unsafeToDoubleArray()[other.bufferStart + i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(other.shape, resBuffer)
|
return DoubleTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.minus(value: Double): DoubleTensor {
|
override fun DoubleTensor.minus(value: Double): DoubleTensor {
|
||||||
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
this.buffer.unsafeToDoubleArray()[i] - value
|
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] - value
|
||||||
}
|
}
|
||||||
return DoubleTensor(this.shape, resBuffer)
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -97,15 +98,15 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
val broadcast = broadcastTensors(this, other)
|
val broadcast = broadcastTensors(this, other)
|
||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
|
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||||
newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i]
|
newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.minusAssign(value: Double) {
|
override fun DoubleTensor.minusAssign(value: Double) {
|
||||||
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.unsafeToDoubleArray()[i] -= value
|
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] -= value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,8 +116,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
|
|
||||||
override fun Double.times(other: DoubleTensor): DoubleTensor {
|
override fun Double.times(other: DoubleTensor): DoubleTensor {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
||||||
other.buffer.unsafeToDoubleArray()[i] * this
|
other.buffer.unsafeToDoubleArray()[other.bufferStart + i] * this
|
||||||
}
|
}
|
||||||
return DoubleTensor(other.shape, resBuffer)
|
return DoubleTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -126,36 +127,38 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
|
|
||||||
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
this.buffer.unsafeToDoubleArray()[i] * other.buffer.unsafeToDoubleArray()[i]
|
this.buffer.unsafeToDoubleArray()[other.bufferStart + i] *
|
||||||
|
other.buffer.unsafeToDoubleArray()[other.bufferStart + i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(this.shape, resBuffer)
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.timesAssign(value: Double) {
|
override fun DoubleTensor.timesAssign(value: Double) {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.unsafeToDoubleArray()[i] *= value
|
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] *= value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.unsafeToDoubleArray()[i] *= other.buffer.unsafeToDoubleArray()[i]
|
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] *=
|
||||||
|
other.buffer.unsafeToDoubleArray()[this.bufferStart + i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.unaryMinus(): DoubleTensor {
|
override fun DoubleTensor.unaryMinus(): DoubleTensor {
|
||||||
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
this.buffer.unsafeToDoubleArray()[i].unaryMinus()
|
this.buffer.unsafeToDoubleArray()[this.bufferStart + i].unaryMinus()
|
||||||
}
|
}
|
||||||
return DoubleTensor(this.shape, resBuffer)
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.transpose(i: Int, j: Int): DoubleTensor {
|
override fun DoubleTensor.transpose(i: Int, j: Int): DoubleTensor {
|
||||||
checkTranspose(this.dimension, i, j)
|
checkTranspose(this.dimension, i, j)
|
||||||
val n = this.buffer.size
|
val n = this.strides.linearSize
|
||||||
val resBuffer = DoubleArray(n)
|
val resBuffer = DoubleArray(n)
|
||||||
|
|
||||||
val resShape = this.shape.copyOf()
|
val resShape = this.shape.copyOf()
|
||||||
@ -169,14 +172,16 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
|
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
|
||||||
|
|
||||||
val linearIndex = resTensor.strides.offset(newMultiIndex)
|
val linearIndex = resTensor.strides.offset(newMultiIndex)
|
||||||
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = this.buffer.unsafeToDoubleArray()[offset]
|
resTensor.buffer.unsafeToDoubleArray()[linearIndex] =
|
||||||
|
this.buffer.unsafeToDoubleArray()[this.bufferStart + offset]
|
||||||
}
|
}
|
||||||
return resTensor
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun DoubleTensor.view(shape: IntArray): DoubleTensor {
|
override fun DoubleTensor.view(shape: IntArray): DoubleTensor {
|
||||||
return DoubleTensor(shape, this.buffer.unsafeToDoubleArray())
|
checkView(this, shape)
|
||||||
|
return DoubleTensor(shape, this.buffer.unsafeToDoubleArray(), this.bufferStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.viewAs(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.viewAs(other: DoubleTensor): DoubleTensor {
|
||||||
|
@ -55,7 +55,8 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
|
|||||||
}
|
}
|
||||||
|
|
||||||
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
||||||
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = tensor.buffer.unsafeToDoubleArray()[curLinearIndex]
|
resTensor.buffer.unsafeToDoubleArray()[linearIndex] =
|
||||||
|
tensor.buffer.unsafeToDoubleArray()[tensor.bufferStart + curLinearIndex]
|
||||||
}
|
}
|
||||||
res.add(resTensor)
|
res.add(resTensor)
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import kotlin.test.Test
|
|||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
class TestRealTensor {
|
class TestDoubleTensor {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun valueTest() = DoubleTensorAlgebra {
|
fun valueTest() = DoubleTensorAlgebra {
|
@ -4,7 +4,7 @@ package space.kscience.kmath.tensors
|
|||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
class TestRealTensorAlgebra {
|
class TestDoubleTensorAlgebra {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun doublePlus() = DoubleTensorAlgebra {
|
fun doublePlus() = DoubleTensorAlgebra {
|
Loading…
Reference in New Issue
Block a user