get dim 0 operator for tensors

This commit is contained in:
Roland Grinis 2021-03-15 22:39:29 +00:00
parent 7cb5cd8f71
commit 791f55ee8a
6 changed files with 52 additions and 43 deletions

View File

@ -72,7 +72,7 @@ public class DoubleLinearOpsTensorAlgebra :
return Pair(luTensor, pivotsTensor)*/ return Pair(luTensor, pivotsTensor)*/
TODO("Andrei, first we need to view and get(Int)") TODO("Andrei, use view, get, as2D, as1D")
} }
override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {

View File

@ -7,11 +7,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
check(this.shape contentEquals intArrayOf(1)) { check(this.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}" "Inconsistent value for tensor of shape ${shape.toList()}"
} }
return this.buffer.unsafeToDoubleArray()[this.bufferStart] return this.buffer.array()[this.bufferStart]
} }
override fun DoubleTensor.get(i: Int): DoubleTensor { override operator fun DoubleTensor.get(i: Int): DoubleTensor {
TODO("TOP PRIORITY") val lastShape = this.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + this.bufferStart
return DoubleTensor(newShape, this.buffer.array(), newStart)
} }
override fun zeros(shape: IntArray): DoubleTensor { override fun zeros(shape: IntArray): DoubleTensor {
@ -43,13 +46,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
} }
override fun DoubleTensor.copy(): DoubleTensor { override fun DoubleTensor.copy(): DoubleTensor {
return DoubleTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf(), this.bufferStart) return DoubleTensor(this.shape, this.buffer.array().copyOf(), this.bufferStart)
} }
override fun Double.plus(other: DoubleTensor): DoubleTensor { override fun Double.plus(other: DoubleTensor): DoubleTensor {
val resBuffer = DoubleArray(other.strides.linearSize) { i -> val resBuffer = DoubleArray(other.strides.linearSize) { i ->
other.buffer.unsafeToDoubleArray()[other.bufferStart + i] + this other.buffer.array()[other.bufferStart + i] + this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
@ -61,35 +64,35 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.strides.linearSize) { i -> val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i] newThis.buffer.array()[i] + newOther.buffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun DoubleTensor.plusAssign(value: Double) { override fun DoubleTensor.plusAssign(value: Double) {
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] += value this.buffer.array()[this.bufferStart + i] += value
} }
} }
override fun DoubleTensor.plusAssign(other: DoubleTensor) { override fun DoubleTensor.plusAssign(other: DoubleTensor) {
//todo should be change with broadcasting //todo should be change with broadcasting
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] += this.buffer.array()[this.bufferStart + i] +=
other.buffer.unsafeToDoubleArray()[this.bufferStart + i] other.buffer.array()[this.bufferStart + i]
} }
} }
override fun Double.minus(other: DoubleTensor): DoubleTensor { override fun Double.minus(other: DoubleTensor): DoubleTensor {
val resBuffer = DoubleArray(other.strides.linearSize) { i -> val resBuffer = DoubleArray(other.strides.linearSize) { i ->
this - other.buffer.unsafeToDoubleArray()[other.bufferStart + i] this - other.buffer.array()[other.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun DoubleTensor.minus(value: Double): DoubleTensor { override fun DoubleTensor.minus(value: Double): DoubleTensor {
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val resBuffer = DoubleArray(this.strides.linearSize) { i ->
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] - value this.buffer.array()[this.bufferStart + i] - value
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
} }
@ -99,14 +102,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.strides.linearSize) { i -> val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i] newThis.buffer.array()[i] - newOther.buffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun DoubleTensor.minusAssign(value: Double) { override fun DoubleTensor.minusAssign(value: Double) {
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] -= value this.buffer.array()[this.bufferStart + i] -= value
} }
} }
@ -117,7 +120,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun Double.times(other: DoubleTensor): DoubleTensor { override fun Double.times(other: DoubleTensor): DoubleTensor {
//todo should be change with broadcasting //todo should be change with broadcasting
val resBuffer = DoubleArray(other.strides.linearSize) { i -> val resBuffer = DoubleArray(other.strides.linearSize) { i ->
other.buffer.unsafeToDoubleArray()[other.bufferStart + i] * this other.buffer.array()[other.bufferStart + i] * this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
@ -128,8 +131,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor { override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
//todo should be change with broadcasting //todo should be change with broadcasting
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val resBuffer = DoubleArray(this.strides.linearSize) { i ->
this.buffer.unsafeToDoubleArray()[other.bufferStart + i] * this.buffer.array()[other.bufferStart + i] *
other.buffer.unsafeToDoubleArray()[other.bufferStart + i] other.buffer.array()[other.bufferStart + i]
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
} }
@ -137,21 +140,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.timesAssign(value: Double) { override fun DoubleTensor.timesAssign(value: Double) {
//todo should be change with broadcasting //todo should be change with broadcasting
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] *= value this.buffer.array()[this.bufferStart + i] *= value
} }
} }
override fun DoubleTensor.timesAssign(other: DoubleTensor) { override fun DoubleTensor.timesAssign(other: DoubleTensor) {
//todo should be change with broadcasting //todo should be change with broadcasting
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] *= this.buffer.array()[this.bufferStart + i] *=
other.buffer.unsafeToDoubleArray()[this.bufferStart + i] other.buffer.array()[this.bufferStart + i]
} }
} }
override fun DoubleTensor.unaryMinus(): DoubleTensor { override fun DoubleTensor.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val resBuffer = DoubleArray(this.strides.linearSize) { i ->
this.buffer.unsafeToDoubleArray()[this.bufferStart + i].unaryMinus() this.buffer.array()[this.bufferStart + i].unaryMinus()
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
} }
@ -172,8 +175,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] } newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
val linearIndex = resTensor.strides.offset(newMultiIndex) val linearIndex = resTensor.strides.offset(newMultiIndex)
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = resTensor.buffer.array()[linearIndex] =
this.buffer.unsafeToDoubleArray()[this.bufferStart + offset] this.buffer.array()[this.bufferStart + offset]
} }
return resTensor return resTensor
} }
@ -181,7 +184,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.view(shape: IntArray): DoubleTensor { override fun DoubleTensor.view(shape: IntArray): DoubleTensor {
checkView(this, shape) checkView(this, shape)
return DoubleTensor(shape, this.buffer.unsafeToDoubleArray(), this.bufferStart) return DoubleTensor(shape, this.buffer.array(), this.bufferStart)
} }
override fun DoubleTensor.viewAs(other: DoubleTensor): DoubleTensor { override fun DoubleTensor.viewAs(other: DoubleTensor): DoubleTensor {

View File

@ -42,7 +42,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public operator fun TensorType.unaryMinus(): TensorType public operator fun TensorType.unaryMinus(): TensorType
//https://pytorch.org/cppdocs/notes/tensor_indexing.html //https://pytorch.org/cppdocs/notes/tensor_indexing.html
public fun TensorType.get(i: Int): TensorType public operator fun TensorType.get(i: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.transpose.html //https://pytorch.org/docs/stable/generated/torch.transpose.html
public fun TensorType.transpose(i: Int, j: Int): TensorType public fun TensorType.transpose(i: Int, j: Int): TensorType

View File

@ -55,8 +55,8 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
} }
val curLinearIndex = tensor.strides.offset(curMultiIndex) val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = resTensor.buffer.array()[linearIndex] =
tensor.buffer.unsafeToDoubleArray()[tensor.bufferStart + curLinearIndex] tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
} }
res.add(resTensor) res.add(resTensor)
} }
@ -99,7 +99,7 @@ internal inline fun <T, TensorType : TensorStructure<T>,
/** /**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer]. * Returns a reference to [IntArray] containing all of the elements of this [Buffer].
*/ */
internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) { internal fun Buffer<Int>.array(): IntArray = when(this) {
is IntBuffer -> array is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray") else -> throw RuntimeException("Failed to cast Buffer to IntArray")
} }
@ -107,7 +107,7 @@ internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
/** /**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer]. * Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/ */
internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) { internal fun Buffer<Long>.array(): LongArray = when(this) {
is LongBuffer -> array is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray") else -> throw RuntimeException("Failed to cast Buffer to LongArray")
} }
@ -115,7 +115,7 @@ internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
/** /**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer]. * Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/ */
internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) { internal fun Buffer<Float>.array(): FloatArray = when(this) {
is FloatBuffer -> array is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray") else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
} }
@ -123,7 +123,7 @@ internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
/** /**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer]. * Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
*/ */
internal fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) { internal fun Buffer<Double>.array(): DoubleArray = when(this) {
is RealBuffer -> array is RealBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray") else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
} }

