slicing API changed
This commit is contained in:
parent
bc43afe93b
commit
371674c9d3
@ -13,6 +13,7 @@ import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
||||
|
||||
typealias Slice = Pair<Int,Int>
|
||||
|
||||
public sealed class NoaAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||
protected constructor(protected val scope: NoaScope) :
|
||||
@ -92,13 +93,13 @@ protected constructor(protected val scope: NoaScope) :
|
||||
|
||||
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
|
||||
|
||||
public operator fun Tensor<T>.get(dim: Int, start: Int, end: Int): TensorType =
|
||||
wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, start, end))
|
||||
public operator fun Tensor<T>.get(dim: Int, slice: Slice): TensorType =
|
||||
wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, slice.first, slice.second))
|
||||
|
||||
public operator fun TensorType.set(dim: Int, start: Int, end: Int, value: Tensor<T>): Unit =
|
||||
JNoa.setSliceTensor(tensorHandle, dim, start, end, value.tensor.tensorHandle)
|
||||
public operator fun TensorType.set(dim: Int, slice: Slice, value: Tensor<T>): Unit =
|
||||
JNoa.setSliceTensor(tensorHandle, dim, slice.first, slice.second, value.tensor.tensorHandle)
|
||||
|
||||
public abstract operator fun TensorType.set(dim: Int, start: Int, end: Int, array: PrimitiveArray): Unit
|
||||
public abstract operator fun TensorType.set(dim: Int, slice: Slice, array: PrimitiveArray): Unit
|
||||
|
||||
override fun diagonalEmbedding(
|
||||
diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
|
||||
@ -429,8 +430,8 @@ protected constructor(scope: NoaScope) :
|
||||
override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit =
|
||||
JNoa.setBlobDouble(tensorHandle, i, array)
|
||||
|
||||
override fun NoaDoubleTensor.set(dim: Int, start: Int, end: Int, array: DoubleArray): Unit =
|
||||
JNoa.setSliceBlobDouble(tensorHandle, dim, start, end, array)
|
||||
override fun NoaDoubleTensor.set(dim: Int, slice: Slice, array: DoubleArray): Unit =
|
||||
JNoa.setSliceBlobDouble(tensorHandle, dim, slice.first, slice.second, array)
|
||||
}
|
||||
|
||||
public sealed class NoaFloatAlgebra
|
||||
@ -521,8 +522,8 @@ protected constructor(scope: NoaScope) :
|
||||
override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit =
|
||||
JNoa.setBlobFloat(tensorHandle, i, array)
|
||||
|
||||
override fun NoaFloatTensor.set(dim: Int, start: Int, end: Int, array: FloatArray): Unit =
|
||||
JNoa.setSliceBlobFloat(tensorHandle, dim, start, end, array)
|
||||
override fun NoaFloatTensor.set(dim: Int, slice: Slice, array: FloatArray): Unit =
|
||||
JNoa.setSliceBlobFloat(tensorHandle, dim, slice.first, slice.second, array)
|
||||
|
||||
}
|
||||
|
||||
@ -599,8 +600,8 @@ protected constructor(scope: NoaScope) :
|
||||
override fun NoaLongTensor.set(i: Int, array: LongArray): Unit =
|
||||
JNoa.setBlobLong(tensorHandle, i, array)
|
||||
|
||||
override fun NoaLongTensor.set(dim: Int, start: Int, end: Int, array: LongArray): Unit =
|
||||
JNoa.setSliceBlobLong(tensorHandle, dim, start, end, array)
|
||||
override fun NoaLongTensor.set(dim: Int, slice: Slice, array: LongArray): Unit =
|
||||
JNoa.setSliceBlobLong(tensorHandle, dim, slice.first, slice.second, array)
|
||||
}
|
||||
|
||||
public sealed class NoaIntAlgebra
|
||||
@ -676,6 +677,6 @@ protected constructor(scope: NoaScope) :
|
||||
override fun NoaIntTensor.set(i: Int, array: IntArray): Unit =
|
||||
JNoa.setBlobInt(tensorHandle, i, array)
|
||||
|
||||
override fun NoaIntTensor.set(dim: Int, start: Int, end: Int, array: IntArray): Unit =
|
||||
JNoa.setSliceBlobInt(tensorHandle, dim, start, end, array)
|
||||
override fun NoaIntTensor.set(dim: Int, slice: Slice, array: IntArray): Unit =
|
||||
JNoa.setSliceBlobInt(tensorHandle, dim, slice.first, slice.second, array)
|
||||
}
|
||||
|
@ -53,12 +53,12 @@ internal fun NoaFloat.testingBatchedGetterSetter(device: Device = Device.CPU) {
|
||||
|
||||
val updateArray = floatArrayOf(15f, 20f)
|
||||
val updateTensor = full(5.0f, intArrayOf(4), device)
|
||||
updateTensor[0, 1, 3] = updateArray
|
||||
updateTensor[0, Slice(1, 3)] = updateArray
|
||||
|
||||
NoaFloat {
|
||||
tensor[0][1] = updateArray
|
||||
tensor[1] = updateTensor.view(intArrayOf(2, 2))
|
||||
updateTensor[0, 2, 4] = updateTensor[0, 0, 2]
|
||||
updateTensor[0, Slice(2, 4)] = updateTensor[0, Slice(0, 2)]
|
||||
}!!
|
||||
|
||||
assertTrue(
|
||||
|
Loading…
Reference in New Issue
Block a user