JNI wrapper

This commit is contained in:
Roland Grinis 2021-01-19 09:58:47 +00:00
parent d599d1132b
commit 17e6ebbc14
19 changed files with 1643 additions and 186 deletions

3
.gitignore vendored
View File

@ -10,6 +10,7 @@ out/
# Cache of project # Cache of project
.gradletasknamecache .gradletasknamecache
# Generated by javac -h # Generated by javac -h and runtime
*.class *.class
*.log

View File

@ -1,6 +1,6 @@
# LibTorch extension (`kmath-torch`) # LibTorch extension (`kmath-torch`)
This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of This is a `Kotlin/Native` & `JVM` module, with only `linuxX64` supported so far. The library wraps some of
the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`. the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
## Installation ## Installation
@ -11,7 +11,7 @@ To install the library, you have to build & publish locally `kmath-core`, `kmath
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal ./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
``` ```
This builds `ctorch`, a C wrapper for `LibTorch` placed inside: This builds `ctorch` a C wrapper and `jtorch` a JNI wrapper for `LibTorch`, placed inside:
`~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build` `~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`
@ -19,8 +19,8 @@ You will have to link against it in your own project.
## Usage ## Usage
Tensors are implemented over the `MutableNDStructure`. They can only be instantiated through provided factory methods Tensors are implemented over the `MutableNDStructure`. They can only be created through provided factory methods
and require scoping: and require scoping within a `TensorAlgebra` instance:
```kotlin ```kotlin
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
@ -63,4 +63,4 @@ TorchTensorRealAlgebra {
val hessianAtX = expressionAtX hess tensorX val hessianAtX = expressionAtX hess tensorX
} }
``` ```
Contributed by [Roland Grinis](https://github.com/rgrit91)

View File

@ -127,7 +127,6 @@ val generateJNIHeader by tasks.registering {
kotlin { kotlin {
explicitApiWarning() explicitApiWarning()
jvm { jvm {
withJava() withJava()
} }
@ -158,7 +157,6 @@ kotlin {
val test by nativeTarget.compilations.getting val test by nativeTarget.compilations.getting
sourceSets { sourceSets {
val commonMain by getting { val commonMain by getting {
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
@ -171,6 +169,11 @@ kotlin {
} }
} }
val jvmMain by getting {
dependencies {
api(project(":kmath-core"))
}
}
} }
} }

View File

@ -25,12 +25,11 @@ public interface TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType : Tor
public fun full(value: T, shape: IntArray, device: Device): TorchTensorType public fun full(value: T, shape: IntArray, device: Device): TorchTensorType
public fun randIntegral( public fun randIntegral(
low: T, high: T, shape: IntArray, low: Long, high: Long, shape: IntArray,
device: Device = Device.CPU device: Device = Device.CPU
): TorchTensorType ): TorchTensorType
public fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType
public fun TorchTensorType.randIntegral(low: T, high: T): TorchTensorType public fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit
public fun TorchTensorType.randIntegralAssign(low: T, high: T): Unit
public fun TorchTensorType.copy(): TorchTensorType public fun TorchTensorType.copy(): TorchTensorType
public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType

View File

@ -73,8 +73,8 @@ internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
numIter, numIter,
device, device,
"integer [0,100]", "integer [0,100]",
{sh, dc -> randIntegral(0f, 100f, shape = sh, device = dc)}, {sh, dc -> randIntegral(0, 100, shape = sh, device = dc)},
{ten -> ten.randIntegralAssign(0f, 100f) } {ten -> ten.randIntegralAssign(0, 100) }
) )
} }

View File

@ -43,7 +43,7 @@ internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverFie
internal inline fun <TorchTensorType : TorchTensor<Int>, internal inline fun <TorchTensorType : TorchTensor<Int>,
TorchTensorAlgebraType : TorchTensorAlgebra<Int, IntArray, TorchTensorType>> TorchTensorAlgebraType : TorchTensorAlgebra<Int, IntArray, TorchTensorType>>
TorchTensorAlgebraType.testingViewWithNoCopy(device: Device = Device.CPU) { TorchTensorAlgebraType.testingViewWithNoCopy(device: Device = Device.CPU) {
val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6)) val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6), device)
val viewTensor = tensor.view(intArrayOf(2, 3)) val viewTensor = tensor.view(intArrayOf(2, 3))
assertTrue(viewTensor.shape contentEquals intArrayOf(2, 3)) assertTrue(viewTensor.shape contentEquals intArrayOf(2, 3))
viewTensor[intArrayOf(0, 0)] = 10 viewTensor[intArrayOf(0, 0)] = 10

View File

