advanced slicing

This commit is contained in:
Roland Grinis 2021-07-31 21:23:15 +01:00
parent c6acedf9e0
commit 3d1a3e3b69
5 changed files with 118 additions and 15 deletions

View File

@ -77,7 +77,7 @@ import torch
tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0] tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0]
``` ```
The most efficient way pass data between the `JVM` and the native backend The most efficient way passing data between the `JVM` and the native backend
is to rely on primitive arrays: is to rely on primitive arrays:
```kotlin ```kotlin
val array = (1..8).map { 100f * it }.toFloatArray() val array = (1..8).map { 100f * it }.toFloatArray()

View File

@ -354,4 +354,18 @@ class JNoa {
public static native void getBlobInt(long tensorHandle, int[] data); public static native void getBlobInt(long tensorHandle, int[] data);
public static native void setTensor(long tensorHandle, int i, long tensorValue);
public static native long getSliceTensor(long tensorHandle, int dim, int start, int end);
public static native void setSliceTensor(long tensorHandle, int dim, int start, int end, long tensorValue);
public static native void setSliceBlobDouble(long tensorHandle, int dim, int start, int end, double[] data);
public static native void setSliceBlobFloat(long tensorHandle, int dim, int start, int end, float[] data);
public static native void setSliceBlobLong(long tensorHandle, int dim, int start, int end, long[] data);
public static native void setSliceBlobInt(long tensorHandle, int dim, int start, int end, int[] data);
} }

View File

@ -87,6 +87,19 @@ protected constructor(protected val scope: NoaScope) :
override operator fun Tensor<T>.get(i: Int): TensorType = override operator fun Tensor<T>.get(i: Int): TensorType =
wrap(JNoa.getIndex(tensor.tensorHandle, i)) wrap(JNoa.getIndex(tensor.tensorHandle, i))
public operator fun TensorType.set(i: Int, value: Tensor<T>): Unit =
JNoa.setTensor(tensorHandle, i, value.tensor.tensorHandle)
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
public operator fun Tensor<T>.get(dim: Int, start: Int, end: Int): TensorType =
wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, start, end))
public operator fun TensorType.set(dim: Int, start: Int, end: Int, value: Tensor<T>): Unit =
JNoa.setSliceTensor(tensorHandle, dim, start, end, value.tensor.tensorHandle)
public abstract operator fun TensorType.set(dim: Int, start: Int, end: Int, array: PrimitiveArray): Unit
override fun diagonalEmbedding( override fun diagonalEmbedding(
diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
): TensorType = ): TensorType =
@ -166,8 +179,6 @@ protected constructor(protected val scope: NoaScope) :
public abstract fun TensorType.assignFromArray(array: PrimitiveArray): Unit public abstract fun TensorType.assignFromArray(array: PrimitiveArray): Unit
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
} }
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>> public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
@ -418,6 +429,8 @@ protected constructor(scope: NoaScope) :
override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit = override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit =
JNoa.setBlobDouble(tensorHandle, i, array) JNoa.setBlobDouble(tensorHandle, i, array)
override fun NoaDoubleTensor.set(dim: Int, start: Int, end: Int, array: DoubleArray): Unit =
JNoa.setSliceBlobDouble(tensorHandle, dim, start, end, array)
} }
public sealed class NoaFloatAlgebra public sealed class NoaFloatAlgebra
@ -508,6 +521,9 @@ protected constructor(scope: NoaScope) :
override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit = override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit =
JNoa.setBlobFloat(tensorHandle, i, array) JNoa.setBlobFloat(tensorHandle, i, array)
override fun NoaFloatTensor.set(dim: Int, start: Int, end: Int, array: FloatArray): Unit =
JNoa.setSliceBlobFloat(tensorHandle, dim, start, end, array)
} }
public sealed class NoaLongAlgebra public sealed class NoaLongAlgebra
@ -583,6 +599,8 @@ protected constructor(scope: NoaScope) :
override fun NoaLongTensor.set(i: Int, array: LongArray): Unit = override fun NoaLongTensor.set(i: Int, array: LongArray): Unit =
JNoa.setBlobLong(tensorHandle, i, array) JNoa.setBlobLong(tensorHandle, i, array)
override fun NoaLongTensor.set(dim: Int, start: Int, end: Int, array: LongArray): Unit =
JNoa.setSliceBlobLong(tensorHandle, dim, start, end, array)
} }
public sealed class NoaIntAlgebra public sealed class NoaIntAlgebra
@ -658,4 +676,6 @@ protected constructor(scope: NoaScope) :
override fun NoaIntTensor.set(i: Int, array: IntArray): Unit = override fun NoaIntTensor.set(i: Int, array: IntArray): Unit =
JNoa.setBlobInt(tensorHandle, i, array) JNoa.setBlobInt(tensorHandle, i, array)
override fun NoaIntTensor.set(dim: Int, start: Int, end: Int, array: IntArray): Unit =
JNoa.setSliceBlobInt(tensorHandle, dim, start, end, array)
} }

