slicing API changed
This commit is contained in:
parent
bc43afe93b
commit
371674c9d3
@ -13,6 +13,7 @@ import space.kscience.kmath.tensors.api.Tensor
|
|||||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||||
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
||||||
|
|
||||||
|
typealias Slice = Pair<Int,Int>
|
||||||
|
|
||||||
public sealed class NoaAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
public sealed class NoaAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||||
protected constructor(protected val scope: NoaScope) :
|
protected constructor(protected val scope: NoaScope) :
|
||||||
@ -92,13 +93,13 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
|
|
||||||
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
|
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
|
||||||
|
|
||||||
public operator fun Tensor<T>.get(dim: Int, start: Int, end: Int): TensorType =
|
public operator fun Tensor<T>.get(dim: Int, slice: Slice): TensorType =
|
||||||
wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, start, end))
|
wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, slice.first, slice.second))
|
||||||
|
|
||||||
public operator fun TensorType.set(dim: Int, start: Int, end: Int, value: Tensor<T>): Unit =
|
public operator fun TensorType.set(dim: Int, slice: Slice, value: Tensor<T>): Unit =
|
||||||
JNoa.setSliceTensor(tensorHandle, dim, start, end, value.tensor.tensorHandle)
|
JNoa.setSliceTensor(tensorHandle, dim, slice.first, slice.second, value.tensor.tensorHandle)
|
||||||
|
|
||||||
public abstract operator fun TensorType.set(dim: Int, start: Int, end: Int, array: PrimitiveArray): Unit
|
public abstract operator fun TensorType.set(dim: Int, slice: Slice, array: PrimitiveArray): Unit
|
||||||
|
|
||||||
override fun diagonalEmbedding(
|
override fun diagonalEmbedding(
|
||||||
diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
|
diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
|
||||||
@ -429,8 +430,8 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit =
|
override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit =
|
||||||
JNoa.setBlobDouble(tensorHandle, i, array)
|
JNoa.setBlobDouble(tensorHandle, i, array)
|
||||||
|
|
||||||
override fun NoaDoubleTensor.set(dim: Int, start: Int, end: Int, array: DoubleArray): Unit =
|
override fun NoaDoubleTensor.set(dim: Int, slice: Slice, array: DoubleArray): Unit =
|
||||||
JNoa.setSliceBlobDouble(tensorHandle, dim, start, end, array)
|
JNoa.setSliceBlobDouble(tensorHandle, dim, slice.first, slice.second, array)
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaFloatAlgebra
|
public sealed class NoaFloatAlgebra
|
||||||
@ -521,8 +522,8 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit =
|
override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit =
|
||||||
JNoa.setBlobFloat(tensorHandle, i, array)
|
JNoa.setBlobFloat(tensorHandle, i, array)
|
||||||
|
|
||||||
override fun NoaFloatTensor.set(dim: Int, start: Int, end: Int, array: FloatArray): Unit =
|
override fun NoaFloatTensor.set(dim: Int, slice: Slice, array: FloatArray): Unit =
|
||||||
JNoa.setSliceBlobFloat(tensorHandle, dim, start, end, array)
|
JNoa.setSliceBlobFloat(tensorHandle, dim, slice.first, slice.second, array)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -599,8 +600,8 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun NoaLongTensor.set(i: Int, array: LongArray): Unit =
|
override fun NoaLongTensor.set(i: Int, array: LongArray): Unit =
|
||||||
JNoa.setBlobLong(tensorHandle, i, array)
|
JNoa.setBlobLong(tensorHandle, i, array)
|
||||||
|
|
||||||
override fun NoaLongTensor.set(dim: Int, start: Int, end: Int, array: LongArray): Unit =
|
override fun NoaLongTensor.set(dim: Int, slice: Slice, array: LongArray): Unit =
|
||||||
JNoa.setSliceBlobLong(tensorHandle, dim, start, end, array)
|
JNoa.setSliceBlobLong(tensorHandle, dim, slice.first, slice.second, array)
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaIntAlgebra
|
public sealed class NoaIntAlgebra
|
||||||
@ -676,6 +677,6 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun NoaIntTensor.set(i: Int, array: IntArray): Unit =
|
override fun NoaIntTensor.set(i: Int, array: IntArray): Unit =
|
||||||
JNoa.setBlobInt(tensorHandle, i, array)
|
JNoa.setBlobInt(tensorHandle, i, array)
|
||||||
|
|
||||||
override fun NoaIntTensor.set(dim: Int, start: Int, end: Int, array: IntArray): Unit =
|
override fun NoaIntTensor.set(dim: Int, slice: Slice, array: IntArray): Unit =
|
||||||
JNoa.setSliceBlobInt(tensorHandle, dim, start, end, array)
|
JNoa.setSliceBlobInt(tensorHandle, dim, slice.first, slice.second, array)
|
||||||
}
|
}
|
||||||
|
@ -53,12 +53,12 @@ internal fun NoaFloat.testingBatchedGetterSetter(device: Device = Device.CPU) {
|
|||||||
|
|
||||||
val updateArray = floatArrayOf(15f, 20f)
|
val updateArray = floatArrayOf(15f, 20f)
|
||||||
val updateTensor = full(5.0f, intArrayOf(4), device)
|
val updateTensor = full(5.0f, intArrayOf(4), device)
|
||||||
updateTensor[0, 1, 3] = updateArray
|
updateTensor[0, Slice(1, 3)] = updateArray
|
||||||
|
|
||||||
NoaFloat {
|
NoaFloat {
|
||||||
tensor[0][1] = updateArray
|
tensor[0][1] = updateArray
|
||||||
tensor[1] = updateTensor.view(intArrayOf(2, 2))
|
tensor[1] = updateTensor.view(intArrayOf(2, 2))
|
||||||
updateTensor[0, 2, 4] = updateTensor[0, 0, 2]
|
updateTensor[0, Slice(2, 4)] = updateTensor[0, Slice(0, 2)]
|
||||||
}!!
|
}!!
|
||||||
|
|
||||||
assertTrue(
|
assertTrue(
|
||||||
|
Loading…
Reference in New Issue
Block a user