From 17e6ebbc14b8afe2a84d77d0842edde37255deed Mon Sep 17 00:00:00 2001
From: Roland Grinis
Date: Tue, 19 Jan 2021 09:58:47 +0000
Subject: [PATCH] JNI wrapper

---
 .gitignore                                    |   3 +-
 kmath-torch/README.md                         |  10 +-
 kmath-torch/build.gradle.kts                  |   7 +-
 .../TorchTensorAlgebra.kt                     |   7 +-
 .../BenchmarkRandomGenerators.kt              |   4 +-
 .../kscience.kmath.torch/TestTorchTensor.kt   |   2 +-
 .../kotlin/kscience.kmath.torch/TestUtils.kt  |   4 +-
 kmath-torch/src/cppMain/include/ctorch.h      |  17 +-
 .../include/kscience_kmath_torch_JTorch.h     | 772 +++++++++++++++++-
 kmath-torch/src/cppMain/include/utils.hh      |  44 +-
 kmath-torch/src/cppMain/src/ctorch.cc         | 118 ++-
 kmath-torch/src/cppMain/src/jtorch.cc         | 555 ++++++++++++-
 .../java/kscience/kmath/torch/JTorch.java     | 200 ++++-
 .../kscience/kmath/torch/TorchTensorJVM.kt    |   4 +
 .../kotlin/kscience/kmath/torch/Utils.kt      |  17 -
 .../kotlin/kscience/kmath/torch/TestUtils.kt  |  12 +-
 .../TorchTensorAlgebraNative.kt               |  49 +-
 .../kscience/kmath/torch/BenchmarkMatMul.kt   |   2 -
 .../kotlin/kscience/kmath/torch/TestUtils.kt  |   2 +-
 19 files changed, 1643 insertions(+), 186 deletions(-)
 create mode 100644 kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorJVM.kt
 delete mode 100644 kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/Utils.kt

diff --git a/.gitignore b/.gitignore
index 2a146aab6..b2451a1a8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ out/
 # Cache of project
 .gradletasknamecache
 
-# Generated by javac -h
+# Generated by javac -h and runtime
 *.class
+*.log
 
diff --git a/kmath-torch/README.md b/kmath-torch/README.md
index 3137906f3..b7a228a6e 100644
--- a/kmath-torch/README.md
+++ b/kmath-torch/README.md
@@ -1,6 +1,6 @@
 # LibTorch extension (`kmath-torch`)
 
-This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of
+This is a `Kotlin/Native` & `JVM` module, with only `linuxX64` supported so far. The library wraps some of
 the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
 
 ## Installation
@@ -11,7 +11,7 @@ To install the library, you have to build & publish locally `kmath-core`, `kmath-memory` with `kmath-torch`:
 ```shell
 ./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
 ```
 
-This builds `ctorch`, a C wrapper for `LibTorch` placed inside:
+This builds `ctorch`, a C wrapper, and `jtorch`, a JNI wrapper for `LibTorch`, placed inside:
 
 `~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`
 
@@ -19,8 +19,8 @@ You will have to link against it in your own project.
 
 ## Usage
 
-Tensors are implemented over the `MutableNDStructure`. They can only be instantiated through provided factory methods
-and require scoping:
+Tensors are implemented over the `MutableNDStructure`.
They can only be created through provided factory methods +and require scoping within a `TensorAlgebra` instance: ```kotlin TorchTensorRealAlgebra { @@ -63,4 +63,4 @@ TorchTensorRealAlgebra { val hessianAtX = expressionAtX hess tensorX } ``` - +Contributed by [Roland Grinis](https://github.com/rgrit91) diff --git a/kmath-torch/build.gradle.kts b/kmath-torch/build.gradle.kts index 999ea1fa1..1508f0ef7 100644 --- a/kmath-torch/build.gradle.kts +++ b/kmath-torch/build.gradle.kts @@ -127,7 +127,6 @@ val generateJNIHeader by tasks.registering { kotlin { explicitApiWarning() - jvm { withJava() } @@ -158,7 +157,6 @@ kotlin { val test by nativeTarget.compilations.getting sourceSets { - val commonMain by getting { dependencies { api(project(":kmath-core")) @@ -171,6 +169,11 @@ kotlin { } } + val jvmMain by getting { + dependencies { + api(project(":kmath-core")) + } + } } } diff --git a/kmath-torch/src/commonMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt b/kmath-torch/src/commonMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt index 676a26034..c1c18fa6d 100644 --- a/kmath-torch/src/commonMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt +++ b/kmath-torch/src/commonMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt @@ -25,12 +25,11 @@ public interface TorchTensorAlgebra, numIter, device, "integer [0,100]", - {sh, dc -> randIntegral(0f, 100f, shape = sh, device = dc)}, - {ten -> ten.randIntegralAssign(0f, 100f) } + {sh, dc -> randIntegral(0, 100, shape = sh, device = dc)}, + {ten -> ten.randIntegralAssign(0, 100) } ) } diff --git a/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestTorchTensor.kt b/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestTorchTensor.kt index b5d9c9431..a35d154a9 100644 --- a/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestTorchTensor.kt +++ b/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestTorchTensor.kt @@ -43,7 +43,7 @@ internal inline fun , TorchTensorAlgebraType : TorchTensorAlgebra> TorchTensorAlgebraType.testingViewWithNoCopy(device: Device = Device.CPU) { - val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6)) + val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6), device) val viewTensor = tensor.view(intArrayOf(2, 3)) assertTrue(viewTensor.shape contentEquals intArrayOf(2, 3)) viewTensor[intArrayOf(0, 0)] = 10 diff --git a/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestUtils.kt b/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestUtils.kt index 93bd57fd3..b6dbb11c6 100644 --- a/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestUtils.kt +++ b/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestUtils.kt @@ -26,11 +26,11 @@ internal inline fun , TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra> TorchTensorAlgebraType.testingSetSeed(device: Device = Device.CPU): Unit { setSeed(SEED) - val integral = randIntegral(0f, 100f, IntArray(0), device = device).value() + val integral = randIntegral(0, 100, IntArray(0), device = device).value() val normal = randNormal(IntArray(0), device = device).value() val uniform = randUniform(IntArray(0), device = device).value() setSeed(SEED) - val nextIntegral = randIntegral(0f, 100f, IntArray(0), device = device).value() + val nextIntegral = randIntegral(0, 100, IntArray(0), device = device).value() val nextNormal = randNormal(IntArray(0), device = device).value() val nextUniform = randUniform(IntArray(0), device = device).value() assertEquals(normal, nextNormal) diff --git 
a/kmath-torch/src/cppMain/include/ctorch.h b/kmath-torch/src/cppMain/include/ctorch.h index 0f60d7356..5df36db0f 100644 --- a/kmath-torch/src/cppMain/include/ctorch.h +++ b/kmath-torch/src/cppMain/include/ctorch.h @@ -70,16 +70,15 @@ extern "C" TorchTensorHandle randint_double(long low, long high, int *shape, int shape_size, int device); TorchTensorHandle randint_float(long low, long high, int *shape, int shape_size, int device); TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device); - TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device); + TorchTensorHandle randint_int(long low, long high, int *shape, int shape_size, int device); TorchTensorHandle rand_like(TorchTensorHandle tensor_handle); void rand_like_assign(TorchTensorHandle tensor_handle); TorchTensorHandle randn_like(TorchTensorHandle tensor_handle); void randn_like_assign(TorchTensorHandle tensor_handle); - TorchTensorHandle randint_long_like(TorchTensorHandle tensor_handle, long low, long high); - void randint_long_like_assign(TorchTensorHandle tensor_handle, long low, long high); - TorchTensorHandle randint_int_like(TorchTensorHandle tensor_handle, int low, int high); - void randint_int_like_assign(TorchTensorHandle tensor_handle, int low, int high); + TorchTensorHandle randint_like(TorchTensorHandle tensor_handle, long low, long high); + void randint_like_assign(TorchTensorHandle tensor_handle, long low, long high); + TorchTensorHandle full_double(double value, int *shape, int shape_size, int device); TorchTensorHandle full_float(float value, int *shape, int shape_size, int device); @@ -106,14 +105,14 @@ extern "C" void plus_long_assign(long value, TorchTensorHandle other); void plus_int_assign(int value, TorchTensorHandle other); - TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); - void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); - TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); - void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); + TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); + void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); + TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); + void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle); TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle); diff --git a/kmath-torch/src/cppMain/include/kscience_kmath_torch_JTorch.h b/kmath-torch/src/cppMain/include/kscience_kmath_torch_JTorch.h index 7173e1f9e..154231ff1 100644 --- a/kmath-torch/src/cppMain/include/kscience_kmath_torch_JTorch.h +++ b/kmath-torch/src/cppMain/include/kscience_kmath_torch_JTorch.h @@ -25,18 +25,130 @@ JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads /* * Class: kscience_kmath_torch_JTorch - * Method: createTensor - * Signature: ()J + * Method: cudaIsAvailable + * Signature: ()Z */ -JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor +JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_cudaIsAvailable (JNIEnv *, jclass); /* * Class: 
kscience_kmath_torch_JTorch - * Method: printTensor - * Signature: (J)V + * Method: setSeed + * Signature: (I)V */ -JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setSeed + (JNIEnv *, jclass, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: emptyTensor + * Signature: ()J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_emptyTensor + (JNIEnv *, jclass); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: fromBlobDouble + * Signature: ([D[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobDouble + (JNIEnv *, jclass, jdoubleArray, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: fromBlobFloat + * Signature: ([F[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobFloat + (JNIEnv *, jclass, jfloatArray, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: fromBlobLong + * Signature: ([J[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobLong + (JNIEnv *, jclass, jlongArray, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: fromBlobInt + * Signature: ([I[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobInt + (JNIEnv *, jclass, jintArray, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: copyTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: copyToDevice + * Signature: (JI)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDevice + (JNIEnv *, jclass, jlong, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: copyToDouble + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDouble + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: copyToFloat + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToFloat + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: copyToLong + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToLong + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: copyToInt + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToInt + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: swapTensors + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_swapTensors + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: viewTensor + * Signature: (J[I)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_viewTensor + (JNIEnv *, jclass, jlong, jintArray); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: tensorToString + * Signature: (J)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_kscience_kmath_torch_JTorch_tensorToString (JNIEnv *, jclass, jlong); /* @@ -47,6 +159,654 @@ JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor (JNIEnv *, jclass, jlong); +/* + * Class: kscience_kmath_torch_JTorch + * Method: getDim + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDim + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * 
Method: getNumel + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumel + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getShapeAt + * Signature: (JI)I + */ +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getShapeAt + (JNIEnv *, jclass, jlong, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getStrideAt + * Signature: (JI)I + */ +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getStrideAt + (JNIEnv *, jclass, jlong, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getDevice + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDevice + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getItemDouble + * Signature: (J)D + */ +JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getItemDouble + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getItemFloat + * Signature: (J)F + */ +JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getItemFloat + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getItemLong + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getItemLong + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getItemInt + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getItemInt + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getDouble + * Signature: (J[I)D + */ +JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getDouble + (JNIEnv *, jclass, jlong, jintArray); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getFloat + * Signature: (J[I)F + */ +JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getFloat + (JNIEnv *, jclass, jlong, jintArray); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getLong + * Signature: (J[I)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getLong + (JNIEnv *, jclass, jlong, jintArray); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: getInt + * Signature: (J[I)I + */ +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getInt + (JNIEnv *, jclass, jlong, jintArray); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: setDouble + * Signature: (J[ID)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setDouble + (JNIEnv *, jclass, jlong, jintArray, jdouble); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: setFloat + * Signature: (J[IF)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setFloat + (JNIEnv *, jclass, jlong, jintArray, jfloat); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: setLong + * Signature: (J[IJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setLong + (JNIEnv *, jclass, jlong, jintArray, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: setInt + * Signature: (J[II)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setInt + (JNIEnv *, jclass, jlong, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randDouble + * Signature: ([II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randDouble + (JNIEnv *, jclass, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randnDouble + * Signature: ([II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnDouble + (JNIEnv *, jclass, jintArray, jint); + +/* + * Class: 
kscience_kmath_torch_JTorch + * Method: randFloat + * Signature: ([II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randFloat + (JNIEnv *, jclass, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randnFloat + * Signature: ([II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnFloat + (JNIEnv *, jclass, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randintDouble + * Signature: (JJ[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintDouble + (JNIEnv *, jclass, jlong, jlong, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randintFloat + * Signature: (JJ[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintFloat + (JNIEnv *, jclass, jlong, jlong, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randintLong + * Signature: (JJ[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLong + (JNIEnv *, jclass, jlong, jlong, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randintInt + * Signature: (JJ[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintInt + (JNIEnv *, jclass, jlong, jlong, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randLike + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randLike + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randLikeAssign + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randLikeAssign + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randnLike + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnLike + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randnLikeAssign + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randnLikeAssign + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randintLike + * Signature: (JJJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLike + (JNIEnv *, jclass, jlong, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: randintLikeAssign + * Signature: (JJJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randintLikeAssign + (JNIEnv *, jclass, jlong, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: fullDouble + * Signature: (D[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullDouble + (JNIEnv *, jclass, jdouble, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: fullFloat + * Signature: (F[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullFloat + (JNIEnv *, jclass, jfloat, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: fullLong + * Signature: (J[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullLong + (JNIEnv *, jclass, jlong, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: fullInt + * Signature: (I[II)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullInt + (JNIEnv *, jclass, jint, jintArray, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: timesDouble + * Signature: (DJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesDouble + (JNIEnv *, jclass, jdouble, jlong); + +/* + * Class: 
kscience_kmath_torch_JTorch + * Method: timesFloat + * Signature: (FJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesFloat + (JNIEnv *, jclass, jfloat, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: timesLong + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesLong + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: timesInt + * Signature: (IJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesInt + (JNIEnv *, jclass, jint, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: timesDoubleAssign + * Signature: (DJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesDoubleAssign + (JNIEnv *, jclass, jdouble, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: timesFloatAssign + * Signature: (FJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesFloatAssign + (JNIEnv *, jclass, jfloat, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: timesLongAssign + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesLongAssign + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: timesIntAssign + * Signature: (IJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesIntAssign + (JNIEnv *, jclass, jint, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: plusDouble + * Signature: (DJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusDouble + (JNIEnv *, jclass, jdouble, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: plusFloat + * Signature: (FJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusFloat + (JNIEnv *, jclass, jfloat, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: plusLong + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusLong + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: plusInt + * Signature: (IJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusInt + (JNIEnv *, jclass, jint, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: plusDoubleAssign + * Signature: (DJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusDoubleAssign + (JNIEnv *, jclass, jdouble, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: plusFloatAssign + * Signature: (FJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusFloatAssign + (JNIEnv *, jclass, jfloat, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: plusLongAssign + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusLongAssign + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: plusIntAssign + * Signature: (IJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusIntAssign + (JNIEnv *, jclass, jint, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: timesTensor + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesTensor + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: timesTensorAssign + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesTensorAssign + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: divTensor + * Signature: (JJ)J + */ +JNIEXPORT jlong 
JNICALL Java_kscience_kmath_torch_JTorch_divTensor + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: divTensorAssign + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_divTensorAssign + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: plusTensor + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusTensor + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: plusTensorAssign + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusTensorAssign + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: minusTensor + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_minusTensor + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: minusTensorAssign + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_minusTensorAssign + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: unaryMinus + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_unaryMinus + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: absTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_absTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: absTensorAssign + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_absTensorAssign + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: transposeTensor + * Signature: (JII)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_transposeTensor + (JNIEnv *, jclass, jlong, jint, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: transposeTensorAssign + * Signature: (JII)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_transposeTensorAssign + (JNIEnv *, jclass, jlong, jint, jint); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: expTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_expTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: expTensorAssign + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_expTensorAssign + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: logTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_logTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: logTensorAssign + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_logTensorAssign + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: sumTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_sumTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: sumTensorAssign + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_sumTensorAssign + (JNIEnv *, jclass, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: matmul + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_matmul + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: kscience_kmath_torch_JTorch + * Method: 
matmulAssign
+ * Signature: (JJ)V
+ */
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulAssign
+  (JNIEnv *, jclass, jlong, jlong);
+
+/*
+ * Class:     kscience_kmath_torch_JTorch
+ * Method:    matmulRightAssign
+ * Signature: (JJ)V
+ */
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulRightAssign
+  (JNIEnv *, jclass, jlong, jlong);
+
+/*
+ * Class:     kscience_kmath_torch_JTorch
+ * Method:    diagEmbed
+ * Signature: (JIII)J
+ */
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_diagEmbed
+  (JNIEnv *, jclass, jlong, jint, jint, jint);
+
+/*
+ * Class:     kscience_kmath_torch_JTorch
+ * Method:    svdTensor
+ * Signature: (JJJJ)V
+ */
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_svdTensor
+  (JNIEnv *, jclass, jlong, jlong, jlong, jlong);
+
+/*
+ * Class:     kscience_kmath_torch_JTorch
+ * Method:    symeigTensor
+ * Signature: (JJJZ)V
+ */
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_symeigTensor
+  (JNIEnv *, jclass, jlong, jlong, jlong, jboolean);
+
+/*
+ * Class:     kscience_kmath_torch_JTorch
+ * Method:    requiresGrad
+ * Signature: (J)Z
+ */
+JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_requiresGrad
+  (JNIEnv *, jclass, jlong);
+
+/*
+ * Class:     kscience_kmath_torch_JTorch
+ * Method:    setRequiresGrad
+ * Signature: (JZ)V
+ */
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setRequiresGrad
+  (JNIEnv *, jclass, jlong, jboolean);
+
+/*
+ * Class:     kscience_kmath_torch_JTorch
+ * Method:    detachFromGraph
+ * Signature: (J)J
+ */
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_detachFromGraph
+  (JNIEnv *, jclass, jlong);
+
+/*
+ * Class:     kscience_kmath_torch_JTorch
+ * Method:    autogradTensor
+ * Signature: (JJZ)J
+ */
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_autogradTensor
+  (JNIEnv *, jclass, jlong, jlong, jboolean);
+
+/*
+ * Class:     kscience_kmath_torch_JTorch
+ * Method:    autohessTensor
+ * Signature: (JJ)J
+ */
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_autohessTensor
+  (JNIEnv *, jclass, jlong, jlong);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/kmath-torch/src/cppMain/include/utils.hh b/kmath-torch/src/cppMain/include/utils.hh
index 2d886352c..0aff9b2e9 100644
--- a/kmath-torch/src/cppMain/include/utils.hh
+++ b/kmath-torch/src/cppMain/include/utils.hh
@@ -3,7 +3,7 @@
 namespace ctorch
 {
-    using TorchTensorHandle = void*;
+    using TorchTensorHandle = void *;
 
     template <typename Dtype>
     inline c10::ScalarType dtype()
@@ -29,16 +29,28 @@ namespace ctorch
         return torch::kInt32;
     }
 
-    inline torch::Tensor &cast(const TorchTensorHandle &tensor_handle)
+    template <typename Handle>
+    inline torch::Tensor &cast(const Handle &tensor_handle)
     {
-        return *static_cast<torch::Tensor *>(tensor_handle);
+        return *static_cast<torch::Tensor *>((TorchTensorHandle)tensor_handle);
+    }
+
+    template <typename Handle>
+    inline void dispose_tensor(const Handle &tensor_handle)
+    {
+        delete static_cast<torch::Tensor *>((TorchTensorHandle)tensor_handle);
+    }
+
+    inline std::string tensor_to_string(const torch::Tensor &tensor)
+    {
+        std::stringstream bufrep;
+        bufrep << tensor;
+        return bufrep.str();
     }
 
     inline char *tensor_to_char(const torch::Tensor &tensor)
     {
-        std::stringstream bufrep;
-        bufrep << tensor;
-        auto rep = bufrep.str();
+        auto rep = tensor_to_string(tensor);
         char *crep = (char *)malloc(rep.length() + 1);
         std::strcpy(crep, rep.c_str());
         return crep;
@@ -72,45 +84,43 @@ namespace ctorch
     }
 
     template <typename Dtype>
-    inline torch::Tensor from_blob(Dtype *data, std::vector<int> shape, torch::Device device, bool copy)
+    inline torch::Tensor from_blob(Dtype *data, const std::vector<int> &shape, torch::Device device, bool copy)
     {
         return 
torch::from_blob(data, shape, dtype()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy); } template - inline NumType get(const TorchTensorHandle &tensor_handle, int *index) + inline NumType get(const torch::Tensor &tensor, int *index) { - auto ten = ctorch::cast(tensor_handle); - return ten.index(to_index(index, ten.dim())).item(); + return tensor.index(to_index(index, tensor.dim())).item(); } template - inline void set(TorchTensorHandle &tensor_handle, int *index, NumType value) + inline void set(const torch::Tensor &tensor, int *index, NumType value) { - auto ten = ctorch::cast(tensor_handle); - ten.index(to_index(index, ten.dim())) = value; + tensor.index(to_index(index, tensor.dim())) = value; } template - inline torch::Tensor randn(std::vector shape, torch::Device device) + inline torch::Tensor randn(const std::vector &shape, torch::Device device) { return torch::randn(shape, torch::TensorOptions().dtype(dtype()).layout(torch::kStrided).device(device)); } template - inline torch::Tensor rand(std::vector shape, torch::Device device) + inline torch::Tensor rand(const std::vector &shape, torch::Device device) { return torch::rand(shape, torch::TensorOptions().dtype(dtype()).layout(torch::kStrided).device(device)); } template - inline torch::Tensor randint(long low, long high, std::vector shape, torch::Device device) + inline torch::Tensor randint(long low, long high, const std::vector &shape, torch::Device device) { return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype()).layout(torch::kStrided).device(device)); } template - inline torch::Tensor full(Dtype value, std::vector shape, torch::Device device) + inline torch::Tensor full(Dtype value, const std::vector &shape, torch::Device device) { return torch::full(shape, value, torch::TensorOptions().dtype(dtype()).layout(torch::kStrided).device(device)); } diff --git a/kmath-torch/src/cppMain/src/ctorch.cc b/kmath-torch/src/cppMain/src/ctorch.cc index df9635fd0..22256b27b 100644 --- a/kmath-torch/src/cppMain/src/ctorch.cc +++ b/kmath-torch/src/cppMain/src/ctorch.cc @@ -74,9 +74,9 @@ void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle) { std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle)); } -TorchTensorHandle view_tensor(TorchTensorHandle tensor_handle, int *shape, int dim) +TorchTensorHandle view_tensor(TorchTensorHandle tensor_handle, int *shape, int dim) { - return new torch::Tensor(ctorch::cast(tensor_handle).view(ctorch::to_vec_int(shape, dim))); + return new torch::Tensor(ctorch::cast(tensor_handle).view(ctorch::to_vec_int(shape, dim))); } char *tensor_to_string(TorchTensorHandle tensor_handle) @@ -89,7 +89,7 @@ void dispose_char(char *ptr) } void dispose_tensor(TorchTensorHandle tensor_handle) { - delete static_cast(tensor_handle); + ctorch::dispose_tensor(tensor_handle); } int get_dim(TorchTensorHandle tensor_handle) @@ -149,35 +149,35 @@ int get_item_int(TorchTensorHandle tensor_handle) double get_double(TorchTensorHandle tensor_handle, int *index) { - return ctorch::get(tensor_handle, index); + return ctorch::get(ctorch::cast(tensor_handle), index); } float get_float(TorchTensorHandle tensor_handle, int *index) { - return ctorch::get(tensor_handle, index); + return ctorch::get(ctorch::cast(tensor_handle), index); } long get_long(TorchTensorHandle tensor_handle, int *index) { - return ctorch::get(tensor_handle, index); + return ctorch::get(ctorch::cast(tensor_handle), index); } int get_int(TorchTensorHandle tensor_handle, int *index) { - return 
ctorch::get(tensor_handle, index); + return ctorch::get(ctorch::cast(tensor_handle), index); } void set_double(TorchTensorHandle tensor_handle, int *index, double value) { - ctorch::set(tensor_handle, index, value); + ctorch::set(ctorch::cast(tensor_handle), index, value); } void set_float(TorchTensorHandle tensor_handle, int *index, float value) { - ctorch::set(tensor_handle, index, value); + ctorch::set(ctorch::cast(tensor_handle), index, value); } void set_long(TorchTensorHandle tensor_handle, int *index, long value) { - ctorch::set(tensor_handle, index, value); + ctorch::set(ctorch::cast(tensor_handle), index, value); } void set_int(TorchTensorHandle tensor_handle, int *index, int value) { - ctorch::set(tensor_handle, index, value); + ctorch::set(ctorch::cast(tensor_handle), index, value); } TorchTensorHandle rand_double(int *shape, int shape_size, int device) @@ -209,7 +209,7 @@ TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, { return new torch::Tensor(ctorch::randint(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); } -TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device) +TorchTensorHandle randint_int(long low, long high, int *shape, int shape_size, int device) { return new torch::Tensor(ctorch::randint(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); } @@ -230,19 +230,11 @@ void randn_like_assign(TorchTensorHandle tensor_handle) { ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle)); } -TorchTensorHandle randint_long_like(TorchTensorHandle tensor_handle, long low, long high) +TorchTensorHandle randint_like(TorchTensorHandle tensor_handle, long low, long high) { return new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high)); } -void randint_long_like_assign(TorchTensorHandle tensor_handle, long low, long high) -{ - ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high); -} -TorchTensorHandle randint_int_like(TorchTensorHandle tensor_handle, int low, int high) -{ - return new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high)); -} -void randint_int_like_assign(TorchTensorHandle tensor_handle, int low, int high) +void randint_like_assign(TorchTensorHandle tensor_handle, long low, long high) { ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high); } @@ -264,39 +256,6 @@ TorchTensorHandle full_int(int value, int *shape, int shape_size, int device) return new torch::Tensor(ctorch::full(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); } -TorchTensorHandle plus_double(double value, TorchTensorHandle other) -{ - return new torch::Tensor(ctorch::cast(other) + value); -} -TorchTensorHandle plus_float(float value, TorchTensorHandle other) -{ - return new torch::Tensor(ctorch::cast(other) + value); -} -TorchTensorHandle plus_long(long value, TorchTensorHandle other) -{ - return new torch::Tensor(ctorch::cast(other) + value); -} -TorchTensorHandle plus_int(int value, TorchTensorHandle other) -{ - return new torch::Tensor(ctorch::cast(other) + value); -} -void plus_double_assign(double value, TorchTensorHandle other) -{ - ctorch::cast(other) += value; -} -void plus_float_assign(float value, TorchTensorHandle other) -{ - ctorch::cast(other) += value; -} -void plus_long_assign(long value, TorchTensorHandle other) -{ - ctorch::cast(other) += value; -} -void plus_int_assign(int value, 
TorchTensorHandle other) -{ - ctorch::cast(other) += value; -} - TorchTensorHandle times_double(double value, TorchTensorHandle other) { return new torch::Tensor(value * ctorch::cast(other)); @@ -330,22 +289,39 @@ void times_int_assign(int value, TorchTensorHandle other) ctorch::cast(other) *= value; } -TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs) +TorchTensorHandle plus_double(double value, TorchTensorHandle other) { - return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs)); + return new torch::Tensor(ctorch::cast(other) + value); } -void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) +TorchTensorHandle plus_float(float value, TorchTensorHandle other) { - ctorch::cast(lhs) += ctorch::cast(rhs); + return new torch::Tensor(ctorch::cast(other) + value); } -TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs) +TorchTensorHandle plus_long(long value, TorchTensorHandle other) { - return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs)); + return new torch::Tensor(ctorch::cast(other) + value); } -void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) +TorchTensorHandle plus_int(int value, TorchTensorHandle other) { - ctorch::cast(lhs) -= ctorch::cast(rhs); + return new torch::Tensor(ctorch::cast(other) + value); } +void plus_double_assign(double value, TorchTensorHandle other) +{ + ctorch::cast(other) += value; +} +void plus_float_assign(float value, TorchTensorHandle other) +{ + ctorch::cast(other) += value; +} +void plus_long_assign(long value, TorchTensorHandle other) +{ + ctorch::cast(other) += value; +} +void plus_int_assign(int value, TorchTensorHandle other) +{ + ctorch::cast(other) += value; +} + TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs) { return new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs)); @@ -362,6 +338,22 @@ void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) { ctorch::cast(lhs) /= ctorch::cast(rhs); } +TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs) +{ + return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs)); +} +void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) +{ + ctorch::cast(lhs) += ctorch::cast(rhs); +} +TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs) +{ + return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs)); +} +void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) +{ + ctorch::cast(lhs) -= ctorch::cast(rhs); +} TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle) { return new torch::Tensor(-ctorch::cast(tensor_handle)); diff --git a/kmath-torch/src/cppMain/src/jtorch.cc b/kmath-torch/src/cppMain/src/jtorch.cc index f803af0e4..4ce22e35c 100644 --- a/kmath-torch/src/cppMain/src/jtorch.cc +++ b/kmath-torch/src/cppMain/src/jtorch.cc @@ -15,21 +15,558 @@ JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads(JNIEnv *, torch::set_num_threads(num_threads); } -JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor(JNIEnv *, jclass) +JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_cudaIsAvailable(JNIEnv *, jclass) { - auto ten = torch::randn({2, 3}); - std::cout << ten << std::endl; - void *ptr = new torch::Tensor(ten); - return (long)ptr; + return torch::cuda::is_available(); } -JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor(JNIEnv *, jclass, jlong tensor_handle) +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setSeed(JNIEnv 
*, jclass, jint seed) { - auto ten = ctorch::cast((void *)tensor_handle); - std::cout << ten << std::endl; + torch::manual_seed(seed); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_emptyTensor(JNIEnv *, jclass) +{ + return (long)new torch::Tensor; +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobDouble(JNIEnv *env, jclass, jdoubleArray data, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::from_blob( + env->GetDoubleArrayElements(data, 0), + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device), true)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobFloat(JNIEnv *env, jclass, jfloatArray data, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::from_blob( + env->GetFloatArrayElements(data, 0), + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device), true)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobLong(JNIEnv *env, jclass, jlongArray data, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::from_blob( + env->GetLongArrayElements(data, 0), + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device), true)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobInt(JNIEnv *env, jclass, jintArray data, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::from_blob( + env->GetIntArrayElements(data, 0), + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device), true)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyTensor(JNIEnv *, jclass, jlong tensor_handle) +{ + return (long)new torch::Tensor(ctorch::cast(tensor_handle).clone()); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDevice(JNIEnv *, jclass, jlong tensor_handle, jint device) +{ + return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDouble(JNIEnv *, jclass, jlong tensor_handle) +{ + return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype(), false, true)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToFloat(JNIEnv *, jclass, jlong tensor_handle) +{ + return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype(), false, true)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToLong(JNIEnv *, jclass, jlong tensor_handle) +{ + return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype(), false, true)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToInt(JNIEnv *, jclass, jlong tensor_handle) +{ + return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype(), false, true)); +} + +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_swapTensors(JNIEnv *, jclass, jlong lhs_handle, jlong rhs_handle) +{ + std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_viewTensor(JNIEnv *env, jclass, jlong tensor_handle, jintArray shape) +{ + return (long)new torch::Tensor( + ctorch::cast(tensor_handle).view(ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)))); +} + +JNIEXPORT jstring JNICALL 
Java_kscience_kmath_torch_JTorch_tensorToString(JNIEnv *env, jclass, jlong tensor_handle) +{ + return env->NewStringUTF(ctorch::tensor_to_string(ctorch::cast(tensor_handle)).c_str()); } JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor(JNIEnv *, jclass, jlong tensor_handle) { - delete static_cast((void *)tensor_handle); + ctorch::dispose_tensor(tensor_handle); +} + +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDim(JNIEnv *, jclass, jlong tensor_handle) +{ + return ctorch::cast(tensor_handle).dim(); +} + +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumel(JNIEnv *, jclass, jlong tensor_handle) +{ + return ctorch::cast(tensor_handle).numel(); +} + +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getShapeAt(JNIEnv *, jclass, jlong tensor_handle, jint d) +{ + return ctorch::cast(tensor_handle).size(d); +} + +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getStrideAt(JNIEnv *, jclass, jlong tensor_handle, jint d) +{ + return ctorch::cast(tensor_handle).stride(d); +} + +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDevice(JNIEnv *, jclass, jlong tensor_handle) +{ + return ctorch::device_to_int(ctorch::cast(tensor_handle)); +} + +JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getItemDouble(JNIEnv *, jclass, jlong tensor_handle) +{ + return ctorch::cast(tensor_handle).item(); +} + +JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getItemFloat(JNIEnv *, jclass, jlong tensor_handle) +{ + return ctorch::cast(tensor_handle).item(); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getItemLong(JNIEnv *, jclass, jlong tensor_handle) +{ + return ctorch::cast(tensor_handle).item(); +} + +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getItemInt(JNIEnv *, jclass, jlong tensor_handle) +{ + return ctorch::cast(tensor_handle).item(); +} + +JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getDouble(JNIEnv *env, jclass, jlong tensor_handle, jintArray index) +{ + return ctorch::get(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0)); +} + +JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getFloat(JNIEnv *env, jclass, jlong tensor_handle, jintArray index) +{ + return ctorch::get(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getLong(JNIEnv *env, jclass, jlong tensor_handle, jintArray index) +{ + return ctorch::get(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0)); +} + +JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getInt(JNIEnv *env, jclass, jlong tensor_handle, jintArray index) +{ + return ctorch::get(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0)); +} + +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setDouble(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jdouble value) +{ + ctorch::set(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value); +} + +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setFloat(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jfloat value) +{ + ctorch::set(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value); +} + +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setLong(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jlong value) +{ + ctorch::set(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value); +} + +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setInt(JNIEnv *env, jclass, jlong 
tensor_handle, jintArray index, jint value) +{ + ctorch::set(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randDouble(JNIEnv *env, jclass, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::rand( + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device))); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnDouble(JNIEnv *env, jclass, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::randn( + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device))); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randFloat(JNIEnv *env, jclass, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::rand( + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device))); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnFloat(JNIEnv *env, jclass, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::randn( + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device))); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintDouble(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::randint(low, high, + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device))); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintFloat(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::randint(low, high, + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device))); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLong(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::randint(low, high, + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device))); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintInt(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device) +{ + return (long)new torch::Tensor( + ctorch::randint(low, high, + ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)), + ctorch::int_to_device(device))); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randLike(JNIEnv *, jclass, jlong tensor_handle) +{ + return (long)new torch::Tensor(torch::rand_like(ctorch::cast(tensor_handle))); +} + +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randLikeAssign(JNIEnv *, jclass, jlong tensor_handle) +{ + ctorch::cast(tensor_handle) = torch::rand_like(ctorch::cast(tensor_handle)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnLike(JNIEnv *, jclass, jlong tensor_handle) +{ + return (long)new torch::Tensor(torch::randn_like(ctorch::cast(tensor_handle))); +} + +JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randnLikeAssign(JNIEnv *, jclass, jlong tensor_handle) +{ + ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle)); +} + +JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLike(JNIEnv *, 
jclass, jlong tensor_handle, jlong low, jlong high)
+{
+    return (long)new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randintLikeAssign(JNIEnv *, jclass, jlong tensor_handle, jlong low, jlong high)
+{
+    ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullDouble(JNIEnv *env, jclass, jdouble value, jintArray shape, jint device)
+{
+    jint *shape_data = env->GetIntArrayElements(shape, nullptr);
+    auto *result = new torch::Tensor(
+        ctorch::full(
+            value,
+            ctorch::to_vec_int(shape_data, env->GetArrayLength(shape)),
+            ctorch::int_to_device(device)));
+    // Release the JNI array copy to avoid leaking it on every call.
+    env->ReleaseIntArrayElements(shape, shape_data, JNI_ABORT);
+    return (long)result;
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullFloat(JNIEnv *env, jclass, jfloat value, jintArray shape, jint device)
+{
+    jint *shape_data = env->GetIntArrayElements(shape, nullptr);
+    auto *result = new torch::Tensor(
+        ctorch::full(
+            value,
+            ctorch::to_vec_int(shape_data, env->GetArrayLength(shape)),
+            ctorch::int_to_device(device)));
+    env->ReleaseIntArrayElements(shape, shape_data, JNI_ABORT);
+    return (long)result;
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullLong(JNIEnv *env, jclass, jlong value, jintArray shape, jint device)
+{
+    jint *shape_data = env->GetIntArrayElements(shape, nullptr);
+    auto *result = new torch::Tensor(
+        ctorch::full(
+            value,
+            ctorch::to_vec_int(shape_data, env->GetArrayLength(shape)),
+            ctorch::int_to_device(device)));
+    env->ReleaseIntArrayElements(shape, shape_data, JNI_ABORT);
+    return (long)result;
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullInt(JNIEnv *env, jclass, jint value, jintArray shape, jint device)
+{
+    jint *shape_data = env->GetIntArrayElements(shape, nullptr);
+    auto *result = new torch::Tensor(
+        ctorch::full(
+            value,
+            ctorch::to_vec_int(shape_data, env->GetArrayLength(shape)),
+            ctorch::int_to_device(device)));
+    env->ReleaseIntArrayElements(shape, shape_data, JNI_ABORT);
+    return (long)result;
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesDouble(JNIEnv *, jclass, jdouble value, jlong other)
+{
+    return (long)new torch::Tensor(value * ctorch::cast(other));
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesFloat(JNIEnv *, jclass, jfloat value, jlong other)
+{
+    return (long)new torch::Tensor(value * ctorch::cast(other));
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesLong(JNIEnv *, jclass, jlong value, jlong other)
+{
+    return (long)new torch::Tensor(value * ctorch::cast(other));
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesInt(JNIEnv *, jclass, jint value, jlong other)
+{
+    return (long)new torch::Tensor(value * ctorch::cast(other));
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesDoubleAssign(JNIEnv *, jclass, jdouble value, jlong other)
+{
+    ctorch::cast(other) *= value;
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesFloatAssign(JNIEnv *, jclass, jfloat value, jlong other)
+{
+    ctorch::cast(other) *= value;
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesLongAssign(JNIEnv *, jclass, jlong value, jlong other)
+{
+    ctorch::cast(other) *= value;
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesIntAssign(JNIEnv *, jclass, jint value, jlong other)
+{
+    ctorch::cast(other) *= value;
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusDouble(JNIEnv *, jclass, jdouble value, jlong other)
+{
+    return (long)new torch::Tensor(value + ctorch::cast(other));
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusFloat(JNIEnv *, jclass, jfloat value, jlong other)
+{
+    return (long)new torch::Tensor(value + ctorch::cast(other));
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusLong(JNIEnv *, jclass, jlong value, jlong other)
+{
+    return (long)new torch::Tensor(value + ctorch::cast(other));
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusInt(JNIEnv *, jclass, jint value, jlong other)
+{
+    return (long)new torch::Tensor(value + ctorch::cast(other));
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusDoubleAssign(JNIEnv *, jclass, jdouble value, jlong other)
+{
+    ctorch::cast(other) += value;
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusFloatAssign(JNIEnv *, jclass, jfloat value, jlong other)
+{
+    ctorch::cast(other) += value;
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusLongAssign(JNIEnv *, jclass, jlong value, jlong other)
+{
+    ctorch::cast(other) += value;
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusIntAssign(JNIEnv *, jclass, jint value, jlong other)
+{
+    ctorch::cast(other) += value;
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    return (long)new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    ctorch::cast(lhs) *= ctorch::cast(rhs);
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_divTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    return (long)new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs));
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_divTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    ctorch::cast(lhs) /= ctorch::cast(rhs);
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    return (long)new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    ctorch::cast(lhs) += ctorch::cast(rhs);
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_minusTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    return (long)new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_minusTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    ctorch::cast(lhs) -= ctorch::cast(rhs);
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_unaryMinus(JNIEnv *, jclass, jlong tensor_handle)
+{
+    return (long)new torch::Tensor(-ctorch::cast(tensor_handle));
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_absTensor(JNIEnv *, jclass, jlong tensor_handle)
+{
+    return (long)new torch::Tensor(ctorch::cast(tensor_handle).abs());
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_absTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
+{
+    ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs();
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_transposeTensor(JNIEnv *, jclass, jlong tensor_handle, jint i, jint j)
+{
+    return (long)new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_transposeTensorAssign(JNIEnv *, jclass, jlong tensor_handle, jint i, jint j)
+{
+    ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j);
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_expTensor(JNIEnv *, jclass, jlong tensor_handle)
+{
+    return (long)new torch::Tensor(ctorch::cast(tensor_handle).exp());
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_expTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
+{
+    ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp();
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_logTensor(JNIEnv *, jclass, jlong tensor_handle)
+{
+    return (long)new torch::Tensor(ctorch::cast(tensor_handle).log());
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_logTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
+{
+    ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log();
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_sumTensor(JNIEnv *, jclass, jlong tensor_handle)
+{
+    return (long)new torch::Tensor(ctorch::cast(tensor_handle).sum());
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_sumTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
+{
+    ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_matmul(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    return (long)new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulRightAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
+{
+    ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
+}
+
+JNIEXPORT jlong JNICALL
+Java_kscience_kmath_torch_JTorch_diagEmbed(JNIEnv *, jclass, jlong diags_handle, jint offset, jint dim1, jint dim2)
+{
+    return (long)new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
+}
+
+JNIEXPORT void JNICALL
+Java_kscience_kmath_torch_JTorch_svdTensor(JNIEnv *, jclass, jlong tensor_handle, jlong U_handle, jlong S_handle, jlong V_handle)
+{
+    auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
+    ctorch::cast(U_handle) = U;
+    ctorch::cast(S_handle) = S;
+    ctorch::cast(V_handle) = V;
+}
+
+JNIEXPORT void JNICALL
+Java_kscience_kmath_torch_JTorch_symeigTensor(JNIEnv *, jclass, jlong tensor_handle, jlong S_handle, jlong V_handle, jboolean eigenvectors)
+{
+    auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
+    ctorch::cast(S_handle) = S;
+    ctorch::cast(V_handle) = V;
+}
+
+JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_requiresGrad(JNIEnv *, jclass, jlong tensor_handle)
+{
+    return ctorch::cast(tensor_handle).requires_grad();
+}
+
+JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setRequiresGrad(JNIEnv *, jclass, jlong tensor_handle, jboolean status)
+{
+    ctorch::cast(tensor_handle).requires_grad_(status);
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_detachFromGraph(JNIEnv *, jclass, jlong tensor_handle)
+{
+    return (long)new torch::Tensor(ctorch::cast(tensor_handle).detach());
+}
+
+JNIEXPORT jlong JNICALL
+Java_kscience_kmath_torch_JTorch_autogradTensor(JNIEnv *, jclass, jlong value, jlong variable, jboolean retain_graph)
+{
+    return (long)new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
+}
+
+JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_autohessTensor(JNIEnv *, jclass, jlong value, jlong variable)
+{
+    return (long)new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
+}
\ No newline at end of file
diff --git a/kmath-torch/src/jvmMain/java/kscience/kmath/torch/JTorch.java b/kmath-torch/src/jvmMain/java/kscience/kmath/torch/JTorch.java
index 818a3d694..f912c36e2 100644
--- a/kmath-torch/src/jvmMain/java/kscience/kmath/torch/JTorch.java
+++ b/kmath-torch/src/jvmMain/java/kscience/kmath/torch/JTorch.java
@@ -7,8 +7,202 @@ class JTorch {
     }
 
     public static native int getNumThreads();
+
     public static native void setNumThreads(int numThreads);
-    public static native long createTensor();
-    public static native void printTensor(long tensorHandle);
+
+    public static native boolean cudaIsAvailable();
+
+    public static native void setSeed(int seed);
+
+    public static native long emptyTensor();
+
+    public static native long fromBlobDouble(double[] data, int[] shape, int device);
+
+    public static native long fromBlobFloat(float[] data, int[] shape, int device);
+
+    public static native long fromBlobLong(long[] data, int[] shape, int device);
+
+    public static native long fromBlobInt(int[] data, int[] shape, int device);
+
+    public static native long copyTensor(long tensorHandle);
+
+    public static native long copyToDevice(long tensorHandle, int device);
+
+    public static native long copyToDouble(long tensorHandle);
+
+    public static native long copyToFloat(long tensorHandle);
+
+    public static native long copyToLong(long tensorHandle);
+
+    public static native long copyToInt(long tensorHandle);
+
+    public static native void swapTensors(long lhsHandle, long rhsHandle);
+
+    public static native long viewTensor(long tensorHandle, int[] shape);
+
+    public static native String tensorToString(long tensorHandle);
+
     public static native void disposeTensor(long tensorHandle);
-}
\ No newline at end of file
+
+    public static native int getDim(long tensorHandle);
+
+    public static native int getNumel(long tensorHandle);
+
+    public static native int getShapeAt(long tensorHandle, int d);
+
+    public static native int getStrideAt(long tensorHandle, int d);
+
+    public static native int getDevice(long tensorHandle);
+
+    public static native double getItemDouble(long tensorHandle);
+
+    public static native float getItemFloat(long tensorHandle);
+
+    public static native long getItemLong(long tensorHandle);
+
+    public static native int getItemInt(long tensorHandle);
+
+    public static native double getDouble(long tensorHandle, int[] index);
+
+    public static native float getFloat(long tensorHandle, int[] index);
+
+    public static native long getLong(long tensorHandle, int[] index);
+
+    public static native int getInt(long tensorHandle, int[] index);
+
+    public static native void setDouble(long tensorHandle, int[] index, double value);
+
+    public static native void setFloat(long tensorHandle, int[] index, float value);
+
+    public static native void setLong(long tensorHandle, int[] index, long value);
+
+    public static native void setInt(long tensorHandle, int[] index, int value);
+
+    public static native long randDouble(int[] shape, int device);
+
+    public static native long randnDouble(int[] shape, int device);
+
+    public static native long randFloat(int[] shape, int device);
+
+    public static native long randnFloat(int[] shape, int device);
+
+    public static native long randintDouble(long low, long high, int[] shape, int device);
+
+    public static native long randintFloat(long low, long high, int[] shape, int device);
+
+    public static native long randintLong(long low, long high, int[] shape, int device);
+
+    public static native long randintInt(long low, long high, int[] shape, int device);
+
+    public static native long randLike(long tensorHandle);
+
+    public static native void randLikeAssign(long tensorHandle);
+
+    public static native long randnLike(long tensorHandle);
+
+    public static native void randnLikeAssign(long tensorHandle);
+
+    public static native long randintLike(long tensorHandle, long low, long high);
+
+    public static native void randintLikeAssign(long tensorHandle, long low, long high);
+
+    public static native long fullDouble(double value, int[] shape, int device);
+
+    public static native long fullFloat(float value, int[] shape, int device);
+
+    public static native long fullLong(long value, int[] shape, int device);
+
+    public static native long fullInt(int value, int[] shape, int device);
+
+    public static native long timesDouble(double value, long other);
+
+    public static native long timesFloat(float value, long other);
+
+    public static native long timesLong(long value, long other);
+
+    public static native long timesInt(int value, long other);
+
+    public static native void timesDoubleAssign(double value, long other);
+
+    public static native void timesFloatAssign(float value, long other);
+
+    public static native void timesLongAssign(long value, long other);
+
+    public static native void timesIntAssign(int value, long other);
+
+    public static native long plusDouble(double value, long other);
+
+    public static native long plusFloat(float value, long other);
+
+    public static native long plusLong(long value, long other);
+
+    public static native long plusInt(int value, long other);
+
+    public static native void plusDoubleAssign(double value, long other);
+
+    public static native void plusFloatAssign(float value, long other);
+
+    public static native void plusLongAssign(long value, long other);
+
+    public static native void plusIntAssign(int value, long other);
+
+    public static native long timesTensor(long lhs, long rhs);
+
+    public static native void timesTensorAssign(long lhs, long rhs);
+
+    public static native long divTensor(long lhs, long rhs);
+
+    public static native void divTensorAssign(long lhs, long rhs);
+
+    public static native long plusTensor(long lhs, long rhs);
+
+    public static native void plusTensorAssign(long lhs, long rhs);
+
+    public static native long minusTensor(long lhs, long rhs);
+
+    public static native void minusTensorAssign(long lhs, long rhs);
+
+    public static native long unaryMinus(long tensorHandle);
+
+    public static native long absTensor(long tensorHandle);
+
+    public static native void absTensorAssign(long tensorHandle);
+
+    public static native long transposeTensor(long tensorHandle, int i, int j);
+
+    public static native void transposeTensorAssign(long tensorHandle, int i, int j);
+
+    public static native long expTensor(long tensorHandle);
+
+    public static native void expTensorAssign(long tensorHandle);
+
+    public static native long logTensor(long tensorHandle);
+
+    public static native void logTensorAssign(long tensorHandle);
+
+    public static native long sumTensor(long tensorHandle);
+
+    public static native void sumTensorAssign(long tensorHandle);
+
+    public static native long matmul(long lhs, long rhs);
+
+    public static native void matmulAssign(long lhs, long rhs);
+
+    public static native void matmulRightAssign(long lhs, long rhs);
+
+    public static native long diagEmbed(long diagsHandle, int offset, int dim1, int dim2);
+
+    public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle);
+
+    public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle, boolean eigenvectors);
+
+    public static native boolean requiresGrad(long tensorHandle);
+
+    public static native void setRequiresGrad(long tensorHandle, boolean status);
+
+    public static native long detachFromGraph(long tensorHandle);
+
+    public static native long autogradTensor(long value, long variable, boolean retainGraph);
+
+    public static native long autohessTensor(long value, long variable);
+}
diff --git a/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorJVM.kt b/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorJVM.kt
new file mode 100644
index 000000000..0064b33b3
--- /dev/null
+++ b/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorJVM.kt
@@ -0,0 +1,4 @@
+package kscience.kmath.torch
+
+public class TorchTensorJVM {
+}
\ No newline at end of file
diff --git a/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/Utils.kt b/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/Utils.kt
deleted file mode 100644
index dc08360ae..000000000
--- a/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/Utils.kt
+++ /dev/null
@@ -1,17 +0,0 @@
-package kscience.kmath.torch
-
-
-public fun getNumThreads(): Int {
-    return JTorch.getNumThreads()
-}
-
-public fun setNumThreads(numThreads: Int): Unit {
-    JTorch.setNumThreads(numThreads)
-}
-
-
-public fun runCPD(): Unit {
-    val tensorHandle = JTorch.createTensor()
-    JTorch.printTensor(tensorHandle)
-    JTorch.disposeTensor(tensorHandle)
-}
\ No newline at end of file
diff --git a/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestUtils.kt b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestUtils.kt
index 316a9e01c..2371b576e 100644
--- a/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestUtils.kt
+++ b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestUtils.kt
@@ -6,14 +6,10 @@ import kotlin.test.*
 class TestUtils {
 
     @Test
-    fun testSetNumThreads() {
-        val numThreads = 2
-        setNumThreads(numThreads)
-        assertEquals(numThreads, getNumThreads())
+    fun testJTorch() {
+        val tensor = JTorch.fullInt(54, intArrayOf(3), 0)
+        println(JTorch.tensorToString(tensor))
+        JTorch.disposeTensor(tensor)
     }
 
-    @Test
-    fun testCPD() {
-        runCPD()
-    }
 }
\ No newline at end of file
diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebraNative.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebraNative.kt
index aa0564f2e..93987fcf0 100644
--- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebraNative.kt
+++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebraNative.kt
@@ -117,6 +117,13 @@ public sealed class TorchTensorAlgebraNative<
         sum_tensor_assign(tensorHandle)
     }
 
+    override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
+        wrap(randint_like(this.tensorHandle, low, high)!!)
+
+    override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit {
+        randint_like_assign(this.tensorHandle, low, high)
+    }
+
     override fun TorchTensorType.copy(): TorchTensorType =
         wrap(copy_tensor(this.tensorHandle)!!)
 
@@ -224,6 +231,9 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
     override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
         wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!)
 
+    override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorReal =
+        wrap(randint_double(low, high, shape.toCValues(), shape.size, device.toInt())!!)
+
     override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
         wrap(plus_double(this, other.tensorHandle)!!)
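This file's hunks replace the per-dtype `randIntegral(low, high)` overloads (removed just below) with the single `Long`-based signature added above, shared by every algebra. A minimal usage sketch of the unified API (the block-scoped `TorchTensorRealAlgebra { ... }` invocation follows the library's factory convention; the shape and bounds are illustrative):

```kotlin
TorchTensorRealAlgebra {
    // Integral sampling now takes Long bounds for every tensor type.
    val tensor = randIntegral(0, 100, shape = intArrayOf(2, 3), device = Device.CPU)
    // In-place resampling with new bounds, backed by randint_like.
    tensor.randIntegralAssign(0, 10)
}
```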
@@ -256,18 +266,9 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
 
     override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
         wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
-
-    override fun randIntegral(low: Double, high: Double, shape: IntArray, device: Device): TorchTensorReal =
-        wrap(randint_double(low.toLong(), high.toLong(), shape.toCValues(), shape.size, device.toInt())!!)
-
-    override fun TorchTensorReal.randIntegral(low: Double, high: Double): TorchTensorReal =
-        wrap(randint_long_like(this.tensorHandle, low.toLong(), high.toLong())!!)
-
-    override fun TorchTensorReal.randIntegralAssign(low: Double, high: Double): Unit {
-        randint_long_like_assign(this.tensorHandle, low.toLong(), high.toLong())
-    }
 }
 
+
 public class TorchTensorFloatAlgebra(scope: DeferScope) :
     TorchTensorPartialDivisionAlgebraNative(scope) {
     override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
@@ -295,6 +296,9 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
     override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
         wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!)
 
+    override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorFloat =
+        wrap(randint_float(low, high, shape.toCValues(), shape.size, device.toInt())!!)
+
     override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
         wrap(plus_float(this, other.tensorHandle)!!)
 
@@ -328,15 +332,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
     override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
         wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
-
-    override fun randIntegral(low: Float, high: Float, shape: IntArray, device: Device): TorchTensorFloat =
-        wrap(randint_float(low.toLong(), high.toLong(), shape.toCValues(), shape.size, device.toInt())!!)
-
-    override fun TorchTensorFloat.randIntegral(low: Float, high: Float): TorchTensorFloat =
-        wrap(randint_long_like(this.tensorHandle, low.toLong(), high.toLong())!!)
-
-    override fun TorchTensorFloat.randIntegralAssign(low: Float, high: Float): Unit {
-        randint_long_like_assign(this.tensorHandle, low.toLong(), high.toLong())
-    }
 }
 
 public class TorchTensorLongAlgebra(scope: DeferScope) :
@@ -363,13 +358,6 @@
     override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
         wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!)
 
-    override fun TorchTensorLong.randIntegral(low: Long, high: Long): TorchTensorLong =
-        wrap(randint_long_like(this.tensorHandle, low, high)!!)
-
-    override fun TorchTensorLong.randIntegralAssign(low: Long, high: Long): Unit {
-        randint_long_like_assign(this.tensorHandle, low, high)
-    }
-
     override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
         wrap(plus_long(this, other.tensorHandle)!!)
 
@@ -425,16 +413,9 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
         return get_data_int(this.tensorHandle)!!
     }
 
-    override fun randIntegral(low: Int, high: Int, shape: IntArray, device: Device): TorchTensorInt =
+    override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorInt =
         wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!)
 
-    override fun TorchTensorInt.randIntegral(low: Int, high: Int): TorchTensorInt =
-        wrap(randint_int_like(this.tensorHandle, low, high)!!)
-
-    override fun TorchTensorInt.randIntegralAssign(low: Int, high: Int): Unit {
-        randint_int_like_assign(this.tensorHandle, low, high)
-    }
-
     override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
         wrap(plus_int(this, other.tensorHandle)!!)
 
diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt
index f990dbc1a..c3ba10dc0 100644
--- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt
+++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt
@@ -22,6 +22,4 @@ internal class BenchmarkMatMul {
             benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0))
         }
     }
-
-
 }
\ No newline at end of file
diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt
index 02f362165..3da1eac0d 100644
--- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt
+++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt
@@ -6,7 +6,7 @@ import kotlin.test.*
 internal class TestUtils {
     @Test
    fun testSetNumThreads() {
-        TorchTensorIntAlgebra {
+        TorchTensorLongAlgebra {
             testingSetNumThreads()
         }
     }
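Since each JNI entry point above returns a freshly allocated `torch::Tensor*` as a raw `long` handle, every handle obtained from `JTorch` must eventually be released with `disposeTensor`. A hedged Kotlin sketch of that discipline (the `withTensor` helper is hypothetical and not part of this patch; `fullInt`, `tensorToString` and `disposeTensor` are the natives declared above, and the code sits in the `kscience.kmath.torch` package because `JTorch` is package-private):

```kotlin
package kscience.kmath.torch

// Hypothetical helper: runs block with a native tensor handle and
// guarantees the underlying torch::Tensor is freed afterwards.
internal inline fun <T> withTensor(handle: Long, block: (Long) -> T): T =
    try {
        block(handle)
    } finally {
        JTorch.disposeTensor(handle)
    }

fun main() {
    // Create a rank-1 tensor of three 54s on the CPU (device code 0) and print it.
    withTensor(JTorch.fullInt(54, intArrayOf(3), 0)) { tensor ->
        println(JTorch.tensorToString(tensor))
    }
}
```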
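Note also the calling convention of the factorization natives: `svdTensor` and `symeigTensor` write their results into caller-allocated handles rather than returning new ones. A sketch under the assumption that `emptyTensor` allocates suitable destinations (the 4x4 shape is illustrative; again same-package access to `JTorch` is assumed):

```kotlin
fun svdExample() {
    val x = JTorch.randnDouble(intArrayOf(4, 4), 0)
    val u = JTorch.emptyTensor()
    val s = JTorch.emptyTensor()
    val v = JTorch.emptyTensor()
    JTorch.svdTensor(x, u, s, v) // fills u, s, v with the factors of x
    println(JTorch.tensorToString(s)) // singular values
    listOf(x, u, s, v).forEach(JTorch::disposeTensor) // free every native handle
}
```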