View File

@ -21,4 +21,10 @@ class TestDoubleTensor {
assertEquals(tensor[intArrayOf(0,1)], 5.8) assertEquals(tensor[intArrayOf(0,1)], 5.8)
assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray()) assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray())
} }
@Test
fun getTest() = DoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
assertTrue(tensor[0].elements().map{ it.second }.toList().toDoubleArray() contentEquals doubleArrayOf(3.5,5.8))
}
} }

View File

@ -10,7 +10,7 @@ class TestDoubleTensorAlgebra {
fun doublePlus() = DoubleTensorAlgebra { fun doublePlus() = DoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0)) val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
val res = 10.0 + tensor val res = 10.0 + tensor
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0)) assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0,12.0))
} }
@Test @Test
@ -18,7 +18,7 @@ class TestDoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0)) val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0))
val res = tensor.transpose(0, 0) val res = tensor.transpose(0, 0)
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(0.0)) assertTrue(res.buffer.array() contentEquals doubleArrayOf(0.0))
assertTrue(res.shape contentEquals intArrayOf(1)) assertTrue(res.shape contentEquals intArrayOf(1))
} }
@ -27,7 +27,7 @@ class TestDoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res = tensor.transpose(1, 0) val res = tensor.transpose(1, 0)
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.shape contentEquals intArrayOf(2, 3)) assertTrue(res.shape contentEquals intArrayOf(2, 3))
} }
@ -42,9 +42,9 @@ class TestDoubleTensorAlgebra {
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1)) assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2)) assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
assertTrue(res01.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res01.buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res02.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
} }
@Test @Test
@ -70,9 +70,9 @@ class TestDoubleTensorAlgebra {
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3)) assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3)) assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[0].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0)) assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
} }
@Test @Test
@ -82,14 +82,14 @@ class TestDoubleTensorAlgebra {
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0)) val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3)) assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
assertTrue((tensor2 - tensor1).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) assertTrue((tensor2 - tensor1).buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3)) assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
assertTrue((tensor3 - tensor1).buffer.unsafeToDoubleArray() assertTrue((tensor3 - tensor1).buffer.array()
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)) contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3)) assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
assertTrue((tensor3 - tensor2).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0)) assertTrue((tensor3 - tensor2).buffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
} }
} }