diff --git a/kmath-noa/README.md b/kmath-noa/README.md index c9e20eed4..876e822b2 100644 --- a/kmath-noa/README.md +++ b/kmath-noa/README.md @@ -77,7 +77,7 @@ import torch tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0] ``` -The most efficient way pass data between the `JVM` and the native backend +The most efficient way passing data between the `JVM` and the native backend is to rely on primitive arrays: ```kotlin val array = (1..8).map { 100f * it }.toFloatArray() diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index 166eeba6c..5f1d56d34 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -350,8 +350,22 @@ class JNoa { public static native void getBlobFloat(long tensorHandle, float[] data); - public static native void getBlobLong(long tensorHandle, long[] data); + public static native void getBlobLong(long tensorHandle, long[] data); public static native void getBlobInt(long tensorHandle, int[] data); + public static native void setTensor(long tensorHandle, int i, long tensorValue); + + public static native long getSliceTensor(long tensorHandle, int dim, int start, int end); + + public static native void setSliceTensor(long tensorHandle, int dim, int start, int end, long tensorValue); + + public static native void setSliceBlobDouble(long tensorHandle, int dim, int start, int end, double[] data); + + public static native void setSliceBlobFloat(long tensorHandle, int dim, int start, int end, float[] data); + + public static native void setSliceBlobLong(long tensorHandle, int dim, int start, int end, long[] data); + + public static native void setSliceBlobInt(long tensorHandle, int dim, int start, int end, int[] data); + } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index ef553cdea..f1c2a94f9 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -87,6 +87,19 @@ protected constructor(protected val scope: NoaScope) : override operator fun Tensor.get(i: Int): TensorType = wrap(JNoa.getIndex(tensor.tensorHandle, i)) + public operator fun TensorType.set(i: Int, value: Tensor): Unit = + JNoa.setTensor(tensorHandle, i, value.tensor.tensorHandle) + + public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit + + public operator fun Tensor.get(dim: Int, start: Int, end: Int): TensorType = + wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, start, end)) + + public operator fun TensorType.set(dim: Int, start: Int, end: Int, value: Tensor): Unit = + JNoa.setSliceTensor(tensorHandle, dim, start, end, value.tensor.tensorHandle) + + public abstract operator fun TensorType.set(dim: Int, start: Int, end: Int, array: PrimitiveArray): Unit + override fun diagonalEmbedding( diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int ): TensorType = @@ -166,8 +179,6 @@ protected constructor(protected val scope: NoaScope) : public abstract fun TensorType.assignFromArray(array: PrimitiveArray): Unit - public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit - } public sealed class NoaPartialDivisionAlgebra> @@ -418,6 +429,8 @@ protected constructor(scope: NoaScope) : override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit = JNoa.setBlobDouble(tensorHandle, i, array) + override fun NoaDoubleTensor.set(dim: Int, start: Int, end: Int, array: DoubleArray): Unit = + JNoa.setSliceBlobDouble(tensorHandle, dim, start, end, array) } public sealed class NoaFloatAlgebra @@ -508,6 +521,9 @@ protected constructor(scope: NoaScope) : override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit = JNoa.setBlobFloat(tensorHandle, i, array) + override fun NoaFloatTensor.set(dim: Int, start: Int, end: Int, array: FloatArray): Unit = + JNoa.setSliceBlobFloat(tensorHandle, dim, start, end, array) + } public sealed class NoaLongAlgebra @@ -583,6 +599,8 @@ protected constructor(scope: NoaScope) : override fun NoaLongTensor.set(i: Int, array: LongArray): Unit = JNoa.setBlobLong(tensorHandle, i, array) + override fun NoaLongTensor.set(dim: Int, start: Int, end: Int, array: LongArray): Unit = + JNoa.setSliceBlobLong(tensorHandle, dim, start, end, array) } public sealed class NoaIntAlgebra @@ -658,4 +676,6 @@ protected constructor(scope: NoaScope) : override fun NoaIntTensor.set(i: Int, array: IntArray): Unit = JNoa.setBlobInt(tensorHandle, i, array) + override fun NoaIntTensor.set(dim: Int, start: Int, end: Int, array: IntArray): Unit = + JNoa.setSliceBlobInt(tensorHandle, dim, start, end, array) } diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h index 34db64490..e301f240e 100644 --- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h +++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h @@ -1343,6 +1343,62 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobLong JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobInt (JNIEnv *, jclass, jlong, jintArray); +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: setTensor + * Signature: (JIJ)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setTensor + (JNIEnv *, jclass, jlong, jint, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: getSliceTensor + * Signature: (JIII)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getSliceTensor + (JNIEnv *, jclass, jlong, jint, jint, jint); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: setSliceTensor + * Signature: (JIIIJ)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceTensor + (JNIEnv *, jclass, jlong, jint, jint, jint, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: setSliceBlobDouble + * Signature: (JIII[D)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobDouble + (JNIEnv *, jclass, jlong, jint, jint, jint, jdoubleArray); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: setSliceBlobFloat + * Signature: (JIII[F)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobFloat + (JNIEnv *, jclass, jlong, jint, jint, jint, jfloatArray); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: setSliceBlobLong + * Signature: (JIII[J)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobLong + (JNIEnv *, jclass, jlong, jint, jint, jint, jlongArray); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: setSliceBlobInt + * Signature: (JIII[I)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobInt + (JNIEnv *, jclass, jlong, jint, jint, jint, jintArray); + #ifdef __cplusplus } #endif diff --git a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt index 4ecd203a0..ddf2c8ec3 100644 --- a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt +++ b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt @@ -45,18 +45,31 @@ internal fun NoaFloat.testingSerialisation(tensorPath: String, device: Device = assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray()) } -internal fun NoaFloat.testingArrayTransfer(device: Device = Device.CPU) { +internal fun NoaFloat.testingBatchedGetterSetter(device: Device = Device.CPU) { val array = (1..8).map { 100f * it }.toFloatArray() + val tensor = full(0.0f, intArrayOf(2, 2, 2), device) + tensor.assignFromArray(array) + assertTrue(tensor.copyToArray() contentEquals array) + val updateArray = floatArrayOf(15f, 20f) - val resArray = NoaFloat { - val tensor = copyFromArray(array, intArrayOf(2, 2, 2), device) - NoaFloat { - tensor[0][1] = updateArray - }!! - assertEquals(tensor[intArrayOf(0, 1, 0)], 15f) - tensor.copyToArray() + val updateTensor = full(5.0f, intArrayOf(4), device) + updateTensor[0, 1, 3] = updateArray + + NoaFloat { + tensor[0][1] = updateArray + tensor[1] = updateTensor.view(intArrayOf(2, 2)) + updateTensor[0, 2, 4] = updateTensor[0, 0, 2] }!! - assertEquals(resArray[3], 20f) + + assertTrue( + tensor.copyToArray() contentEquals + floatArrayOf(100f, 200f, 15f, 20f, 5f, 15f, 20f, 5f) + ) + assertTrue( + updateTensor.copyToArray() contentEquals + floatArrayOf(5f, 15f, 5f, 15f) + ) + } class TestTensor { @@ -110,9 +123,9 @@ class TestTensor { }!! @Test - fun testArrayTransfer() = NoaFloat { + fun testBatchedGetterSetter() = NoaFloat { withCuda { device -> - testingArrayTransfer(device) + testingBatchedGetterSetter(device) } }!!