forked from kscience/kmath
advanced slicing
This commit is contained in:
parent
c6acedf9e0
commit
3d1a3e3b69
@ -77,7 +77,7 @@ import torch
|
|||||||
tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0]
|
tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0]
|
||||||
```
|
```
|
||||||
|
|
||||||
The most efficient way pass data between the `JVM` and the native backend
|
The most efficient way passing data between the `JVM` and the native backend
|
||||||
is to rely on primitive arrays:
|
is to rely on primitive arrays:
|
||||||
```kotlin
|
```kotlin
|
||||||
val array = (1..8).map { 100f * it }.toFloatArray()
|
val array = (1..8).map { 100f * it }.toFloatArray()
|
||||||
|
@ -350,8 +350,22 @@ class JNoa {
|
|||||||
|
|
||||||
public static native void getBlobFloat(long tensorHandle, float[] data);
|
public static native void getBlobFloat(long tensorHandle, float[] data);
|
||||||
|
|
||||||
public static native void getBlobLong(long tensorHandle, long[] data);
|
public static native void getBlobLong(long tensorHandle, long[] data);
|
||||||
|
|
||||||
public static native void getBlobInt(long tensorHandle, int[] data);
|
public static native void getBlobInt(long tensorHandle, int[] data);
|
||||||
|
|
||||||
|
public static native void setTensor(long tensorHandle, int i, long tensorValue);
|
||||||
|
|
||||||
|
public static native long getSliceTensor(long tensorHandle, int dim, int start, int end);
|
||||||
|
|
||||||
|
public static native void setSliceTensor(long tensorHandle, int dim, int start, int end, long tensorValue);
|
||||||
|
|
||||||
|
public static native void setSliceBlobDouble(long tensorHandle, int dim, int start, int end, double[] data);
|
||||||
|
|
||||||
|
public static native void setSliceBlobFloat(long tensorHandle, int dim, int start, int end, float[] data);
|
||||||
|
|
||||||
|
public static native void setSliceBlobLong(long tensorHandle, int dim, int start, int end, long[] data);
|
||||||
|
|
||||||
|
public static native void setSliceBlobInt(long tensorHandle, int dim, int start, int end, int[] data);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -87,6 +87,19 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
override operator fun Tensor<T>.get(i: Int): TensorType =
|
override operator fun Tensor<T>.get(i: Int): TensorType =
|
||||||
wrap(JNoa.getIndex(tensor.tensorHandle, i))
|
wrap(JNoa.getIndex(tensor.tensorHandle, i))
|
||||||
|
|
||||||
|
public operator fun TensorType.set(i: Int, value: Tensor<T>): Unit =
|
||||||
|
JNoa.setTensor(tensorHandle, i, value.tensor.tensorHandle)
|
||||||
|
|
||||||
|
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
|
||||||
|
|
||||||
|
public operator fun Tensor<T>.get(dim: Int, start: Int, end: Int): TensorType =
|
||||||
|
wrap(JNoa.getSliceTensor(tensor.tensorHandle, dim, start, end))
|
||||||
|
|
||||||
|
public operator fun TensorType.set(dim: Int, start: Int, end: Int, value: Tensor<T>): Unit =
|
||||||
|
JNoa.setSliceTensor(tensorHandle, dim, start, end, value.tensor.tensorHandle)
|
||||||
|
|
||||||
|
public abstract operator fun TensorType.set(dim: Int, start: Int, end: Int, array: PrimitiveArray): Unit
|
||||||
|
|
||||||
override fun diagonalEmbedding(
|
override fun diagonalEmbedding(
|
||||||
diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
|
diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
|
||||||
): TensorType =
|
): TensorType =
|
||||||
@ -166,8 +179,6 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
|
|
||||||
public abstract fun TensorType.assignFromArray(array: PrimitiveArray): Unit
|
public abstract fun TensorType.assignFromArray(array: PrimitiveArray): Unit
|
||||||
|
|
||||||
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||||
@ -418,6 +429,8 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit =
|
override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit =
|
||||||
JNoa.setBlobDouble(tensorHandle, i, array)
|
JNoa.setBlobDouble(tensorHandle, i, array)
|
||||||
|
|
||||||
|
override fun NoaDoubleTensor.set(dim: Int, start: Int, end: Int, array: DoubleArray): Unit =
|
||||||
|
JNoa.setSliceBlobDouble(tensorHandle, dim, start, end, array)
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaFloatAlgebra
|
public sealed class NoaFloatAlgebra
|
||||||
@ -508,6 +521,9 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit =
|
override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit =
|
||||||
JNoa.setBlobFloat(tensorHandle, i, array)
|
JNoa.setBlobFloat(tensorHandle, i, array)
|
||||||
|
|
||||||
|
override fun NoaFloatTensor.set(dim: Int, start: Int, end: Int, array: FloatArray): Unit =
|
||||||
|
JNoa.setSliceBlobFloat(tensorHandle, dim, start, end, array)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaLongAlgebra
|
public sealed class NoaLongAlgebra
|
||||||
@ -583,6 +599,8 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun NoaLongTensor.set(i: Int, array: LongArray): Unit =
|
override fun NoaLongTensor.set(i: Int, array: LongArray): Unit =
|
||||||
JNoa.setBlobLong(tensorHandle, i, array)
|
JNoa.setBlobLong(tensorHandle, i, array)
|
||||||
|
|
||||||
|
override fun NoaLongTensor.set(dim: Int, start: Int, end: Int, array: LongArray): Unit =
|
||||||
|
JNoa.setSliceBlobLong(tensorHandle, dim, start, end, array)
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaIntAlgebra
|
public sealed class NoaIntAlgebra
|
||||||
@ -658,4 +676,6 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun NoaIntTensor.set(i: Int, array: IntArray): Unit =
|
override fun NoaIntTensor.set(i: Int, array: IntArray): Unit =
|
||||||
JNoa.setBlobInt(tensorHandle, i, array)
|
JNoa.setBlobInt(tensorHandle, i, array)
|
||||||
|
|
||||||
|
override fun NoaIntTensor.set(dim: Int, start: Int, end: Int, array: IntArray): Unit =
|
||||||
|
JNoa.setSliceBlobInt(tensorHandle, dim, start, end, array)
|
||||||
}
|
}
|
||||||
|
@ -1343,6 +1343,62 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobLong
|
|||||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobInt
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobInt
|
||||||
(JNIEnv *, jclass, jlong, jintArray);
|
(JNIEnv *, jclass, jlong, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: setTensor
|
||||||
|
* Signature: (JIJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jint, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: getSliceTensor
|
||||||
|
* Signature: (JIII)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getSliceTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jint, jint, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: setSliceTensor
|
||||||
|
* Signature: (JIIIJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jint, jint, jint, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: setSliceBlobDouble
|
||||||
|
* Signature: (JIII[D)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobDouble
|
||||||
|
(JNIEnv *, jclass, jlong, jint, jint, jint, jdoubleArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: setSliceBlobFloat
|
||||||
|
* Signature: (JIII[F)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobFloat
|
||||||
|
(JNIEnv *, jclass, jlong, jint, jint, jint, jfloatArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: setSliceBlobLong
|
||||||
|
* Signature: (JIII[J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobLong
|
||||||
|
(JNIEnv *, jclass, jlong, jint, jint, jint, jlongArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: setSliceBlobInt
|
||||||
|
* Signature: (JIII[I)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSliceBlobInt
|
||||||
|
(JNIEnv *, jclass, jlong, jint, jint, jint, jintArray);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -45,18 +45,31 @@ internal fun NoaFloat.testingSerialisation(tensorPath: String, device: Device =
|
|||||||
assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray())
|
assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun NoaFloat.testingArrayTransfer(device: Device = Device.CPU) {
|
internal fun NoaFloat.testingBatchedGetterSetter(device: Device = Device.CPU) {
|
||||||
val array = (1..8).map { 100f * it }.toFloatArray()
|
val array = (1..8).map { 100f * it }.toFloatArray()
|
||||||
|
val tensor = full(0.0f, intArrayOf(2, 2, 2), device)
|
||||||
|
tensor.assignFromArray(array)
|
||||||
|
assertTrue(tensor.copyToArray() contentEquals array)
|
||||||
|
|
||||||
val updateArray = floatArrayOf(15f, 20f)
|
val updateArray = floatArrayOf(15f, 20f)
|
||||||
val resArray = NoaFloat {
|
val updateTensor = full(5.0f, intArrayOf(4), device)
|
||||||
val tensor = copyFromArray(array, intArrayOf(2, 2, 2), device)
|
updateTensor[0, 1, 3] = updateArray
|
||||||
NoaFloat {
|
|
||||||
tensor[0][1] = updateArray
|
NoaFloat {
|
||||||
}!!
|
tensor[0][1] = updateArray
|
||||||
assertEquals(tensor[intArrayOf(0, 1, 0)], 15f)
|
tensor[1] = updateTensor.view(intArrayOf(2, 2))
|
||||||
tensor.copyToArray()
|
updateTensor[0, 2, 4] = updateTensor[0, 0, 2]
|
||||||
}!!
|
}!!
|
||||||
assertEquals(resArray[3], 20f)
|
|
||||||
|
assertTrue(
|
||||||
|
tensor.copyToArray() contentEquals
|
||||||
|
floatArrayOf(100f, 200f, 15f, 20f, 5f, 15f, 20f, 5f)
|
||||||
|
)
|
||||||
|
assertTrue(
|
||||||
|
updateTensor.copyToArray() contentEquals
|
||||||
|
floatArrayOf(5f, 15f, 5f, 15f)
|
||||||
|
)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class TestTensor {
|
class TestTensor {
|
||||||
@ -110,9 +123,9 @@ class TestTensor {
|
|||||||
}!!
|
}!!
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testArrayTransfer() = NoaFloat {
|
fun testBatchedGetterSetter() = NoaFloat {
|
||||||
withCuda { device ->
|
withCuda { device ->
|
||||||
testingArrayTransfer(device)
|
testingBatchedGetterSetter(device)
|
||||||
}
|
}
|
||||||
}!!
|
}!!
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user