From 791f55ee8a87f2822769099b22e70e04f4daa2f5 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Mon, 15 Mar 2021 22:39:29 +0000 Subject: [PATCH] get dim 0 operator for tensors --- .../tensors/DoubleLinearOpsTensorAlgebra.kt | 2 +- .../kmath/tensors/DoubleTensorAlgebra.kt | 49 ++++++++++--------- .../kscience/kmath/tensors/TensorAlgebra.kt | 2 +- .../space/kscience/kmath/tensors/utils.kt | 12 ++--- .../kmath/tensors/TestDoubleTensor.kt | 6 +++ .../kmath/tensors/TestDoubleTensorAlgebra.kt | 24 ++++----- 6 files changed, 52 insertions(+), 43 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleLinearOpsTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleLinearOpsTensorAlgebra.kt index eceb28459..d6b202556 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleLinearOpsTensorAlgebra.kt @@ -72,7 +72,7 @@ public class DoubleLinearOpsTensorAlgebra : return Pair(luTensor, pivotsTensor)*/ - TODO("Andrei, first we need to view and get(Int)") + TODO("Andrei, use view, get, as2D, as1D") } override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt index 391c2895f..3b65e89da 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt @@ -7,11 +7,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra - other.buffer.unsafeToDoubleArray()[other.bufferStart + i] + this + other.buffer.array()[other.bufferStart + i] + this } return DoubleTensor(other.shape, resBuffer) } @@ -61,35 +64,35 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra - newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i] + newThis.buffer.array()[i] + newOther.buffer.array()[i] } return DoubleTensor(newThis.shape, resBuffer) } override fun DoubleTensor.plusAssign(value: Double) { for (i in 0 until this.strides.linearSize) { - this.buffer.unsafeToDoubleArray()[this.bufferStart + i] += value + this.buffer.array()[this.bufferStart + i] += value } } override fun DoubleTensor.plusAssign(other: DoubleTensor) { //todo should be change with broadcasting for (i in 0 until this.strides.linearSize) { - this.buffer.unsafeToDoubleArray()[this.bufferStart + i] += - other.buffer.unsafeToDoubleArray()[this.bufferStart + i] + this.buffer.array()[this.bufferStart + i] += + other.buffer.array()[this.bufferStart + i] } } override fun Double.minus(other: DoubleTensor): DoubleTensor { val resBuffer = DoubleArray(other.strides.linearSize) { i -> - this - other.buffer.unsafeToDoubleArray()[other.bufferStart + i] + this - other.buffer.array()[other.bufferStart + i] } return DoubleTensor(other.shape, resBuffer) } override fun DoubleTensor.minus(value: Double): DoubleTensor { val resBuffer = DoubleArray(this.strides.linearSize) { i -> - this.buffer.unsafeToDoubleArray()[this.bufferStart + i] - value + this.buffer.array()[this.bufferStart + i] - value } return DoubleTensor(this.shape, resBuffer) } @@ -99,14 +102,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra - newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i] + newThis.buffer.array()[i] - newOther.buffer.array()[i] } return DoubleTensor(newThis.shape, resBuffer) } override fun DoubleTensor.minusAssign(value: Double) { for (i in 0 until this.strides.linearSize) { - this.buffer.unsafeToDoubleArray()[this.bufferStart + i] -= value + this.buffer.array()[this.bufferStart + i] -= value } } @@ -117,7 +120,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra - other.buffer.unsafeToDoubleArray()[other.bufferStart + i] * this + other.buffer.array()[other.bufferStart + i] * this } return DoubleTensor(other.shape, resBuffer) } @@ -128,8 +131,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra - this.buffer.unsafeToDoubleArray()[other.bufferStart + i] * - other.buffer.unsafeToDoubleArray()[other.bufferStart + i] + this.buffer.array()[other.bufferStart + i] * + other.buffer.array()[other.bufferStart + i] } return DoubleTensor(this.shape, resBuffer) } @@ -137,21 +140,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra - this.buffer.unsafeToDoubleArray()[this.bufferStart + i].unaryMinus() + this.buffer.array()[this.bufferStart + i].unaryMinus() } return DoubleTensor(this.shape, resBuffer) } @@ -172,8 +175,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra> { public operator fun TensorType.unaryMinus(): TensorType //https://pytorch.org/cppdocs/notes/tensor_indexing.html - public fun TensorType.get(i: Int): TensorType + public operator fun TensorType.get(i: Int): TensorType //https://pytorch.org/docs/stable/generated/torch.transpose.html public fun TensorType.transpose(i: Int, j: Int): TensorType diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/utils.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/utils.kt index 74a774f45..3e32e2b72 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/utils.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/utils.kt @@ -55,8 +55,8 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List, /** * Returns a reference to [IntArray] containing all of the elements of this [Buffer]. */ -internal fun Buffer.unsafeToIntArray(): IntArray = when(this) { +internal fun Buffer.array(): IntArray = when(this) { is IntBuffer -> array else -> throw RuntimeException("Failed to cast Buffer to IntArray") } @@ -107,7 +107,7 @@ internal fun Buffer.unsafeToIntArray(): IntArray = when(this) { /** * Returns a reference to [LongArray] containing all of the elements of this [Buffer]. */ -internal fun Buffer.unsafeToLongArray(): LongArray = when(this) { +internal fun Buffer.array(): LongArray = when(this) { is LongBuffer -> array else -> throw RuntimeException("Failed to cast Buffer to LongArray") } @@ -115,7 +115,7 @@ internal fun Buffer.unsafeToLongArray(): LongArray = when(this) { /** * Returns a reference to [FloatArray] containing all of the elements of this [Buffer]. */ -internal fun Buffer.unsafeToFloatArray(): FloatArray = when(this) { +internal fun Buffer.array(): FloatArray = when(this) { is FloatBuffer -> array else -> throw RuntimeException("Failed to cast Buffer to FloatArray") } @@ -123,7 +123,7 @@ internal fun Buffer.unsafeToFloatArray(): FloatArray = when(this) { /** * Returns a reference to [DoubleArray] containing all of the elements of this [Buffer]. */ -internal fun Buffer.unsafeToDoubleArray(): DoubleArray = when(this) { +internal fun Buffer.array(): DoubleArray = when(this) { is RealBuffer -> array else -> throw RuntimeException("Failed to cast Buffer to DoubleArray") } diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestDoubleTensor.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestDoubleTensor.kt index 31c6ccbbf..b1c8cd6dd 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestDoubleTensor.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestDoubleTensor.kt @@ -21,4 +21,10 @@ class TestDoubleTensor { assertEquals(tensor[intArrayOf(0,1)], 5.8) assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray()) } + + @Test + fun getTest() = DoubleTensorAlgebra { + val tensor = DoubleTensor(intArrayOf(2,2), doubleArrayOf(3.5,5.8,58.4,2.4)) + assertTrue(tensor[0].elements().map{ it.second }.toList().toDoubleArray() contentEquals doubleArrayOf(3.5,5.8)) + } } \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestDoubleTensorAlgebra.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestDoubleTensorAlgebra.kt index 91181484c..226454bf4 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestDoubleTensorAlgebra.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestDoubleTensorAlgebra.kt @@ -10,7 +10,7 @@ class TestDoubleTensorAlgebra { fun doublePlus() = DoubleTensorAlgebra { val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0)) val res = 10.0 + tensor - assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0)) + assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0,12.0)) } @Test @@ -18,7 +18,7 @@ class TestDoubleTensorAlgebra { val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0)) val res = tensor.transpose(0, 0) - assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(0.0)) + assertTrue(res.buffer.array() contentEquals doubleArrayOf(0.0)) assertTrue(res.shape contentEquals intArrayOf(1)) } @@ -27,7 +27,7 @@ class TestDoubleTensorAlgebra { val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val res = tensor.transpose(1, 0) - assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) + assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) assertTrue(res.shape contentEquals intArrayOf(2, 3)) } @@ -42,9 +42,9 @@ class TestDoubleTensorAlgebra { assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1)) assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2)) - assertTrue(res01.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) - assertTrue(res02.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) - assertTrue(res12.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) + assertTrue(res01.buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + assertTrue(res02.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) + assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) } @Test @@ -70,9 +70,9 @@ class TestDoubleTensorAlgebra { assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3)) assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3)) - assertTrue(res[0].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) - assertTrue(res[1].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) - assertTrue(res[2].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0)) + assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) + assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0)) } @Test @@ -82,14 +82,14 @@ class TestDoubleTensorAlgebra { val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0)) assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3)) - assertTrue((tensor2 - tensor1).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) + assertTrue((tensor2 - tensor1).buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3)) - assertTrue((tensor3 - tensor1).buffer.unsafeToDoubleArray() + assertTrue((tensor3 - tensor1).buffer.array() contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)) assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3)) - assertTrue((tensor3 - tensor2).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0)) + assertTrue((tensor3 - tensor2).buffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0)) } } \ No newline at end of file