From 371674c9d3d57c25bed005f6ae545674d40a39e7 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Sun, 1 Aug 2021 14:57:02 +0100 Subject: [PATCH] slicing API changed --- .../space/kscience/kmath/noa/algebras.kt | 27 ++++++++++--------- .../space/kscience/kmath/noa/TestTensor.kt | 4 +-- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index f1c2a94f9..93dbe8f62 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -13,6 +13,7 @@ import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.core.TensorLinearStructure +typealias Slice = Pair public sealed class NoaAlgebra> protected constructor(protected val scope: NoaScope) : @@ -92,13 +93,13 @@ protected constructor(protected val scope: NoaScope) : public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit - public operator fun Tensor.get(dim: Int, start: Int, end: Int): TensorType = - wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, start, end)) + public operator fun Tensor.get(dim: Int, slice: Slice): TensorType = + wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, slice.first, slice.second)) - public operator fun TensorType.set(dim: Int, start: Int, end: Int, value: Tensor): Unit = - JNoa.setSliceTensor(tensorHandle, dim, start, end, value.tensor.tensorHandle) + public operator fun TensorType.set(dim: Int, slice: Slice, value: Tensor): Unit = + JNoa.setSliceTensor(tensorHandle, dim, slice.first, slice.second, value.tensor.tensorHandle) - public abstract operator fun TensorType.set(dim: Int, start: Int, end: Int, array: PrimitiveArray): Unit + public abstract operator fun TensorType.set(dim: Int, slice: Slice, array: PrimitiveArray): Unit override fun diagonalEmbedding( diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int @@ -429,8 +430,8 @@ protected constructor(scope: NoaScope) : override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit = JNoa.setBlobDouble(tensorHandle, i, array) - override fun NoaDoubleTensor.set(dim: Int, start: Int, end: Int, array: DoubleArray): Unit = - JNoa.setSliceBlobDouble(tensorHandle, dim, start, end, array) + override fun NoaDoubleTensor.set(dim: Int, slice: Slice, array: DoubleArray): Unit = + JNoa.setSliceBlobDouble(tensorHandle, dim, slice.first, slice.second, array) } public sealed class NoaFloatAlgebra @@ -521,8 +522,8 @@ protected constructor(scope: NoaScope) : override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit = JNoa.setBlobFloat(tensorHandle, i, array) - override fun NoaFloatTensor.set(dim: Int, start: Int, end: Int, array: FloatArray): Unit = - JNoa.setSliceBlobFloat(tensorHandle, dim, start, end, array) + override fun NoaFloatTensor.set(dim: Int, slice: Slice, array: FloatArray): Unit = + JNoa.setSliceBlobFloat(tensorHandle, dim, slice.first, slice.second, array) } @@ -599,8 +600,8 @@ protected constructor(scope: NoaScope) : override fun NoaLongTensor.set(i: Int, array: LongArray): Unit = JNoa.setBlobLong(tensorHandle, i, array) - override fun NoaLongTensor.set(dim: Int, start: Int, end: Int, array: LongArray): Unit = - JNoa.setSliceBlobLong(tensorHandle, dim, start, end, array) + override fun NoaLongTensor.set(dim: Int, slice: Slice, array: LongArray): Unit = + JNoa.setSliceBlobLong(tensorHandle, dim, slice.first, slice.second, array) } public sealed class NoaIntAlgebra @@ -676,6 +677,6 @@ protected constructor(scope: NoaScope) : override fun NoaIntTensor.set(i: Int, array: IntArray): Unit = JNoa.setBlobInt(tensorHandle, i, array) - override fun NoaIntTensor.set(dim: Int, start: Int, end: Int, array: IntArray): Unit = - JNoa.setSliceBlobInt(tensorHandle, dim, start, end, array) + override fun NoaIntTensor.set(dim: Int, slice: Slice, array: IntArray): Unit = + JNoa.setSliceBlobInt(tensorHandle, dim, slice.first, slice.second, array) } diff --git a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt index ddf2c8ec3..969c79e29 100644 --- a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt +++ b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt @@ -53,12 +53,12 @@ internal fun NoaFloat.testingBatchedGetterSetter(device: Device = Device.CPU) { val updateArray = floatArrayOf(15f, 20f) val updateTensor = full(5.0f, intArrayOf(4), device) - updateTensor[0, 1, 3] = updateArray + updateTensor[0, Slice(1, 3)] = updateArray NoaFloat { tensor[0][1] = updateArray tensor[1] = updateTensor.view(intArrayOf(2, 2)) - updateTensor[0, 2, 4] = updateTensor[0, 0, 2] + updateTensor[0, Slice(2, 4)] = updateTensor[0, Slice(0, 2)] }!! assertTrue(