new functionality for primitive arrays

This commit is contained in:
Roland Grinis 2021-07-31 18:37:42 +01:00
parent 36bc127260
commit c6acedf9e0
5 changed files with 217 additions and 18 deletions

View File

@ -48,7 +48,7 @@ To load the native library you will need to add to the VM options:
The library is under active development. Many more features
will be available soon.
### Linear Algebra
### Tensors and Linear Algebra
We implement the tensor algebra interfaces
from [kmath-tensors](../kmath-tensors):
@ -70,12 +70,31 @@ NoaFloat {
tensorReg.save("tensorReg.pt")
}
```
The saved tensor can be loaded in `C++` or in `python`:
```python
import torch
tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0]
```
The most efficient way pass data between the `JVM` and the native backend
is to rely on primitive arrays:
```kotlin
val array = (1..8).map { 100f * it }.toFloatArray()
val updateArray = floatArrayOf(15f, 20f)
val resArray = NoaFloat {
val tensor = copyFromArray(array, intArrayOf(2, 2, 2))
NoaFloat {
// The call `tensor[0]` creates a native tensor instance pointing to a slice of `tensor`
// The second call `[1]` is a setter call and does not create any new instances
tensor[0][1] = updateArray
// The instance `tensor[0]` is destroyed as we move out of the scope
}!! // if the computation fails the result fill be null
tensor.copyToArray()
// the instance `tensor` is destroyed here
}!!
```
### Automatic Differentiation
The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)

View File

@ -329,4 +329,29 @@ class JNoa {
public static native void saveTensor(long tensorHandle, String path);
public static native void saveJitModule(long jitModuleHandle, String path);
public static native void assignBlobDouble(long tensorHandle, double[] data);
public static native void assignBlobFloat(long tensorHandle, float[] data);
public static native void assignBlobLong(long tensorHandle, long[] data);
public static native void assignBlobInt(long tensorHandle, int[] data);
public static native void setBlobDouble(long tensorHandle, int i, double[] data);
public static native void setBlobFloat(long tensorHandle, int i, float[] data);
public static native void setBlobLong(long tensorHandle, int i, long[] data);
public static native void setBlobInt(long tensorHandle, int i, int[] data);
public static native void getBlobDouble(long tensorHandle, double[] data);
public static native void getBlobFloat(long tensorHandle, float[] data);
public static native void getBlobLong(long tensorHandle, long[] data);
public static native void getBlobInt(long tensorHandle, int[] data);
}

View File

@ -22,6 +22,7 @@ protected constructor(protected val scope: NoaScope) :
protected abstract fun wrap(tensorHandle: TensorHandle): TensorType
@PerformancePitfall
public fun Tensor<T>.cast(): TensorType = tensor
/**
@ -38,8 +39,7 @@ protected constructor(protected val scope: NoaScope) :
public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device = Device.CPU): TensorType
@PerformancePitfall
public abstract fun Tensor<T>.copyToArray(): PrimitiveArray
public abstract fun TensorType.copyToArray(): PrimitiveArray
public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device = Device.CPU): TensorType
@ -164,6 +164,10 @@ protected constructor(protected val scope: NoaScope) :
public infix fun TensorType.swap(other: TensorType): Unit =
JNoa.swapTensors(tensorHandle, other.tensorHandle)
public abstract fun TensorType.assignFromArray(array: PrimitiveArray): Unit
public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit
}
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
@ -345,9 +349,11 @@ protected constructor(scope: NoaScope) :
override fun wrap(tensorHandle: TensorHandle): NoaDoubleTensor =
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
override fun Tensor<Double>.copyToArray(): DoubleArray =
tensor.elements().map { it.second }.toList().toDoubleArray()
override fun NoaDoubleTensor.copyToArray(): DoubleArray {
val array = DoubleArray(numElements)
JNoa.getBlobDouble(tensorHandle, array)
return array
}
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.fromBlobDouble(array, shape, device.toInt()))
@ -405,6 +411,13 @@ protected constructor(scope: NoaScope) :
override fun loadTensor(path: String, device: Device): NoaDoubleTensor =
wrap(JNoa.loadTensorDouble(path, device.toInt()))
override fun NoaDoubleTensor.assignFromArray(array: DoubleArray): Unit =
JNoa.assignBlobDouble(tensorHandle, array)
override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit =
JNoa.setBlobDouble(tensorHandle, i, array)
}
public sealed class NoaFloatAlgebra
@ -426,9 +439,11 @@ protected constructor(scope: NoaScope) :
override fun wrap(tensorHandle: TensorHandle): NoaFloatTensor =
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
override fun Tensor<Float>.copyToArray(): FloatArray =
tensor.elements().map { it.second }.toList().toFloatArray()
override fun NoaFloatTensor.copyToArray(): FloatArray {
val res = FloatArray(numElements)
JNoa.getBlobFloat(tensorHandle, res)
return res
}
override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.fromBlobFloat(array, shape, device.toInt()))
@ -486,6 +501,13 @@ protected constructor(scope: NoaScope) :
override fun loadTensor(path: String, device: Device): NoaFloatTensor =
wrap(JNoa.loadTensorFloat(path, device.toInt()))
override fun NoaFloatTensor.assignFromArray(array: FloatArray): Unit =
JNoa.assignBlobFloat(tensorHandle, array)
override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit =
JNoa.setBlobFloat(tensorHandle, i, array)
}
public sealed class NoaLongAlgebra
@ -507,9 +529,11 @@ protected constructor(scope: NoaScope) :
override fun wrap(tensorHandle: TensorHandle): NoaLongTensor =
NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
override fun Tensor<Long>.copyToArray(): LongArray =
tensor.elements().map { it.second }.toList().toLongArray()
override fun NoaLongTensor.copyToArray(): LongArray {
val array = LongArray(numElements)
JNoa.getBlobLong(tensorHandle, array)
return array
}
override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.fromBlobLong(array, shape, device.toInt()))
@ -553,6 +577,12 @@ protected constructor(scope: NoaScope) :
override fun loadTensor(path: String, device: Device): NoaLongTensor =
wrap(JNoa.loadTensorLong(path, device.toInt()))
override fun NoaLongTensor.assignFromArray(array: LongArray): Unit =
JNoa.assignBlobLong(tensorHandle, array)
override fun NoaLongTensor.set(i: Int, array: LongArray): Unit =
JNoa.setBlobLong(tensorHandle, i, array)
}
public sealed class NoaIntAlgebra
@ -574,9 +604,11 @@ protected constructor(scope: NoaScope) :
override fun wrap(tensorHandle: TensorHandle): NoaIntTensor =
NoaIntTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
override fun Tensor<Int>.copyToArray(): IntArray =
tensor.elements().map { it.second }.toList().toIntArray()
override fun NoaIntTensor.copyToArray(): IntArray {
val array = IntArray(numElements)
JNoa.getBlobInt(tensorHandle, array)
return array
}
override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.fromBlobInt(array, shape, device.toInt()))
@ -620,4 +652,10 @@ protected constructor(scope: NoaScope) :
override fun loadTensor(path: String, device: Device): NoaIntTensor =
wrap(JNoa.loadTensorInt(path, device.toInt()))
override fun NoaIntTensor.assignFromArray(array: IntArray): Unit =
JNoa.assignBlobInt(tensorHandle, array)
override fun NoaIntTensor.set(i: Int, array: IntArray): Unit =
JNoa.setBlobInt(tensorHandle, i, array)
}

View File

@ -1247,6 +1247,102 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveTensor
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveJitModule
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: assignBlobDouble
* Signature: (J[D)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_assignBlobDouble
(JNIEnv *, jclass, jlong, jdoubleArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: assignBlobFloat
* Signature: (J[F)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_assignBlobFloat
(JNIEnv *, jclass, jlong, jfloatArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: assignBlobLong
* Signature: (J[J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_assignBlobLong
(JNIEnv *, jclass, jlong, jlongArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: assignBlobInt
* Signature: (J[I)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_assignBlobInt
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setBlobDouble
* Signature: (JI[D)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setBlobDouble
(JNIEnv *, jclass, jlong, jint, jdoubleArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setBlobFloat
* Signature: (JI[F)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setBlobFloat
(JNIEnv *, jclass, jlong, jint, jfloatArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setBlobLong
* Signature: (JI[J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setBlobLong
(JNIEnv *, jclass, jlong, jint, jlongArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setBlobInt
* Signature: (JI[I)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setBlobInt
(JNIEnv *, jclass, jlong, jint, jintArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getBlobDouble
* Signature: (J[D)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobDouble
(JNIEnv *, jclass, jlong, jdoubleArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getBlobFloat
* Signature: (J[F)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobFloat
(JNIEnv *, jclass, jlong, jfloatArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getBlobLong
* Signature: (J[J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobLong
(JNIEnv *, jclass, jlong, jlongArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getBlobInt
* Signature: (J[I)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobInt
(JNIEnv *, jclass, jlong, jintArray);
#ifdef __cplusplus
}
#endif

View File

@ -38,13 +38,27 @@ internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) {
assertEquals(tensor[intArrayOf(0)], 10)
}
internal fun NoaFloat.testingTensorSerialisation(tensorPath: String, device: Device = Device.CPU){
internal fun NoaFloat.testingSerialisation(tensorPath: String, device: Device = Device.CPU) {
val tensor = copyFromArray(floatArrayOf(45.5f, 98.6f), intArrayOf(2), device)
tensor.save(tensorPath)
val loadedTensor = loadTensor(tensorPath, device)
assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray())
}
internal fun NoaFloat.testingArrayTransfer(device: Device = Device.CPU) {
val array = (1..8).map { 100f * it }.toFloatArray()
val updateArray = floatArrayOf(15f, 20f)
val resArray = NoaFloat {
val tensor = copyFromArray(array, intArrayOf(2, 2, 2), device)
NoaFloat {
tensor[0][1] = updateArray
}!!
assertEquals(tensor[intArrayOf(0, 1, 0)], 15f)
tensor.copyToArray()
}!!
assertEquals(resArray[3], 20f)
}
class TestTensor {
private val resources = File("").resolve("src/test/resources")
@ -89,9 +103,16 @@ class TestTensor {
}!!
@Test
fun tensorSerialisation() = NoaFloat {
fun testSerialisation() = NoaFloat {
withCuda { device ->
testingTensorSerialisation(tensorPath, device)
testingSerialisation(tensorPath, device)
}
}!!
@Test
fun testArrayTransfer() = NoaFloat {
withCuda { device ->
testingArrayTransfer(device)
}
}!!