@ -26,11 +26,11 @@ internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>> TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
TorchTensorAlgebraType.testingSetSeed(device: Device = Device.CPU): Unit { TorchTensorAlgebraType.testingSetSeed(device: Device = Device.CPU): Unit {
setSeed(SEED) setSeed(SEED)
val integral = randIntegral(0f, 100f, IntArray(0), device = device).value() val integral = randIntegral(0, 100, IntArray(0), device = device).value()
val normal = randNormal(IntArray(0), device = device).value() val normal = randNormal(IntArray(0), device = device).value()
val uniform = randUniform(IntArray(0), device = device).value() val uniform = randUniform(IntArray(0), device = device).value()
setSeed(SEED) setSeed(SEED)
val nextIntegral = randIntegral(0f, 100f, IntArray(0), device = device).value() val nextIntegral = randIntegral(0, 100, IntArray(0), device = device).value()
val nextNormal = randNormal(IntArray(0), device = device).value() val nextNormal = randNormal(IntArray(0), device = device).value()
val nextUniform = randUniform(IntArray(0), device = device).value() val nextUniform = randUniform(IntArray(0), device = device).value()
assertEquals(normal, nextNormal) assertEquals(normal, nextNormal)

View File

@ -70,16 +70,15 @@ extern "C"
TorchTensorHandle randint_double(long low, long high, int *shape, int shape_size, int device); TorchTensorHandle randint_double(long low, long high, int *shape, int shape_size, int device);
TorchTensorHandle randint_float(long low, long high, int *shape, int shape_size, int device); TorchTensorHandle randint_float(long low, long high, int *shape, int shape_size, int device);
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device); TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device);
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device); TorchTensorHandle randint_int(long low, long high, int *shape, int shape_size, int device);
TorchTensorHandle rand_like(TorchTensorHandle tensor_handle); TorchTensorHandle rand_like(TorchTensorHandle tensor_handle);
void rand_like_assign(TorchTensorHandle tensor_handle); void rand_like_assign(TorchTensorHandle tensor_handle);
TorchTensorHandle randn_like(TorchTensorHandle tensor_handle); TorchTensorHandle randn_like(TorchTensorHandle tensor_handle);
void randn_like_assign(TorchTensorHandle tensor_handle); void randn_like_assign(TorchTensorHandle tensor_handle);
TorchTensorHandle randint_long_like(TorchTensorHandle tensor_handle, long low, long high); TorchTensorHandle randint_like(TorchTensorHandle tensor_handle, long low, long high);
void randint_long_like_assign(TorchTensorHandle tensor_handle, long low, long high); void randint_like_assign(TorchTensorHandle tensor_handle, long low, long high);
TorchTensorHandle randint_int_like(TorchTensorHandle tensor_handle, int low, int high);
void randint_int_like_assign(TorchTensorHandle tensor_handle, int low, int high);
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device); TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device); TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
@ -106,14 +105,14 @@ extern "C"
void plus_long_assign(long value, TorchTensorHandle other); void plus_long_assign(long value, TorchTensorHandle other);
void plus_int_assign(int value, TorchTensorHandle other); void plus_int_assign(int value, TorchTensorHandle other);
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle); TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle);
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle); TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);

View File

