new functionality for primitive arrays
parent: 36bc127260
commit: c6acedf9e0
@@ -48,7 +48,7 @@ To load the native library you will need to add to the VM options:
The library is under active development. Many more features
will be available soon.

### Linear Algebra
### Tensors and Linear Algebra

We implement the tensor algebra interfaces
from [kmath-tensors](../kmath-tensors):
@@ -70,12 +70,31 @@ NoaFloat {
    tensorReg.save("tensorReg.pt")
}
```

The saved tensor can be loaded in `C++` or in Python:
```python
import torch
tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0]
```
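
The same file can be read back on the JVM side through the `loadTensor` member introduced further down in this commit. A minimal sketch, assuming the tensor was saved with `save` as above and requesting the CPU device explicitly:
```kotlin
NoaFloat {
    // load the previously saved tensor into a fresh native handle
    val restored = loadTensor("tensorReg.pt", Device.CPU)
    // pull the data back into a primitive FloatArray for inspection
    println(restored.copyToArray().joinToString())
}!!
```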

The most efficient way to pass data between the `JVM` and the native backend
is to rely on primitive arrays:
```kotlin
val array = (1..8).map { 100f * it }.toFloatArray()
val updateArray = floatArrayOf(15f, 20f)
val resArray = NoaFloat {
    val tensor = copyFromArray(array, intArrayOf(2, 2, 2))
    NoaFloat {
        // The call `tensor[0]` creates a native tensor instance pointing to a slice of `tensor`
        // The second call `[1]` is a setter call and does not create any new instances
        tensor[0][1] = updateArray
        // The instance `tensor[0]` is destroyed as we move out of the scope
    }!! // if the computation fails the result will be null
    tensor.copyToArray()
    // The instance `tensor` is destroyed here
}!!

```
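
Beyond `copyToArray`, this commit adds two bulk-update helpers to the algebras: `assignFromArray`, which overwrites the whole tensor in place from a primitive array, and a `set(i, array)` operator that writes an array into the `i`-th slice. A minimal sketch (shapes and values are illustrative, and it assumes `set` addresses slices along the leading dimension, mirroring `tensor[0][1] = updateArray` above):
```kotlin
NoaFloat {
    val tensor = copyFromArray(FloatArray(4), intArrayOf(2, 2))
    // overwrite all four elements at once from a primitive array
    tensor.assignFromArray(floatArrayOf(1f, 2f, 3f, 4f))
    // overwrite only the second row
    tensor[1] = floatArrayOf(30f, 40f)
    // expected (under the slice assumption above): 1.0, 2.0, 30.0, 40.0
    println(tensor.copyToArray().joinToString())
}!!
```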

### Automatic Differentiation
The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)
@@ -329,4 +329,29 @@ class JNoa {

    public static native void saveTensor(long tensorHandle, String path);

    public static native void saveJitModule(long jitModuleHandle, String path);

    public static native void assignBlobDouble(long tensorHandle, double[] data);

    public static native void assignBlobFloat(long tensorHandle, float[] data);

    public static native void assignBlobLong(long tensorHandle, long[] data);

    public static native void assignBlobInt(long tensorHandle, int[] data);

    public static native void setBlobDouble(long tensorHandle, int i, double[] data);

    public static native void setBlobFloat(long tensorHandle, int i, float[] data);

    public static native void setBlobLong(long tensorHandle, int i, long[] data);

    public static native void setBlobInt(long tensorHandle, int i, int[] data);

    public static native void getBlobDouble(long tensorHandle, double[] data);

    public static native void getBlobFloat(long tensorHandle, float[] data);

    public static native void getBlobLong(long tensorHandle, long[] data);

    public static native void getBlobInt(long tensorHandle, int[] data);
}
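
All of these entry points operate on caller-allocated primitive arrays: `getBlob*` copies the tensor contents into the supplied array, while `assignBlob*` and `setBlob*` copy data the other way. The array length is expected to match the number of elements involved, which is how the Kotlin algebras below use them. A rough sketch of the intended call pattern (the `readBack` helper and the fixed length are illustrative only):
```kotlin
// Hypothetical helper: `handle` must be a valid native tensor handle
// holding exactly 8 float elements.
fun readBack(handle: Long): FloatArray {
    val out = FloatArray(8)        // caller allocates the buffer
    JNoa.getBlobFloat(handle, out) // native side fills it in place
    return out
}
```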
@@ -22,6 +22,7 @@ protected constructor(protected val scope: NoaScope) :

    protected abstract fun wrap(tensorHandle: TensorHandle): TensorType

    @PerformancePitfall
    public fun Tensor<T>.cast(): TensorType = tensor

    /**
@@ -38,8 +39,7 @@ protected constructor(protected val scope: NoaScope) :

    public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device = Device.CPU): TensorType

    @PerformancePitfall
    public abstract fun Tensor<T>.copyToArray(): PrimitiveArray
    public abstract fun TensorType.copyToArray(): PrimitiveArray

    public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device = Device.CPU): TensorType

@@ -164,6 +164,10 @@ protected constructor(protected val scope: NoaScope) :
    public infix fun TensorType.swap(other: TensorType): Unit =
        JNoa.swapTensors(tensorHandle, other.tensorHandle)

    public abstract fun TensorType.assignFromArray(array: PrimitiveArray): Unit

    public abstract operator fun TensorType.set(i: Int, array: PrimitiveArray): Unit

}

public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
@@ -345,9 +349,11 @@ protected constructor(scope: NoaScope) :
    override fun wrap(tensorHandle: TensorHandle): NoaDoubleTensor =
        NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)

    @PerformancePitfall
    override fun Tensor<Double>.copyToArray(): DoubleArray =
        tensor.elements().map { it.second }.toList().toDoubleArray()
    override fun NoaDoubleTensor.copyToArray(): DoubleArray {
        val array = DoubleArray(numElements)
        JNoa.getBlobDouble(tensorHandle, array)
        return array
    }

    override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor =
        wrap(JNoa.fromBlobDouble(array, shape, device.toInt()))
@@ -405,6 +411,13 @@ protected constructor(scope: NoaScope) :

    override fun loadTensor(path: String, device: Device): NoaDoubleTensor =
        wrap(JNoa.loadTensorDouble(path, device.toInt()))

    override fun NoaDoubleTensor.assignFromArray(array: DoubleArray): Unit =
        JNoa.assignBlobDouble(tensorHandle, array)

    override fun NoaDoubleTensor.set(i: Int, array: DoubleArray): Unit =
        JNoa.setBlobDouble(tensorHandle, i, array)

}

public sealed class NoaFloatAlgebra
@@ -426,9 +439,11 @@ protected constructor(scope: NoaScope) :
    override fun wrap(tensorHandle: TensorHandle): NoaFloatTensor =
        NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)

    @PerformancePitfall
    override fun Tensor<Float>.copyToArray(): FloatArray =
        tensor.elements().map { it.second }.toList().toFloatArray()
    override fun NoaFloatTensor.copyToArray(): FloatArray {
        val res = FloatArray(numElements)
        JNoa.getBlobFloat(tensorHandle, res)
        return res
    }

    override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor =
        wrap(JNoa.fromBlobFloat(array, shape, device.toInt()))
@@ -486,6 +501,13 @@ protected constructor(scope: NoaScope) :

    override fun loadTensor(path: String, device: Device): NoaFloatTensor =
        wrap(JNoa.loadTensorFloat(path, device.toInt()))

    override fun NoaFloatTensor.assignFromArray(array: FloatArray): Unit =
        JNoa.assignBlobFloat(tensorHandle, array)

    override fun NoaFloatTensor.set(i: Int, array: FloatArray): Unit =
        JNoa.setBlobFloat(tensorHandle, i, array)

}

public sealed class NoaLongAlgebra
@@ -507,9 +529,11 @@ protected constructor(scope: NoaScope) :
    override fun wrap(tensorHandle: TensorHandle): NoaLongTensor =
        NoaLongTensor(scope = scope, tensorHandle = tensorHandle)

    @PerformancePitfall
    override fun Tensor<Long>.copyToArray(): LongArray =
        tensor.elements().map { it.second }.toList().toLongArray()
    override fun NoaLongTensor.copyToArray(): LongArray {
        val array = LongArray(numElements)
        JNoa.getBlobLong(tensorHandle, array)
        return array
    }

    override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor =
        wrap(JNoa.fromBlobLong(array, shape, device.toInt()))
@@ -553,6 +577,12 @@ protected constructor(scope: NoaScope) :
    override fun loadTensor(path: String, device: Device): NoaLongTensor =
        wrap(JNoa.loadTensorLong(path, device.toInt()))

    override fun NoaLongTensor.assignFromArray(array: LongArray): Unit =
        JNoa.assignBlobLong(tensorHandle, array)

    override fun NoaLongTensor.set(i: Int, array: LongArray): Unit =
        JNoa.setBlobLong(tensorHandle, i, array)

}

public sealed class NoaIntAlgebra
@@ -574,9 +604,11 @@ protected constructor(scope: NoaScope) :
    override fun wrap(tensorHandle: TensorHandle): NoaIntTensor =
        NoaIntTensor(scope = scope, tensorHandle = tensorHandle)

    @PerformancePitfall
    override fun Tensor<Int>.copyToArray(): IntArray =
        tensor.elements().map { it.second }.toList().toIntArray()
    override fun NoaIntTensor.copyToArray(): IntArray {
        val array = IntArray(numElements)
        JNoa.getBlobInt(tensorHandle, array)
        return array
    }

    override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor =
        wrap(JNoa.fromBlobInt(array, shape, device.toInt()))
@@ -620,4 +652,10 @@ protected constructor(scope: NoaScope) :
    override fun loadTensor(path: String, device: Device): NoaIntTensor =
        wrap(JNoa.loadTensorInt(path, device.toInt()))

    override fun NoaIntTensor.assignFromArray(array: IntArray): Unit =
        JNoa.assignBlobInt(tensorHandle, array)

    override fun NoaIntTensor.set(i: Int, array: IntArray): Unit =
        JNoa.setBlobInt(tensorHandle, i, array)

}
@@ -1247,6 +1247,102 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveTensor
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveJitModule
  (JNIEnv *, jclass, jlong, jstring);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    assignBlobDouble
 * Signature: (J[D)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_assignBlobDouble
  (JNIEnv *, jclass, jlong, jdoubleArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    assignBlobFloat
 * Signature: (J[F)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_assignBlobFloat
  (JNIEnv *, jclass, jlong, jfloatArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    assignBlobLong
 * Signature: (J[J)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_assignBlobLong
  (JNIEnv *, jclass, jlong, jlongArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    assignBlobInt
 * Signature: (J[I)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_assignBlobInt
  (JNIEnv *, jclass, jlong, jintArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    setBlobDouble
 * Signature: (JI[D)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setBlobDouble
  (JNIEnv *, jclass, jlong, jint, jdoubleArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    setBlobFloat
 * Signature: (JI[F)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setBlobFloat
  (JNIEnv *, jclass, jlong, jint, jfloatArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    setBlobLong
 * Signature: (JI[J)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setBlobLong
  (JNIEnv *, jclass, jlong, jint, jlongArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    setBlobInt
 * Signature: (JI[I)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setBlobInt
  (JNIEnv *, jclass, jlong, jint, jintArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    getBlobDouble
 * Signature: (J[D)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobDouble
  (JNIEnv *, jclass, jlong, jdoubleArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    getBlobFloat
 * Signature: (J[F)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobFloat
  (JNIEnv *, jclass, jlong, jfloatArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    getBlobLong
 * Signature: (J[J)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobLong
  (JNIEnv *, jclass, jlong, jlongArray);

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    getBlobInt
 * Signature: (J[I)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_getBlobInt
  (JNIEnv *, jclass, jlong, jintArray);

#ifdef __cplusplus
}
#endif
@@ -38,13 +38,27 @@ internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) {
    assertEquals(tensor[intArrayOf(0)], 10)
}

internal fun NoaFloat.testingTensorSerialisation(tensorPath: String, device: Device = Device.CPU){
internal fun NoaFloat.testingSerialisation(tensorPath: String, device: Device = Device.CPU) {
    val tensor = copyFromArray(floatArrayOf(45.5f, 98.6f), intArrayOf(2), device)
    tensor.save(tensorPath)
    val loadedTensor = loadTensor(tensorPath, device)
    assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray())
}

internal fun NoaFloat.testingArrayTransfer(device: Device = Device.CPU) {
    val array = (1..8).map { 100f * it }.toFloatArray()
    val updateArray = floatArrayOf(15f, 20f)
    val resArray = NoaFloat {
        val tensor = copyFromArray(array, intArrayOf(2, 2, 2), device)
        NoaFloat {
            tensor[0][1] = updateArray
        }!!
        assertEquals(tensor[intArrayOf(0, 1, 0)], 15f)
        tensor.copyToArray()
    }!!
    assertEquals(resArray[3], 20f)
}

class TestTensor {

    private val resources = File("").resolve("src/test/resources")
@@ -89,9 +103,16 @@ class TestTensor {
    }!!

    @Test
    fun tensorSerialisation() = NoaFloat {
    fun testSerialisation() = NoaFloat {
        withCuda { device ->
            testingTensorSerialisation(tensorPath, device)
            testingSerialisation(tensorPath, device)
        }
    }!!

    @Test
    fun testArrayTransfer() = NoaFloat {
        withCuda { device ->
            testingArrayTransfer(device)
        }
    }!!