forked from kscience/kmath
JNI wrapper
This commit is contained in:
parent
d599d1132b
commit
17e6ebbc14
3
.gitignore
vendored
3
.gitignore
vendored
@ -10,6 +10,7 @@ out/
|
||||
# Cache of project
|
||||
.gradletasknamecache
|
||||
|
||||
# Generated by javac -h
|
||||
# Generated by javac -h and runtime
|
||||
*.class
|
||||
*.log
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# LibTorch extension (`kmath-torch`)
|
||||
|
||||
This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of
|
||||
This is a `Kotlin/Native` & `JVM` module, with only `linuxX64` supported so far. The library wraps some of
|
||||
the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
|
||||
|
||||
## Installation
|
||||
@ -11,7 +11,7 @@ To install the library, you have to build & publish locally `kmath-core`, `kmath
|
||||
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
|
||||
```
|
||||
|
||||
This builds `ctorch`, a C wrapper for `LibTorch` placed inside:
|
||||
This builds `ctorch` a C wrapper and `jtorch` a JNI wrapper for `LibTorch`, placed inside:
|
||||
|
||||
`~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`
|
||||
|
||||
@ -19,8 +19,8 @@ You will have to link against it in your own project.
|
||||
|
||||
## Usage
|
||||
|
||||
Tensors are implemented over the `MutableNDStructure`. They can only be instantiated through provided factory methods
|
||||
and require scoping:
|
||||
Tensors are implemented over the `MutableNDStructure`. They can only be created through provided factory methods
|
||||
and require scoping within a `TensorAlgebra` instance:
|
||||
|
||||
```kotlin
|
||||
TorchTensorRealAlgebra {
|
||||
@ -63,4 +63,4 @@ TorchTensorRealAlgebra {
|
||||
val hessianAtX = expressionAtX hess tensorX
|
||||
}
|
||||
```
|
||||
|
||||
Contributed by [Roland Grinis](https://github.com/rgrit91)
|
||||
|
@ -127,7 +127,6 @@ val generateJNIHeader by tasks.registering {
|
||||
kotlin {
|
||||
explicitApiWarning()
|
||||
|
||||
|
||||
jvm {
|
||||
withJava()
|
||||
}
|
||||
@ -158,7 +157,6 @@ kotlin {
|
||||
val test by nativeTarget.compilations.getting
|
||||
|
||||
sourceSets {
|
||||
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
@ -171,6 +169,11 @@ kotlin {
|
||||
}
|
||||
}
|
||||
|
||||
val jvmMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -25,12 +25,11 @@ public interface TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType : Tor
|
||||
public fun full(value: T, shape: IntArray, device: Device): TorchTensorType
|
||||
|
||||
public fun randIntegral(
|
||||
low: T, high: T, shape: IntArray,
|
||||
low: Long, high: Long, shape: IntArray,
|
||||
device: Device = Device.CPU
|
||||
): TorchTensorType
|
||||
|
||||
public fun TorchTensorType.randIntegral(low: T, high: T): TorchTensorType
|
||||
public fun TorchTensorType.randIntegralAssign(low: T, high: T): Unit
|
||||
public fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType
|
||||
public fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit
|
||||
|
||||
public fun TorchTensorType.copy(): TorchTensorType
|
||||
public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType
|
||||
|
@ -73,8 +73,8 @@ internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
numIter,
|
||||
device,
|
||||
"integer [0,100]",
|
||||
{sh, dc -> randIntegral(0f, 100f, shape = sh, device = dc)},
|
||||
{ten -> ten.randIntegralAssign(0f, 100f) }
|
||||
{sh, dc -> randIntegral(0, 100, shape = sh, device = dc)},
|
||||
{ten -> ten.randIntegralAssign(0, 100) }
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -43,7 +43,7 @@ internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverFie
|
||||
internal inline fun <TorchTensorType : TorchTensor<Int>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<Int, IntArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingViewWithNoCopy(device: Device = Device.CPU) {
|
||||
val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6))
|
||||
val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6), device)
|
||||
val viewTensor = tensor.view(intArrayOf(2, 3))
|
||||
assertTrue(viewTensor.shape contentEquals intArrayOf(2, 3))
|
||||
viewTensor[intArrayOf(0, 0)] = 10
|
||||
|
@ -26,11 +26,11 @@ internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingSetSeed(device: Device = Device.CPU): Unit {
|
||||
setSeed(SEED)
|
||||
val integral = randIntegral(0f, 100f, IntArray(0), device = device).value()
|
||||
val integral = randIntegral(0, 100, IntArray(0), device = device).value()
|
||||
val normal = randNormal(IntArray(0), device = device).value()
|
||||
val uniform = randUniform(IntArray(0), device = device).value()
|
||||
setSeed(SEED)
|
||||
val nextIntegral = randIntegral(0f, 100f, IntArray(0), device = device).value()
|
||||
val nextIntegral = randIntegral(0, 100, IntArray(0), device = device).value()
|
||||
val nextNormal = randNormal(IntArray(0), device = device).value()
|
||||
val nextUniform = randUniform(IntArray(0), device = device).value()
|
||||
assertEquals(normal, nextNormal)
|
||||
|
@ -70,16 +70,15 @@ extern "C"
|
||||
TorchTensorHandle randint_double(long low, long high, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randint_float(long low, long high, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randint_int(long low, long high, int *shape, int shape_size, int device);
|
||||
|
||||
TorchTensorHandle rand_like(TorchTensorHandle tensor_handle);
|
||||
void rand_like_assign(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle randn_like(TorchTensorHandle tensor_handle);
|
||||
void randn_like_assign(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle randint_long_like(TorchTensorHandle tensor_handle, long low, long high);
|
||||
void randint_long_like_assign(TorchTensorHandle tensor_handle, long low, long high);
|
||||
TorchTensorHandle randint_int_like(TorchTensorHandle tensor_handle, int low, int high);
|
||||
void randint_int_like_assign(TorchTensorHandle tensor_handle, int low, int high);
|
||||
TorchTensorHandle randint_like(TorchTensorHandle tensor_handle, long low, long high);
|
||||
void randint_like_assign(TorchTensorHandle tensor_handle, long low, long high);
|
||||
|
||||
|
||||
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
|
||||
@ -106,14 +105,14 @@ extern "C"
|
||||
void plus_long_assign(long value, TorchTensorHandle other);
|
||||
void plus_int_assign(int value, TorchTensorHandle other);
|
||||
|
||||
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle);
|
||||
|
||||
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
|
||||
|
@ -25,18 +25,130 @@ JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: createTensor
|
||||
* Signature: ()J
|
||||
* Method: cudaIsAvailable
|
||||
* Signature: ()Z
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor
|
||||
JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_cudaIsAvailable
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: printTensor
|
||||
* Signature: (J)V
|
||||
* Method: setSeed
|
||||
* Signature: (I)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setSeed
|
||||
(JNIEnv *, jclass, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: emptyTensor
|
||||
* Signature: ()J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_emptyTensor
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: fromBlobDouble
|
||||
* Signature: ([D[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobDouble
|
||||
(JNIEnv *, jclass, jdoubleArray, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: fromBlobFloat
|
||||
* Signature: ([F[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobFloat
|
||||
(JNIEnv *, jclass, jfloatArray, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: fromBlobLong
|
||||
* Signature: ([J[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobLong
|
||||
(JNIEnv *, jclass, jlongArray, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: fromBlobInt
|
||||
* Signature: ([I[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobInt
|
||||
(JNIEnv *, jclass, jintArray, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: copyTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: copyToDevice
|
||||
* Signature: (JI)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDevice
|
||||
(JNIEnv *, jclass, jlong, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: copyToDouble
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDouble
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: copyToFloat
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToFloat
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: copyToLong
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToLong
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: copyToInt
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToInt
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: swapTensors
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_swapTensors
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: viewTensor
|
||||
* Signature: (J[I)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_viewTensor
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: tensorToString
|
||||
* Signature: (J)Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jstring JNICALL Java_kscience_kmath_torch_JTorch_tensorToString
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
@ -47,6 +159,654 @@ JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getDim
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDim
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getNumel
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumel
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getShapeAt
|
||||
* Signature: (JI)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getShapeAt
|
||||
(JNIEnv *, jclass, jlong, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getStrideAt
|
||||
* Signature: (JI)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getStrideAt
|
||||
(JNIEnv *, jclass, jlong, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getDevice
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDevice
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getItemDouble
|
||||
* Signature: (J)D
|
||||
*/
|
||||
JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getItemDouble
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getItemFloat
|
||||
* Signature: (J)F
|
||||
*/
|
||||
JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getItemFloat
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getItemLong
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getItemLong
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getItemInt
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getItemInt
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getDouble
|
||||
* Signature: (J[I)D
|
||||
*/
|
||||
JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getDouble
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getFloat
|
||||
* Signature: (J[I)F
|
||||
*/
|
||||
JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getFloat
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getLong
|
||||
* Signature: (J[I)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getLong
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getInt
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getInt
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: setDouble
|
||||
* Signature: (J[ID)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setDouble
|
||||
(JNIEnv *, jclass, jlong, jintArray, jdouble);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: setFloat
|
||||
* Signature: (J[IF)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setFloat
|
||||
(JNIEnv *, jclass, jlong, jintArray, jfloat);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: setLong
|
||||
* Signature: (J[IJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setLong
|
||||
(JNIEnv *, jclass, jlong, jintArray, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: setInt
|
||||
* Signature: (J[II)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setInt
|
||||
(JNIEnv *, jclass, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randDouble
|
||||
* Signature: ([II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randDouble
|
||||
(JNIEnv *, jclass, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randnDouble
|
||||
* Signature: ([II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnDouble
|
||||
(JNIEnv *, jclass, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randFloat
|
||||
* Signature: ([II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randFloat
|
||||
(JNIEnv *, jclass, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randnFloat
|
||||
* Signature: ([II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnFloat
|
||||
(JNIEnv *, jclass, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randintDouble
|
||||
* Signature: (JJ[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintDouble
|
||||
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randintFloat
|
||||
* Signature: (JJ[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintFloat
|
||||
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randintLong
|
||||
* Signature: (JJ[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLong
|
||||
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randintInt
|
||||
* Signature: (JJ[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintInt
|
||||
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randLike
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randLike
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randLikeAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randLikeAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randnLike
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnLike
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randnLikeAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randnLikeAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randintLike
|
||||
* Signature: (JJJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLike
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: randintLikeAssign
|
||||
* Signature: (JJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randintLikeAssign
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: fullDouble
|
||||
* Signature: (D[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullDouble
|
||||
(JNIEnv *, jclass, jdouble, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: fullFloat
|
||||
* Signature: (F[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullFloat
|
||||
(JNIEnv *, jclass, jfloat, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: fullLong
|
||||
* Signature: (J[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullLong
|
||||
(JNIEnv *, jclass, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: fullInt
|
||||
* Signature: (I[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullInt
|
||||
(JNIEnv *, jclass, jint, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: timesDouble
|
||||
* Signature: (DJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesDouble
|
||||
(JNIEnv *, jclass, jdouble, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: timesFloat
|
||||
* Signature: (FJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesFloat
|
||||
(JNIEnv *, jclass, jfloat, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: timesLong
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesLong
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: timesInt
|
||||
* Signature: (IJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesInt
|
||||
(JNIEnv *, jclass, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: timesDoubleAssign
|
||||
* Signature: (DJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesDoubleAssign
|
||||
(JNIEnv *, jclass, jdouble, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: timesFloatAssign
|
||||
* Signature: (FJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesFloatAssign
|
||||
(JNIEnv *, jclass, jfloat, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: timesLongAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesLongAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: timesIntAssign
|
||||
* Signature: (IJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesIntAssign
|
||||
(JNIEnv *, jclass, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: plusDouble
|
||||
* Signature: (DJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusDouble
|
||||
(JNIEnv *, jclass, jdouble, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: plusFloat
|
||||
* Signature: (FJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusFloat
|
||||
(JNIEnv *, jclass, jfloat, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: plusLong
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusLong
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: plusInt
|
||||
* Signature: (IJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusInt
|
||||
(JNIEnv *, jclass, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: plusDoubleAssign
|
||||
* Signature: (DJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusDoubleAssign
|
||||
(JNIEnv *, jclass, jdouble, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: plusFloatAssign
|
||||
* Signature: (FJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusFloatAssign
|
||||
(JNIEnv *, jclass, jfloat, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: plusLongAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusLongAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: plusIntAssign
|
||||
* Signature: (IJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusIntAssign
|
||||
(JNIEnv *, jclass, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: timesTensor
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesTensor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: timesTensorAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesTensorAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: divTensor
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_divTensor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: divTensorAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_divTensorAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: plusTensor
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusTensor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: plusTensorAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusTensorAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: minusTensor
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_minusTensor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: minusTensorAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_minusTensorAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: unaryMinus
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_unaryMinus
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: absTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_absTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: absTensorAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_absTensorAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: transposeTensor
|
||||
* Signature: (JII)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_transposeTensor
|
||||
(JNIEnv *, jclass, jlong, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: transposeTensorAssign
|
||||
* Signature: (JII)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_transposeTensorAssign
|
||||
(JNIEnv *, jclass, jlong, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: expTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_expTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: expTensorAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_expTensorAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: logTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_logTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: logTensorAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_logTensorAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: sumTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_sumTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: sumTensorAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_sumTensorAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: matmul
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_matmul
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: matmulAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: matmulRightAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulRightAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: diagEmbed
|
||||
* Signature: (JIII)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_diagEmbed
|
||||
(JNIEnv *, jclass, jlong, jint, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: svdTensor
|
||||
* Signature: (JJJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_svdTensor
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: symeigTensor
|
||||
* Signature: (JJJZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_symeigTensor
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong, jboolean);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: requiresGrad
|
||||
* Signature: (J)Z
|
||||
*/
|
||||
JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_requiresGrad
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: setRequiresGrad
|
||||
* Signature: (JZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setRequiresGrad
|
||||
(JNIEnv *, jclass, jlong, jboolean);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: detachFromGraph
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_detachFromGraph
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: autogradTensor
|
||||
* Signature: (JJZ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_autogradTensor
|
||||
(JNIEnv *, jclass, jlong, jlong, jboolean);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: autohessTensor
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_autohessTensor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -3,7 +3,7 @@
|
||||
namespace ctorch
|
||||
{
|
||||
|
||||
using TorchTensorHandle = void*;
|
||||
using TorchTensorHandle = void *;
|
||||
|
||||
template <typename Dtype>
|
||||
inline c10::ScalarType dtype()
|
||||
@ -29,16 +29,28 @@ namespace ctorch
|
||||
return torch::kInt32;
|
||||
}
|
||||
|
||||
inline torch::Tensor &cast(const TorchTensorHandle &tensor_handle)
|
||||
template <typename Handle>
|
||||
inline torch::Tensor &cast(const Handle &tensor_handle)
|
||||
{
|
||||
return *static_cast<torch::Tensor *>(tensor_handle);
|
||||
return *static_cast<torch::Tensor *>((TorchTensorHandle)tensor_handle);
|
||||
}
|
||||
|
||||
template <typename Handle>
|
||||
inline void dispose_tensor(const Handle &tensor_handle)
|
||||
{
|
||||
delete static_cast<torch::Tensor *>((TorchTensorHandle)tensor_handle);
|
||||
}
|
||||
|
||||
inline std::string tensor_to_string(const torch::Tensor &tensor)
|
||||
{
|
||||
std::stringstream bufrep;
|
||||
bufrep << tensor;
|
||||
return bufrep.str();
|
||||
}
|
||||
|
||||
inline char *tensor_to_char(const torch::Tensor &tensor)
|
||||
{
|
||||
std::stringstream bufrep;
|
||||
bufrep << tensor;
|
||||
auto rep = bufrep.str();
|
||||
auto rep = tensor_to_string(tensor);
|
||||
char *crep = (char *)malloc(rep.length() + 1);
|
||||
std::strcpy(crep, rep.c_str());
|
||||
return crep;
|
||||
@ -72,45 +84,43 @@ namespace ctorch
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device, bool copy)
|
||||
inline torch::Tensor from_blob(Dtype *data, const std::vector<int64_t> &shape, torch::Device device, bool copy)
|
||||
{
|
||||
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy);
|
||||
}
|
||||
|
||||
template <typename NumType>
|
||||
inline NumType get(const TorchTensorHandle &tensor_handle, int *index)
|
||||
inline NumType get(const torch::Tensor &tensor, int *index)
|
||||
{
|
||||
auto ten = ctorch::cast(tensor_handle);
|
||||
return ten.index(to_index(index, ten.dim())).item<NumType>();
|
||||
return tensor.index(to_index(index, tensor.dim())).item<NumType>();
|
||||
}
|
||||
|
||||
template <typename NumType>
|
||||
inline void set(TorchTensorHandle &tensor_handle, int *index, NumType value)
|
||||
inline void set(const torch::Tensor &tensor, int *index, NumType value)
|
||||
{
|
||||
auto ten = ctorch::cast(tensor_handle);
|
||||
ten.index(to_index(index, ten.dim())) = value;
|
||||
tensor.index(to_index(index, tensor.dim())) = value;
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor randn(std::vector<int64_t> shape, torch::Device device)
|
||||
inline torch::Tensor randn(const std::vector<int64_t> &shape, torch::Device device)
|
||||
{
|
||||
return torch::randn(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor rand(std::vector<int64_t> shape, torch::Device device)
|
||||
inline torch::Tensor rand(const std::vector<int64_t> &shape, torch::Device device)
|
||||
{
|
||||
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor randint(long low, long high, std::vector<int64_t> shape, torch::Device device)
|
||||
inline torch::Tensor randint(long low, long high, const std::vector<int64_t> &shape, torch::Device device)
|
||||
{
|
||||
return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device)
|
||||
inline torch::Tensor full(Dtype value, const std::vector<int64_t> &shape, torch::Device device)
|
||||
{
|
||||
return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
@ -74,9 +74,9 @@ void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)
|
||||
{
|
||||
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
|
||||
}
|
||||
TorchTensorHandle view_tensor(TorchTensorHandle tensor_handle, int *shape, int dim)
|
||||
TorchTensorHandle view_tensor(TorchTensorHandle tensor_handle, int *shape, int dim)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).view(ctorch::to_vec_int(shape, dim)));
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).view(ctorch::to_vec_int(shape, dim)));
|
||||
}
|
||||
|
||||
char *tensor_to_string(TorchTensorHandle tensor_handle)
|
||||
@ -89,7 +89,7 @@ void dispose_char(char *ptr)
|
||||
}
|
||||
void dispose_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
delete static_cast<torch::Tensor *>(tensor_handle);
|
||||
ctorch::dispose_tensor(tensor_handle);
|
||||
}
|
||||
|
||||
int get_dim(TorchTensorHandle tensor_handle)
|
||||
@ -149,35 +149,35 @@ int get_item_int(TorchTensorHandle tensor_handle)
|
||||
|
||||
double get_double(TorchTensorHandle tensor_handle, int *index)
|
||||
{
|
||||
return ctorch::get<double>(tensor_handle, index);
|
||||
return ctorch::get<double>(ctorch::cast(tensor_handle), index);
|
||||
}
|
||||
float get_float(TorchTensorHandle tensor_handle, int *index)
|
||||
{
|
||||
return ctorch::get<float>(tensor_handle, index);
|
||||
return ctorch::get<float>(ctorch::cast(tensor_handle), index);
|
||||
}
|
||||
long get_long(TorchTensorHandle tensor_handle, int *index)
|
||||
{
|
||||
return ctorch::get<long>(tensor_handle, index);
|
||||
return ctorch::get<long>(ctorch::cast(tensor_handle), index);
|
||||
}
|
||||
int get_int(TorchTensorHandle tensor_handle, int *index)
|
||||
{
|
||||
return ctorch::get<int>(tensor_handle, index);
|
||||
return ctorch::get<int>(ctorch::cast(tensor_handle), index);
|
||||
}
|
||||
void set_double(TorchTensorHandle tensor_handle, int *index, double value)
|
||||
{
|
||||
ctorch::set<double>(tensor_handle, index, value);
|
||||
ctorch::set<double>(ctorch::cast(tensor_handle), index, value);
|
||||
}
|
||||
void set_float(TorchTensorHandle tensor_handle, int *index, float value)
|
||||
{
|
||||
ctorch::set<float>(tensor_handle, index, value);
|
||||
ctorch::set<float>(ctorch::cast(tensor_handle), index, value);
|
||||
}
|
||||
void set_long(TorchTensorHandle tensor_handle, int *index, long value)
|
||||
{
|
||||
ctorch::set<long>(tensor_handle, index, value);
|
||||
ctorch::set<long>(ctorch::cast(tensor_handle), index, value);
|
||||
}
|
||||
void set_int(TorchTensorHandle tensor_handle, int *index, int value)
|
||||
{
|
||||
ctorch::set<int>(tensor_handle, index, value);
|
||||
ctorch::set<int>(ctorch::cast(tensor_handle), index, value);
|
||||
}
|
||||
|
||||
TorchTensorHandle rand_double(int *shape, int shape_size, int device)
|
||||
@ -209,7 +209,7 @@ TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size,
|
||||
{
|
||||
return new torch::Tensor(ctorch::randint<long>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device)
|
||||
TorchTensorHandle randint_int(long low, long high, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::randint<int>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
@ -230,19 +230,11 @@ void randn_like_assign(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle));
|
||||
}
|
||||
TorchTensorHandle randint_long_like(TorchTensorHandle tensor_handle, long low, long high)
|
||||
TorchTensorHandle randint_like(TorchTensorHandle tensor_handle, long low, long high)
|
||||
{
|
||||
return new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
|
||||
}
|
||||
void randint_long_like_assign(TorchTensorHandle tensor_handle, long low, long high)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
|
||||
}
|
||||
TorchTensorHandle randint_int_like(TorchTensorHandle tensor_handle, int low, int high)
|
||||
{
|
||||
return new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
|
||||
}
|
||||
void randint_int_like_assign(TorchTensorHandle tensor_handle, int low, int high)
|
||||
void randint_like_assign(TorchTensorHandle tensor_handle, long low, long high)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
|
||||
}
|
||||
@ -264,39 +256,6 @@ TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
|
||||
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
TorchTensorHandle plus_double(double value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
TorchTensorHandle plus_float(float value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
TorchTensorHandle plus_long(long value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
TorchTensorHandle plus_int(int value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
void plus_double_assign(double value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
void plus_float_assign(float value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
void plus_long_assign(long value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
void plus_int_assign(int value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
TorchTensorHandle times_double(double value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(value * ctorch::cast(other));
|
||||
@ -330,22 +289,39 @@ void times_int_assign(int value, TorchTensorHandle other)
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
TorchTensorHandle plus_double(double value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
TorchTensorHandle plus_float(float value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(lhs) += ctorch::cast(rhs);
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
TorchTensorHandle plus_long(long value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
TorchTensorHandle plus_int(int value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
void plus_double_assign(double value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
void plus_float_assign(float value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
void plus_long_assign(long value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
void plus_int_assign(int value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
|
||||
@ -362,6 +338,22 @@ void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) /= ctorch::cast(rhs);
|
||||
}
|
||||
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
|
||||
}
|
||||
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) += ctorch::cast(rhs);
|
||||
}
|
||||
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
|
||||
}
|
||||
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
||||
}
|
||||
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(-ctorch::cast(tensor_handle));
|
||||
|
@ -15,21 +15,558 @@ JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads(JNIEnv *,
|
||||
torch::set_num_threads(num_threads);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor(JNIEnv *, jclass)
|
||||
JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_cudaIsAvailable(JNIEnv *, jclass)
|
||||
{
|
||||
auto ten = torch::randn({2, 3});
|
||||
std::cout << ten << std::endl;
|
||||
void *ptr = new torch::Tensor(ten);
|
||||
return (long)ptr;
|
||||
return torch::cuda::is_available();
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setSeed(JNIEnv *, jclass, jint seed)
|
||||
{
|
||||
auto ten = ctorch::cast((void *)tensor_handle);
|
||||
std::cout << ten << std::endl;
|
||||
torch::manual_seed(seed);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_emptyTensor(JNIEnv *, jclass)
|
||||
{
|
||||
return (long)new torch::Tensor;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobDouble(JNIEnv *env, jclass, jdoubleArray data, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::from_blob<double>(
|
||||
env->GetDoubleArrayElements(data, 0),
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device), true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobFloat(JNIEnv *env, jclass, jfloatArray data, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::from_blob<float>(
|
||||
env->GetFloatArrayElements(data, 0),
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device), true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobLong(JNIEnv *env, jclass, jlongArray data, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::from_blob<long>(
|
||||
env->GetLongArrayElements(data, 0),
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device), true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fromBlobInt(JNIEnv *env, jclass, jintArray data, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::from_blob<int>(
|
||||
env->GetIntArrayElements(data, 0),
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device), true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).clone());
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDevice(JNIEnv *, jclass, jlong tensor_handle, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToDouble(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<double>(), false, true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToFloat(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<float>(), false, true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToLong(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<long>(), false, true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_copyToInt(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<int>(), false, true));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_swapTensors(JNIEnv *, jclass, jlong lhs_handle, jlong rhs_handle)
|
||||
{
|
||||
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_viewTensor(JNIEnv *env, jclass, jlong tensor_handle, jintArray shape)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::cast(tensor_handle).view(ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape))));
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL Java_kscience_kmath_torch_JTorch_tensorToString(JNIEnv *env, jclass, jlong tensor_handle)
|
||||
{
|
||||
return env->NewStringUTF(ctorch::tensor_to_string(ctorch::cast(tensor_handle)).c_str());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
delete static_cast<torch::Tensor *>((void *)tensor_handle);
|
||||
ctorch::dispose_tensor(tensor_handle);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDim(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).dim();
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumel(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).numel();
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getShapeAt(JNIEnv *, jclass, jlong tensor_handle, jint d)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).size(d);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getStrideAt(JNIEnv *, jclass, jlong tensor_handle, jint d)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).stride(d);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getDevice(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getItemDouble(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<double>();
|
||||
}
|
||||
|
||||
JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getItemFloat(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<float>();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getItemLong(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<long>();
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getItemInt(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<int>();
|
||||
}
|
||||
|
||||
JNIEXPORT jdouble JNICALL Java_kscience_kmath_torch_JTorch_getDouble(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||
{
|
||||
return ctorch::get<double>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||
}
|
||||
|
||||
JNIEXPORT jfloat JNICALL Java_kscience_kmath_torch_JTorch_getFloat(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||
{
|
||||
return ctorch::get<float>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_getLong(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||
{
|
||||
return ctorch::get<long>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getInt(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||
{
|
||||
return ctorch::get<int>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setDouble(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jdouble value)
|
||||
{
|
||||
ctorch::set<double>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setFloat(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jfloat value)
|
||||
{
|
||||
ctorch::set<float>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setLong(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jlong value)
|
||||
{
|
||||
ctorch::set<long>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setInt(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jint value)
|
||||
{
|
||||
ctorch::set<int>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randDouble(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::rand<double>(
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnDouble(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randn<double>(
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randFloat(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::rand<float>(
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnFloat(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randn<float>(
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintDouble(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randint<double>(low, high,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintFloat(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randint<float>(low, high,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLong(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randint<long>(low, high,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintInt(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randint<int>(low, high,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randLike(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::rand_like(ctorch::cast(tensor_handle)));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randLikeAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::rand_like(ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randnLike(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::randn_like(ctorch::cast(tensor_handle)));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randnLikeAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_randintLike(JNIEnv *, jclass, jlong tensor_handle, jlong low, jlong high)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_randintLikeAssign(JNIEnv *, jclass, jlong tensor_handle, jlong low, jlong high)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullDouble(JNIEnv *env, jclass, jdouble value, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::full<double>(
|
||||
value,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullFloat(JNIEnv *env, jclass, jfloat value, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::full<float>(
|
||||
value,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullLong(JNIEnv *env, jclass, jlong value, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::full<long>(
|
||||
value,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_fullInt(JNIEnv *env, jclass, jint value, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::full<int>(
|
||||
value,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesDouble(JNIEnv *, jclass, jdouble value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesFloat(JNIEnv *, jclass, jfloat value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesLong(JNIEnv *, jclass, jlong value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesInt(JNIEnv *, jclass, jint value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesDoubleAssign(JNIEnv *, jclass, jdouble value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesFloatAssign(JNIEnv *, jclass, jfloat value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesLongAssign(JNIEnv *, jclass, jlong value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesIntAssign(JNIEnv *, jclass, jint value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusDouble(JNIEnv *, jclass, jdouble value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusFloat(JNIEnv *, jclass, jfloat value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusLong(JNIEnv *, jclass, jlong value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusInt(JNIEnv *, jclass, jint value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusDoubleAssign(JNIEnv *, jclass, jdouble value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusFloatAssign(JNIEnv *, jclass, jfloat value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusLongAssign(JNIEnv *, jclass, jlong value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusIntAssign(JNIEnv *, jclass, jint value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_timesTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_timesTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(lhs) *= ctorch::cast(rhs);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_divTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_divTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(lhs) /= ctorch::cast(rhs);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_plusTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_plusTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(lhs) += ctorch::cast(rhs);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_minusTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_minusTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_unaryMinus(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(-ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_absTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_absTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_transposeTensor(JNIEnv *, jclass, jlong tensor_handle, jint i, jint j)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_transposeTensorAssign(JNIEnv *, jclass, jlong tensor_handle, jint i, jint j)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_expTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).exp());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_expTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_logTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).log());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_logTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_sumTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).sum());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_sumTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_matmul(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_matmulRightAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_kscience_kmath_torch_JTorch_diagEmbed(JNIEnv *, jclass, jlong diags_handle, jint offset, jint dim1, jint dim2)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_kscience_kmath_torch_JTorch_svdTensor(JNIEnv *, jclass, jlong tensor_handle, jlong U_handle, jlong S_handle, jlong V_handle)
|
||||
{
|
||||
auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
|
||||
ctorch::cast(U_handle) = U;
|
||||
ctorch::cast(S_handle) = S;
|
||||
ctorch::cast(V_handle) = V;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_kscience_kmath_torch_JTorch_symeigTensor(JNIEnv *, jclass, jlong tensor_handle, jlong S_handle, jlong V_handle, jboolean eigenvectors)
|
||||
{
|
||||
auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
|
||||
ctorch::cast(S_handle) = S;
|
||||
ctorch::cast(V_handle) = V;
|
||||
}
|
||||
|
||||
JNIEXPORT jboolean JNICALL Java_kscience_kmath_torch_JTorch_requiresGrad(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).requires_grad();
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setRequiresGrad(JNIEnv *, jclass, jlong tensor_handle, jboolean status)
|
||||
{
|
||||
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_detachFromGraph(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_kscience_kmath_torch_JTorch_autogradTensor(JNIEnv *, jclass, jlong value, jlong variable, jboolean retain_graph)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_autohessTensor(JNIEnv *, jclass, jlong value, jlong variable)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
|
||||
}
|
@ -7,8 +7,202 @@ class JTorch {
|
||||
}
|
||||
|
||||
public static native int getNumThreads();
|
||||
|
||||
public static native void setNumThreads(int numThreads);
|
||||
public static native long createTensor();
|
||||
public static native void printTensor(long tensorHandle);
|
||||
|
||||
public static native boolean cudaIsAvailable();
|
||||
|
||||
public static native void setSeed(int seed);
|
||||
|
||||
public static native long emptyTensor();
|
||||
|
||||
public static native long fromBlobDouble(double[] data, int[] shape, int device);
|
||||
|
||||
public static native long fromBlobFloat(float[] data, int[] shape, int device);
|
||||
|
||||
public static native long fromBlobLong(long[] data, int[] shape, int device);
|
||||
|
||||
public static native long fromBlobInt(int[] data, int[] shape, int device);
|
||||
|
||||
public static native long copyTensor(long tensorHandle);
|
||||
|
||||
public static native long copyToDevice(long tensorHandle, int device);
|
||||
|
||||
public static native long copyToDouble(long tensorHandle);
|
||||
|
||||
public static native long copyToFloat(long tensorHandle);
|
||||
|
||||
public static native long copyToLong(long tensorHandle);
|
||||
|
||||
public static native long copyToInt(long tensorHandle);
|
||||
|
||||
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
||||
|
||||
public static native long viewTensor(long tensorHandle, int[] shape);
|
||||
|
||||
public static native String tensorToString(long tensorHandle);
|
||||
|
||||
public static native void disposeTensor(long tensorHandle);
|
||||
}
|
||||
|
||||
public static native int getDim(long tensorHandle);
|
||||
|
||||
public static native int getNumel(long tensorHandle);
|
||||
|
||||
public static native int getShapeAt(long tensorHandle, int d);
|
||||
|
||||
public static native int getStrideAt(long tensorHandle, int d);
|
||||
|
||||
public static native int getDevice(long tensorHandle);
|
||||
|
||||
public static native double getItemDouble(long tensorHandle);
|
||||
|
||||
public static native float getItemFloat(long tensorHandle);
|
||||
|
||||
public static native long getItemLong(long tensorHandle);
|
||||
|
||||
public static native int getItemInt(long tensorHandle);
|
||||
|
||||
public static native double getDouble(long tensorHandle, int[] index);
|
||||
|
||||
public static native float getFloat(long tensorHandle, int[] index);
|
||||
|
||||
public static native long getLong(long tensorHandle, int[] index);
|
||||
|
||||
public static native int getInt(long tensorHandle, int[] index);
|
||||
|
||||
public static native void setDouble(long tensorHandle, int[] index, double value);
|
||||
|
||||
public static native void setFloat(long tensorHandle, int[] index, float value);
|
||||
|
||||
public static native void setLong(long tensorHandle, int[] index, long value);
|
||||
|
||||
public static native void setInt(long tensorHandle, int[] index, int value);
|
||||
|
||||
public static native long randDouble(int[] shape, int device);
|
||||
|
||||
public static native long randnDouble(int[] shape, int device);
|
||||
|
||||
public static native long randFloat(int[] shape, int device);
|
||||
|
||||
public static native long randnFloat(int[] shape, int device);
|
||||
|
||||
public static native long randintDouble(long low, long high, int[] shape, int device);
|
||||
|
||||
public static native long randintFloat(long low, long high, int[] shape, int device);
|
||||
|
||||
public static native long randintLong(long low, long high, int[] shape, int device);
|
||||
|
||||
public static native long randintInt(long low, long high, int[] shape, int device);
|
||||
|
||||
public static native long randLike(long tensorHandle);
|
||||
|
||||
public static native void randLikeAssign(long tensorHandle);
|
||||
|
||||
public static native long randnLike(long tensorHandle);
|
||||
|
||||
public static native void randnLikeAssign(long tensorHandle);
|
||||
|
||||
public static native long randintLike(long tensorHandle, long low, long high);
|
||||
|
||||
public static native void randintLikeAssign(long tensorHandle, long low, long high);
|
||||
|
||||
public static native long fullDouble(double value, int[] shape, int device);
|
||||
|
||||
public static native long fullFloat(float value, int[] shape, int device);
|
||||
|
||||
public static native long fullLong(long value, int[] shape, int device);
|
||||
|
||||
public static native long fullInt(int value, int[] shape, int device);
|
||||
|
||||
public static native long timesDouble(double value, long other);
|
||||
|
||||
public static native long timesFloat(float value, long other);
|
||||
|
||||
public static native long timesLong(long value, long other);
|
||||
|
||||
public static native long timesInt(int value, long other);
|
||||
|
||||
public static native void timesDoubleAssign(double value, long other);
|
||||
|
||||
public static native void timesFloatAssign(float value, long other);
|
||||
|
||||
public static native void timesLongAssign(long value, long other);
|
||||
|
||||
public static native void timesIntAssign(int value, long other);
|
||||
|
||||
public static native long plusDouble(double value, long other);
|
||||
|
||||
public static native long plusFloat(float value, long other);
|
||||
|
||||
public static native long plusLong(long value, long other);
|
||||
|
||||
public static native long plusInt(int value, long other);
|
||||
|
||||
public static native void plusDoubleAssign(double value, long other);
|
||||
|
||||
public static native void plusFloatAssign(float value, long other);
|
||||
|
||||
public static native void plusLongAssign(long value, long other);
|
||||
|
||||
public static native void plusIntAssign(int value, long other);
|
||||
|
||||
public static native long timesTensor(long lhs, long rhs);
|
||||
|
||||
public static native void timesTensorAssign(long lhs, long rhs);
|
||||
|
||||
public static native long divTensor(long lhs, long rhs);
|
||||
|
||||
public static native void divTensorAssign(long lhs, long rhs);
|
||||
|
||||
public static native long plusTensor(long lhs, long rhs);
|
||||
|
||||
public static native void plusTensorAssign(long lhs, long rhs);
|
||||
|
||||
public static native long minusTensor(long lhs, long rhs);
|
||||
|
||||
public static native void minusTensorAssign(long lhs, long rhs);
|
||||
|
||||
public static native long unaryMinus(long tensorHandle);
|
||||
|
||||
public static native long absTensor(long tensorHandle);
|
||||
|
||||
public static native void absTensorAssign(long tensorHandle);
|
||||
|
||||
public static native long transposeTensor(long tensorHandle, int i, int j);
|
||||
|
||||
public static native void transposeTensorAssign(long tensorHandle, int i, int j);
|
||||
|
||||
public static native long expTensor(long tensorHandle);
|
||||
|
||||
public static native void expTensorAssign(long tensorHandle);
|
||||
|
||||
public static native long logTensor(long tensorHandle);
|
||||
|
||||
public static native void logTensorAssign(long tensorHandle);
|
||||
|
||||
public static native long sumTensor(long tensorHandle);
|
||||
|
||||
public static native void sumTensorAssign(long tensorHandle);
|
||||
|
||||
public static native long matmul(long lhs, long rhs);
|
||||
|
||||
public static native void matmulAssign(long lhs, long rhs);
|
||||
|
||||
public static native void matmulRightAssign(long lhs, long rhs);
|
||||
|
||||
public static native long diagEmbed(long diagsHandle, int offset, int dim1, int dim2);
|
||||
|
||||
public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle);
|
||||
|
||||
public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle, boolean eigenvectors);
|
||||
|
||||
public static native boolean requiresGrad(long tensorHandle);
|
||||
|
||||
public static native void setRequiresGrad(long tensorHandle, boolean status);
|
||||
|
||||
public static native long detachFromGraph(long tensorHandle);
|
||||
|
||||
public static native long autogradTensor(long value, long variable, boolean retainGraph);
|
||||
|
||||
public static native long autohessTensor(long value, long variable);
|
||||
}
|
||||
|
@ -0,0 +1,4 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
public class TorchTensorJVM {
|
||||
}
|
@ -1,17 +0,0 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
|
||||
public fun getNumThreads(): Int {
|
||||
return JTorch.getNumThreads()
|
||||
}
|
||||
|
||||
public fun setNumThreads(numThreads: Int): Unit {
|
||||
JTorch.setNumThreads(numThreads)
|
||||
}
|
||||
|
||||
|
||||
public fun runCPD(): Unit {
|
||||
val tensorHandle = JTorch.createTensor()
|
||||
JTorch.printTensor(tensorHandle)
|
||||
JTorch.disposeTensor(tensorHandle)
|
||||
}
|
@ -6,14 +6,10 @@ import kotlin.test.*
|
||||
class TestUtils {
|
||||
|
||||
@Test
|
||||
fun testSetNumThreads() {
|
||||
val numThreads = 2
|
||||
setNumThreads(numThreads)
|
||||
assertEquals(numThreads, getNumThreads())
|
||||
fun testJTorch() {
|
||||
val tensor = JTorch.fullInt(54, intArrayOf(3), 0)
|
||||
println(JTorch.tensorToString(tensor))
|
||||
JTorch.disposeTensor(tensor)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCPD() {
|
||||
runCPD()
|
||||
}
|
||||
}
|
@ -117,6 +117,13 @@ public sealed class TorchTensorAlgebraNative<
|
||||
sum_tensor_assign(tensorHandle)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
|
||||
wrap(randint_like(this.tensorHandle, low, high)!!)
|
||||
|
||||
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit {
|
||||
randint_like_assign(this.tensorHandle, low, high)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.copy(): TorchTensorType =
|
||||
wrap(copy_tensor(this.tensorHandle)!!)
|
||||
|
||||
@ -224,6 +231,9 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(randint_double(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(plus_double(this, other.tensorHandle)!!)
|
||||
|
||||
@ -256,18 +266,9 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
|
||||
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun randIntegral(low: Double, high: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(randint_double(low.toLong(), high.toLong(), shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun TorchTensorReal.randIntegral(low: Double, high: Double): TorchTensorReal =
|
||||
wrap(randint_long_like(this.tensorHandle, low.toLong(), high.toLong())!!)
|
||||
|
||||
override fun TorchTensorReal.randIntegralAssign(low: Double, high: Double): Unit {
|
||||
randint_long_like_assign(this.tensorHandle, low.toLong(), high.toLong())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
TorchTensorPartialDivisionAlgebraNative<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
|
||||
@ -295,6 +296,9 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(randint_float(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(plus_float(this, other.tensorHandle)!!)
|
||||
|
||||
@ -328,15 +332,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun randIntegral(low: Float, high: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(randint_float(low.toLong(), high.toLong(), shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun TorchTensorFloat.randIntegral(low: Float, high: Float): TorchTensorFloat =
|
||||
wrap(randint_long_like(this.tensorHandle, low.toLong(), high.toLong())!!)
|
||||
|
||||
override fun TorchTensorFloat.randIntegralAssign(low: Float, high: Float): Unit {
|
||||
randint_long_like_assign(this.tensorHandle, low.toLong(), high.toLong())
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
@ -363,13 +358,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun TorchTensorLong.randIntegral(low: Long, high: Long): TorchTensorLong =
|
||||
wrap(randint_long_like(this.tensorHandle, low, high)!!)
|
||||
|
||||
override fun TorchTensorLong.randIntegralAssign(low: Long, high: Long): Unit {
|
||||
randint_long_like_assign(this.tensorHandle, low, high)
|
||||
}
|
||||
|
||||
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(plus_long(this, other.tensorHandle)!!)
|
||||
|
||||
@ -425,16 +413,9 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
return get_data_int(this.tensorHandle)!!
|
||||
}
|
||||
|
||||
override fun randIntegral(low: Int, high: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun TorchTensorInt.randIntegral(low: Int, high: Int): TorchTensorInt =
|
||||
wrap(randint_int_like(this.tensorHandle, low, high)!!)
|
||||
|
||||
override fun TorchTensorInt.randIntegralAssign(low: Int, high: Int): Unit {
|
||||
randint_int_like_assign(this.tensorHandle, low, high)
|
||||
}
|
||||
|
||||
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(plus_int(this, other.tensorHandle)!!)
|
||||
|
||||
|
@ -22,6 +22,4 @@ internal class BenchmarkMatMul {
|
||||
benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -6,7 +6,7 @@ import kotlin.test.*
|
||||
internal class TestUtils {
|
||||
@Test
|
||||
fun testSetNumThreads() {
|
||||
TorchTensorIntAlgebra {
|
||||
TorchTensorLongAlgebra {
|
||||
testingSetNumThreads()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user