get dim 0 operator for tensors

This commit is contained in:
Roland Grinis 2021-03-15 22:39:29 +00:00
parent 7cb5cd8f71
commit 791f55ee8a
6 changed files with 52 additions and 43 deletions

View File

@ -72,7 +72,7 @@ public class DoubleLinearOpsTensorAlgebra :
return Pair(luTensor, pivotsTensor)*/
TODO("Andrei, first we need to view and get(Int)")
TODO("Andrei, use view, get, as2D, as1D")
}
override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {

View File

@ -7,11 +7,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
check(this.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}"
}
return this.buffer.unsafeToDoubleArray()[this.bufferStart]
return this.buffer.array()[this.bufferStart]
}
override fun DoubleTensor.get(i: Int): DoubleTensor {
TODO("TOP PRIORITY")
override operator fun DoubleTensor.get(i: Int): DoubleTensor {
val lastShape = this.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + this.bufferStart
return DoubleTensor(newShape, this.buffer.array(), newStart)
}
override fun zeros(shape: IntArray): DoubleTensor {
@ -43,13 +46,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
}
override fun DoubleTensor.copy(): DoubleTensor {
return DoubleTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf(), this.bufferStart)
return DoubleTensor(this.shape, this.buffer.array().copyOf(), this.bufferStart)
}
override fun Double.plus(other: DoubleTensor): DoubleTensor {
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
other.buffer.unsafeToDoubleArray()[other.bufferStart + i] + this
other.buffer.array()[other.bufferStart + i] + this
}
return DoubleTensor(other.shape, resBuffer)
}
@ -61,35 +64,35 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i]
newThis.buffer.array()[i] + newOther.buffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun DoubleTensor.plusAssign(value: Double) {
for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] += value
this.buffer.array()[this.bufferStart + i] += value
}
}
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
//todo should be change with broadcasting
for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] +=
other.buffer.unsafeToDoubleArray()[this.bufferStart + i]
this.buffer.array()[this.bufferStart + i] +=
other.buffer.array()[this.bufferStart + i]
}
}
override fun Double.minus(other: DoubleTensor): DoubleTensor {
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
this - other.buffer.unsafeToDoubleArray()[other.bufferStart + i]
this - other.buffer.array()[other.bufferStart + i]
}
return DoubleTensor(other.shape, resBuffer)
}
override fun DoubleTensor.minus(value: Double): DoubleTensor {
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] - value
this.buffer.array()[this.bufferStart + i] - value
}
return DoubleTensor(this.shape, resBuffer)
}
@ -99,14 +102,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i]
newThis.buffer.array()[i] - newOther.buffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun DoubleTensor.minusAssign(value: Double) {
for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] -= value
this.buffer.array()[this.bufferStart + i] -= value
}
}
@ -117,7 +120,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun Double.times(other: DoubleTensor): DoubleTensor {
//todo should be change with broadcasting
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
other.buffer.unsafeToDoubleArray()[other.bufferStart + i] * this
other.buffer.array()[other.bufferStart + i] * this
}
return DoubleTensor(other.shape, resBuffer)
}
@ -128,8 +131,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
//todo should be change with broadcasting
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
this.buffer.unsafeToDoubleArray()[other.bufferStart + i] *
other.buffer.unsafeToDoubleArray()[other.bufferStart + i]
this.buffer.array()[other.bufferStart + i] *
other.buffer.array()[other.bufferStart + i]
}
return DoubleTensor(this.shape, resBuffer)
}
@ -137,21 +140,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.timesAssign(value: Double) {
//todo should be change with broadcasting
for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] *= value
this.buffer.array()[this.bufferStart + i] *= value
}
}
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
//todo should be change with broadcasting
for (i in 0 until this.strides.linearSize) {
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] *=
other.buffer.unsafeToDoubleArray()[this.bufferStart + i]
this.buffer.array()[this.bufferStart + i] *=
other.buffer.array()[this.bufferStart + i]
}
}
override fun DoubleTensor.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
this.buffer.unsafeToDoubleArray()[this.bufferStart + i].unaryMinus()
this.buffer.array()[this.bufferStart + i].unaryMinus()
}
return DoubleTensor(this.shape, resBuffer)
}
@ -172,8 +175,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
val linearIndex = resTensor.strides.offset(newMultiIndex)
resTensor.buffer.unsafeToDoubleArray()[linearIndex] =
this.buffer.unsafeToDoubleArray()[this.bufferStart + offset]
resTensor.buffer.array()[linearIndex] =
this.buffer.array()[this.bufferStart + offset]
}
return resTensor
}
@ -181,7 +184,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.view(shape: IntArray): DoubleTensor {
checkView(this, shape)
return DoubleTensor(shape, this.buffer.unsafeToDoubleArray(), this.bufferStart)
return DoubleTensor(shape, this.buffer.array(), this.bufferStart)
}
override fun DoubleTensor.viewAs(other: DoubleTensor): DoubleTensor {

View File

@ -42,7 +42,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public operator fun TensorType.unaryMinus(): TensorType
//https://pytorch.org/cppdocs/notes/tensor_indexing.html
public fun TensorType.get(i: Int): TensorType
public operator fun TensorType.get(i: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.transpose.html
public fun TensorType.transpose(i: Int, j: Int): TensorType

View File

@ -55,8 +55,8 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
}
val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.unsafeToDoubleArray()[linearIndex] =
tensor.buffer.unsafeToDoubleArray()[tensor.bufferStart + curLinearIndex]
resTensor.buffer.array()[linearIndex] =
tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
}
res.add(resTensor)
}
@ -99,7 +99,7 @@ internal inline fun <T, TensorType : TensorStructure<T>,
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
internal fun Buffer<Int>.array(): IntArray = when(this) {
is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
}
@ -107,7 +107,7 @@ internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
/**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
internal fun Buffer<Long>.array(): LongArray = when(this) {
is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
}
@ -115,7 +115,7 @@ internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
/**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
internal fun Buffer<Float>.array(): FloatArray = when(this) {
is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
}
@ -123,7 +123,7 @@ internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
internal fun Buffer<Double>.array(): DoubleArray = when(this) {
is RealBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
}

View File

@ -21,4 +21,10 @@ class TestDoubleTensor {
assertEquals(tensor[intArrayOf(0,1)], 5.8)
assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray())
}
@Test
fun getTest() = DoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
assertTrue(tensor[0].elements().map{ it.second }.toList().toDoubleArray() contentEquals doubleArrayOf(3.5,5.8))
}
}

View File

@ -10,7 +10,7 @@ class TestDoubleTensorAlgebra {
fun doublePlus() = DoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
val res = 10.0 + tensor
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0,12.0))
}
@Test
@ -18,7 +18,7 @@ class TestDoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0))
val res = tensor.transpose(0, 0)
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(0.0))
assertTrue(res.buffer.array() contentEquals doubleArrayOf(0.0))
assertTrue(res.shape contentEquals intArrayOf(1))
}
@ -27,7 +27,7 @@ class TestDoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res = tensor.transpose(1, 0)
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.shape contentEquals intArrayOf(2, 3))
}
@ -42,9 +42,9 @@ class TestDoubleTensorAlgebra {
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
assertTrue(res01.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res01.buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
}
@Test
@ -70,9 +70,9 @@ class TestDoubleTensorAlgebra {
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[0].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
}
@Test
@ -82,14 +82,14 @@ class TestDoubleTensorAlgebra {
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
assertTrue((tensor2 - tensor1).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue((tensor2 - tensor1).buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
assertTrue((tensor3 - tensor1).buffer.unsafeToDoubleArray()
assertTrue((tensor3 - tensor1).buffer.array()
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
assertTrue((tensor3 - tensor2).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
assertTrue((tensor3 - tensor2).buffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
}
}