forked from kscience/kmath
advanced slicing
This commit is contained in:
parent
c6acedf9e0
commit
3d1a3e3b69
@ -77,7 +77,7 @@ import torch
|
||||
tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0]
|
||||
```
|
||||
|
||||
The most efficient way pass data between the `JVM` and the native backend
|
||||
The most efficient way passing data between the `JVM` and the native backend
|
||||
is to rely on primitive arrays:
|
||||
```kotlin
|
||||
val array = (1..8).map { 100f * it }.toFloatArray()
|
||||
|
@ -350,8 +350,22 @@ class JNoa {
|
||||
|
||||
public static native void getBlobFloat(long tensorHandle, float[] data);
|
||||
|
||||
public static native void getBlobLong(long tensorHandle, long[] data);
|
||||
public static native void getBlobLong(long tensorHandle, long[] data);
|
||||
|
||||
public static native void getBlobInt(long tensorHandle, int[] data);
|
||||
|
||||
public static native void setTensor(long tensorHandle, int i, long tensorValue);
|
||||
|
||||
public static native long getSliceTensor(long tensorHandle, int dim, int start, int end);
|
||||
|
||||
public static native void setSliceTensor(long tensorHandle, int dim, int start, int end, long tensorValue);
|
||||
|
||||
public static native void setSliceBlobDouble(long tensorHandle, int dim, int start, int end, double[] data);
|
||||
|
||||
public static native void setSliceBlobFloat(long tensorHandle, int dim, int start, int end, float[] data);
|
||||
|
||||
public static native void setSliceBlobLong(long tensorHandle, int dim, int start, int end, long[] data);
|
||||
|
||||
public static native void setSliceBlobInt(long tensorHandle, int dim, int start, int end, int[] data);
|
||||
|
||||
}
|
||||
|
@ -87,6 +87,19 @@ protected constructor(protected val scope: NoaScope) :
|
||||
override operator fun Tensor<T>.get(i: Int): TensorType =
|
||||
wrap(JNoa.getIndex(tensor.tensorHandle, i))
|
||||
|
||||
public operator fun TensorType.set(i: Int, value: Tensor<T>): Unit =
|
||||
JNoa.setTensor(tensorHandle, i, value.tensor.tensorHandle)
|
||||
|
||||
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
|
||||
|
||||
public operator fun Tensor<T>.get(dim: Int, start: Int, end: Int): TensorType =
|
||||
wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, start, end))
|
||||
|
||||
public operator fun TensorType.set(dim: Int, start: Int, end: Int, value: Tensor<T>): Unit =
|
||||
JNoa.setSliceTensor(tensorHandle, dim, start, end, value.tensor.tensorHandle)
|
||||
|
||||
public abstract operator fun TensorType.set(dim: Int, start: Int, end: Int, array: PrimitiveArray): Unit
|
||||
|
||||
override fun diagonalEmbedding(
|
||||
diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
|
||||
): TensorType =
|
||||
@ -166,8 +179,6 @@ protected constructor(protected val scope: NoaScope) :
|
||||
|
||||
public abstract fun TensorType.assignFromArray(array: PrimitiveArray): Unit
|
||||
|
||||
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||
@ -418,6 +429,8 @@ protected constructor(scope: NoaScope) :
|
||||
override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit =
|
||||
JNoa.setBlobDouble(tensorHandle, i, array)
|
||||
|
||||
override fun NoaDoubleTensor.set(dim: Int, start: Int, end: Int, array: DoubleArray): Unit =
|
||||
JNoa.setSliceBlobDouble(tensorHandle, dim, start, end, array)
|
||||
}
|
||||
|
||||
public sealed class NoaFloatAlgebra
|
||||
@ -508,6 +521,9 @@ protected constructor(scope: NoaScope) :
|
||||
override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit =
|
||||
JNoa.setBlobFloat(tensorHandle, i, array)
|
||||
|
||||
override fun NoaFloatTensor.set(dim: Int, start: Int, end: Int, array: FloatArray): Unit =
|
||||
JNoa.setSliceBlobFloat(tensorHandle, dim, start, end, array)
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaLongAlgebra
|
||||
@ -583,6 +599,8 @@ protected constructor(scope: NoaScope) :
|
||||
override fun NoaLongTensor.set(i: Int, array: LongArray): Unit =
|
||||
JNoa.setBlobLong(tensorHandle, i, array)
|
||||
|
||||
override fun NoaLongTensor.set(dim: Int, start: Int, end: Int, array: LongArray): Unit =
|
||||
JNoa.setSliceBlobLong(tensorHandle, dim, start, end, array)
|
||||
}
|
||||
|
||||
public sealed class NoaIntAlgebra
|
||||
@ -658,4 +676,6 @@ protected constructor(scope: NoaScope) :
|
||||
override fun NoaIntTensor.set(i: Int, array: IntArray): Unit =
|
||||
JNoa.setBlobInt(tensorHandle, i, array)
|
||||
|
||||
override fun NoaIntTensor.set(dim: Int, start: Int, end: Int, array: IntArray): Unit =
|
||||
JNoa.setSliceBlobInt(tensorHandle, dim, start, end, array)
|
||||
}
|
||||
|
@ -1343,6 +1343,62 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobLong
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobInt
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: setTensor
|
||||
* Signature: (JIJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setTensor
|
||||
(JNIEnv *, jclass, jlong, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: getSliceTensor
|
||||
* Signature: (JIII)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getSliceTensor
|
||||
(JNIEnv *, jclass, jlong, jint, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: setSliceTensor
|
||||
* Signature: (JIIIJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceTensor
|
||||
(JNIEnv *, jclass, jlong, jint, jint, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: setSliceBlobDouble
|
||||
* Signature: (JIII[D)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobDouble
|
||||
(JNIEnv *, jclass, jlong, jint, jint, jint, jdoubleArray);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: setSliceBlobFloat
|
||||
* Signature: (JIII[F)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobFloat
|
||||
(JNIEnv *, jclass, jlong, jint, jint, jint, jfloatArray);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: setSliceBlobLong
|
||||
* Signature: (JIII[J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobLong
|
||||
(JNIEnv *, jclass, jlong, jint, jint, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: setSliceBlobInt
|
||||
* Signature: (JIII[I)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobInt
|
||||
(JNIEnv *, jclass, jlong, jint, jint, jint, jintArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -45,18 +45,31 @@ internal fun NoaFloat.testingSerialisation(tensorPath: String, device: Device =
|
||||
assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray())
|
||||
}
|
||||
|
||||
internal fun NoaFloat.testingArrayTransfer(device: Device = Device.CPU) {
|
||||
internal fun NoaFloat.testingBatchedGetterSetter(device: Device = Device.CPU) {
|
||||
val array = (1..8).map { 100f * it }.toFloatArray()
|
||||
val tensor = full(0.0f, intArrayOf(2, 2, 2), device)
|
||||
tensor.assignFromArray(array)
|
||||
assertTrue(tensor.copyToArray() contentEquals array)
|
||||
|
||||
val updateArray = floatArrayOf(15f, 20f)
|
||||
val resArray = NoaFloat {
|
||||
val tensor = copyFromArray(array, intArrayOf(2, 2, 2), device)
|
||||
NoaFloat {
|
||||
tensor[0][1] = updateArray
|
||||
}!!
|
||||
assertEquals(tensor[intArrayOf(0, 1, 0)], 15f)
|
||||
tensor.copyToArray()
|
||||
val updateTensor = full(5.0f, intArrayOf(4), device)
|
||||
updateTensor[0, 1, 3] = updateArray
|
||||
|
||||
NoaFloat {
|
||||
tensor[0][1] = updateArray
|
||||
tensor[1] = updateTensor.view(intArrayOf(2, 2))
|
||||
updateTensor[0, 2, 4] = updateTensor[0, 0, 2]
|
||||
}!!
|
||||
assertEquals(resArray[3], 20f)
|
||||
|
||||
assertTrue(
|
||||
tensor.copyToArray() contentEquals
|
||||
floatArrayOf(100f, 200f, 15f, 20f, 5f, 15f, 20f, 5f)
|
||||
)
|
||||
assertTrue(
|
||||
updateTensor.copyToArray() contentEquals
|
||||
floatArrayOf(5f, 15f, 5f, 15f)
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
class TestTensor {
|
||||
@ -110,9 +123,9 @@ class TestTensor {
|
||||
}!!
|
||||
|
||||
@Test
|
||||
fun testArrayTransfer() = NoaFloat {
|
||||
fun testBatchedGetterSetter() = NoaFloat {
|
||||
withCuda { device ->
|
||||
testingArrayTransfer(device)
|
||||
testingBatchedGetterSetter(device)
|
||||
}
|
||||
}!!
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user