@ -25,18 +25,130 @@ JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads
/* /*
* Class: kscience_kmath_torch_JTorch * Class: kscience_kmath_torch_JTorch
* Method: createTensor * Method: cudaIsAvailable
* Signature: ()J * Signature: ()Z
*/ */
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_cudaIsAvailable
(JNIEnv *, jclass); (JNIEnv *, jclass);
/* /*
* Class: kscience_kmath_torch_JTorch * Class: kscience_kmath_torch_JTorch
* Method: printTensor * Method: setSeed
* Signature: (J)V * Signature: (I)V
*/ */
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setSeed
(JNIEnv *, jclass, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: emptyTensor
* Signature: ()J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_emptyTensor
(JNIEnv *, jclass);
/*
* Class: kscience_kmath_torch_JTorch
* Method: fromBlobDouble
* Signature: ([D[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobDouble
(JNIEnv *, jclass, jdoubleArray, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: fromBlobFloat
* Signature: ([F[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobFloat
(JNIEnv *, jclass, jfloatArray, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: fromBlobLong
* Signature: ([J[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobLong
(JNIEnv *, jclass, jlongArray, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: fromBlobInt
* Signature: ([I[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobInt
(JNIEnv *, jclass, jintArray, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: copyTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyTensor
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: copyToDevice
* Signature: (JI)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDevice
(JNIEnv *, jclass, jlong, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: copyToDouble
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDouble
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: copyToFloat
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToFloat
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: copyToLong
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToLong
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: copyToInt
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToInt
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: swapTensors
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_swapTensors
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: viewTensor
* Signature: (J[I)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_viewTensor
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: kscience_kmath_torch_JTorch
* Method: tensorToString
* Signature: (J)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_kscience_kmath_torch_JTorch_tensorToString
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong);
/* /*
@ -47,6 +159,654 @@ JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getDim
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDim
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getNumel
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumel
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getShapeAt
* Signature: (JI)I
*/
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getShapeAt
(JNIEnv *, jclass, jlong, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getStrideAt
* Signature: (JI)I
*/
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getStrideAt
(JNIEnv *, jclass, jlong, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getDevice
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDevice
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getItemDouble
* Signature: (J)D
*/
JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getItemDouble
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getItemFloat
* Signature: (J)F
*/
JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getItemFloat
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getItemLong
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getItemLong
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getItemInt
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getItemInt
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getDouble
* Signature: (J[I)D
*/
JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getDouble
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getFloat
* Signature: (J[I)F
*/
JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getFloat
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getLong
* Signature: (J[I)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getLong
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: kscience_kmath_torch_JTorch
* Method: getInt
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getInt
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: kscience_kmath_torch_JTorch
* Method: setDouble
* Signature: (J[ID)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setDouble
(JNIEnv *, jclass, jlong, jintArray, jdouble);
/*
* Class: kscience_kmath_torch_JTorch
* Method: setFloat
* Signature: (J[IF)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setFloat
(JNIEnv *, jclass, jlong, jintArray, jfloat);
/*
* Class: kscience_kmath_torch_JTorch
* Method: setLong
* Signature: (J[IJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setLong
(JNIEnv *, jclass, jlong, jintArray, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: setInt
* Signature: (J[II)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setInt
(JNIEnv *, jclass, jlong, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randDouble
* Signature: ([II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randDouble
(JNIEnv *, jclass, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randnDouble
* Signature: ([II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnDouble
(JNIEnv *, jclass, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randFloat
* Signature: ([II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randFloat
(JNIEnv *, jclass, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randnFloat
* Signature: ([II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnFloat
(JNIEnv *, jclass, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randintDouble
* Signature: (JJ[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintDouble
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randintFloat
* Signature: (JJ[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintFloat
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randintLong
* Signature: (JJ[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLong
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randintInt
* Signature: (JJ[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintInt
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randLike
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randLike
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randLikeAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randLikeAssign
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randnLike
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnLike
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randnLikeAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randnLikeAssign
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randintLike
* Signature: (JJJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLike
(JNIEnv *, jclass, jlong, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: randintLikeAssign
* Signature: (JJJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randintLikeAssign
(JNIEnv *, jclass, jlong, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: fullDouble
* Signature: (D[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullDouble
(JNIEnv *, jclass, jdouble, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: fullFloat
* Signature: (F[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullFloat
(JNIEnv *, jclass, jfloat, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: fullLong
* Signature: (J[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullLong
(JNIEnv *, jclass, jlong, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: fullInt
* Signature: (I[II)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullInt
(JNIEnv *, jclass, jint, jintArray, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: timesDouble
* Signature: (DJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesDouble
(JNIEnv *, jclass, jdouble, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: timesFloat
* Signature: (FJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesFloat
(JNIEnv *, jclass, jfloat, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: timesLong
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesLong
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: timesInt
* Signature: (IJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesInt
(JNIEnv *, jclass, jint, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: timesDoubleAssign
* Signature: (DJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesDoubleAssign
(JNIEnv *, jclass, jdouble, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: timesFloatAssign
* Signature: (FJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesFloatAssign
(JNIEnv *, jclass, jfloat, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: timesLongAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesLongAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: timesIntAssign
* Signature: (IJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesIntAssign
(JNIEnv *, jclass, jint, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: plusDouble
* Signature: (DJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusDouble
(JNIEnv *, jclass, jdouble, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: plusFloat
* Signature: (FJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusFloat
(JNIEnv *, jclass, jfloat, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: plusLong
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusLong
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: plusInt
* Signature: (IJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusInt
(JNIEnv *, jclass, jint, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: plusDoubleAssign
* Signature: (DJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusDoubleAssign
(JNIEnv *, jclass, jdouble, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: plusFloatAssign
* Signature: (FJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusFloatAssign
(JNIEnv *, jclass, jfloat, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: plusLongAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusLongAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: plusIntAssign
* Signature: (IJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusIntAssign
(JNIEnv *, jclass, jint, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: timesTensor
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesTensor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: timesTensorAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesTensorAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: divTensor
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_divTensor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: divTensorAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_divTensorAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: plusTensor
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusTensor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: plusTensorAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusTensorAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: minusTensor
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_minusTensor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: minusTensorAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_minusTensorAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: unaryMinus
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_unaryMinus
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: absTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_absTensor
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: absTensorAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_absTensorAssign
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: transposeTensor
* Signature: (JII)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_transposeTensor
(JNIEnv *, jclass, jlong, jint, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: transposeTensorAssign
* Signature: (JII)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_transposeTensorAssign
(JNIEnv *, jclass, jlong, jint, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: expTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_expTensor
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: expTensorAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_expTensorAssign
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: logTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_logTensor
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: logTensorAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_logTensorAssign
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: sumTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_sumTensor
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: sumTensorAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_sumTensorAssign
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: matmul
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_matmul
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: matmulAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: matmulRightAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulRightAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: diagEmbed
* Signature: (JIII)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_diagEmbed
(JNIEnv *, jclass, jlong, jint, jint, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: svdTensor
* Signature: (JJJJ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_svdTensor
(JNIEnv *, jclass, jlong, jlong, jlong, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: symeigTensor
* Signature: (JJJZ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_symeigTensor
(JNIEnv *, jclass, jlong, jlong, jlong, jboolean);
/*
* Class: kscience_kmath_torch_JTorch
* Method: requiresGrad
* Signature: (J)Z
*/
JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_requiresGrad
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: setRequiresGrad
* Signature: (JZ)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setRequiresGrad
(JNIEnv *, jclass, jlong, jboolean);
/*
* Class: kscience_kmath_torch_JTorch
* Method: detachFromGraph
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_detachFromGraph
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: autogradTensor
* Signature: (JJZ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_autogradTensor
(JNIEnv *, jclass, jlong, jlong, jboolean);
/*
* Class: kscience_kmath_torch_JTorch
* Method: autohessTensor
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_autohessTensor
(JNIEnv *, jclass, jlong, jlong);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -3,7 +3,7 @@
namespace ctorch namespace ctorch
{ {
using TorchTensorHandle = void*; using TorchTensorHandle = void *;
template <typename Dtype> template <typename Dtype>
inline c10::ScalarType dtype() inline c10::ScalarType dtype()
@ -29,16 +29,28 @@ namespace ctorch
return torch::kInt32; return torch::kInt32;
} }
inline torch::Tensor &cast(const TorchTensorHandle &tensor_handle) template <typename Handle>
inline torch::Tensor &cast(const Handle &tensor_handle)
{ {
return *static_cast<torch::Tensor *>(tensor_handle); return *static_cast<torch::Tensor *>((TorchTensorHandle)tensor_handle);
}
template <typename Handle>
inline void dispose_tensor(const Handle &tensor_handle)
{
delete static_cast<torch::Tensor *>((TorchTensorHandle)tensor_handle);
}
inline std::string tensor_to_string(const torch::Tensor &tensor)
{
std::stringstream bufrep;
bufrep << tensor;
return bufrep.str();
} }
inline char *tensor_to_char(const torch::Tensor &tensor) inline char *tensor_to_char(const torch::Tensor &tensor)
{ {
std::stringstream bufrep; auto rep = tensor_to_string(tensor);
bufrep << tensor;
auto rep = bufrep.str();
char *crep = (char *)malloc(rep.length() + 1); char *crep = (char *)malloc(rep.length() + 1);
std::strcpy(crep, rep.c_str()); std::strcpy(crep, rep.c_str());
return crep; return crep;
@ -72,45 +84,43 @@ namespace ctorch
} }
template <typename Dtype> template <typename Dtype>
inline torch::Tensor from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device, bool copy) inline torch::Tensor from_blob(Dtype *data, const std::vector<int64_t> &shape, torch::Device device, bool copy)
{ {
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy); return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy);
} }
template <typename NumType> template <typename NumType>
inline NumType get(const TorchTensorHandle &tensor_handle, int *index) inline NumType get(const torch::Tensor &tensor, int *index)
{ {
auto ten = ctorch::cast(tensor_handle); return tensor.index(to_index(index, tensor.dim())).item<NumType>();
return ten.index(to_index(index, ten.dim())).item<NumType>();
} }
template <typename NumType> template <typename NumType>
inline void set(TorchTensorHandle &tensor_handle, int *index, NumType value) inline void set(const torch::Tensor &tensor, int *index, NumType value)
{ {
auto ten = ctorch::cast(tensor_handle); tensor.index(to_index(index, tensor.dim())) = value;
ten.index(to_index(index, ten.dim())) = value;
} }
template <typename Dtype> template <typename Dtype>
inline torch::Tensor randn(std::vector<int64_t> shape, torch::Device device) inline torch::Tensor randn(const std::vector<int64_t> &shape, torch::Device device)
{ {
return torch::randn(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device)); return torch::randn(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
} }
template <typename Dtype> template <typename Dtype>
inline torch::Tensor rand(std::vector<int64_t> shape, torch::Device device) inline torch::Tensor rand(const std::vector<int64_t> &shape, torch::Device device)
{ {
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device)); return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
} }
template <typename Dtype> template <typename Dtype>
inline torch::Tensor randint(long low, long high, std::vector<int64_t> shape, torch::Device device) inline torch::Tensor randint(long low, long high, const std::vector<int64_t> &shape, torch::Device device)
{ {
return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device)); return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
} }
template <typename Dtype> template <typename Dtype>
inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device) inline torch::Tensor full(Dtype value, const std::vector<int64_t> &shape, torch::Device device)
{ {
return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device)); return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
} }

View File

@ -74,9 +74,9 @@ void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)
{ {
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle)); std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
} }
TorchTensorHandle view_tensor(TorchTensorHandle tensor_handle, int *shape, int dim) TorchTensorHandle view_tensor(TorchTensorHandle tensor_handle, int *shape, int dim)
{ {
return new torch::Tensor(ctorch::cast(tensor_handle).view(ctorch::to_vec_int(shape, dim))); return new torch::Tensor(ctorch::cast(tensor_handle).view(ctorch::to_vec_int(shape, dim)));
} }
char *tensor_to_string(TorchTensorHandle tensor_handle) char *tensor_to_string(TorchTensorHandle tensor_handle)
@ -89,7 +89,7 @@ void dispose_char(char *ptr)
} }
void dispose_tensor(TorchTensorHandle tensor_handle) void dispose_tensor(TorchTensorHandle tensor_handle)
{ {
delete static_cast<torch::Tensor *>(tensor_handle); ctorch::dispose_tensor(tensor_handle);
} }
int get_dim(TorchTensorHandle tensor_handle) int get_dim(TorchTensorHandle tensor_handle)
@ -149,35 +149,35 @@ int get_item_int(TorchTensorHandle tensor_handle)
double get_double(TorchTensorHandle tensor_handle, int *index) double get_double(TorchTensorHandle tensor_handle, int *index)
{ {
return ctorch::get<double>(tensor_handle, index); return ctorch::get<double>(ctorch::cast(tensor_handle), index);
} }
float get_float(TorchTensorHandle tensor_handle, int *index) float get_float(TorchTensorHandle tensor_handle, int *index)
{ {
return ctorch::get<float>(tensor_handle, index); return ctorch::get<float>(ctorch::cast(tensor_handle), index);
} }
long get_long(TorchTensorHandle tensor_handle, int *index) long get_long(TorchTensorHandle tensor_handle, int *index)
{ {
return ctorch::get<long>(tensor_handle, index); return ctorch::get<long>(ctorch::cast(tensor_handle), index);
} }
int get_int(TorchTensorHandle tensor_handle, int *index) int get_int(TorchTensorHandle tensor_handle, int *index)
{ {
return ctorch::get<int>(tensor_handle, index); return ctorch::get<int>(ctorch::cast(tensor_handle), index);
} }
void set_double(TorchTensorHandle tensor_handle, int *index, double value) void set_double(TorchTensorHandle tensor_handle, int *index, double value)
{ {
ctorch::set<double>(tensor_handle, index, value); ctorch::set<double>(ctorch::cast(tensor_handle), index, value);
} }
void set_float(TorchTensorHandle tensor_handle, int *index, float value) void set_float(TorchTensorHandle tensor_handle, int *index, float value)
{ {
ctorch::set<float>(tensor_handle, index, value); ctorch::set<float>(ctorch::cast(tensor_handle), index, value);
} }
void set_long(TorchTensorHandle tensor_handle, int *index, long value) void set_long(TorchTensorHandle tensor_handle, int *index, long value)
{ {
ctorch::set<long>(tensor_handle, index, value); ctorch::set<long>(ctorch::cast(tensor_handle), index, value);
} }
void set_int(TorchTensorHandle tensor_handle, int *index, int value) void set_int(TorchTensorHandle tensor_handle, int *index, int value)
{ {
ctorch::set<int>(tensor_handle, index, value); ctorch::set<int>(ctorch::cast(tensor_handle), index, value);
} }
TorchTensorHandle rand_double(int *shape, int shape_size, int device) TorchTensorHandle rand_double(int *shape, int shape_size, int device)
@ -209,7 +209,7 @@ TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size,
{ {
return new torch::Tensor(ctorch::randint<long>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::randint<long>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
} }
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device) TorchTensorHandle randint_int(long low, long high, int *shape, int shape_size, int device)
{ {
return new torch::Tensor(ctorch::randint<int>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::randint<int>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
} }
@ -230,19 +230,11 @@ void randn_like_assign(TorchTensorHandle tensor_handle)
{ {
ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle)); ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle));
} }
TorchTensorHandle randint_long_like(TorchTensorHandle tensor_handle, long low, long high) TorchTensorHandle randint_like(TorchTensorHandle tensor_handle, long low, long high)
{ {
return new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high)); return new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
} }
void randint_long_like_assign(TorchTensorHandle tensor_handle, long low, long high) void randint_like_assign(TorchTensorHandle tensor_handle, long low, long high)
{
ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
}
TorchTensorHandle randint_int_like(TorchTensorHandle tensor_handle, int low, int high)
{
return new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
}
void randint_int_like_assign(TorchTensorHandle tensor_handle, int low, int high)
{ {
ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high); ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
} }
@ -264,39 +256,6 @@ TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
} }
TorchTensorHandle plus_double(double value, TorchTensorHandle other)
{
return new torch::Tensor(ctorch::cast(other) + value);
}
TorchTensorHandle plus_float(float value, TorchTensorHandle other)
{
return new torch::Tensor(ctorch::cast(other) + value);
}
TorchTensorHandle plus_long(long value, TorchTensorHandle other)
{
return new torch::Tensor(ctorch::cast(other) + value);
}
TorchTensorHandle plus_int(int value, TorchTensorHandle other)
{
return new torch::Tensor(ctorch::cast(other) + value);
}
void plus_double_assign(double value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_float_assign(float value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_long_assign(long value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_int_assign(int value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
TorchTensorHandle times_double(double value, TorchTensorHandle other) TorchTensorHandle times_double(double value, TorchTensorHandle other)
{ {
return new torch::Tensor(value * ctorch::cast(other)); return new torch::Tensor(value * ctorch::cast(other));
@ -330,22 +289,39 @@ void times_int_assign(int value, TorchTensorHandle other)
ctorch::cast(other) *= value; ctorch::cast(other) *= value;
} }
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs) TorchTensorHandle plus_double(double value, TorchTensorHandle other)
{ {
return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs)); return new torch::Tensor(ctorch::cast(other) + value);
} }
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) TorchTensorHandle plus_float(float value, TorchTensorHandle other)
{ {
ctorch::cast(lhs) += ctorch::cast(rhs); return new torch::Tensor(ctorch::cast(other) + value);
} }
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs) TorchTensorHandle plus_long(long value, TorchTensorHandle other)
{ {
return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs)); return new torch::Tensor(ctorch::cast(other) + value);
} }
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) TorchTensorHandle plus_int(int value, TorchTensorHandle other)
{ {
ctorch::cast(lhs) -= ctorch::cast(rhs); return new torch::Tensor(ctorch::cast(other) + value);
} }
void plus_double_assign(double value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_float_assign(float value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_long_assign(long value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_int_assign(int value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs) TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
{ {
return new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs)); return new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
@ -362,6 +338,22 @@ void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{ {
ctorch::cast(lhs) /= ctorch::cast(rhs); ctorch::cast(lhs) /= ctorch::cast(rhs);
} }
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
}
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(lhs) += ctorch::cast(rhs);
}
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
}
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(lhs) -= ctorch::cast(rhs);
}
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle) TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle)
{ {
return new torch::Tensor(-ctorch::cast(tensor_handle)); return new torch::Tensor(-ctorch::cast(tensor_handle));

View File

@ -15,21 +15,558 @@ JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads(JNIEnv *,
torch::set_num_threads(num_threads); torch::set_num_threads(num_threads);
} }
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor(JNIEnv *, jclass) JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_cudaIsAvailable(JNIEnv *, jclass)
{ {
auto ten = torch::randn({2, 3}); return torch::cuda::is_available();
std::cout << ten << std::endl;
void *ptr = new torch::Tensor(ten);
return (long)ptr;
} }
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor(JNIEnv *, jclass, jlong tensor_handle) JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setSeed(JNIEnv *, jclass, jint seed)
{ {
auto ten = ctorch::cast((void *)tensor_handle); torch::manual_seed(seed);
std::cout << ten << std::endl; }
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_emptyTensor(JNIEnv *, jclass)
{
return (long)new torch::Tensor;
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobDouble(JNIEnv *env, jclass, jdoubleArray data, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::from_blob<double>(
env->GetDoubleArrayElements(data, 0),
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device), true));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobFloat(JNIEnv *env, jclass, jfloatArray data, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::from_blob<float>(
env->GetFloatArrayElements(data, 0),
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device), true));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobLong(JNIEnv *env, jclass, jlongArray data, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::from_blob<long>(
env->GetLongArrayElements(data, 0),
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device), true));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobInt(JNIEnv *env, jclass, jintArray data, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::from_blob<int>(
env->GetIntArrayElements(data, 0),
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device), true));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyTensor(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).clone());
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDevice(JNIEnv *, jclass, jlong tensor_handle, jint device)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDouble(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<double>(), false, true));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToFloat(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<float>(), false, true));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToLong(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<long>(), false, true));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToInt(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<int>(), false, true));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_swapTensors(JNIEnv *, jclass, jlong lhs_handle, jlong rhs_handle)
{
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_viewTensor(JNIEnv *env, jclass, jlong tensor_handle, jintArray shape)
{
return (long)new torch::Tensor(
ctorch::cast(tensor_handle).view(ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape))));
}
JNIEXPORT jstring JNICALL Java_kscience_kmath_torch_JTorch_tensorToString(JNIEnv *env, jclass, jlong tensor_handle)
{
return env->NewStringUTF(ctorch::tensor_to_string(ctorch::cast(tensor_handle)).c_str());
} }
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor(JNIEnv *, jclass, jlong tensor_handle) JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor(JNIEnv *, jclass, jlong tensor_handle)
{ {
delete static_cast<torch::Tensor *>((void *)tensor_handle); ctorch::dispose_tensor(tensor_handle);
}
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDim(JNIEnv *, jclass, jlong tensor_handle)
{
return ctorch::cast(tensor_handle).dim();
}
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumel(JNIEnv *, jclass, jlong tensor_handle)
{
return ctorch::cast(tensor_handle).numel();
}
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getShapeAt(JNIEnv *, jclass, jlong tensor_handle, jint d)
{
return ctorch::cast(tensor_handle).size(d);
}
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getStrideAt(JNIEnv *, jclass, jlong tensor_handle, jint d)
{
return ctorch::cast(tensor_handle).stride(d);
}
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDevice(JNIEnv *, jclass, jlong tensor_handle)
{
return ctorch::device_to_int(ctorch::cast(tensor_handle));
}
JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getItemDouble(JNIEnv *, jclass, jlong tensor_handle)
{
return ctorch::cast(tensor_handle).item<double>();
}
JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getItemFloat(JNIEnv *, jclass, jlong tensor_handle)
{
return ctorch::cast(tensor_handle).item<float>();
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getItemLong(JNIEnv *, jclass, jlong tensor_handle)
{
return ctorch::cast(tensor_handle).item<long>();
}
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getItemInt(JNIEnv *, jclass, jlong tensor_handle)
{
return ctorch::cast(tensor_handle).item<int>();
}
JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getDouble(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
{
return ctorch::get<double>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
}
JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getFloat(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
{
return ctorch::get<float>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getLong(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
{
return ctorch::get<long>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
}
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getInt(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
{
return ctorch::get<int>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setDouble(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jdouble value)
{
ctorch::set<double>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setFloat(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jfloat value)
{
ctorch::set<float>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setLong(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jlong value)
{
ctorch::set<long>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setInt(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jint value)
{
ctorch::set<int>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randDouble(JNIEnv *env, jclass, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::rand<double>(
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnDouble(JNIEnv *env, jclass, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::randn<double>(
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randFloat(JNIEnv *env, jclass, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::rand<float>(
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnFloat(JNIEnv *env, jclass, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::randn<float>(
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintDouble(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::randint<double>(low, high,
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintFloat(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::randint<float>(low, high,
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLong(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::randint<long>(low, high,
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintInt(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::randint<int>(low, high,
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randLike(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(torch::rand_like(ctorch::cast(tensor_handle)));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randLikeAssign(JNIEnv *, jclass, jlong tensor_handle)
{
ctorch::cast(tensor_handle) = torch::rand_like(ctorch::cast(tensor_handle));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnLike(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(torch::randn_like(ctorch::cast(tensor_handle)));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randnLikeAssign(JNIEnv *, jclass, jlong tensor_handle)
{
ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLike(JNIEnv *, jclass, jlong tensor_handle, jlong low, jlong high)
{
return (long)new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randintLikeAssign(JNIEnv *, jclass, jlong tensor_handle, jlong low, jlong high)
{
ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullDouble(JNIEnv *env, jclass, jdouble value, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::full<double>(
value,
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullFloat(JNIEnv *env, jclass, jfloat value, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::full<float>(
value,
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullLong(JNIEnv *env, jclass, jlong value, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::full<long>(
value,
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullInt(JNIEnv *env, jclass, jint value, jintArray shape, jint device)
{
return (long)new torch::Tensor(
ctorch::full<int>(
value,
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
ctorch::int_to_device(device)));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesDouble(JNIEnv *, jclass, jdouble value, jlong other)
{
return (long)new torch::Tensor(value * ctorch::cast(other));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesFloat(JNIEnv *, jclass, jfloat value, jlong other)
{
return (long)new torch::Tensor(value * ctorch::cast(other));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesLong(JNIEnv *, jclass, jlong value, jlong other)
{
return (long)new torch::Tensor(value * ctorch::cast(other));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesInt(JNIEnv *, jclass, jint value, jlong other)
{
return (long)new torch::Tensor(value * ctorch::cast(other));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesDoubleAssign(JNIEnv *, jclass, jdouble value, jlong other)
{
ctorch::cast(other) *= value;
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesFloatAssign(JNIEnv *, jclass, jfloat value, jlong other)
{
ctorch::cast(other) *= value;
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesLongAssign(JNIEnv *, jclass, jlong value, jlong other)
{
ctorch::cast(other) *= value;
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesIntAssign(JNIEnv *, jclass, jint value, jlong other)
{
ctorch::cast(other) *= value;
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusDouble(JNIEnv *, jclass, jdouble value, jlong other)
{
return (long)new torch::Tensor(value + ctorch::cast(other));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusFloat(JNIEnv *, jclass, jfloat value, jlong other)
{
return (long)new torch::Tensor(value + ctorch::cast(other));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusLong(JNIEnv *, jclass, jlong value, jlong other)
{
return (long)new torch::Tensor(value + ctorch::cast(other));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusInt(JNIEnv *, jclass, jint value, jlong other)
{
return (long)new torch::Tensor(value + ctorch::cast(other));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusDoubleAssign(JNIEnv *, jclass, jdouble value, jlong other)
{
ctorch::cast(other) += value;
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusFloatAssign(JNIEnv *, jclass, jfloat value, jlong other)
{
ctorch::cast(other) += value;
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusLongAssign(JNIEnv *, jclass, jlong value, jlong other)
{
ctorch::cast(other) += value;
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusIntAssign(JNIEnv *, jclass, jint value, jlong other)
{
ctorch::cast(other) += value;
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
return (long)new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
ctorch::cast(lhs) *= ctorch::cast(rhs);
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_divTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
return (long)new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_divTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
ctorch::cast(lhs) /= ctorch::cast(rhs);
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
return (long)new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
ctorch::cast(lhs) += ctorch::cast(rhs);
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_minusTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
return (long)new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_minusTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
ctorch::cast(lhs) -= ctorch::cast(rhs);
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_unaryMinus(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(-ctorch::cast(tensor_handle));
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_absTensor(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).abs());
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_absTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs();
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_transposeTensor(JNIEnv *, jclass, jlong tensor_handle, jint i, jint j)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_transposeTensorAssign(JNIEnv *, jclass, jlong tensor_handle, jint i, jint j)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j);
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_expTensor(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).exp());
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_expTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp();
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_logTensor(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).log());
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_logTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log();
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_sumTensor(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).sum());
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_sumTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_matmul(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
return (long)new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulRightAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
{
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
}
JNIEXPORT jlong JNICALL
Java_kscience_kmath_torch_JTorch_diagEmbed(JNIEnv *, jclass, jlong diags_handle, jint offset, jint dim1, jint dim2)
{
return (long)new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
}
JNIEXPORT void JNICALL
Java_kscience_kmath_torch_JTorch_svdTensor(JNIEnv *, jclass, jlong tensor_handle, jlong U_handle, jlong S_handle, jlong V_handle)
{
auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
ctorch::cast(U_handle) = U;
ctorch::cast(S_handle) = S;
ctorch::cast(V_handle) = V;
}
JNIEXPORT void JNICALL
Java_kscience_kmath_torch_JTorch_symeigTensor(JNIEnv *, jclass, jlong tensor_handle, jlong S_handle, jlong V_handle, jboolean eigenvectors)
{
auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
ctorch::cast(S_handle) = S;
ctorch::cast(V_handle) = V;
}
JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_requiresGrad(JNIEnv *, jclass, jlong tensor_handle)
{
return ctorch::cast(tensor_handle).requires_grad();
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setRequiresGrad(JNIEnv *, jclass, jlong tensor_handle, jboolean status)
{
ctorch::cast(tensor_handle).requires_grad_(status);
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_detachFromGraph(JNIEnv *, jclass, jlong tensor_handle)
{
return (long)new torch::Tensor(ctorch::cast(tensor_handle).detach());
}
JNIEXPORT jlong JNICALL
Java_kscience_kmath_torch_JTorch_autogradTensor(JNIEnv *, jclass, jlong value, jlong variable, jboolean retain_graph)
{
return (long)new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_autohessTensor(JNIEnv *, jclass, jlong value, jlong variable)
{
return (long)new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
} }

View File

@ -7,8 +7,202 @@ class JTorch {
} }
public static native int getNumThreads(); public static native int getNumThreads();
public static native void setNumThreads(int numThreads); public static native void setNumThreads(int numThreads);
public static native long createTensor();
public static native void printTensor(long tensorHandle); public static native boolean cudaIsAvailable();
public static native void setSeed(int seed);
public static native long emptyTensor();
public static native long fromBlobDouble(double[] data, int[] shape, int device);
public static native long fromBlobFloat(float[] data, int[] shape, int device);
public static native long fromBlobLong(long[] data, int[] shape, int device);
public static native long fromBlobInt(int[] data, int[] shape, int device);
public static native long copyTensor(long tensorHandle);
public static native long copyToDevice(long tensorHandle, int device);
public static native long copyToDouble(long tensorHandle);
public static native long copyToFloat(long tensorHandle);
public static native long copyToLong(long tensorHandle);
public static native long copyToInt(long tensorHandle);
public static native void swapTensors(long lhsHandle, long rhsHandle);
public static native long viewTensor(long tensorHandle, int[] shape);
public static native String tensorToString(long tensorHandle);
public static native void disposeTensor(long tensorHandle); public static native void disposeTensor(long tensorHandle);
}
public static native int getDim(long tensorHandle);
public static native int getNumel(long tensorHandle);
public static native int getShapeAt(long tensorHandle, int d);
public static native int getStrideAt(long tensorHandle, int d);
public static native int getDevice(long tensorHandle);
public static native double getItemDouble(long tensorHandle);
public static native float getItemFloat(long tensorHandle);
public static native long getItemLong(long tensorHandle);
public static native int getItemInt(long tensorHandle);
public static native double getDouble(long tensorHandle, int[] index);
public static native float getFloat(long tensorHandle, int[] index);
public static native long getLong(long tensorHandle, int[] index);
public static native int getInt(long tensorHandle, int[] index);
public static native void setDouble(long tensorHandle, int[] index, double value);
public static native void setFloat(long tensorHandle, int[] index, float value);
public static native void setLong(long tensorHandle, int[] index, long value);
public static native void setInt(long tensorHandle, int[] index, int value);
public static native long randDouble(int[] shape, int device);
public static native long randnDouble(int[] shape, int device);
public static native long randFloat(int[] shape, int device);
public static native long randnFloat(int[] shape, int device);
public static native long randintDouble(long low, long high, int[] shape, int device);
public static native long randintFloat(long low, long high, int[] shape, int device);
public static native long randintLong(long low, long high, int[] shape, int device);
public static native long randintInt(long low, long high, int[] shape, int device);
public static native long randLike(long tensorHandle);
public static native void randLikeAssign(long tensorHandle);
public static native long randnLike(long tensorHandle);
public static native void randnLikeAssign(long tensorHandle);
public static native long randintLike(long tensorHandle, long low, long high);
public static native void randintLikeAssign(long tensorHandle, long low, long high);
public static native long fullDouble(double value, int[] shape, int device);
public static native long fullFloat(float value, int[] shape, int device);
public static native long fullLong(long value, int[] shape, int device);
public static native long fullInt(int value, int[] shape, int device);
public static native long timesDouble(double value, long other);
public static native long timesFloat(float value, long other);
public static native long timesLong(long value, long other);
public static native long timesInt(int value, long other);
public static native void timesDoubleAssign(double value, long other);
public static native void timesFloatAssign(float value, long other);
public static native void timesLongAssign(long value, long other);
public static native void timesIntAssign(int value, long other);
public static native long plusDouble(double value, long other);
public static native long plusFloat(float value, long other);
public static native long plusLong(long value, long other);
public static native long plusInt(int value, long other);
public static native void plusDoubleAssign(double value, long other);
public static native void plusFloatAssign(float value, long other);
public static native void plusLongAssign(long value, long other);
public static native void plusIntAssign(int value, long other);
public static native long timesTensor(long lhs, long rhs);
public static native void timesTensorAssign(long lhs, long rhs);
public static native long divTensor(long lhs, long rhs);
public static native void divTensorAssign(long lhs, long rhs);
public static native long plusTensor(long lhs, long rhs);
public static native void plusTensorAssign(long lhs, long rhs);
public static native long minusTensor(long lhs, long rhs);
public static native void minusTensorAssign(long lhs, long rhs);
public static native long unaryMinus(long tensorHandle);
public static native long absTensor(long tensorHandle);
public static native void absTensorAssign(long tensorHandle);
public static native long transposeTensor(long tensorHandle, int i, int j);
public static native void transposeTensorAssign(long tensorHandle, int i, int j);
public static native long expTensor(long tensorHandle);
public static native void expTensorAssign(long tensorHandle);
public static native long logTensor(long tensorHandle);
public static native void logTensorAssign(long tensorHandle);
public static native long sumTensor(long tensorHandle);
public static native void sumTensorAssign(long tensorHandle);
public static native long matmul(long lhs, long rhs);
public static native void matmulAssign(long lhs, long rhs);
public static native void matmulRightAssign(long lhs, long rhs);
public static native long diagEmbed(long diagsHandle, int offset, int dim1, int dim2);
public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle);
public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle, boolean eigenvectors);
public static native boolean requiresGrad(long tensorHandle);
public static native void setRequiresGrad(long tensorHandle, boolean status);
public static native long detachFromGraph(long tensorHandle);
public static native long autogradTensor(long value, long variable, boolean retainGraph);
public static native long autohessTensor(long value, long variable);
}

View File

@ -0,0 +1,4 @@
package kscience.kmath.torch
public class TorchTensorJVM {
}

View File

@ -1,17 +0,0 @@
package kscience.kmath.torch
public fun getNumThreads(): Int {
return JTorch.getNumThreads()
}
public fun setNumThreads(numThreads: Int): Unit {
JTorch.setNumThreads(numThreads)
}
public fun runCPD(): Unit {
val tensorHandle = JTorch.createTensor()
JTorch.printTensor(tensorHandle)
JTorch.disposeTensor(tensorHandle)
}

View File

@ -6,14 +6,10 @@ import kotlin.test.*
class TestUtils { class TestUtils {
@Test @Test
fun testSetNumThreads() { fun testJTorch() {
val numThreads = 2 val tensor = JTorch.fullInt(54, intArrayOf(3), 0)
setNumThreads(numThreads) println(JTorch.tensorToString(tensor))
assertEquals(numThreads, getNumThreads()) JTorch.disposeTensor(tensor)
} }
@Test
fun testCPD() {
runCPD()
}
} }

View File

@ -117,6 +117,13 @@ public sealed class TorchTensorAlgebraNative<
sum_tensor_assign(tensorHandle) sum_tensor_assign(tensorHandle)
} }
override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
wrap(randint_like(this.tensorHandle, low, high)!!)
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit {
randint_like_assign(this.tensorHandle, low, high)
}
override fun TorchTensorType.copy(): TorchTensorType = override fun TorchTensorType.copy(): TorchTensorType =
wrap(copy_tensor(this.tensorHandle)!!) wrap(copy_tensor(this.tensorHandle)!!)
@ -224,6 +231,9 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
override fun randUniform(shape: IntArray, device: Device): TorchTensorReal = override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!) wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!)
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorReal =
wrap(randint_double(low, high, shape.toCValues(), shape.size, device.toInt())!!)
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal = override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
wrap(plus_double(this, other.tensorHandle)!!) wrap(plus_double(this, other.tensorHandle)!!)
@ -256,18 +266,9 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal = override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!) wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
override fun randIntegral(low: Double, high: Double, shape: IntArray, device: Device): TorchTensorReal =
wrap(randint_double(low.toLong(), high.toLong(), shape.toCValues(), shape.size, device.toInt())!!)
override fun TorchTensorReal.randIntegral(low: Double, high: Double): TorchTensorReal =
wrap(randint_long_like(this.tensorHandle, low.toLong(), high.toLong())!!)
override fun TorchTensorReal.randIntegralAssign(low: Double, high: Double): Unit {
randint_long_like_assign(this.tensorHandle, low.toLong(), high.toLong())
}
} }
public class TorchTensorFloatAlgebra(scope: DeferScope) : public class TorchTensorFloatAlgebra(scope: DeferScope) :
TorchTensorPartialDivisionAlgebraNative<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) { TorchTensorPartialDivisionAlgebraNative<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat = override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
@ -295,6 +296,9 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat = override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!) wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!)
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorFloat =
wrap(randint_float(low, high, shape.toCValues(), shape.size, device.toInt())!!)
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat = override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
wrap(plus_float(this, other.tensorHandle)!!) wrap(plus_float(this, other.tensorHandle)!!)
@ -328,15 +332,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat = override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!) wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
override fun randIntegral(low: Float, high: Float, shape: IntArray, device: Device): TorchTensorFloat =
wrap(randint_float(low.toLong(), high.toLong(), shape.toCValues(), shape.size, device.toInt())!!)
override fun TorchTensorFloat.randIntegral(low: Float, high: Float): TorchTensorFloat =
wrap(randint_long_like(this.tensorHandle, low.toLong(), high.toLong())!!)
override fun TorchTensorFloat.randIntegralAssign(low: Float, high: Float): Unit {
randint_long_like_assign(this.tensorHandle, low.toLong(), high.toLong())
}
} }
public class TorchTensorLongAlgebra(scope: DeferScope) : public class TorchTensorLongAlgebra(scope: DeferScope) :
@ -363,13 +358,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong = override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!) wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!)
override fun TorchTensorLong.randIntegral(low: Long, high: Long): TorchTensorLong =
wrap(randint_long_like(this.tensorHandle, low, high)!!)
override fun TorchTensorLong.randIntegralAssign(low: Long, high: Long): Unit {
randint_long_like_assign(this.tensorHandle, low, high)
}
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong = override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
wrap(plus_long(this, other.tensorHandle)!!) wrap(plus_long(this, other.tensorHandle)!!)
@ -425,16 +413,9 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
return get_data_int(this.tensorHandle)!! return get_data_int(this.tensorHandle)!!
} }
override fun randIntegral(low: Int, high: Int, shape: IntArray, device: Device): TorchTensorInt = override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorInt =
wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!) wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!)
override fun TorchTensorInt.randIntegral(low: Int, high: Int): TorchTensorInt =
wrap(randint_int_like(this.tensorHandle, low, high)!!)
override fun TorchTensorInt.randIntegralAssign(low: Int, high: Int): Unit {
randint_int_like_assign(this.tensorHandle, low, high)
}
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt = override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
wrap(plus_int(this, other.tensorHandle)!!) wrap(plus_int(this, other.tensorHandle)!!)

View File

@ -22,6 +22,4 @@ internal class BenchmarkMatMul {
benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0)) benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0))
} }
} }
} }

View File

@ -6,7 +6,7 @@ import kotlin.test.*
internal class TestUtils { internal class TestUtils {
@Test @Test
fun testSetNumThreads() { fun testSetNumThreads() {
TorchTensorIntAlgebra { TorchTensorLongAlgebra {
testingSetNumThreads() testingSetNumThreads()
} }
} }