View File

@ -1343,6 +1343,62 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobLong
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobInt JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobInt
(JNIEnv *, jclass, jlong, jintArray); (JNIEnv *, jclass, jlong, jintArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setTensor
* Signature: (JIJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setTensor
(JNIEnv *, jclass, jlong, jint, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getSliceTensor
* Signature: (JIII)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getSliceTensor
(JNIEnv *, jclass, jlong, jint, jint, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setSliceTensor
* Signature: (JIIIJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceTensor
(JNIEnv *, jclass, jlong, jint, jint, jint, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setSliceBlobDouble
* Signature: (JIII[D)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobDouble
(JNIEnv *, jclass, jlong, jint, jint, jint, jdoubleArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setSliceBlobFloat
* Signature: (JIII[F)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobFloat
(JNIEnv *, jclass, jlong, jint, jint, jint, jfloatArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setSliceBlobLong
* Signature: (JIII[J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobLong
(JNIEnv *, jclass, jlong, jint, jint, jint, jlongArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setSliceBlobInt
* Signature: (JIII[I)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobInt
(JNIEnv *, jclass, jlong, jint, jint, jint, jintArray);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -45,18 +45,31 @@ internal fun NoaFloat.testingSerialisation(tensorPath: String, device: Device =
assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray()) assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray())
} }
internal fun NoaFloat.testingArrayTransfer(device: Device = Device.CPU) { internal fun NoaFloat.testingBatchedGetterSetter(device: Device = Device.CPU) {
val array = (1..8).map { 100f * it }.toFloatArray() val array = (1..8).map { 100f * it }.toFloatArray()
val tensor = full(0.0f, intArrayOf(2, 2, 2), device)
tensor.assignFromArray(array)
assertTrue(tensor.copyToArray() contentEquals array)
val updateArray = floatArrayOf(15f, 20f) val updateArray = floatArrayOf(15f, 20f)
val resArray = NoaFloat { val updateTensor = full(5.0f, intArrayOf(4), device)
val tensor = copyFromArray(array, intArrayOf(2, 2, 2), device) updateTensor[0, 1, 3] = updateArray
NoaFloat { NoaFloat {
tensor[0][1] = updateArray tensor[0][1] = updateArray
tensor[1] = updateTensor.view(intArrayOf(2, 2))
updateTensor[0, 2, 4] = updateTensor[0, 0, 2]
}!! }!!
assertEquals(tensor[intArrayOf(0, 1, 0)], 15f)
tensor.copyToArray() assertTrue(
}!! tensor.copyToArray() contentEquals
assertEquals(resArray[3], 20f) floatArrayOf(100f, 200f, 15f, 20f, 5f, 15f, 20f, 5f)
)
assertTrue(
updateTensor.copyToArray() contentEquals
floatArrayOf(5f, 15f, 5f, 15f)
)
} }
class TestTensor { class TestTensor {
@ -110,9 +123,9 @@ class TestTensor {
}!! }!!
@Test @Test
fun testArrayTransfer() = NoaFloat { fun testBatchedGetterSetter() = NoaFloat {
withCuda { device -> withCuda { device ->
testingArrayTransfer(device) testingBatchedGetterSetter(device)
} }
}!! }!!