slicing API changed

This commit is contained in:
Roland Grinis 2021-08-01 14:57:02 +01:00
parent bc43afe93b
commit 371674c9d3
2 changed files with 16 additions and 15 deletions

View File

@ -13,6 +13,7 @@ import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
import space.kscience.kmath.tensors.core.TensorLinearStructure import space.kscience.kmath.tensors.core.TensorLinearStructure
typealias Slice = Pair<Int,Int>
public sealed class NoaAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>> public sealed class NoaAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
protected constructor(protected val scope: NoaScope) : protected constructor(protected val scope: NoaScope) :
@ -92,13 +93,13 @@ protected constructor(protected val scope: NoaScope) :
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
public operator fun Tensor<T>.get(dim: Int, start: Int, end: Int): TensorType = public operator fun Tensor<T>.get(dim: Int, slice: Slice): TensorType =
wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, start, end)) wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, slice.first, slice.second))
public operator fun TensorType.set(dim: Int, start: Int, end: Int, value: Tensor<T>): Unit = public operator fun TensorType.set(dim: Int, slice: Slice, value: Tensor<T>): Unit =
JNoa.setSliceTensor(tensorHandle, dim, start, end, value.tensor.tensorHandle) JNoa.setSliceTensor(tensorHandle, dim, slice.first, slice.second, value.tensor.tensorHandle)
public abstract operator fun TensorType.set(dim: Int, start: Int, end: Int, array: PrimitiveArray): Unit public abstract operator fun TensorType.set(dim: Int, slice: Slice, array: PrimitiveArray): Unit
override fun diagonalEmbedding( override fun diagonalEmbedding(
diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
@ -429,8 +430,8 @@ protected constructor(scope: NoaScope) :
override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit = override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit =
JNoa.setBlobDouble(tensorHandle, i, array) JNoa.setBlobDouble(tensorHandle, i, array)
override fun NoaDoubleTensor.set(dim: Int, start: Int, end: Int, array: DoubleArray): Unit = override fun NoaDoubleTensor.set(dim: Int, slice: Slice, array: DoubleArray): Unit =
JNoa.setSliceBlobDouble(tensorHandle, dim, start, end, array) JNoa.setSliceBlobDouble(tensorHandle, dim, slice.first, slice.second, array)
} }
public sealed class NoaFloatAlgebra public sealed class NoaFloatAlgebra
@ -521,8 +522,8 @@ protected constructor(scope: NoaScope) :
override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit = override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit =
JNoa.setBlobFloat(tensorHandle, i, array) JNoa.setBlobFloat(tensorHandle, i, array)
override fun NoaFloatTensor.set(dim: Int, start: Int, end: Int, array: FloatArray): Unit = override fun NoaFloatTensor.set(dim: Int, slice: Slice, array: FloatArray): Unit =
JNoa.setSliceBlobFloat(tensorHandle, dim, start, end, array) JNoa.setSliceBlobFloat(tensorHandle, dim, slice.first, slice.second, array)
} }
@ -599,8 +600,8 @@ protected constructor(scope: NoaScope) :
override fun NoaLongTensor.set(i: Int, array: LongArray): Unit = override fun NoaLongTensor.set(i: Int, array: LongArray): Unit =
JNoa.setBlobLong(tensorHandle, i, array) JNoa.setBlobLong(tensorHandle, i, array)
override fun NoaLongTensor.set(dim: Int, start: Int, end: Int, array: LongArray): Unit = override fun NoaLongTensor.set(dim: Int, slice: Slice, array: LongArray): Unit =
JNoa.setSliceBlobLong(tensorHandle, dim, start, end, array) JNoa.setSliceBlobLong(tensorHandle, dim, slice.first, slice.second, array)
} }
public sealed class NoaIntAlgebra public sealed class NoaIntAlgebra
@ -676,6 +677,6 @@ protected constructor(scope: NoaScope) :
override fun NoaIntTensor.set(i: Int, array: IntArray): Unit = override fun NoaIntTensor.set(i: Int, array: IntArray): Unit =
JNoa.setBlobInt(tensorHandle, i, array) JNoa.setBlobInt(tensorHandle, i, array)
override fun NoaIntTensor.set(dim: Int, start: Int, end: Int, array: IntArray): Unit = override fun NoaIntTensor.set(dim: Int, slice: Slice, array: IntArray): Unit =
JNoa.setSliceBlobInt(tensorHandle, dim, start, end, array) JNoa.setSliceBlobInt(tensorHandle, dim, slice.first, slice.second, array)
} }

View File

@ -53,12 +53,12 @@ internal fun NoaFloat.testingBatchedGetterSetter(device: Device = Device.CPU) {
val updateArray = floatArrayOf(15f, 20f) val updateArray = floatArrayOf(15f, 20f)
val updateTensor = full(5.0f, intArrayOf(4), device) val updateTensor = full(5.0f, intArrayOf(4), device)
updateTensor[0, 1, 3] = updateArray updateTensor[0, Slice(1, 3)] = updateArray
NoaFloat { NoaFloat {
tensor[0][1] = updateArray tensor[0][1] = updateArray
tensor[1] = updateTensor.view(intArrayOf(2, 2)) tensor[1] = updateTensor.view(intArrayOf(2, 2))
updateTensor[0, 2, 4] = updateTensor[0, 0, 2] updateTensor[0, Slice(2, 4)] = updateTensor[0, Slice(0, 2)]
}!! }!!
assertTrue( assertTrue(