Compare commits
43 Commits
dev
...
commandert
Author | SHA1 | Date | |
---|---|---|---|
|
7ee7daa1ab | ||
|
3d8390a130 | ||
|
bd736c6716 | ||
|
86528fc943 | ||
|
3e3d433a64 | ||
|
0a4e7acb4c | ||
|
6594ffc965 | ||
|
9b3258b06b | ||
|
120189fd89 | ||
|
40895e5936 | ||
|
c141c04e99 | ||
|
c9dfb6a08c | ||
|
391eb28cad | ||
|
17e6ebbc14 | ||
|
d599d1132b | ||
|
274d1a3105 | ||
|
b30ca920e1 | ||
|
6eb718f64a | ||
|
889691a122 | ||
|
e5205d5afd | ||
|
ed4ac2623d | ||
|
97ef57697d | ||
|
4f4fcba559 | ||
|
ca2082405a | ||
|
0b784474b4 | ||
|
fbb414731b | ||
|
7d25aa2834 | ||
|
ef570254e6 | ||
|
80f28dbcd5 | ||
|
ca3cca65ef | ||
|
39f3a87bbd | ||
|
524b1d80d1 | ||
|
8967691b7d | ||
|
7105331149 | ||
|
7894799e8e | ||
|
0cb2c3f0da | ||
|
9b1a958491 | ||
|
cfe93886ac | ||
|
fb9d612081 | ||
|
d97f8857a0 | ||
|
0fc29b40c5 | ||
|
32e4b68061 | ||
|
a229aaa6a4 |
@ -0,0 +1,7 @@
|
||||
package space.kscience.kmath.memory
|
||||
|
||||
public expect class DeferScope {
|
||||
public inline fun defer(crossinline block: () -> Unit)
|
||||
}
|
||||
|
||||
public expect inline fun <R> withDeferScope(block: DeferScope.() -> R): R
|
@ -0,0 +1,30 @@
|
||||
package space.kscience.kmath.memory
|
||||
|
||||
private typealias Deferred = () -> Unit
|
||||
|
||||
public actual class DeferScope {
|
||||
@PublishedApi
|
||||
internal val deferred: MutableList<Deferred> = mutableListOf()
|
||||
|
||||
@PublishedApi
|
||||
internal fun executeAllDeferred() {
|
||||
deferred.forEach(Deferred::invoke)
|
||||
deferred.clear()
|
||||
}
|
||||
|
||||
public actual inline fun defer(crossinline block: () -> Unit) {
|
||||
deferred += {
|
||||
try {
|
||||
block()
|
||||
} catch (ignored: Throwable) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
|
||||
val ds = DeferScope()
|
||||
val r = ds.block()
|
||||
ds.executeAllDeferred()
|
||||
return r
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
package space.kscience.kmath.memory
|
||||
|
||||
private typealias Deferred = () -> Unit
|
||||
|
||||
public actual class DeferScope {
|
||||
@PublishedApi
|
||||
internal val deferred: MutableList<Deferred> = mutableListOf()
|
||||
|
||||
@PublishedApi
|
||||
internal fun executeAllDeferred() {
|
||||
deferred.forEach(Deferred::invoke)
|
||||
deferred.clear()
|
||||
}
|
||||
|
||||
public actual inline fun defer(crossinline block: () -> Unit) {
|
||||
deferred += {
|
||||
try {
|
||||
block()
|
||||
} catch (ignored: Throwable) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
|
||||
val ds = DeferScope()
|
||||
val r = ds.block()
|
||||
ds.executeAllDeferred()
|
||||
return r
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
package space.kscience.kmath.memory
|
||||
|
||||
import kotlinx.cinterop.memScoped
|
||||
|
||||
public actual typealias DeferScope = kotlinx.cinterop.DeferScope
|
||||
|
||||
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R = memScoped(block)
|
66
kmath-torch/README.md
Normal file
66
kmath-torch/README.md
Normal file
@ -0,0 +1,66 @@
|
||||
# LibTorch extension (`kmath-torch`)
|
||||
|
||||
This is a `Kotlin/Native` & `JVM` module, with only `linuxX64` supported so far. The library wraps some of
|
||||
the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
|
||||
|
||||
## Installation
|
||||
|
||||
To install the library, you have to build & publish locally `kmath-core`, `kmath-memory` with `kmath-torch`:
|
||||
|
||||
```
|
||||
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
|
||||
```
|
||||
|
||||
This builds `ctorch` a C wrapper and `jtorch` a JNI wrapper for `LibTorch`, placed inside:
|
||||
|
||||
`~/.konan/third-party/kmath-torch-0.2.0/cpp-build`
|
||||
|
||||
You will have to link against it in your own project.
|
||||
|
||||
## Usage
|
||||
|
||||
Tensors are implemented over the `MutableNDStructure`. They can only be created through provided factory methods
|
||||
and require scoping within a `TensorAlgebra` instance:
|
||||
|
||||
```kotlin
|
||||
TorchTensorRealAlgebra {
|
||||
|
||||
val realTensor: TorchTensorReal = copyFromArray(
|
||||
array = (1..10).map { it + 50.0 }.toList().toDoubleArray(),
|
||||
shape = intArrayOf(2, 5)
|
||||
)
|
||||
println(realTensor)
|
||||
|
||||
val gpuRealTensor: TorchTensorReal = copyFromArray(
|
||||
array = (1..8).map { it * 2.5 }.toList().toDoubleArray(),
|
||||
shape = intArrayOf(2, 2, 2),
|
||||
device = Device.CUDA(0)
|
||||
)
|
||||
println(gpuRealTensor)
|
||||
}
|
||||
```
|
||||
|
||||
High performance automatic differentiation engine is available:
|
||||
|
||||
```kotlin
|
||||
TorchTensorRealAlgebra {
|
||||
val dim = 10
|
||||
val device = Device.CPU //or Device.CUDA(0)
|
||||
|
||||
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
||||
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||
val tensorSigma = randFeatures + randFeatures.transpose(0, 1)
|
||||
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
||||
|
||||
// expression to differentiate w.r.t. x evaluated at x = tensorX
|
||||
val expressionAtX = withGradAt(tensorX, { x ->
|
||||
0.5 * (x dot (tensorSigma dot x)) + (tensorMu dot x) + 25.9
|
||||
})
|
||||
|
||||
// value of the gradient at x = tensorX
|
||||
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
||||
// value of the hessian at x = tensorX
|
||||
val hessianAtX = expressionAtX hess tensorX
|
||||
}
|
||||
```
|
||||
Contributed by [Roland Grinis](https://github.com/rgrit91)
|
397
kmath-torch/api/kmath-torch.api
Normal file
397
kmath-torch/api/kmath-torch.api
Normal file
@ -0,0 +1,397 @@
|
||||
public abstract class space/kscience/kmath/torch/Device {
|
||||
public static final field Companion Lspace/kscience/kmath/torch/Device$Companion;
|
||||
public final fun toInt ()I
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/Device$CPU : space/kscience/kmath/torch/Device {
|
||||
public static final field INSTANCE Lspace/kscience/kmath/torch/Device$CPU;
|
||||
public fun toString ()Ljava/lang/String;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/Device$CUDA : space/kscience/kmath/torch/Device {
|
||||
public fun <init> (I)V
|
||||
public final fun component1 ()I
|
||||
public final fun copy (I)Lspace/kscience/kmath/torch/Device$CUDA;
|
||||
public static synthetic fun copy$default (Lspace/kscience/kmath/torch/Device$CUDA;IILjava/lang/Object;)Lspace/kscience/kmath/torch/Device$CUDA;
|
||||
public fun equals (Ljava/lang/Object;)Z
|
||||
public final fun getIndex ()I
|
||||
public fun hashCode ()I
|
||||
public fun toString ()Ljava/lang/String;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/Device$Companion {
|
||||
public final fun fromInt (I)Lspace/kscience/kmath/torch/Device;
|
||||
}
|
||||
|
||||
public abstract interface class space/kscience/kmath/torch/TorchTensor : space/kscience/kmath/tensors/TensorStructure {
|
||||
public abstract fun elements ()Lkotlin/sequences/Sequence;
|
||||
public abstract fun getDevice ()Lspace/kscience/kmath/torch/Device;
|
||||
public abstract fun getSize ()I
|
||||
public abstract fun getStrides ()[I
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensor$DefaultImpls {
|
||||
public static fun elements (Lspace/kscience/kmath/torch/TorchTensor;)Lkotlin/sequences/Sequence;
|
||||
public static fun getDimension (Lspace/kscience/kmath/torch/TorchTensor;)I
|
||||
public static fun value (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||
}
|
||||
|
||||
public abstract interface class space/kscience/kmath/torch/TorchTensorAlgebra : space/kscience/kmath/tensors/TensorAlgebra {
|
||||
public abstract fun copy (Lspace/kscience/kmath/torch/TorchTensor;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public abstract fun copyFromArray (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public abstract fun copyToArray (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||
public abstract fun copyToDevice (Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public abstract fun cudaAvailable ()Z
|
||||
public abstract fun full (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public abstract fun getChecks ()Z
|
||||
public abstract fun getNumThreads ()I
|
||||
public abstract fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public abstract fun randIntegral (Lspace/kscience/kmath/torch/TorchTensor;JJ)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public abstract fun randIntegralAssign (Lspace/kscience/kmath/torch/TorchTensor;JJ)V
|
||||
public abstract fun setChecks (Z)V
|
||||
public abstract fun setNumThreads (I)V
|
||||
public abstract fun setSeed (I)V
|
||||
public abstract fun swap (Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/TorchTensor;)V
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorAlgebra$DefaultImpls {
|
||||
public static synthetic fun copyFromArray$default (Lspace/kscience/kmath/torch/TorchTensorAlgebra;Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;ILjava/lang/Object;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public static synthetic fun randIntegral$default (Lspace/kscience/kmath/torch/TorchTensorAlgebra;JJ[ILspace/kscience/kmath/torch/Device;ILjava/lang/Object;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
}
|
||||
|
||||
public abstract class space/kscience/kmath/torch/TorchTensorAlgebraJVM : space/kscience/kmath/torch/TorchTensorAlgebra {
|
||||
public synthetic fun <init> (Lspace/kscience/kmath/memory/DeferScope;Lkotlin/jvm/internal/DefaultConstructorMarker;)V
|
||||
public synthetic fun abs (Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun abs (Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun absAssign (Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||
public fun absAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||
public synthetic fun copy (Lspace/kscience/kmath/torch/TorchTensor;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun copy (Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun copyToDevice (Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun copyToDevice (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public fun cudaAvailable ()Z
|
||||
public synthetic fun diagonalEmbedding (Lspace/kscience/kmath/tensors/TensorStructure;III)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun diagonalEmbedding (Lspace/kscience/kmath/torch/TorchTensorJVM;III)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun dot (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun dot (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun dotAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||
public fun dotAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||
public synthetic fun dotRightAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||
public fun dotRightAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||
public fun getChecks ()Z
|
||||
public fun getNumThreads ()I
|
||||
public synthetic fun minus (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun minus (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun minusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||
public fun minusAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||
public synthetic fun plus (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun plus (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun plusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||
public fun plusAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||
public synthetic fun randIntegral (Lspace/kscience/kmath/torch/TorchTensor;JJ)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun randIntegral (Lspace/kscience/kmath/torch/TorchTensorJVM;JJ)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun randIntegralAssign (Lspace/kscience/kmath/torch/TorchTensor;JJ)V
|
||||
public fun randIntegralAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;JJ)V
|
||||
public fun setChecks (Z)V
|
||||
public fun setNumThreads (I)V
|
||||
public fun setSeed (I)V
|
||||
public synthetic fun sum (Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun sum (Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun sumAssign (Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||
public fun sumAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||
public synthetic fun swap (Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/TorchTensor;)V
|
||||
public fun swap (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||
public synthetic fun times (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun times (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun timesAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||
public fun timesAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||
public synthetic fun transpose (Lspace/kscience/kmath/tensors/TensorStructure;II)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun transpose (Lspace/kscience/kmath/torch/TorchTensorJVM;II)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun transposeAssign (Lspace/kscience/kmath/tensors/TensorStructure;II)V
|
||||
public fun transposeAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;II)V
|
||||
public synthetic fun unaryMinus (Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun unaryMinus (Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
public synthetic fun view (Lspace/kscience/kmath/tensors/TensorStructure;[I)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun view (Lspace/kscience/kmath/torch/TorchTensorJVM;[I)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorAlgebraJVMKt {
|
||||
public static final fun TorchTensorFloatAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
||||
public static final fun TorchTensorIntAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
||||
public static final fun TorchTensorLongAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
||||
public static final fun TorchTensorRealAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorAlgebraKt {
|
||||
public static final fun checkDeviceCompatible (Lspace/kscience/kmath/torch/TorchTensorAlgebra;Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/TorchTensor;)V
|
||||
public static final fun checkDotOperation (Lspace/kscience/kmath/torch/TorchTensorAlgebra;Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/TorchTensor;)V
|
||||
public static final fun checkLinearOperation (Lspace/kscience/kmath/torch/TorchTensorAlgebra;Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/TorchTensor;)V
|
||||
public static final fun withChecks (Lspace/kscience/kmath/torch/TorchTensorAlgebra;Lkotlin/jvm/functions/Function1;)V
|
||||
public static final fun withGradAt (Lspace/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra;Lspace/kscience/kmath/torch/TorchTensorOverField;Lkotlin/jvm/functions/Function2;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorFloat : space/kscience/kmath/torch/TorchTensorOverFieldJVM {
|
||||
public fun get ([I)Ljava/lang/Float;
|
||||
public synthetic fun get ([I)Ljava/lang/Object;
|
||||
public fun item ()Ljava/lang/Float;
|
||||
public synthetic fun item ()Ljava/lang/Object;
|
||||
public fun set ([IF)V
|
||||
public synthetic fun set ([ILjava/lang/Object;)V
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorFloatAlgebra : space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebraJVM {
|
||||
public fun <init> (Lspace/kscience/kmath/memory/DeferScope;)V
|
||||
public synthetic fun copyFromArray (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun copyFromArray ([F[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public synthetic fun copyToArray (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||
public fun copyToArray (Lspace/kscience/kmath/torch/TorchTensorFloat;)[F
|
||||
public fun full (F[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public synthetic fun full (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun minus (FLspace/kscience/kmath/torch/TorchTensorFloat;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public synthetic fun minus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun minus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun minus (Lspace/kscience/kmath/torch/TorchTensorFloat;F)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public synthetic fun minusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun minusAssign (Lspace/kscience/kmath/torch/TorchTensorFloat;F)V
|
||||
public fun plus (FLspace/kscience/kmath/torch/TorchTensorFloat;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public synthetic fun plus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun plus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun plus (Lspace/kscience/kmath/torch/TorchTensorFloat;F)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public synthetic fun plusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun plusAssign (Lspace/kscience/kmath/torch/TorchTensorFloat;F)V
|
||||
public synthetic fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public fun randNormal ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public synthetic fun randNormal ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public fun randUniform ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public synthetic fun randUniform ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public fun times (FLspace/kscience/kmath/torch/TorchTensorFloat;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public synthetic fun times (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun times (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun times (Lspace/kscience/kmath/torch/TorchTensorFloat;F)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public synthetic fun timesAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun timesAssign (Lspace/kscience/kmath/torch/TorchTensorFloat;F)V
|
||||
public synthetic fun wrap$kmath_torch (J)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorInt : space/kscience/kmath/torch/TorchTensorOverFieldJVM {
|
||||
public fun get ([I)Ljava/lang/Integer;
|
||||
public synthetic fun get ([I)Ljava/lang/Object;
|
||||
public fun item ()Ljava/lang/Integer;
|
||||
public synthetic fun item ()Ljava/lang/Object;
|
||||
public fun set ([II)V
|
||||
public synthetic fun set ([ILjava/lang/Object;)V
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorIntAlgebra : space/kscience/kmath/torch/TorchTensorAlgebraJVM {
|
||||
public fun <init> (Lspace/kscience/kmath/memory/DeferScope;)V
|
||||
public synthetic fun copyFromArray (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun copyFromArray ([I[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||
public synthetic fun copyToArray (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||
public fun copyToArray (Lspace/kscience/kmath/torch/TorchTensorInt;)[I
|
||||
public fun full (I[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||
public synthetic fun full (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun minus (ILspace/kscience/kmath/torch/TorchTensorInt;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||
public synthetic fun minus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun minus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun minus (Lspace/kscience/kmath/torch/TorchTensorInt;I)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||
public synthetic fun minusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun minusAssign (Lspace/kscience/kmath/torch/TorchTensorInt;I)V
|
||||
public fun plus (ILspace/kscience/kmath/torch/TorchTensorInt;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||
public synthetic fun plus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun plus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun plus (Lspace/kscience/kmath/torch/TorchTensorInt;I)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||
public synthetic fun plusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun plusAssign (Lspace/kscience/kmath/torch/TorchTensorInt;I)V
|
||||
public synthetic fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||
public fun times (ILspace/kscience/kmath/torch/TorchTensorInt;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||
public synthetic fun times (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun times (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun times (Lspace/kscience/kmath/torch/TorchTensorInt;I)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||
public synthetic fun timesAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun timesAssign (Lspace/kscience/kmath/torch/TorchTensorInt;I)V
|
||||
public synthetic fun wrap$kmath_torch (J)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
}
|
||||
|
||||
public abstract class space/kscience/kmath/torch/TorchTensorJVM : space/kscience/kmath/torch/TorchTensorMemoryHolder, space/kscience/kmath/torch/TorchTensor {
|
||||
public synthetic fun <init> (Lspace/kscience/kmath/memory/DeferScope;JLkotlin/jvm/internal/DefaultConstructorMarker;)V
|
||||
protected fun close ()V
|
||||
public final fun copyToDouble ()Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public final fun copyToFloat ()Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||
public final fun copyToInt ()Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||
public final fun copyToLong ()Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||
public fun elements ()Lkotlin/sequences/Sequence;
|
||||
public fun getDevice ()Lspace/kscience/kmath/torch/Device;
|
||||
public fun getDimension ()I
|
||||
public fun getShape ()[I
|
||||
public fun getSize ()I
|
||||
public fun getStrides ()[I
|
||||
public fun toString ()Ljava/lang/String;
|
||||
public fun value ()Ljava/lang/Object;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorLong : space/kscience/kmath/torch/TorchTensorOverFieldJVM {
|
||||
public fun get ([I)Ljava/lang/Long;
|
||||
public synthetic fun get ([I)Ljava/lang/Object;
|
||||
public fun item ()Ljava/lang/Long;
|
||||
public synthetic fun item ()Ljava/lang/Object;
|
||||
public fun set ([IJ)V
|
||||
public synthetic fun set ([ILjava/lang/Object;)V
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorLongAlgebra : space/kscience/kmath/torch/TorchTensorAlgebraJVM {
|
||||
public fun <init> (Lspace/kscience/kmath/memory/DeferScope;)V
|
||||
public synthetic fun copyFromArray (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun copyFromArray ([J[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||
public synthetic fun copyToArray (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||
public fun copyToArray (Lspace/kscience/kmath/torch/TorchTensorLong;)[J
|
||||
public fun full (J[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||
public synthetic fun full (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun minus (JLspace/kscience/kmath/torch/TorchTensorLong;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||
public synthetic fun minus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun minus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun minus (Lspace/kscience/kmath/torch/TorchTensorLong;J)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||
public synthetic fun minusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun minusAssign (Lspace/kscience/kmath/torch/TorchTensorLong;J)V
|
||||
public fun plus (JLspace/kscience/kmath/torch/TorchTensorLong;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||
public synthetic fun plus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun plus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun plus (Lspace/kscience/kmath/torch/TorchTensorLong;J)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||
public synthetic fun plusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun plusAssign (Lspace/kscience/kmath/torch/TorchTensorLong;J)V
|
||||
public synthetic fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||
public fun times (JLspace/kscience/kmath/torch/TorchTensorLong;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||
public synthetic fun times (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun times (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun times (Lspace/kscience/kmath/torch/TorchTensorLong;J)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||
public synthetic fun timesAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun timesAssign (Lspace/kscience/kmath/torch/TorchTensorLong;J)V
|
||||
public synthetic fun wrap$kmath_torch (J)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
}
|
||||
|
||||
public abstract class space/kscience/kmath/torch/TorchTensorMemoryHolder {
|
||||
protected abstract fun close ()V
|
||||
public fun equals (Ljava/lang/Object;)Z
|
||||
public final fun getScope ()Lspace/kscience/kmath/memory/DeferScope;
|
||||
public fun hashCode ()I
|
||||
}
|
||||
|
||||
public abstract interface class space/kscience/kmath/torch/TorchTensorOverField : space/kscience/kmath/torch/TorchTensor {
|
||||
public abstract fun getRequiresGrad ()Z
|
||||
public abstract fun setRequiresGrad (Z)V
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorOverField$DefaultImpls {
|
||||
public static fun elements (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lkotlin/sequences/Sequence;
|
||||
public static fun getDimension (Lspace/kscience/kmath/torch/TorchTensorOverField;)I
|
||||
public static fun value (Lspace/kscience/kmath/torch/TorchTensorOverField;)Ljava/lang/Object;
|
||||
}
|
||||
|
||||
public abstract class space/kscience/kmath/torch/TorchTensorOverFieldJVM : space/kscience/kmath/torch/TorchTensorJVM, space/kscience/kmath/torch/TorchTensorOverField {
|
||||
public synthetic fun <init> (Lspace/kscience/kmath/memory/DeferScope;JLkotlin/jvm/internal/DefaultConstructorMarker;)V
|
||||
public fun getRequiresGrad ()Z
|
||||
public fun setRequiresGrad (Z)V
|
||||
}
|
||||
|
||||
public abstract interface class space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra : space/kscience/kmath/tensors/TensorPartialDivisionAlgebra, space/kscience/kmath/torch/TorchTensorAlgebra {
|
||||
public abstract fun detachFromGraph (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public abstract fun grad (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public abstract fun grad (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;Z)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public abstract fun hess (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public abstract fun randNormal (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public abstract fun randNormal ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public abstract fun randNormalAssign (Lspace/kscience/kmath/torch/TorchTensorOverField;)V
|
||||
public abstract fun randUniform (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public abstract fun randUniform ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public abstract fun randUniformAssign (Lspace/kscience/kmath/torch/TorchTensorOverField;)V
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra$DefaultImpls {
|
||||
public static fun grad (Lspace/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra;Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public static synthetic fun grad$default (Lspace/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra;Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;ZILjava/lang/Object;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public static synthetic fun randNormal$default (Lspace/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra;[ILspace/kscience/kmath/torch/Device;ILjava/lang/Object;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public static synthetic fun randUniform$default (Lspace/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra;[ILspace/kscience/kmath/torch/Device;ILjava/lang/Object;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
}
|
||||
|
||||
public abstract class space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebraJVM : space/kscience/kmath/torch/TorchTensorAlgebraJVM, space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra {
|
||||
public synthetic fun <init> (Lspace/kscience/kmath/memory/DeferScope;Lkotlin/jvm/internal/DefaultConstructorMarker;)V
|
||||
public synthetic fun detachFromGraph (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public fun detachFromGraph (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||
public synthetic fun div (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun div (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||
public synthetic fun divAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||
public fun divAssign (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)V
|
||||
public synthetic fun exp (Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun exp (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||
public synthetic fun expAssign (Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||
public fun expAssign (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)V
|
||||
public synthetic fun grad (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public synthetic fun grad (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;Z)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public fun grad (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||
public fun grad (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Z)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||
public synthetic fun hess (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public fun hess (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||
public synthetic fun log (Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun log (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||
public synthetic fun logAssign (Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||
public fun logAssign (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)V
|
||||
public synthetic fun randNormal (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public fun randNormal (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||
public synthetic fun randNormalAssign (Lspace/kscience/kmath/torch/TorchTensorOverField;)V
|
||||
public fun randNormalAssign (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)V
|
||||
public synthetic fun randUniform (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public fun randUniform (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||
public synthetic fun randUniformAssign (Lspace/kscience/kmath/torch/TorchTensorOverField;)V
|
||||
public fun randUniformAssign (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)V
|
||||
public synthetic fun svd (Lspace/kscience/kmath/tensors/TensorStructure;)Lkotlin/Triple;
|
||||
public fun svd (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lkotlin/Triple;
|
||||
public synthetic fun symEig (Lspace/kscience/kmath/tensors/TensorStructure;Z)Lkotlin/Pair;
|
||||
public fun symEig (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Z)Lkotlin/Pair;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorReal : space/kscience/kmath/torch/TorchTensorOverFieldJVM {
|
||||
public fun get ([I)Ljava/lang/Double;
|
||||
public synthetic fun get ([I)Ljava/lang/Object;
|
||||
public fun item ()Ljava/lang/Double;
|
||||
public synthetic fun item ()Ljava/lang/Object;
|
||||
public fun set ([ID)V
|
||||
public synthetic fun set ([ILjava/lang/Object;)V
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/torch/TorchTensorRealAlgebra : space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebraJVM {
|
||||
public fun <init> (Lspace/kscience/kmath/memory/DeferScope;)V
|
||||
public synthetic fun copyFromArray (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun copyFromArray ([D[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public synthetic fun copyToArray (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||
public fun copyToArray (Lspace/kscience/kmath/torch/TorchTensorReal;)[D
|
||||
public fun full (D[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public synthetic fun full (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun minus (DLspace/kscience/kmath/torch/TorchTensorReal;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public synthetic fun minus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun minus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun minus (Lspace/kscience/kmath/torch/TorchTensorReal;D)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public synthetic fun minusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun minusAssign (Lspace/kscience/kmath/torch/TorchTensorReal;D)V
|
||||
public fun plus (DLspace/kscience/kmath/torch/TorchTensorReal;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public synthetic fun plus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun plus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun plus (Lspace/kscience/kmath/torch/TorchTensorReal;D)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public synthetic fun plusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun plusAssign (Lspace/kscience/kmath/torch/TorchTensorReal;D)V
|
||||
public synthetic fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||
public fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public synthetic fun randNormal ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public fun randNormal ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public synthetic fun randUniform ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||
public fun randUniform ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public fun times (DLspace/kscience/kmath/torch/TorchTensorReal;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public synthetic fun times (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public synthetic fun times (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||
public fun times (Lspace/kscience/kmath/torch/TorchTensorReal;D)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||
public synthetic fun timesAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||
public fun timesAssign (Lspace/kscience/kmath/torch/TorchTensorReal;D)V
|
||||
public synthetic fun wrap$kmath_torch (J)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||
}
|
||||
|
194
kmath-torch/build.gradle.kts
Normal file
194
kmath-torch/build.gradle.kts
Normal file
@ -0,0 +1,194 @@
|
||||
import de.undercouch.gradle.tasks.download.Download
|
||||
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
|
||||
import org.gradle.api.JavaVersion.VERSION_11
|
||||
|
||||
|
||||
plugins {
|
||||
id("ru.mipt.npm.gradle.mpp")
|
||||
id("de.undercouch.download")
|
||||
}
|
||||
|
||||
java {
|
||||
sourceCompatibility = VERSION_11
|
||||
targetCompatibility = VERSION_11
|
||||
}
|
||||
|
||||
val home = System.getProperty("user.home")
|
||||
val javaHome = System.getProperty("java.home")
|
||||
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
|
||||
val cppBuildDir = "$thirdPartyDir/cpp-build"
|
||||
val cppSources = projectDir.resolve("src/cppMain")
|
||||
|
||||
val cudaHome: String? = System.getenv("CUDA_HOME")
|
||||
val cudaDefault = file("/usr/local/cuda").exists()
|
||||
val cudaFound = cudaHome?.isNotEmpty() ?: false or cudaDefault
|
||||
|
||||
val cmakeArchive = "cmake-3.19.2-Linux-x86_64"
|
||||
val torchArchive = "libtorch"
|
||||
|
||||
val cmakeCmd = "$thirdPartyDir/$cmakeArchive/bin/cmake"
|
||||
val ninjaCmd = "$thirdPartyDir/ninja"
|
||||
|
||||
val downloadCMake by tasks.registering(Download::class) {
|
||||
val tarFile = "$cmakeArchive.tar.gz"
|
||||
src("https://github.com/Kitware/CMake/releases/download/v3.19.2/$tarFile")
|
||||
dest(File(thirdPartyDir, tarFile))
|
||||
overwrite(false)
|
||||
}
|
||||
|
||||
val downloadNinja by tasks.registering(Download::class) {
|
||||
src("https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-linux.zip")
|
||||
dest(File(thirdPartyDir, "ninja-linux.zip"))
|
||||
overwrite(false)
|
||||
}
|
||||
|
||||
val downloadTorch by tasks.registering(Download::class) {
|
||||
val abiMeta = "$torchArchive-cxx11-abi-shared-with-deps-1.7.1%2B"
|
||||
val cudaUrl = "https://download.pytorch.org/libtorch/cu110/${abiMeta}cu110.zip"
|
||||
val cpuUrl = "https://download.pytorch.org/libtorch/cpu/${abiMeta}cpu.zip"
|
||||
val url = if (cudaFound) cudaUrl else cpuUrl
|
||||
src(url)
|
||||
dest(File(thirdPartyDir, "$torchArchive.zip"))
|
||||
overwrite(false)
|
||||
}
|
||||
|
||||
val extractCMake by tasks.registering(Copy::class) {
|
||||
dependsOn(downloadCMake)
|
||||
from(tarTree(resources.gzip(downloadCMake.get().dest)))
|
||||
into(thirdPartyDir)
|
||||
}
|
||||
|
||||
val extractTorch by tasks.registering(Copy::class) {
|
||||
dependsOn(downloadTorch)
|
||||
from(zipTree(downloadTorch.get().dest))
|
||||
into(thirdPartyDir)
|
||||
}
|
||||
|
||||
val extractNinja by tasks.registering(Copy::class) {
|
||||
dependsOn(downloadNinja)
|
||||
from(zipTree(downloadNinja.get().dest))
|
||||
into(thirdPartyDir)
|
||||
}
|
||||
|
||||
val configureCpp by tasks.registering {
|
||||
dependsOn(extractCMake)
|
||||
dependsOn(extractNinja)
|
||||
dependsOn(extractTorch)
|
||||
onlyIf { !file(cppBuildDir).exists() }
|
||||
doLast {
|
||||
exec {
|
||||
workingDir(thirdPartyDir)
|
||||
commandLine("mkdir", "-p", cppBuildDir)
|
||||
}
|
||||
exec {
|
||||
workingDir(cppBuildDir)
|
||||
commandLine(
|
||||
cmakeCmd,
|
||||
cppSources,
|
||||
"-GNinja",
|
||||
"-DCMAKE_MAKE_PROGRAM=$ninjaCmd",
|
||||
"-DCMAKE_PREFIX_PATH=$thirdPartyDir/$torchArchive",
|
||||
"-DJAVA_HOME=$javaHome",
|
||||
"-DCMAKE_BUILD_TYPE=Release"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val cleanCppBuild by tasks.registering {
|
||||
onlyIf { file(cppBuildDir).exists() }
|
||||
doLast {
|
||||
exec {
|
||||
workingDir(thirdPartyDir)
|
||||
commandLine("rm", "-rf", cppBuildDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val buildCpp by tasks.registering {
|
||||
dependsOn(configureCpp)
|
||||
doLast {
|
||||
exec {
|
||||
workingDir(cppBuildDir)
|
||||
commandLine(cmakeCmd, "--build", ".", "--config", "Release")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val generateJNIHeader by tasks.registering {
|
||||
doLast {
|
||||
exec {
|
||||
workingDir(projectDir.resolve("src/jvmMain/java/space/kscience/kmath/torch"))
|
||||
commandLine("$javaHome/bin/javac", "-h", cppSources.resolve("include") , "JTorch.java")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kotlin {
|
||||
explicitApiWarning()
|
||||
|
||||
jvm {
|
||||
withJava()
|
||||
}
|
||||
|
||||
val nativeTarget = linuxX64("native")
|
||||
nativeTarget.apply {
|
||||
binaries {
|
||||
all {
|
||||
linkerOpts(
|
||||
"-L$cppBuildDir",
|
||||
"-Wl,-rpath=$cppBuildDir",
|
||||
"-lctorch"
|
||||
)
|
||||
optimized = true
|
||||
debuggable = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val main by nativeTarget.compilations.getting {
|
||||
cinterops {
|
||||
val libctorch by creating {
|
||||
includeDirs(cppSources.resolve("include"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val test by nativeTarget.compilations.getting
|
||||
|
||||
sourceSets {
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
}
|
||||
|
||||
val nativeMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
}
|
||||
|
||||
val jvmMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val native: KotlinNativeTarget by kotlin.targets
|
||||
tasks[native.compilations["main"].cinterops["libctorch"].interopProcessingTaskName]
|
||||
.dependsOn(buildCpp)
|
||||
|
||||
tasks["jvmProcessResources"].dependsOn(buildCpp)
|
||||
|
||||
tasks {
|
||||
withType<Test>{
|
||||
systemProperty("java.library.path", cppBuildDir)
|
||||
}
|
||||
}
|
||||
|
||||
// No JS implementation
|
||||
project.gradle.startParameter.excludedTaskNames.add("jsTest")
|
||||
project.gradle.startParameter.excludedTaskNames.add("jsBrowserTest")
|
@ -0,0 +1,24 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
|
||||
public sealed class Device {
|
||||
public object CPU: Device() {
|
||||
override fun toString(): String {
|
||||
return "CPU"
|
||||
}
|
||||
}
|
||||
public data class CUDA(val index: Int): Device()
|
||||
public fun toInt(): Int {
|
||||
when(this) {
|
||||
is CPU -> return 0
|
||||
is CUDA -> return this.index + 1
|
||||
}
|
||||
}
|
||||
public companion object {
|
||||
public fun fromInt(deviceInt: Int): Device {
|
||||
return if (deviceInt == 0) CPU else CUDA(
|
||||
deviceInt - 1
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
@file:Suppress("NOTHING_TO_INLINE")
|
||||
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import space.kscience.kmath.tensors.*
|
||||
|
||||
public interface TorchTensor<T> : TensorStructure<T> {
|
||||
|
||||
public val strides: IntArray
|
||||
public val size: Int
|
||||
public val device: Device
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||
if (dimension == 0) {
|
||||
return emptySequence()
|
||||
}
|
||||
val indices = (1..size).asSequence().map { indexFromOffset(it - 1, strides, dimension) }
|
||||
return indices.map { it to get(it) }
|
||||
}
|
||||
}
|
||||
|
||||
public interface TorchTensorOverField<T>: TorchTensor<T>
|
||||
{
|
||||
public var requiresGrad: Boolean
|
||||
}
|
||||
|
@ -0,0 +1,105 @@
|
||||
@file:Suppress("NOTHING_TO_INLINE")
|
||||
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import space.kscience.kmath.tensors.*
|
||||
|
||||
public interface TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>> :
|
||||
TensorAlgebra<T, TorchTensorType> {
|
||||
|
||||
public fun getNumThreads(): Int
|
||||
public fun setNumThreads(numThreads: Int): Unit
|
||||
public fun cudaAvailable(): Boolean
|
||||
public fun setSeed(seed: Int): Unit
|
||||
|
||||
public var checks: Boolean
|
||||
|
||||
public fun copyFromArray(
|
||||
array: PrimitiveArrayType,
|
||||
shape: IntArray,
|
||||
device: Device = Device.CPU
|
||||
): TorchTensorType
|
||||
|
||||
public fun TorchTensorType.copyToArray(): PrimitiveArrayType
|
||||
|
||||
public fun full(value: T, shape: IntArray, device: Device): TorchTensorType
|
||||
|
||||
public fun randIntegral(
|
||||
low: Long, high: Long, shape: IntArray,
|
||||
device: Device = Device.CPU
|
||||
): TorchTensorType
|
||||
public fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType
|
||||
public fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit
|
||||
|
||||
public fun TorchTensorType.copy(): TorchTensorType
|
||||
public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType
|
||||
public infix fun TorchTensorType.swap(other: TorchTensorType)
|
||||
|
||||
}
|
||||
|
||||
public interface TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>> :
|
||||
TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>, TensorPartialDivisionAlgebra<T, TorchTensorType> {
|
||||
|
||||
public fun randUniform(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
||||
public fun randNormal(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
||||
public fun TorchTensorType.randUniform(): TorchTensorType
|
||||
public fun TorchTensorType.randUniformAssign(): Unit
|
||||
public fun TorchTensorType.randNormal(): TorchTensorType
|
||||
public fun TorchTensorType.randNormalAssign(): Unit
|
||||
|
||||
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType
|
||||
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||
this.grad(variable, false)
|
||||
|
||||
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType
|
||||
public fun TorchTensorType.detachFromGraph(): TorchTensorType
|
||||
}
|
||||
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.withChecks(block: TorchTensorAlgebraType.() -> Unit): Unit {
|
||||
val state = this.checks
|
||||
this.checks = true
|
||||
this.block()
|
||||
this.checks = state
|
||||
}
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.checkDeviceCompatible(
|
||||
a: TorchTensorType, b: TorchTensorType
|
||||
): Unit =
|
||||
check(a.device == b.device) {
|
||||
"Tensors must be on the same device"
|
||||
}
|
||||
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.checkLinearOperation(
|
||||
a: TorchTensorType,
|
||||
b: TorchTensorType
|
||||
) {
|
||||
if (a.isNotValue() and b.isNotValue()) {
|
||||
checkDeviceCompatible(a, b)
|
||||
checkShapeCompatible(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit {
|
||||
checkDeviceCompatible(a, b)
|
||||
checkDot(a,b)
|
||||
}
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
|
||||
TorchTensorDivisionAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorDivisionAlgebraType.withGradAt(
|
||||
tensor: TorchTensorType,
|
||||
block: TorchTensorDivisionAlgebraType.(TorchTensorType) -> TorchTensorType
|
||||
): TorchTensorType {
|
||||
tensor.requiresGrad = true
|
||||
return this.block(tensor)
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import space.kscience.kmath.memory.DeferScope
|
||||
|
||||
public abstract class TorchTensorMemoryHolder internal constructor(
|
||||
public val scope: DeferScope
|
||||
) {
|
||||
init {
|
||||
scope.defer(::close)
|
||||
}
|
||||
protected abstract fun close(): Unit
|
||||
|
||||
override fun equals(other: Any?): Boolean = false
|
||||
override fun hashCode(): Int = 0
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
@file:Suppress("NOTHING_TO_INLINE")
|
||||
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.time.measureTime
|
||||
|
||||
internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.benchmarkMatMul(
|
||||
scale: Int,
|
||||
numWarmUp: Int,
|
||||
numIter: Int,
|
||||
fieldName: String,
|
||||
device: Device = Device.CPU
|
||||
): Unit {
|
||||
println("Benchmarking $scale x $scale $fieldName matrices on $device: ")
|
||||
setSeed(SEED)
|
||||
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
repeat(numWarmUp) { lhs dotAssign rhs }
|
||||
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
|
||||
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||
}
|
||||
|
@ -0,0 +1,132 @@
|
||||
@file:Suppress("NOTHING_TO_INLINE")
|
||||
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.time.measureTime
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.benchmarkRand(
|
||||
samples: Int,
|
||||
numWarmUp: Int,
|
||||
numIter: Int,
|
||||
device: Device,
|
||||
distName: String,
|
||||
initBock: TorchTensorAlgebraType.(IntArray, Device) -> TorchTensorType,
|
||||
runBlock: TorchTensorAlgebraType.(TorchTensorType) -> Unit
|
||||
): Unit{
|
||||
println("Benchmarking generation of $samples $distName samples on $device: ")
|
||||
setSeed(SEED)
|
||||
val shape = intArrayOf(samples)
|
||||
val tensor = this.initBock(shape,device)
|
||||
repeat(numWarmUp) { this.runBlock(tensor) }
|
||||
val measuredTime = measureTime { repeat(numIter) { this.runBlock(tensor) } }
|
||||
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.benchmarkRandNormal(
|
||||
samples: Int,
|
||||
numWarmUp: Int,
|
||||
numIter: Int,
|
||||
device: Device = Device.CPU): Unit{
|
||||
benchmarkRand(
|
||||
samples,
|
||||
numWarmUp,
|
||||
numIter,
|
||||
device,
|
||||
"Normal",
|
||||
{sh, dc -> randNormal(shape = sh, device = dc)},
|
||||
{ten -> ten.randNormalAssign() }
|
||||
)
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.benchmarkRandUniform(
|
||||
samples: Int,
|
||||
numWarmUp: Int,
|
||||
numIter: Int,
|
||||
device: Device = Device.CPU): Unit{
|
||||
benchmarkRand(
|
||||
samples,
|
||||
numWarmUp,
|
||||
numIter,
|
||||
device,
|
||||
"Uniform",
|
||||
{sh, dc -> randUniform(shape = sh, device = dc)},
|
||||
{ten -> ten.randUniformAssign() }
|
||||
)
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.benchmarkRandIntegral(
|
||||
samples: Int,
|
||||
numWarmUp: Int,
|
||||
numIter: Int,
|
||||
device: Device = Device.CPU): Unit{
|
||||
benchmarkRand(
|
||||
samples,
|
||||
numWarmUp,
|
||||
numIter,
|
||||
device,
|
||||
"integer [0,100]",
|
||||
{sh, dc -> randIntegral(0, 100, shape = sh, device = dc)},
|
||||
{ten -> ten.randIntegralAssign(0, 100) }
|
||||
)
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.benchmarkingRand1(): Unit {
|
||||
benchmarkRandNormal(10, 10, 100000)
|
||||
benchmarkRandUniform(10, 10, 100000)
|
||||
benchmarkRandIntegral(10, 10, 100000)
|
||||
if(cudaAvailable()) {
|
||||
benchmarkRandNormal(10, 10, 100000, device = Device.CUDA(0))
|
||||
benchmarkRandUniform(10, 10, 100000, device = Device.CUDA(0))
|
||||
benchmarkRandIntegral(10, 10, 100000, device = Device.CUDA(0))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.benchmarkingRand3(): Unit {
|
||||
benchmarkRandNormal(1000, 10, 10000)
|
||||
benchmarkRandUniform(1000, 10, 10000)
|
||||
benchmarkRandIntegral(1000, 10, 10000)
|
||||
if(cudaAvailable()) {
|
||||
benchmarkRandNormal(1000, 10, 100000, device = Device.CUDA(0))
|
||||
benchmarkRandUniform(1000, 10, 100000, device = Device.CUDA(0))
|
||||
benchmarkRandIntegral(1000, 10, 100000, device = Device.CUDA(0))
|
||||
}
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.benchmarkingRand5(): Unit {
|
||||
benchmarkRandNormal(100000, 5, 100)
|
||||
benchmarkRandUniform(100000, 5, 100)
|
||||
benchmarkRandIntegral(100000, 5, 100)
|
||||
if(cudaAvailable()){
|
||||
benchmarkRandNormal(100000, 10, 100000, device = Device.CUDA(0))
|
||||
benchmarkRandUniform(100000, 10, 100000, device = Device.CUDA(0))
|
||||
benchmarkRandIntegral(100000, 10, 100000, device = Device.CUDA(0))
|
||||
}
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.benchmarkingRand7(): Unit {
|
||||
benchmarkRandNormal(10000000, 3, 20)
|
||||
benchmarkRandUniform(10000000, 3, 20)
|
||||
benchmarkRandIntegral(10000000, 3, 20)
|
||||
if(cudaAvailable()){
|
||||
benchmarkRandNormal(10000000, 10, 10000, device = Device.CUDA(0))
|
||||
benchmarkRandUniform(10000000, 10, 10000, device = Device.CUDA(0))
|
||||
benchmarkRandIntegral(10000000, 10, 10000, device = Device.CUDA(0))
|
||||
}
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
@file:Suppress("NOTHING_TO_INLINE")
|
||||
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingAutoGrad(device: Device = Device.CPU): Unit {
|
||||
setSeed(SEED)
|
||||
val dim = 3
|
||||
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
||||
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||
val tensorSigma = randFeatures + randFeatures.transpose(0, 1)
|
||||
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
||||
|
||||
val expressionAtX = withGradAt(tensorX, { x ->
|
||||
0.5f * (x dot (tensorSigma dot x)) + (tensorMu dot x) + 25.9f
|
||||
})
|
||||
|
||||
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
||||
val hessianAtX = expressionAtX hess tensorX
|
||||
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
|
||||
|
||||
val error = (gradientAtX - expectedGradientAtX).abs().sum().value() +
|
||||
(hessianAtX - tensorSigma).abs().sum().value()
|
||||
assertTrue(error < TOLERANCE)
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingBatchedAutoGrad(device: Device = Device.CPU): Unit {
|
||||
setSeed(SEED)
|
||||
val batch = intArrayOf(2)
|
||||
val dim = 2
|
||||
val tensorX = randNormal(shape = batch + intArrayOf(1, dim), device = device)
|
||||
val randFeatures = randNormal(shape = batch + intArrayOf(dim, dim), device = device)
|
||||
val tensorSigma = randFeatures + randFeatures.transpose(-2, -1)
|
||||
val tensorMu = randNormal(shape = batch + intArrayOf(1, dim), device = device)
|
||||
|
||||
val expressionAtX = withGradAt(tensorX, { x ->
|
||||
val xt = x.transpose(-1, -2)
|
||||
0.5f * (x dot (tensorSigma dot xt)) + (tensorMu dot xt) + 58.2f
|
||||
})
|
||||
expressionAtX.sumAssign()
|
||||
|
||||
val gradientAtX = expressionAtX grad tensorX
|
||||
val expectedGradientAtX = (tensorX dot tensorSigma) + tensorMu
|
||||
|
||||
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
||||
assertTrue(error < TOLERANCE)
|
||||
}
|
||||
|
@ -0,0 +1,53 @@
|
||||
@file:Suppress("NOTHING_TO_INLINE")
|
||||
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingCopying(device: Device = Device.CPU): Unit {
|
||||
val array = (1..24).map { 10f * it * it }.toFloatArray()
|
||||
val shape = intArrayOf(2, 3, 4)
|
||||
val tensor = copyFromArray(array, shape = shape, device = device)
|
||||
val copyOfTensor = tensor.copy()
|
||||
tensor[intArrayOf(1, 2, 3)] = 0.1f
|
||||
assertTrue(copyOfTensor.copyToArray() contentEquals array)
|
||||
assertEquals(0.1f, tensor[intArrayOf(1, 2, 3)])
|
||||
if(device != Device.CPU){
|
||||
val normalCpu = randNormal(intArrayOf(2, 3))
|
||||
val normalGpu = normalCpu.copyToDevice(device)
|
||||
assertTrue(normalCpu.copyToArray() contentEquals normalGpu.copyToArray())
|
||||
|
||||
val uniformGpu = randUniform(intArrayOf(3,2),device)
|
||||
val uniformCpu = uniformGpu.copyToDevice(Device.CPU)
|
||||
assertTrue(uniformGpu.copyToArray() contentEquals uniformCpu.copyToArray())
|
||||
}
|
||||
}
|
||||
|
||||
internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingRequiresGrad(): Unit {
|
||||
val tensor = randNormal(intArrayOf(3))
|
||||
assertTrue(!tensor.requiresGrad)
|
||||
tensor.requiresGrad = true
|
||||
assertTrue(tensor.requiresGrad)
|
||||
tensor.requiresGrad = false
|
||||
assertTrue(!tensor.requiresGrad)
|
||||
tensor.requiresGrad = true
|
||||
val detachedTensor = tensor.detachFromGraph()
|
||||
assertTrue(!detachedTensor.requiresGrad)
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensor<Int>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<Int, IntArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingViewWithNoCopy(device: Device = Device.CPU) {
|
||||
val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6), device)
|
||||
val viewTensor = tensor.view(intArrayOf(2, 3))
|
||||
assertTrue(viewTensor.shape contentEquals intArrayOf(2, 3))
|
||||
viewTensor[intArrayOf(0, 0)] = 10
|
||||
assertEquals(tensor[intArrayOf(0)], 10)
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,133 @@
|
||||
@file:Suppress("NOTHING_TO_INLINE")
|
||||
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import space.kscience.kmath.linear.RealMatrixContext
|
||||
import space.kscience.kmath.linear.Matrix
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.math.*
|
||||
import kotlin.test.*
|
||||
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingScalarProduct(device: Device = Device.CPU): Unit {
|
||||
val lhs = randUniform(shape = intArrayOf(3), device = device)
|
||||
val rhs = randUniform(shape = intArrayOf(3), device = device)
|
||||
val product = lhs dot rhs
|
||||
var expected = 0.0
|
||||
lhs.elements().forEach {
|
||||
expected += it.second * rhs[it.first]
|
||||
}
|
||||
assertTrue(abs(expected - product.value()) < TOLERANCE)
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingMatrixMultiplication(device: Device = Device.CPU): Unit {
|
||||
setSeed(SEED)
|
||||
|
||||
val lhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||
val rhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||
val product = lhsTensor dot rhsTensor
|
||||
|
||||
val expected: Matrix<Double> = RealMatrixContext {
|
||||
val lhs = produce(3, 3) { i, j -> lhsTensor[intArrayOf(i, j)] }
|
||||
val rhs = produce(3, 3) { i, j -> rhsTensor[intArrayOf(i, j)] }
|
||||
lhs dot rhs
|
||||
}
|
||||
|
||||
val lhsTensorCopy = lhsTensor.copy()
|
||||
val rhsTensorCopy = rhsTensor.copy()
|
||||
|
||||
lhsTensorCopy dotAssign rhsTensor
|
||||
lhsTensor dotRightAssign rhsTensorCopy
|
||||
|
||||
var error = 0.0
|
||||
product.elements().forEach {
|
||||
error += abs(expected[it.first] - it.second) +
|
||||
abs(expected[it.first] - lhsTensorCopy[it.first]) +
|
||||
abs(expected[it.first] - rhsTensorCopy[it.first])
|
||||
}
|
||||
assertTrue(error < TOLERANCE)
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingLinearStructure(device: Device = Device.CPU): Unit {
|
||||
|
||||
val shape = intArrayOf(3)
|
||||
val tensorA = full(value = -4.5, shape = shape, device = device)
|
||||
val tensorB = full(value = 10.9, shape = shape, device = device)
|
||||
val tensorC = full(value = 789.3, shape = shape, device = device)
|
||||
val tensorD = full(value = -72.9, shape = shape, device = device)
|
||||
val tensorE = full(value = 553.1, shape = shape, device = device)
|
||||
val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4
|
||||
val expected = copyFromArray(
|
||||
array = (1..3).map {
|
||||
15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4
|
||||
}
|
||||
.toDoubleArray(),
|
||||
shape = shape,
|
||||
device = device
|
||||
)
|
||||
|
||||
val assignResult = full(value = 0.0, shape = shape, device = device)
|
||||
tensorA *= 15.8
|
||||
tensorB *= 1.5
|
||||
tensorB *= -tensorD
|
||||
tensorC *= 0.02
|
||||
tensorC /= tensorE
|
||||
assignResult += tensorA
|
||||
assignResult -= tensorB
|
||||
assignResult += tensorC
|
||||
assignResult += -39.4
|
||||
|
||||
val error = (expected - result).abs().sum().value() +
|
||||
(expected - assignResult).abs().sum().value()
|
||||
assertTrue(error < TOLERANCE)
|
||||
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingTensorTransformations(device: Device = Device.CPU): Unit {
|
||||
setSeed(SEED)
|
||||
val tensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||
val result = tensor.exp().log()
|
||||
val assignResult = tensor.copy()
|
||||
assignResult.transposeAssign(0, 1)
|
||||
assignResult.expAssign()
|
||||
assignResult.logAssign()
|
||||
assignResult.transposeAssign(0, 1)
|
||||
val error = tensor - result
|
||||
error.absAssign()
|
||||
error.sumAssign()
|
||||
error += (tensor - assignResult).abs().sum()
|
||||
assertTrue(error.value() < TOLERANCE)
|
||||
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingBatchedSVD(device: Device = Device.CPU): Unit {
|
||||
val tensor = randNormal(shape = intArrayOf(7, 5, 3), device = device)
|
||||
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||
val error = tensor - (tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1)))
|
||||
assertTrue(error.abs().sum().value() < TOLERANCE)
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingBatchedSymEig(device: Device = Device.CPU): Unit {
|
||||
val tensor = randNormal(shape = intArrayOf(5, 5), device = device)
|
||||
val tensorSigma = tensor + tensor.transpose(-2, -1)
|
||||
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||
val error = tensorSigma - (tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1)))
|
||||
assertTrue(error.abs().sum().value() < TOLERANCE)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,39 @@
|
||||
@file:Suppress("NOTHING_TO_INLINE")
|
||||
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal val SEED = 987654
|
||||
internal val TOLERANCE = 1e-6
|
||||
|
||||
internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.withCuda(block: TorchTensorAlgebraType.(Device) -> Unit): Unit {
|
||||
this.block(Device.CPU)
|
||||
if (cudaAvailable()) this.block(Device.CUDA(0))
|
||||
}
|
||||
|
||||
internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingSetNumThreads(): Unit {
|
||||
val numThreads = 2
|
||||
setNumThreads(numThreads)
|
||||
assertEquals(numThreads, getNumThreads())
|
||||
}
|
||||
|
||||
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||
TorchTensorAlgebraType.testingSetSeed(device: Device = Device.CPU): Unit {
|
||||
setSeed(SEED)
|
||||
val integral = randIntegral(0, 100, IntArray(0), device = device).value()
|
||||
val normal = randNormal(IntArray(0), device = device).value()
|
||||
val uniform = randUniform(IntArray(0), device = device).value()
|
||||
setSeed(SEED)
|
||||
val nextIntegral = randIntegral(0, 100, IntArray(0), device = device).value()
|
||||
val nextNormal = randNormal(IntArray(0), device = device).value()
|
||||
val nextUniform = randUniform(IntArray(0), device = device).value()
|
||||
assertEquals(normal, nextNormal)
|
||||
assertEquals(uniform, nextUniform)
|
||||
assertEquals(integral, nextIntegral)
|
||||
}
|
34
kmath-torch/src/cppMain/CMakeLists.txt
Normal file
34
kmath-torch/src/cppMain/CMakeLists.txt
Normal file
@ -0,0 +1,34 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
|
||||
project(CTorch LANGUAGES C CXX)
|
||||
|
||||
# Require C++17
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
# Build configuration
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
||||
endif()
|
||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||
|
||||
find_package(Torch REQUIRED)
|
||||
find_package(JNI REQUIRED)
|
||||
|
||||
add_library(ctorch SHARED src/ctorch.cc)
|
||||
target_include_directories(ctorch PRIVATE include)
|
||||
target_link_libraries(ctorch PRIVATE torch)
|
||||
target_compile_options(ctorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
||||
|
||||
add_library(jtorch SHARED src/jtorch.cc)
|
||||
target_include_directories(jtorch PRIVATE include ${JNI_INCLUDE_DIRS})
|
||||
target_link_libraries(jtorch PRIVATE torch)
|
||||
target_compile_options(jtorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
set_target_properties(ctorch PROPERTIES PUBLIC_HEADER include/ctorch.h)
|
||||
install(TARGETS ctorch
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||
|
||||
install(TARGETS jtorch LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
160
kmath-torch/src/cppMain/include/ctorch.h
Normal file
160
kmath-torch/src/cppMain/include/ctorch.h
Normal file
@ -0,0 +1,160 @@
|
||||
#ifndef CTORCH
|
||||
#define CTORCH
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C"
|
||||
{
|
||||
#endif
|
||||
|
||||
typedef void *TorchTensorHandle;
|
||||
|
||||
int get_num_threads();
|
||||
|
||||
void set_num_threads(int num_threads);
|
||||
|
||||
bool cuda_is_available();
|
||||
|
||||
void set_seed(int seed);
|
||||
|
||||
TorchTensorHandle empty_tensor();
|
||||
|
||||
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);
|
||||
TorchTensorHandle copy_to_double(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle copy_to_float(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle copy_to_long(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle copy_to_int(TorchTensorHandle tensor_handle);
|
||||
void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle);
|
||||
TorchTensorHandle view_tensor(TorchTensorHandle tensor_handle, int *shape, int dim);
|
||||
|
||||
char *tensor_to_string(TorchTensorHandle tensor_handle);
|
||||
void dispose_char(char *ptr);
|
||||
void dispose_tensor(TorchTensorHandle tensor_handle);
|
||||
|
||||
int get_dim(TorchTensorHandle tensor_handle);
|
||||
int get_numel(TorchTensorHandle tensor_handle);
|
||||
int get_shape_at(TorchTensorHandle tensor_handle, int d);
|
||||
int get_stride_at(TorchTensorHandle tensor_handle, int d);
|
||||
int get_device(TorchTensorHandle tensor_handle);
|
||||
|
||||
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||
int *get_data_int(TorchTensorHandle tensor_handle);
|
||||
|
||||
double get_item_double(TorchTensorHandle tensor_handle);
|
||||
float get_item_float(TorchTensorHandle tensor_handle);
|
||||
long get_item_long(TorchTensorHandle tensor_handle);
|
||||
int get_item_int(TorchTensorHandle tensor_handle);
|
||||
|
||||
double get_double(TorchTensorHandle tensor_handle, int *index);
|
||||
float get_float(TorchTensorHandle tensor_handle, int *index);
|
||||
long get_long(TorchTensorHandle tensor_handle, int *index);
|
||||
int get_int(TorchTensorHandle tensor_handle, int *index);
|
||||
void set_double(TorchTensorHandle tensor_handle, int *index, double value);
|
||||
void set_float(TorchTensorHandle tensor_handle, int *index, float value);
|
||||
void set_long(TorchTensorHandle tensor_handle, int *index, long value);
|
||||
void set_int(TorchTensorHandle tensor_handle, int *index, int value);
|
||||
|
||||
TorchTensorHandle rand_double(int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randn_double(int *shape, int shape_size, int device);
|
||||
TorchTensorHandle rand_float(int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randn_float(int *shape, int shape_size, int device);
|
||||
|
||||
TorchTensorHandle randint_double(long low, long high, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randint_float(long low, long high, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randint_int(long low, long high, int *shape, int shape_size, int device);
|
||||
|
||||
TorchTensorHandle rand_like(TorchTensorHandle tensor_handle);
|
||||
void rand_like_assign(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle randn_like(TorchTensorHandle tensor_handle);
|
||||
void randn_like_assign(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle randint_like(TorchTensorHandle tensor_handle, long low, long high);
|
||||
void randint_like_assign(TorchTensorHandle tensor_handle, long low, long high);
|
||||
|
||||
|
||||
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device);
|
||||
|
||||
TorchTensorHandle times_double(double value, TorchTensorHandle other);
|
||||
TorchTensorHandle times_float(float value, TorchTensorHandle other);
|
||||
TorchTensorHandle times_long(long value, TorchTensorHandle other);
|
||||
TorchTensorHandle times_int(int value, TorchTensorHandle other);
|
||||
|
||||
void times_double_assign(double value, TorchTensorHandle other);
|
||||
void times_float_assign(float value, TorchTensorHandle other);
|
||||
void times_long_assign(long value, TorchTensorHandle other);
|
||||
void times_int_assign(int value, TorchTensorHandle other);
|
||||
|
||||
TorchTensorHandle plus_double(double value, TorchTensorHandle other);
|
||||
TorchTensorHandle plus_float(float value, TorchTensorHandle other);
|
||||
TorchTensorHandle plus_long(long value, TorchTensorHandle other);
|
||||
TorchTensorHandle plus_int(int value, TorchTensorHandle other);
|
||||
|
||||
void plus_double_assign(double value, TorchTensorHandle other);
|
||||
void plus_float_assign(float value, TorchTensorHandle other);
|
||||
void plus_long_assign(long value, TorchTensorHandle other);
|
||||
void plus_int_assign(int value, TorchTensorHandle other);
|
||||
|
||||
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle);
|
||||
|
||||
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
|
||||
void abs_tensor_assign(TorchTensorHandle tensor_handle);
|
||||
|
||||
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
|
||||
void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j);
|
||||
|
||||
TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle);
|
||||
void exp_tensor_assign(TorchTensorHandle tensor_handle);
|
||||
|
||||
TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle);
|
||||
void log_tensor_assign(TorchTensorHandle tensor_handle);
|
||||
|
||||
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
||||
void sum_tensor_assign(TorchTensorHandle tensor_handle);
|
||||
|
||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
|
||||
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2);
|
||||
|
||||
void svd_tensor(TorchTensorHandle tensor_handle,
|
||||
TorchTensorHandle U_handle,
|
||||
TorchTensorHandle S_handle,
|
||||
TorchTensorHandle V_handle);
|
||||
|
||||
void symeig_tensor(TorchTensorHandle tensor_handle,
|
||||
TorchTensorHandle S_handle,
|
||||
TorchTensorHandle V_handle,
|
||||
bool eigenvectors);
|
||||
|
||||
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph);
|
||||
TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||
TorchTensorHandle autohess_tensor_given_grad(TorchTensorHandle value, TorchTensorHandle variable, TorchTensorHandle gradient);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif //CTORCH
|
@ -0,0 +1,813 @@
|
||||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class space_kscience_kmath_torch_JTorch */
|
||||
|
||||
#ifndef _Included_space_kscience_kmath_torch_JTorch
|
||||
#define _Included_space_kscience_kmath_torch_JTorch
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getNumThreads
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getNumThreads
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: setNumThreads
|
||||
* Signature: (I)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setNumThreads
|
||||
(JNIEnv *, jclass, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: cudaIsAvailable
|
||||
* Signature: ()Z
|
||||
*/
|
||||
JNIEXPORT jboolean JNICALL Java_space_kscience_kmath_torch_JTorch_cudaIsAvailable
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: setSeed
|
||||
* Signature: (I)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setSeed
|
||||
(JNIEnv *, jclass, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: emptyTensor
|
||||
* Signature: ()J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_emptyTensor
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: fromBlobDouble
|
||||
* Signature: ([D[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobDouble
|
||||
(JNIEnv *, jclass, jdoubleArray, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: fromBlobFloat
|
||||
* Signature: ([F[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobFloat
|
||||
(JNIEnv *, jclass, jfloatArray, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: fromBlobLong
|
||||
* Signature: ([J[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobLong
|
||||
(JNIEnv *, jclass, jlongArray, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: fromBlobInt
|
||||
* Signature: ([I[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobInt
|
||||
(JNIEnv *, jclass, jintArray, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: copyTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: copyToDevice
|
||||
* Signature: (JI)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToDevice
|
||||
(JNIEnv *, jclass, jlong, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: copyToDouble
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToDouble
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: copyToFloat
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToFloat
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: copyToLong
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToLong
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: copyToInt
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToInt
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: swapTensors
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_swapTensors
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: viewTensor
|
||||
* Signature: (J[I)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_viewTensor
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: tensorToString
|
||||
* Signature: (J)Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jstring JNICALL Java_space_kscience_kmath_torch_JTorch_tensorToString
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: disposeTensor
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_disposeTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getDim
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getDim
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getNumel
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getNumel
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getShapeAt
|
||||
* Signature: (JI)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getShapeAt
|
||||
(JNIEnv *, jclass, jlong, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getStrideAt
|
||||
* Signature: (JI)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getStrideAt
|
||||
(JNIEnv *, jclass, jlong, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getDevice
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getDevice
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getItemDouble
|
||||
* Signature: (J)D
|
||||
*/
|
||||
JNIEXPORT jdouble JNICALL Java_space_kscience_kmath_torch_JTorch_getItemDouble
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getItemFloat
|
||||
* Signature: (J)F
|
||||
*/
|
||||
JNIEXPORT jfloat JNICALL Java_space_kscience_kmath_torch_JTorch_getItemFloat
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getItemLong
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_getItemLong
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getItemInt
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getItemInt
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getDouble
|
||||
* Signature: (J[I)D
|
||||
*/
|
||||
JNIEXPORT jdouble JNICALL Java_space_kscience_kmath_torch_JTorch_getDouble
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getFloat
|
||||
* Signature: (J[I)F
|
||||
*/
|
||||
JNIEXPORT jfloat JNICALL Java_space_kscience_kmath_torch_JTorch_getFloat
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getLong
|
||||
* Signature: (J[I)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_getLong
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: getInt
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getInt
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: setDouble
|
||||
* Signature: (J[ID)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setDouble
|
||||
(JNIEnv *, jclass, jlong, jintArray, jdouble);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: setFloat
|
||||
* Signature: (J[IF)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setFloat
|
||||
(JNIEnv *, jclass, jlong, jintArray, jfloat);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: setLong
|
||||
* Signature: (J[IJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setLong
|
||||
(JNIEnv *, jclass, jlong, jintArray, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: setInt
|
||||
* Signature: (J[II)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setInt
|
||||
(JNIEnv *, jclass, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randDouble
|
||||
* Signature: ([II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randDouble
|
||||
(JNIEnv *, jclass, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randnDouble
|
||||
* Signature: ([II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnDouble
|
||||
(JNIEnv *, jclass, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randFloat
|
||||
* Signature: ([II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randFloat
|
||||
(JNIEnv *, jclass, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randnFloat
|
||||
* Signature: ([II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnFloat
|
||||
(JNIEnv *, jclass, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randintDouble
|
||||
* Signature: (JJ[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintDouble
|
||||
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randintFloat
|
||||
* Signature: (JJ[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintFloat
|
||||
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randintLong
|
||||
* Signature: (JJ[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintLong
|
||||
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randintInt
|
||||
* Signature: (JJ[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintInt
|
||||
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randLike
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randLike
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randLikeAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randLikeAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randnLike
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnLike
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randnLikeAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randnLikeAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randintLike
|
||||
* Signature: (JJJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintLike
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: randintLikeAssign
|
||||
* Signature: (JJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randintLikeAssign
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: fullDouble
|
||||
* Signature: (D[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullDouble
|
||||
(JNIEnv *, jclass, jdouble, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: fullFloat
|
||||
* Signature: (F[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullFloat
|
||||
(JNIEnv *, jclass, jfloat, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: fullLong
|
||||
* Signature: (J[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullLong
|
||||
(JNIEnv *, jclass, jlong, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: fullInt
|
||||
* Signature: (I[II)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullInt
|
||||
(JNIEnv *, jclass, jint, jintArray, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: timesDouble
|
||||
* Signature: (DJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesDouble
|
||||
(JNIEnv *, jclass, jdouble, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: timesFloat
|
||||
* Signature: (FJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesFloat
|
||||
(JNIEnv *, jclass, jfloat, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: timesLong
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesLong
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: timesInt
|
||||
* Signature: (IJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesInt
|
||||
(JNIEnv *, jclass, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: timesDoubleAssign
|
||||
* Signature: (DJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesDoubleAssign
|
||||
(JNIEnv *, jclass, jdouble, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: timesFloatAssign
|
||||
* Signature: (FJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesFloatAssign
|
||||
(JNIEnv *, jclass, jfloat, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: timesLongAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesLongAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: timesIntAssign
|
||||
* Signature: (IJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesIntAssign
|
||||
(JNIEnv *, jclass, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: plusDouble
|
||||
* Signature: (DJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusDouble
|
||||
(JNIEnv *, jclass, jdouble, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: plusFloat
|
||||
* Signature: (FJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusFloat
|
||||
(JNIEnv *, jclass, jfloat, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: plusLong
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusLong
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: plusInt
|
||||
* Signature: (IJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusInt
|
||||
(JNIEnv *, jclass, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: plusDoubleAssign
|
||||
* Signature: (DJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusDoubleAssign
|
||||
(JNIEnv *, jclass, jdouble, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: plusFloatAssign
|
||||
* Signature: (FJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusFloatAssign
|
||||
(JNIEnv *, jclass, jfloat, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: plusLongAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusLongAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: plusIntAssign
|
||||
* Signature: (IJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusIntAssign
|
||||
(JNIEnv *, jclass, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: timesTensor
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesTensor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: timesTensorAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesTensorAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: divTensor
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_divTensor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: divTensorAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_divTensorAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: plusTensor
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusTensor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: plusTensorAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusTensorAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: minusTensor
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_minusTensor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: minusTensorAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_minusTensorAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: unaryMinus
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_unaryMinus
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: absTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_absTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: absTensorAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_absTensorAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: transposeTensor
|
||||
* Signature: (JII)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_transposeTensor
|
||||
(JNIEnv *, jclass, jlong, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: transposeTensorAssign
|
||||
* Signature: (JII)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_transposeTensorAssign
|
||||
(JNIEnv *, jclass, jlong, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: expTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_expTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: expTensorAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_expTensorAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: logTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_logTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: logTensorAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_logTensorAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: sumTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_sumTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: sumTensorAssign
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_sumTensorAssign
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: matmul
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_matmul
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: matmulAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_matmulAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: matmulRightAssign
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_matmulRightAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: diagEmbed
|
||||
* Signature: (JIII)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_diagEmbed
|
||||
(JNIEnv *, jclass, jlong, jint, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: svdTensor
|
||||
* Signature: (JJJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_svdTensor
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: symeigTensor
|
||||
* Signature: (JJJZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_symeigTensor
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong, jboolean);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: requiresGrad
|
||||
* Signature: (J)Z
|
||||
*/
|
||||
JNIEXPORT jboolean JNICALL Java_space_kscience_kmath_torch_JTorch_requiresGrad
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: setRequiresGrad
|
||||
* Signature: (JZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setRequiresGrad
|
||||
(JNIEnv *, jclass, jlong, jboolean);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: detachFromGraph
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_detachFromGraph
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: autogradTensor
|
||||
* Signature: (JJZ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_autogradTensor
|
||||
(JNIEnv *, jclass, jlong, jlong, jboolean);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_torch_JTorch
|
||||
* Method: autohessTensor
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_autohessTensor
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
151
kmath-torch/src/cppMain/include/utils.hh
Normal file
151
kmath-torch/src/cppMain/include/utils.hh
Normal file
@ -0,0 +1,151 @@
|
||||
#include <torch/torch.h>
|
||||
|
||||
namespace ctorch
|
||||
{
|
||||
|
||||
using TorchTensorHandle = void *;
|
||||
|
||||
template <typename Dtype>
|
||||
inline c10::ScalarType dtype()
|
||||
{
|
||||
return torch::kFloat64;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline c10::ScalarType dtype<float>()
|
||||
{
|
||||
return torch::kFloat32;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline c10::ScalarType dtype<long>()
|
||||
{
|
||||
return torch::kInt64;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline c10::ScalarType dtype<int>()
|
||||
{
|
||||
return torch::kInt32;
|
||||
}
|
||||
|
||||
template <typename Handle>
|
||||
inline torch::Tensor &cast(const Handle &tensor_handle)
|
||||
{
|
||||
return *static_cast<torch::Tensor *>((TorchTensorHandle)tensor_handle);
|
||||
}
|
||||
|
||||
template <typename Handle>
|
||||
inline void dispose_tensor(const Handle &tensor_handle)
|
||||
{
|
||||
delete static_cast<torch::Tensor *>((TorchTensorHandle)tensor_handle);
|
||||
}
|
||||
|
||||
inline std::string tensor_to_string(const torch::Tensor &tensor)
|
||||
{
|
||||
std::stringstream bufrep;
|
||||
bufrep << tensor;
|
||||
return bufrep.str();
|
||||
}
|
||||
|
||||
inline char *tensor_to_char(const torch::Tensor &tensor)
|
||||
{
|
||||
auto rep = tensor_to_string(tensor);
|
||||
char *crep = (char *)malloc(rep.length() + 1);
|
||||
std::strcpy(crep, rep.c_str());
|
||||
return crep;
|
||||
}
|
||||
|
||||
inline int device_to_int(const torch::Tensor &tensor)
|
||||
{
|
||||
return (tensor.device().type() == torch::kCPU) ? 0 : 1 + tensor.device().index();
|
||||
}
|
||||
|
||||
inline torch::Device int_to_device(int device_int)
|
||||
{
|
||||
return (device_int == 0) ? torch::kCPU : torch::Device(torch::kCUDA, device_int - 1);
|
||||
}
|
||||
|
||||
inline std::vector<int64_t> to_vec_int(int *arr, int arr_size)
|
||||
{
|
||||
auto vec = std::vector<int64_t>(arr_size);
|
||||
vec.assign(arr, arr + arr_size);
|
||||
return vec;
|
||||
}
|
||||
|
||||
inline std::vector<at::indexing::TensorIndex> to_index(int *arr, int arr_size)
|
||||
{
|
||||
std::vector<at::indexing::TensorIndex> index;
|
||||
for (int i = 0; i < arr_size; i++)
|
||||
{
|
||||
index.emplace_back(arr[i]);
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor from_blob(Dtype *data, const std::vector<int64_t> &shape, torch::Device device, bool copy)
|
||||
{
|
||||
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy);
|
||||
}
|
||||
|
||||
template <typename NumType>
|
||||
inline NumType get(const torch::Tensor &tensor, int *index)
|
||||
{
|
||||
return tensor.index(to_index(index, tensor.dim())).item<NumType>();
|
||||
}
|
||||
|
||||
template <typename NumType>
|
||||
inline void set(const torch::Tensor &tensor, int *index, NumType value)
|
||||
{
|
||||
tensor.index(to_index(index, tensor.dim())) = value;
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor randn(const std::vector<int64_t> &shape, torch::Device device)
|
||||
{
|
||||
return torch::randn(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor rand(const std::vector<int64_t> &shape, torch::Device device)
|
||||
{
|
||||
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor randint(long low, long high, const std::vector<int64_t> &shape, torch::Device device)
|
||||
{
|
||||
return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor full(Dtype value, const std::vector<int64_t> &shape, torch::Device device)
|
||||
{
|
||||
return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
inline torch::Tensor hessian(const torch::Tensor &value, const torch::Tensor &variable)
|
||||
{
|
||||
auto nelem = variable.numel();
|
||||
auto hess = value.new_zeros({nelem, nelem});
|
||||
auto grad = torch::autograd::grad({value}, {variable}, {}, torch::nullopt, true)[0].view(nelem);
|
||||
int i = 0;
|
||||
for (int j = 0; j < nelem; j++)
|
||||
{
|
||||
auto row = grad[j].requires_grad()
|
||||
? torch::autograd::grad({grad[i]}, {variable}, {}, true, true, true)[0].view(nelem).slice(0, j, nelem)
|
||||
: grad[j].new_zeros(nelem - j);
|
||||
hess[i].slice(0, i, nelem).add_(row.type_as(hess));
|
||||
i++;
|
||||
}
|
||||
auto ndim = variable.dim();
|
||||
auto sizes = variable.sizes().data();
|
||||
auto shape = std::vector<int64_t>(ndim);
|
||||
shape.assign(sizes, sizes + ndim);
|
||||
shape.reserve(2 * ndim);
|
||||
std::copy_n(shape.begin(), ndim, std::back_inserter(shape));
|
||||
return (hess + torch::triu(hess, 1).t()).view(shape);
|
||||
}
|
||||
|
||||
} // namespace ctorch
|
465
kmath-torch/src/cppMain/src/ctorch.cc
Normal file
465
kmath-torch/src/cppMain/src/ctorch.cc
Normal file
@ -0,0 +1,465 @@
|
||||
#include <torch/torch.h>
|
||||
#include <iostream>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "ctorch.h"
|
||||
#include "utils.hh"
|
||||
|
||||
int get_num_threads()
|
||||
{
|
||||
return torch::get_num_threads();
|
||||
}
|
||||
|
||||
void set_num_threads(int num_threads)
|
||||
{
|
||||
torch::set_num_threads(num_threads);
|
||||
}
|
||||
|
||||
bool cuda_is_available()
|
||||
{
|
||||
return torch::cuda::is_available();
|
||||
}
|
||||
|
||||
void set_seed(int seed)
|
||||
{
|
||||
torch::manual_seed(seed);
|
||||
}
|
||||
|
||||
TorchTensorHandle empty_tensor()
|
||||
{
|
||||
return new torch::Tensor;
|
||||
}
|
||||
|
||||
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
|
||||
{
|
||||
return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||
}
|
||||
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy)
|
||||
{
|
||||
return new torch::Tensor(ctorch::from_blob<float>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||
}
|
||||
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy)
|
||||
{
|
||||
return new torch::Tensor(ctorch::from_blob<long>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||
}
|
||||
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy)
|
||||
{
|
||||
return new torch::Tensor(ctorch::from_blob<int>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||
}
|
||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).clone());
|
||||
}
|
||||
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true));
|
||||
}
|
||||
TorchTensorHandle copy_to_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<double>(), false, true));
|
||||
}
|
||||
TorchTensorHandle copy_to_float(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<float>(), false, true));
|
||||
}
|
||||
TorchTensorHandle copy_to_long(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<long>(), false, true));
|
||||
}
|
||||
TorchTensorHandle copy_to_int(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<int>(), false, true));
|
||||
}
|
||||
void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)
|
||||
{
|
||||
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
|
||||
}
|
||||
TorchTensorHandle view_tensor(TorchTensorHandle tensor_handle, int *shape, int dim)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).view(ctorch::to_vec_int(shape, dim)));
|
||||
}
|
||||
|
||||
char *tensor_to_string(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::tensor_to_char(ctorch::cast(tensor_handle));
|
||||
}
|
||||
void dispose_char(char *ptr)
|
||||
{
|
||||
free(ptr);
|
||||
}
|
||||
void dispose_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
ctorch::dispose_tensor(tensor_handle);
|
||||
}
|
||||
|
||||
int get_dim(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).dim();
|
||||
}
|
||||
int get_numel(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).numel();
|
||||
}
|
||||
int get_shape_at(TorchTensorHandle tensor_handle, int d)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).size(d);
|
||||
}
|
||||
int get_stride_at(TorchTensorHandle tensor_handle, int d)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).stride(d);
|
||||
}
|
||||
int get_device(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||
}
|
||||
float *get_data_float(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<float>();
|
||||
}
|
||||
long *get_data_long(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<long>();
|
||||
}
|
||||
int *get_data_int(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<int>();
|
||||
}
|
||||
|
||||
double get_item_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<double>();
|
||||
}
|
||||
float get_item_float(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<float>();
|
||||
}
|
||||
long get_item_long(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<long>();
|
||||
}
|
||||
int get_item_int(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<int>();
|
||||
}
|
||||
|
||||
double get_double(TorchTensorHandle tensor_handle, int *index)
|
||||
{
|
||||
return ctorch::get<double>(ctorch::cast(tensor_handle), index);
|
||||
}
|
||||
float get_float(TorchTensorHandle tensor_handle, int *index)
|
||||
{
|
||||
return ctorch::get<float>(ctorch::cast(tensor_handle), index);
|
||||
}
|
||||
long get_long(TorchTensorHandle tensor_handle, int *index)
|
||||
{
|
||||
return ctorch::get<long>(ctorch::cast(tensor_handle), index);
|
||||
}
|
||||
int get_int(TorchTensorHandle tensor_handle, int *index)
|
||||
{
|
||||
return ctorch::get<int>(ctorch::cast(tensor_handle), index);
|
||||
}
|
||||
void set_double(TorchTensorHandle tensor_handle, int *index, double value)
|
||||
{
|
||||
ctorch::set<double>(ctorch::cast(tensor_handle), index, value);
|
||||
}
|
||||
void set_float(TorchTensorHandle tensor_handle, int *index, float value)
|
||||
{
|
||||
ctorch::set<float>(ctorch::cast(tensor_handle), index, value);
|
||||
}
|
||||
void set_long(TorchTensorHandle tensor_handle, int *index, long value)
|
||||
{
|
||||
ctorch::set<long>(ctorch::cast(tensor_handle), index, value);
|
||||
}
|
||||
void set_int(TorchTensorHandle tensor_handle, int *index, int value)
|
||||
{
|
||||
ctorch::set<int>(ctorch::cast(tensor_handle), index, value);
|
||||
}
|
||||
|
||||
TorchTensorHandle rand_double(int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::rand<double>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle randn_double(int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::randn<double>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle rand_float(int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle randn_float(int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::randn<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
TorchTensorHandle randint_double(long low, long high, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::randint<double>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle randint_float(long low, long high, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::randint<float>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::randint<long>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle randint_int(long low, long high, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::randint<int>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
TorchTensorHandle rand_like(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(torch::rand_like(ctorch::cast(tensor_handle)));
|
||||
}
|
||||
void rand_like_assign(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::rand_like(ctorch::cast(tensor_handle));
|
||||
}
|
||||
TorchTensorHandle randn_like(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(torch::randn_like(ctorch::cast(tensor_handle)));
|
||||
}
|
||||
void randn_like_assign(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle));
|
||||
}
|
||||
TorchTensorHandle randint_like(TorchTensorHandle tensor_handle, long low, long high)
|
||||
{
|
||||
return new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
|
||||
}
|
||||
void randint_like_assign(TorchTensorHandle tensor_handle, long low, long high)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
|
||||
}
|
||||
|
||||
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::full<double>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::full<float>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::full<long>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
TorchTensorHandle times_double(double value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
TorchTensorHandle times_float(float value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
TorchTensorHandle times_long(long value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
TorchTensorHandle times_int(int value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
void times_double_assign(double value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
void times_float_assign(float value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
void times_long_assign(long value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
void times_int_assign(int value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
TorchTensorHandle plus_double(double value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
TorchTensorHandle plus_float(float value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
TorchTensorHandle plus_long(long value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
TorchTensorHandle plus_int(int value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(other) + value);
|
||||
}
|
||||
void plus_double_assign(double value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
void plus_float_assign(float value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
void plus_long_assign(long value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
void plus_int_assign(int value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
|
||||
}
|
||||
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) *= ctorch::cast(rhs);
|
||||
}
|
||||
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs));
|
||||
}
|
||||
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) /= ctorch::cast(rhs);
|
||||
}
|
||||
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
|
||||
}
|
||||
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) += ctorch::cast(rhs);
|
||||
}
|
||||
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
|
||||
}
|
||||
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
||||
}
|
||||
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(-ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
||||
}
|
||||
void abs_tensor_assign(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs();
|
||||
}
|
||||
|
||||
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
|
||||
}
|
||||
void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j);
|
||||
}
|
||||
|
||||
TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).exp());
|
||||
}
|
||||
void exp_tensor_assign(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp();
|
||||
}
|
||||
|
||||
TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).log());
|
||||
}
|
||||
void log_tensor_assign(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log();
|
||||
}
|
||||
|
||||
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).sum());
|
||||
}
|
||||
void sum_tensor_assign(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
|
||||
}
|
||||
|
||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||
}
|
||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2)
|
||||
{
|
||||
return new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
|
||||
}
|
||||
|
||||
void svd_tensor(TorchTensorHandle tensor_handle,
|
||||
TorchTensorHandle U_handle,
|
||||
TorchTensorHandle S_handle,
|
||||
TorchTensorHandle V_handle)
|
||||
{
|
||||
auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
|
||||
ctorch::cast(U_handle) = U;
|
||||
ctorch::cast(S_handle) = S;
|
||||
ctorch::cast(V_handle) = V;
|
||||
}
|
||||
|
||||
void symeig_tensor(TorchTensorHandle tensor_handle,
|
||||
TorchTensorHandle S_handle,
|
||||
TorchTensorHandle V_handle,
|
||||
bool eigenvectors)
|
||||
{
|
||||
auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
|
||||
ctorch::cast(S_handle) = S;
|
||||
ctorch::cast(V_handle) = V;
|
||||
}
|
||||
|
||||
bool requires_grad(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).requires_grad();
|
||||
}
|
||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status)
|
||||
{
|
||||
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||
}
|
||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||
}
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph)
|
||||
{
|
||||
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
|
||||
}
|
||||
TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||
{
|
||||
return new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
|
||||
}
|
572
kmath-torch/src/cppMain/src/jtorch.cc
Normal file
572
kmath-torch/src/cppMain/src/jtorch.cc
Normal file
@ -0,0 +1,572 @@
|
||||
#include <torch/torch.h>
|
||||
#include <iostream>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "space_kscience_kmath_torch_JTorch.h"
|
||||
#include "utils.hh"
|
||||
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getNumThreads(JNIEnv *, jclass)
|
||||
{
|
||||
return torch::get_num_threads();
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setNumThreads(JNIEnv *, jclass, jint num_threads)
|
||||
{
|
||||
torch::set_num_threads(num_threads);
|
||||
}
|
||||
|
||||
JNIEXPORT jboolean JNICALL Java_space_kscience_kmath_torch_JTorch_cudaIsAvailable(JNIEnv *, jclass)
|
||||
{
|
||||
return torch::cuda::is_available();
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setSeed(JNIEnv *, jclass, jint seed)
|
||||
{
|
||||
torch::manual_seed(seed);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_emptyTensor(JNIEnv *, jclass)
|
||||
{
|
||||
return (long)new torch::Tensor;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobDouble(JNIEnv *env, jclass, jdoubleArray data, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::from_blob<double>(
|
||||
env->GetDoubleArrayElements(data, 0),
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device), true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobFloat(JNIEnv *env, jclass, jfloatArray data, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::from_blob<float>(
|
||||
env->GetFloatArrayElements(data, 0),
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device), true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobLong(JNIEnv *env, jclass, jlongArray data, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::from_blob<long>(
|
||||
env->GetLongArrayElements(data, 0),
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device), true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobInt(JNIEnv *env, jclass, jintArray data, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::from_blob<int>(
|
||||
env->GetIntArrayElements(data, 0),
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device), true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).clone());
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToDevice(JNIEnv *, jclass, jlong tensor_handle, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToDouble(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<double>(), false, true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToFloat(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<float>(), false, true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToLong(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<long>(), false, true));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToInt(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<int>(), false, true));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_swapTensors(JNIEnv *, jclass, jlong lhs_handle, jlong rhs_handle)
|
||||
{
|
||||
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_viewTensor(JNIEnv *env, jclass, jlong tensor_handle, jintArray shape)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::cast(tensor_handle).view(ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape))));
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL Java_space_kscience_kmath_torch_JTorch_tensorToString(JNIEnv *env, jclass, jlong tensor_handle)
|
||||
{
|
||||
return env->NewStringUTF(ctorch::tensor_to_string(ctorch::cast(tensor_handle)).c_str());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_disposeTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::dispose_tensor(tensor_handle);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getDim(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).dim();
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getNumel(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).numel();
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getShapeAt(JNIEnv *, jclass, jlong tensor_handle, jint d)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).size(d);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getStrideAt(JNIEnv *, jclass, jlong tensor_handle, jint d)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).stride(d);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getDevice(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jdouble JNICALL Java_space_kscience_kmath_torch_JTorch_getItemDouble(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<double>();
|
||||
}
|
||||
|
||||
JNIEXPORT jfloat JNICALL Java_space_kscience_kmath_torch_JTorch_getItemFloat(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<float>();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_getItemLong(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<long>();
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getItemInt(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).item<int>();
|
||||
}
|
||||
|
||||
JNIEXPORT jdouble JNICALL Java_space_kscience_kmath_torch_JTorch_getDouble(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||
{
|
||||
return ctorch::get<double>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||
}
|
||||
|
||||
JNIEXPORT jfloat JNICALL Java_space_kscience_kmath_torch_JTorch_getFloat(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||
{
|
||||
return ctorch::get<float>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_getLong(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||
{
|
||||
return ctorch::get<long>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getInt(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||
{
|
||||
return ctorch::get<int>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setDouble(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jdouble value)
|
||||
{
|
||||
ctorch::set<double>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setFloat(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jfloat value)
|
||||
{
|
||||
ctorch::set<float>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setLong(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jlong value)
|
||||
{
|
||||
ctorch::set<long>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setInt(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jint value)
|
||||
{
|
||||
ctorch::set<int>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randDouble(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::rand<double>(
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnDouble(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randn<double>(
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randFloat(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::rand<float>(
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnFloat(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randn<float>(
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintDouble(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randint<double>(low, high,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintFloat(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randint<float>(low, high,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintLong(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randint<long>(low, high,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintInt(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::randint<int>(low, high,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randLike(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::rand_like(ctorch::cast(tensor_handle)));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randLikeAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::rand_like(ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnLike(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::randn_like(ctorch::cast(tensor_handle)));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randnLikeAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintLike(JNIEnv *, jclass, jlong tensor_handle, jlong low, jlong high)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randintLikeAssign(JNIEnv *, jclass, jlong tensor_handle, jlong low, jlong high)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullDouble(JNIEnv *env, jclass, jdouble value, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::full<double>(
|
||||
value,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullFloat(JNIEnv *env, jclass, jfloat value, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::full<float>(
|
||||
value,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullLong(JNIEnv *env, jclass, jlong value, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::full<long>(
|
||||
value,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullInt(JNIEnv *env, jclass, jint value, jintArray shape, jint device)
|
||||
{
|
||||
return (long)new torch::Tensor(
|
||||
ctorch::full<int>(
|
||||
value,
|
||||
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||
ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesDouble(JNIEnv *, jclass, jdouble value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesFloat(JNIEnv *, jclass, jfloat value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesLong(JNIEnv *, jclass, jlong value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesInt(JNIEnv *, jclass, jint value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesDoubleAssign(JNIEnv *, jclass, jdouble value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesFloatAssign(JNIEnv *, jclass, jfloat value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesLongAssign(JNIEnv *, jclass, jlong value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesIntAssign(JNIEnv *, jclass, jint value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusDouble(JNIEnv *, jclass, jdouble value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusFloat(JNIEnv *, jclass, jfloat value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusLong(JNIEnv *, jclass, jlong value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusInt(JNIEnv *, jclass, jint value, jlong other)
|
||||
{
|
||||
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusDoubleAssign(JNIEnv *, jclass, jdouble value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusFloatAssign(JNIEnv *, jclass, jfloat value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusLongAssign(JNIEnv *, jclass, jlong value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusIntAssign(JNIEnv *, jclass, jint value, jlong other)
|
||||
{
|
||||
ctorch::cast(other) += value;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(lhs) *= ctorch::cast(rhs);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_divTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_divTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(lhs) /= ctorch::cast(rhs);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(lhs) += ctorch::cast(rhs);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_minusTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_minusTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_unaryMinus(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(-ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_absTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_absTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_transposeTensor(JNIEnv *, jclass, jlong tensor_handle, jint i, jint j)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_transposeTensorAssign(JNIEnv *, jclass, jlong tensor_handle, jint i, jint j)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_expTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).exp());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_expTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_logTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).log());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_logTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_sumTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).sum());
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_sumTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_matmul(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_matmulAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_matmulRightAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||
{
|
||||
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_space_kscience_kmath_torch_JTorch_diagEmbed(JNIEnv *, jclass, jlong diags_handle, jint offset, jint dim1, jint dim2)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_space_kscience_kmath_torch_JTorch_svdTensor(JNIEnv *, jclass, jlong tensor_handle, jlong U_handle, jlong S_handle, jlong V_handle)
|
||||
{
|
||||
auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
|
||||
ctorch::cast(U_handle) = U;
|
||||
ctorch::cast(S_handle) = S;
|
||||
ctorch::cast(V_handle) = V;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_space_kscience_kmath_torch_JTorch_symeigTensor(JNIEnv *, jclass, jlong tensor_handle, jlong S_handle, jlong V_handle, jboolean eigenvectors)
|
||||
{
|
||||
auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
|
||||
ctorch::cast(S_handle) = S;
|
||||
ctorch::cast(V_handle) = V;
|
||||
}
|
||||
|
||||
JNIEXPORT jboolean JNICALL Java_space_kscience_kmath_torch_JTorch_requiresGrad(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).requires_grad();
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setRequiresGrad(JNIEnv *, jclass, jlong tensor_handle, jboolean status)
|
||||
{
|
||||
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_detachFromGraph(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_space_kscience_kmath_torch_JTorch_autogradTensor(JNIEnv *, jclass, jlong value, jlong variable, jboolean retain_graph)
|
||||
{
|
||||
return (long)new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_autohessTensor(JNIEnv *, jclass, jlong value, jlong variable)
|
||||
{
|
||||
return (long)new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
|
||||
}
|
@ -0,0 +1,208 @@
|
||||
package space.kscience.kmath.torch;
|
||||
|
||||
class JTorch {
|
||||
|
||||
static {
|
||||
System.loadLibrary("jtorch");
|
||||
}
|
||||
|
||||
public static native int getNumThreads();
|
||||
|
||||
public static native void setNumThreads(int numThreads);
|
||||
|
||||
public static native boolean cudaIsAvailable();
|
||||
|
||||
public static native void setSeed(int seed);
|
||||
|
||||
public static native long emptyTensor();
|
||||
|
||||
public static native long fromBlobDouble(double[] data, int[] shape, int device);
|
||||
|
||||
public static native long fromBlobFloat(float[] data, int[] shape, int device);
|
||||
|
||||
public static native long fromBlobLong(long[] data, int[] shape, int device);
|
||||
|
||||
public static native long fromBlobInt(int[] data, int[] shape, int device);
|
||||
|
||||
public static native long copyTensor(long tensorHandle);
|
||||
|
||||
public static native long copyToDevice(long tensorHandle, int device);
|
||||
|
||||
public static native long copyToDouble(long tensorHandle);
|
||||
|
||||
public static native long copyToFloat(long tensorHandle);
|
||||
|
||||
public static native long copyToLong(long tensorHandle);
|
||||
|
||||
public static native long copyToInt(long tensorHandle);
|
||||
|
||||
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
||||
|
||||
public static native long viewTensor(long tensorHandle, int[] shape);
|
||||
|
||||
public static native String tensorToString(long tensorHandle);
|
||||
|
||||
public static native void disposeTensor(long tensorHandle);
|
||||
|
||||
public static native int getDim(long tensorHandle);
|
||||
|
||||
public static native int getNumel(long tensorHandle);
|
||||
|
||||
public static native int getShapeAt(long tensorHandle, int d);
|
||||
|
||||
public static native int getStrideAt(long tensorHandle, int d);
|
||||
|
||||
public static native int getDevice(long tensorHandle);
|
||||
|
||||
public static native double getItemDouble(long tensorHandle);
|
||||
|
||||
public static native float getItemFloat(long tensorHandle);
|
||||
|
||||
public static native long getItemLong(long tensorHandle);
|
||||
|
||||
public static native int getItemInt(long tensorHandle);
|
||||
|
||||
public static native double getDouble(long tensorHandle, int[] index);
|
||||
|
||||
public static native float getFloat(long tensorHandle, int[] index);
|
||||
|
||||
public static native long getLong(long tensorHandle, int[] index);
|
||||
|
||||
public static native int getInt(long tensorHandle, int[] index);
|
||||
|
||||
public static native void setDouble(long tensorHandle, int[] index, double value);
|
||||
|
||||
public static native void setFloat(long tensorHandle, int[] index, float value);
|
||||
|
||||
public static native void setLong(long tensorHandle, int[] index, long value);
|
||||
|
||||
public static native void setInt(long tensorHandle, int[] index, int value);
|
||||
|
||||
public static native long randDouble(int[] shape, int device);
|
||||
|
||||
public static native long randnDouble(int[] shape, int device);
|
||||
|
||||
public static native long randFloat(int[] shape, int device);
|
||||
|
||||
public static native long randnFloat(int[] shape, int device);
|
||||
|
||||
public static native long randintDouble(long low, long high, int[] shape, int device);
|
||||
|
||||
public static native long randintFloat(long low, long high, int[] shape, int device);
|
||||
|
||||
public static native long randintLong(long low, long high, int[] shape, int device);
|
||||
|
||||
public static native long randintInt(long low, long high, int[] shape, int device);
|
||||
|
||||
public static native long randLike(long tensorHandle);
|
||||
|
||||
public static native void randLikeAssign(long tensorHandle);
|
||||
|
||||
public static native long randnLike(long tensorHandle);
|
||||
|
||||
public static native void randnLikeAssign(long tensorHandle);
|
||||
|
||||
public static native long randintLike(long tensorHandle, long low, long high);
|
||||
|
||||
public static native void randintLikeAssign(long tensorHandle, long low, long high);
|
||||
|
||||
public static native long fullDouble(double value, int[] shape, int device);
|
||||
|
||||
public static native long fullFloat(float value, int[] shape, int device);
|
||||
|
||||
public static native long fullLong(long value, int[] shape, int device);
|
||||
|
||||
public static native long fullInt(int value, int[] shape, int device);
|
||||
|
||||
public static native long timesDouble(double value, long other);
|
||||
|
||||
public static native long timesFloat(float value, long other);
|
||||
|
||||
public static native long timesLong(long value, long other);
|
||||
|
||||
public static native long timesInt(int value, long other);
|
||||
|
||||
public static native void timesDoubleAssign(double value, long other);
|
||||
|
||||
public static native void timesFloatAssign(float value, long other);
|
||||
|
||||
public static native void timesLongAssign(long value, long other);
|
||||
|
||||
public static native void timesIntAssign(int value, long other);
|
||||
|
||||
public static native long plusDouble(double value, long other);
|
||||
|
||||
public static native long plusFloat(float value, long other);
|
||||
|
||||
public static native long plusLong(long value, long other);
|
||||
|
||||
public static native long plusInt(int value, long other);
|
||||
|
||||
public static native void plusDoubleAssign(double value, long other);
|
||||
|
||||
public static native void plusFloatAssign(float value, long other);
|
||||
|
||||
public static native void plusLongAssign(long value, long other);
|
||||
|
||||
public static native void plusIntAssign(int value, long other);
|
||||
|
||||
public static native long timesTensor(long lhs, long rhs);
|
||||
|
||||
public static native void timesTensorAssign(long lhs, long rhs);
|
||||
|
||||
public static native long divTensor(long lhs, long rhs);
|
||||
|
||||
public static native void divTensorAssign(long lhs, long rhs);
|
||||
|
||||
public static native long plusTensor(long lhs, long rhs);
|
||||
|
||||
public static native void plusTensorAssign(long lhs, long rhs);
|
||||
|
||||
public static native long minusTensor(long lhs, long rhs);
|
||||
|
||||
public static native void minusTensorAssign(long lhs, long rhs);
|
||||
|
||||
public static native long unaryMinus(long tensorHandle);
|
||||
|
||||
public static native long absTensor(long tensorHandle);
|
||||
|
||||
public static native void absTensorAssign(long tensorHandle);
|
||||
|
||||
public static native long transposeTensor(long tensorHandle, int i, int j);
|
||||
|
||||
public static native void transposeTensorAssign(long tensorHandle, int i, int j);
|
||||
|
||||
public static native long expTensor(long tensorHandle);
|
||||
|
||||
public static native void expTensorAssign(long tensorHandle);
|
||||
|
||||
public static native long logTensor(long tensorHandle);
|
||||
|
||||
public static native void logTensorAssign(long tensorHandle);
|
||||
|
||||
public static native long sumTensor(long tensorHandle);
|
||||
|
||||
public static native void sumTensorAssign(long tensorHandle);
|
||||
|
||||
public static native long matmul(long lhs, long rhs);
|
||||
|
||||
public static native void matmulAssign(long lhs, long rhs);
|
||||
|
||||
public static native void matmulRightAssign(long lhs, long rhs);
|
||||
|
||||
public static native long diagEmbed(long diagsHandle, int offset, int dim1, int dim2);
|
||||
|
||||
public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle);
|
||||
|
||||
public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle, boolean eigenvectors);
|
||||
|
||||
public static native boolean requiresGrad(long tensorHandle);
|
||||
|
||||
public static native void setRequiresGrad(long tensorHandle, boolean status);
|
||||
|
||||
public static native long detachFromGraph(long tensorHandle);
|
||||
|
||||
public static native long autogradTensor(long value, long variable, boolean retainGraph);
|
||||
|
||||
public static native long autohessTensor(long value, long variable);
|
||||
}
|
@ -0,0 +1,390 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import space.kscience.kmath.memory.DeferScope
|
||||
import space.kscience.kmath.memory.withDeferScope
|
||||
import space.kscience.kmath.tensors.*
|
||||
|
||||
public sealed class TorchTensorAlgebraJVM<
|
||||
T,
|
||||
PrimitiveArrayType,
|
||||
TorchTensorType : TorchTensorJVM<T>> constructor(
|
||||
internal val scope: DeferScope
|
||||
) : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||
override fun getNumThreads(): Int {
|
||||
return JTorch.getNumThreads()
|
||||
}
|
||||
|
||||
override fun setNumThreads(numThreads: Int): Unit {
|
||||
JTorch.setNumThreads(numThreads)
|
||||
}
|
||||
|
||||
override fun cudaAvailable(): Boolean {
|
||||
return JTorch.cudaIsAvailable()
|
||||
}
|
||||
|
||||
override fun setSeed(seed: Int): Unit {
|
||||
JTorch.setSeed(seed)
|
||||
}
|
||||
|
||||
override var checks: Boolean = false
|
||||
|
||||
internal abstract fun wrap(tensorHandle: Long): TorchTensorType
|
||||
|
||||
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(JTorch.timesTensor(this.tensorHandle, other.tensorHandle))
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(JTorch.plusTensor(this.tensorHandle, other.tensorHandle))
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(JTorch.minusTensor(this.tensorHandle, other.tensorHandle))
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||
wrap(JTorch.unaryMinus(this.tensorHandle))
|
||||
|
||||
override infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
return wrap(JTorch.matmul(this.tensorHandle, other.tensorHandle))
|
||||
}
|
||||
|
||||
override infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
JTorch.matmulAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
JTorch.matmulRightAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(
|
||||
diagonalEntries: TorchTensorType, offset: Int, dim1: Int, dim2: Int
|
||||
): TorchTensorType =
|
||||
wrap(JTorch.diagEmbed(diagonalEntries.tensorHandle, offset, dim1, dim2))
|
||||
|
||||
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType {
|
||||
if (checks) checkTranspose(this.dimension, i, j)
|
||||
return wrap(JTorch.transposeTensor(tensorHandle, i, j))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
||||
if (checks) checkTranspose(this.dimension, i, j)
|
||||
JTorch.transposeTensorAssign(tensorHandle, i, j)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.view(shape: IntArray): TorchTensorType {
|
||||
if (checks) checkView(this, shape)
|
||||
return wrap(JTorch.viewTensor(this.tensorHandle, shape))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.abs(): TorchTensorType = wrap(JTorch.absTensor(tensorHandle))
|
||||
override fun TorchTensorType.absAssign(): Unit = JTorch.absTensorAssign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.sum(): TorchTensorType = wrap(JTorch.sumTensor(tensorHandle))
|
||||
override fun TorchTensorType.sumAssign(): Unit = JTorch.sumTensorAssign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
|
||||
wrap(JTorch.randintLike(this.tensorHandle, low, high))
|
||||
|
||||
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit =
|
||||
JTorch.randintLikeAssign(this.tensorHandle, low, high)
|
||||
|
||||
override fun TorchTensorType.copy(): TorchTensorType =
|
||||
wrap(JTorch.copyTensor(this.tensorHandle))
|
||||
|
||||
override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
|
||||
wrap(JTorch.copyToDevice(this.tensorHandle, device.toInt()))
|
||||
|
||||
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit =
|
||||
JTorch.swapTensors(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public sealed class TorchTensorPartialDivisionAlgebraJVM<T, PrimitiveArrayType,
|
||||
TorchTensorType : TorchTensorOverFieldJVM<T>>(scope: DeferScope) :
|
||||
TorchTensorAlgebraJVM<T, PrimitiveArrayType, TorchTensorType>(scope),
|
||||
TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||
|
||||
override operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(JTorch.divTensor(this.tensorHandle, other.tensorHandle))
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
JTorch.divTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.randUniform(): TorchTensorType =
|
||||
wrap(JTorch.randLike(this.tensorHandle))
|
||||
|
||||
override fun TorchTensorType.randUniformAssign(): Unit =
|
||||
JTorch.randLikeAssign(this.tensorHandle)
|
||||
|
||||
override fun TorchTensorType.randNormal(): TorchTensorType =
|
||||
wrap(JTorch.randnLike(this.tensorHandle))
|
||||
|
||||
override fun TorchTensorType.randNormalAssign(): Unit =
|
||||
JTorch.randnLikeAssign(this.tensorHandle)
|
||||
|
||||
override fun TorchTensorType.exp(): TorchTensorType = wrap(JTorch.expTensor(tensorHandle))
|
||||
override fun TorchTensorType.expAssign(): Unit = JTorch.expTensorAssign(tensorHandle)
|
||||
override fun TorchTensorType.log(): TorchTensorType = wrap(JTorch.logTensor(tensorHandle))
|
||||
override fun TorchTensorType.logAssign(): Unit = JTorch.logTensorAssign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
|
||||
val U = JTorch.emptyTensor()
|
||||
val V = JTorch.emptyTensor()
|
||||
val S = JTorch.emptyTensor()
|
||||
JTorch.svdTensor(this.tensorHandle, U, S, V)
|
||||
return Triple(wrap(U), wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.symEig(eigenvectors: Boolean): Pair<TorchTensorType, TorchTensorType> {
|
||||
val V = JTorch.emptyTensor()
|
||||
val S = JTorch.emptyTensor()
|
||||
JTorch.symeigTensor(this.tensorHandle, S, V, eigenvectors)
|
||||
return Pair(wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean): TorchTensorType {
|
||||
if (checks) this.checkIsValue()
|
||||
return wrap(JTorch.autogradTensor(this.tensorHandle, variable.tensorHandle, retainGraph))
|
||||
}
|
||||
|
||||
override infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
|
||||
if (checks) this.checkIsValue()
|
||||
return wrap(JTorch.autohessTensor(this.tensorHandle, variable.tensorHandle))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||
wrap(JTorch.detachFromGraph(this.tensorHandle))
|
||||
|
||||
}
|
||||
|
||||
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
TorchTensorPartialDivisionAlgebraJVM<Double, DoubleArray, TorchTensorReal>(scope) {
|
||||
override fun wrap(tensorHandle: Long): TorchTensorReal =
|
||||
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
||||
this.elements().map { it.second }.toList().toDoubleArray()
|
||||
|
||||
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(JTorch.fromBlobDouble(array, shape, device.toInt()))
|
||||
|
||||
override fun randNormal(shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(JTorch.randnDouble(shape, device.toInt()))
|
||||
|
||||
override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(JTorch.randDouble(shape, device.toInt()))
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(JTorch.randintDouble(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(JTorch.plusDouble(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.plus(value: Double): TorchTensorReal =
|
||||
wrap(JTorch.plusDouble(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.plusAssign(value: Double): Unit =
|
||||
JTorch.plusDoubleAssign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(JTorch.plusDouble(-this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.minus(value: Double): TorchTensorReal =
|
||||
wrap(JTorch.plusDouble(-value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.minusAssign(value: Double): Unit =
|
||||
JTorch.plusDoubleAssign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(JTorch.timesDouble(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.times(value: Double): TorchTensorReal =
|
||||
wrap(JTorch.timesDouble(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.timesAssign(value: Double): Unit =
|
||||
JTorch.timesDoubleAssign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(JTorch.fullDouble(value, shape, device.toInt()))
|
||||
|
||||
}
|
||||
|
||||
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
TorchTensorPartialDivisionAlgebraJVM<Float, FloatArray, TorchTensorFloat>(scope) {
|
||||
override fun wrap(tensorHandle: Long): TorchTensorFloat =
|
||||
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorFloat.copyToArray(): FloatArray =
|
||||
this.elements().map { it.second }.toList().toFloatArray()
|
||||
|
||||
override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(JTorch.fromBlobFloat(array, shape, device.toInt()))
|
||||
|
||||
override fun randNormal(shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(JTorch.randnFloat(shape, device.toInt()))
|
||||
|
||||
override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(JTorch.randFloat(shape, device.toInt()))
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(JTorch.randintFloat(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(JTorch.plusFloat(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
|
||||
wrap(JTorch.plusFloat(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.plusAssign(value: Float): Unit =
|
||||
JTorch.plusFloatAssign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(JTorch.plusFloat(-this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
|
||||
wrap(JTorch.plusFloat(-value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.minusAssign(value: Float): Unit =
|
||||
JTorch.plusFloatAssign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(JTorch.timesFloat(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
|
||||
wrap(JTorch.timesFloat(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.timesAssign(value: Float): Unit =
|
||||
JTorch.timesFloatAssign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(JTorch.fullFloat(value, shape, device.toInt()))
|
||||
|
||||
}
|
||||
|
||||
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
TorchTensorAlgebraJVM<Long, LongArray, TorchTensorLong>(scope) {
|
||||
override fun wrap(tensorHandle: Long): TorchTensorLong =
|
||||
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorLong.copyToArray(): LongArray =
|
||||
this.elements().map { it.second }.toList().toLongArray()
|
||||
|
||||
override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(JTorch.fromBlobLong(array, shape, device.toInt()))
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(JTorch.randintLong(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(JTorch.plusLong(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
|
||||
wrap(JTorch.plusLong(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.plusAssign(value: Long): Unit =
|
||||
JTorch.plusLongAssign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(JTorch.plusLong(-this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
|
||||
wrap(JTorch.plusLong(-value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.minusAssign(value: Long): Unit =
|
||||
JTorch.plusLongAssign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(JTorch.timesLong(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
|
||||
wrap(JTorch.timesLong(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.timesAssign(value: Long): Unit =
|
||||
JTorch.timesLongAssign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(JTorch.fullLong(value, shape, device.toInt()))
|
||||
}
|
||||
|
||||
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
TorchTensorAlgebraJVM<Int, IntArray, TorchTensorInt>(scope) {
|
||||
override fun wrap(tensorHandle: Long): TorchTensorInt =
|
||||
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorInt.copyToArray(): IntArray =
|
||||
this.elements().map { it.second }.toList().toIntArray()
|
||||
|
||||
override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(JTorch.fromBlobInt(array, shape, device.toInt()))
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(JTorch.randintInt(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(JTorch.plusInt(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
|
||||
wrap(JTorch.plusInt(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.plusAssign(value: Int): Unit =
|
||||
JTorch.plusIntAssign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(JTorch.plusInt(-this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
|
||||
wrap(JTorch.plusInt(-value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.minusAssign(value: Int): Unit =
|
||||
JTorch.plusIntAssign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(JTorch.timesInt(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
|
||||
wrap(JTorch.timesInt(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.timesAssign(value: Int): Unit =
|
||||
JTorch.timesIntAssign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(JTorch.fullInt(value, shape, device.toInt()))
|
||||
|
||||
}
|
||||
|
||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorRealAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorFloatAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorLongAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorIntAlgebra(this).block() }
|
@ -0,0 +1,94 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import space.kscience.kmath.memory.DeferScope
|
||||
|
||||
public sealed class TorchTensorJVM<T> constructor(
|
||||
scope: DeferScope,
|
||||
internal val tensorHandle: Long
|
||||
) : TorchTensor<T>, TorchTensorMemoryHolder(scope)
|
||||
{
|
||||
override fun close(): Unit = JTorch.disposeTensor(tensorHandle)
|
||||
|
||||
override val dimension: Int get() = JTorch.getDim(tensorHandle)
|
||||
override val shape: IntArray
|
||||
get() = (1..dimension).map { JTorch.getShapeAt(tensorHandle, it - 1) }.toIntArray()
|
||||
override val strides: IntArray
|
||||
get() = (1..dimension).map { JTorch.getStrideAt(tensorHandle, it - 1) }.toIntArray()
|
||||
override val size: Int get() = JTorch.getNumel(tensorHandle)
|
||||
override val device: Device get() = Device.fromInt(JTorch.getDevice(tensorHandle))
|
||||
|
||||
override fun toString(): String = JTorch.tensorToString(tensorHandle)
|
||||
|
||||
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = JTorch.copyToDouble(this.tensorHandle)
|
||||
)
|
||||
|
||||
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
|
||||
scope = scope,
|
||||
tensorHandle = JTorch.copyToFloat(this.tensorHandle)
|
||||
)
|
||||
|
||||
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
|
||||
scope = scope,
|
||||
tensorHandle = JTorch.copyToLong(this.tensorHandle)
|
||||
)
|
||||
|
||||
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
|
||||
scope = scope,
|
||||
tensorHandle = JTorch.copyToInt(this.tensorHandle)
|
||||
)
|
||||
}
|
||||
|
||||
public sealed class TorchTensorOverFieldJVM<T> constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: Long
|
||||
) : TorchTensorJVM<T>(scope, tensorHandle), TorchTensorOverField<T> {
|
||||
override var requiresGrad: Boolean
|
||||
get() = JTorch.requiresGrad(tensorHandle)
|
||||
set(value) = JTorch.setRequiresGrad(tensorHandle, value)
|
||||
}
|
||||
|
||||
public class TorchTensorReal internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: Long
|
||||
) : TorchTensorOverFieldJVM<Double>(scope, tensorHandle) {
|
||||
override fun item(): Double = JTorch.getItemDouble(tensorHandle)
|
||||
override fun get(index: IntArray): Double = JTorch.getDouble(tensorHandle, index)
|
||||
override fun set(index: IntArray, value: Double) {
|
||||
JTorch.setDouble(tensorHandle, index, value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorFloat internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: Long
|
||||
) : TorchTensorOverFieldJVM<Float>(scope, tensorHandle) {
|
||||
override fun item(): Float = JTorch.getItemFloat(tensorHandle)
|
||||
override fun get(index: IntArray): Float = JTorch.getFloat(tensorHandle, index)
|
||||
override fun set(index: IntArray, value: Float) {
|
||||
JTorch.setFloat(tensorHandle, index, value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorLong internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: Long
|
||||
) : TorchTensorOverFieldJVM<Long>(scope, tensorHandle) {
|
||||
override fun item(): Long = JTorch.getItemLong(tensorHandle)
|
||||
override fun get(index: IntArray): Long = JTorch.getLong(tensorHandle, index)
|
||||
override fun set(index: IntArray, value: Long) {
|
||||
JTorch.setLong(tensorHandle, index, value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorInt internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: Long
|
||||
) : TorchTensorOverFieldJVM<Int>(scope, tensorHandle) {
|
||||
override fun item(): Int = JTorch.getItemInt(tensorHandle)
|
||||
override fun get(index: IntArray): Int = JTorch.getInt(tensorHandle, index)
|
||||
override fun set(index: IntArray, value: Int) {
|
||||
JTorch.setInt(tensorHandle, index, value)
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
class BenchmarkMatMul {
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMulDouble() = TorchTensorRealAlgebra {
|
||||
benchmarkMatMul(20, 10, 100000, "Real")
|
||||
benchmarkMatMul(200, 10, 10000, "Real")
|
||||
benchmarkMatMul(2000, 3, 20, "Real")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMulFloat() = TorchTensorFloatAlgebra {
|
||||
benchmarkMatMul(20, 10, 100000, "Float")
|
||||
benchmarkMatMul(200, 10, 10000, "Float")
|
||||
benchmarkMatMul(2000, 3, 20, "Float")
|
||||
if (cudaAvailable()) {
|
||||
benchmarkMatMul(20, 10, 100000, "Float", Device.CUDA(0))
|
||||
benchmarkMatMul(200, 10, 10000, "Float", Device.CUDA(0))
|
||||
benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0))
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
class BenchmarkRandomGenerators {
|
||||
@Test
|
||||
fun benchmarkRand1() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand1()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkRand3() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand3()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkRand5() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand5()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkRand7() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand7()
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
class TestAutograd {
|
||||
@Test
|
||||
fun testAutoGrad() = TorchTensorFloatAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingAutoGrad(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBatchedAutoGrad() = TorchTensorFloatAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingBatchedAutoGrad(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,39 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.*
|
||||
|
||||
|
||||
class TestTorchTensor {
|
||||
|
||||
@Test
|
||||
fun testCopying() = TorchTensorFloatAlgebra {
|
||||
withCuda { device ->
|
||||
testingCopying(device)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testRequiresGrad() = TorchTensorRealAlgebra {
|
||||
testingRequiresGrad()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTypeMoving() = TorchTensorFloatAlgebra {
|
||||
val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).copyToInt()
|
||||
TorchTensorIntAlgebra {
|
||||
val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
|
||||
tensorInt swap temporalTensor
|
||||
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
|
||||
}
|
||||
assertTrue(tensorInt.copyToFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testViewWithNoCopy() = TorchTensorIntAlgebra {
|
||||
withChecks {
|
||||
withCuda {
|
||||
device -> testingViewWithNoCopy(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
class TestTorchTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun testScalarProduct() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingScalarProduct(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMatrixMultiplication() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingMatrixMultiplication(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLinearStructure() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingLinearStructure(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTensorTransformations() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingTensorTransformations(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBatchedSVD() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingBatchedSVD(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBatchedSymEig() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingBatchedSymEig(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,21 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.*
|
||||
|
||||
|
||||
class TestUtils {
|
||||
|
||||
@Test
|
||||
fun testSetNumThreads() {
|
||||
TorchTensorLongAlgebra {
|
||||
testingSetNumThreads()
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSeedSetting() = TorchTensorFloatAlgebra {
|
||||
withCuda { device ->
|
||||
testingSetSeed(device)
|
||||
}
|
||||
}
|
||||
}
|
2
kmath-torch/src/nativeInterop/cinterop/libctorch.def
Normal file
2
kmath-torch/src/nativeInterop/cinterop/libctorch.def
Normal file
@ -0,0 +1,2 @@
|
||||
package=space.kscience.kmath.torch.ctorch
|
||||
headers=ctorch.h
|
@ -0,0 +1,444 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import space.kscience.kmath.memory.DeferScope
|
||||
import space.kscience.kmath.memory.withDeferScope
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import space.kscience.kmath.tensors.*
|
||||
import space.kscience.kmath.torch.ctorch.*
|
||||
|
||||
public sealed class TorchTensorAlgebraNative<
|
||||
T,
|
||||
TVar : CPrimitiveVar,
|
||||
PrimitiveArrayType,
|
||||
TorchTensorType : TorchTensorNative<T>> constructor(
|
||||
internal val scope: DeferScope
|
||||
) : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||
|
||||
override fun getNumThreads(): Int {
|
||||
return get_num_threads()
|
||||
}
|
||||
|
||||
override fun setNumThreads(numThreads: Int): Unit {
|
||||
set_num_threads(numThreads)
|
||||
}
|
||||
|
||||
override fun cudaAvailable(): Boolean {
|
||||
return cuda_is_available()
|
||||
}
|
||||
|
||||
override fun setSeed(seed: Int): Unit {
|
||||
set_seed(seed)
|
||||
}
|
||||
|
||||
override var checks: Boolean = false
|
||||
|
||||
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
|
||||
|
||||
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
|
||||
public abstract fun TorchTensorType.getData(): CPointer<TVar>
|
||||
|
||||
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||
wrap(unary_minus(this.tensorHandle)!!)
|
||||
|
||||
override infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
return wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
matmul_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
matmul_right_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(
|
||||
diagonalEntries: TorchTensorType, offset: Int, dim1: Int, dim2: Int
|
||||
): TorchTensorType =
|
||||
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
|
||||
|
||||
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType {
|
||||
if (checks) checkTranspose(this.dimension, i, j)
|
||||
return wrap(transpose_tensor(tensorHandle, i, j)!!)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
||||
if (checks) checkTranspose(this.dimension, i, j)
|
||||
transpose_tensor_assign(tensorHandle, i, j)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.view(shape: IntArray): TorchTensorType {
|
||||
if (checks) checkView(this, shape)
|
||||
return wrap(view_tensor(this.tensorHandle, shape.toCValues(), shape.size)!!)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
|
||||
override fun TorchTensorType.absAssign(): Unit = abs_tensor_assign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
|
||||
override fun TorchTensorType.sumAssign(): Unit = sum_tensor_assign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
|
||||
wrap(randint_like(this.tensorHandle, low, high)!!)
|
||||
|
||||
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit =
|
||||
randint_like_assign(this.tensorHandle, low, high)
|
||||
|
||||
override fun TorchTensorType.copy(): TorchTensorType =
|
||||
wrap(copy_tensor(this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
|
||||
wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||
|
||||
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit =
|
||||
swap_tensors(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitiveVar,
|
||||
PrimitiveArrayType, TorchTensorType : TorchTensorOverFieldNative<T>>(scope: DeferScope) :
|
||||
TorchTensorAlgebraNative<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
|
||||
TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||
|
||||
override operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.randUniform(): TorchTensorType =
|
||||
wrap(rand_like(this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorType.randUniformAssign(): Unit =
|
||||
rand_like_assign(this.tensorHandle)
|
||||
|
||||
|
||||
override fun TorchTensorType.randNormal(): TorchTensorType =
|
||||
wrap(randn_like(this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorType.randNormalAssign(): Unit =
|
||||
randn_like_assign(this.tensorHandle)
|
||||
|
||||
|
||||
override fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
|
||||
override fun TorchTensorType.expAssign(): Unit = exp_tensor_assign(tensorHandle)
|
||||
override fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
|
||||
override fun TorchTensorType.logAssign(): Unit = log_tensor_assign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
|
||||
val U = empty_tensor()!!
|
||||
val V = empty_tensor()!!
|
||||
val S = empty_tensor()!!
|
||||
svd_tensor(this.tensorHandle, U, S, V)
|
||||
return Triple(wrap(U), wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.symEig(eigenvectors: Boolean): Pair<TorchTensorType, TorchTensorType> {
|
||||
val V = empty_tensor()!!
|
||||
val S = empty_tensor()!!
|
||||
symeig_tensor(this.tensorHandle, S, V, eigenvectors)
|
||||
return Pair(wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean): TorchTensorType {
|
||||
if (checks) this.checkIsValue()
|
||||
return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!)
|
||||
}
|
||||
|
||||
override infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
|
||||
if (checks) this.checkIsValue()
|
||||
return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||
wrap(detach_from_graph(this.tensorHandle)!!)
|
||||
|
||||
}
|
||||
|
||||
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
TorchTensorPartialDivisionAlgebraNative<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
|
||||
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
||||
this.elements().map { it.second }.toList().toDoubleArray()
|
||||
|
||||
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(from_blob_double(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||
|
||||
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
|
||||
wrap(from_blob_double(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
|
||||
|
||||
override fun TorchTensorReal.getData(): CPointer<DoubleVar> {
|
||||
require(this.device is Device.CPU) {
|
||||
"This tensor is not on available on CPU"
|
||||
}
|
||||
return get_data_double(this.tensorHandle)!!
|
||||
}
|
||||
|
||||
override fun randNormal(shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(randn_double(shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(randint_double(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(plus_double(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.plus(value: Double): TorchTensorReal =
|
||||
wrap(plus_double(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.plusAssign(value: Double): Unit {
|
||||
plus_double_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(plus_double(-this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.minus(value: Double): TorchTensorReal =
|
||||
wrap(plus_double(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.minusAssign(value: Double): Unit {
|
||||
plus_double_assign(-value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(times_double(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.times(value: Double): TorchTensorReal =
|
||||
wrap(times_double(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorReal.timesAssign(value: Double): Unit {
|
||||
times_double_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
}
|
||||
|
||||
|
||||
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
TorchTensorPartialDivisionAlgebraNative<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
|
||||
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorFloat.copyToArray(): FloatArray =
|
||||
this.elements().map { it.second }.toList().toFloatArray()
|
||||
|
||||
override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(from_blob_float(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||
|
||||
override fun fromBlob(arrayBlob: CPointer<FloatVar>, shape: IntArray): TorchTensorFloat =
|
||||
wrap(from_blob_float(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
|
||||
|
||||
override fun TorchTensorFloat.getData(): CPointer<FloatVar> {
|
||||
require(this.device is Device.CPU) {
|
||||
"This tensor is not on available on CPU"
|
||||
}
|
||||
return get_data_float(this.tensorHandle)!!
|
||||
}
|
||||
|
||||
override fun randNormal(shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(randn_float(shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(randint_float(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(plus_float(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
|
||||
wrap(plus_float(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.plusAssign(value: Float): Unit =
|
||||
plus_float_assign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(plus_float(-this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
|
||||
wrap(plus_float(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.minusAssign(value: Float): Unit =
|
||||
plus_float_assign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(times_float(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
|
||||
wrap(times_float(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.timesAssign(value: Float): Unit =
|
||||
times_float_assign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
}
|
||||
|
||||
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
TorchTensorAlgebraNative<Long, LongVar, LongArray, TorchTensorLong>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong =
|
||||
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorLong.copyToArray(): LongArray =
|
||||
this.elements().map { it.second }.toList().toLongArray()
|
||||
|
||||
override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(from_blob_long(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||
|
||||
override fun fromBlob(arrayBlob: CPointer<LongVar>, shape: IntArray): TorchTensorLong =
|
||||
wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
|
||||
|
||||
override fun TorchTensorLong.getData(): CPointer<LongVar> {
|
||||
check(this.device is Device.CPU) {
|
||||
"This tensor is not on available on CPU"
|
||||
}
|
||||
return get_data_long(this.tensorHandle)!!
|
||||
}
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(plus_long(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
|
||||
wrap(plus_long(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.plusAssign(value: Long): Unit =
|
||||
plus_long_assign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(plus_long(-this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
|
||||
wrap(plus_long(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.minusAssign(value: Long): Unit =
|
||||
plus_long_assign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(times_long(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
|
||||
wrap(times_long(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.timesAssign(value: Long): Unit =
|
||||
times_long_assign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
}
|
||||
|
||||
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
TorchTensorAlgebraNative<Int, IntVar, IntArray, TorchTensorInt>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt =
|
||||
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorInt.copyToArray(): IntArray =
|
||||
this.elements().map { it.second }.toList().toIntArray()
|
||||
|
||||
override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(from_blob_int(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||
|
||||
override fun fromBlob(arrayBlob: CPointer<IntVar>, shape: IntArray): TorchTensorInt =
|
||||
wrap(from_blob_int(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
|
||||
|
||||
override fun TorchTensorInt.getData(): CPointer<IntVar> {
|
||||
require(this.device is Device.CPU) {
|
||||
"This tensor is not on available on CPU"
|
||||
}
|
||||
return get_data_int(this.tensorHandle)!!
|
||||
}
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(plus_int(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
|
||||
wrap(plus_int(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.plusAssign(value: Int): Unit =
|
||||
plus_int_assign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(plus_int(-this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
|
||||
wrap(plus_int(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.minusAssign(value: Int): Unit =
|
||||
plus_int_assign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(times_int(this, other.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
|
||||
wrap(times_int(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.timesAssign(value: Int): Unit =
|
||||
times_int_assign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
}
|
||||
|
||||
|
||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorRealAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorFloatAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorLongAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorIntAlgebra(this).block() }
|
||||
|
@ -0,0 +1,104 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import space.kscience.kmath.memory.DeferScope
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import space.kscience.kmath.torch.ctorch.*
|
||||
|
||||
|
||||
public sealed class TorchTensorNative<T> constructor(
|
||||
scope: DeferScope,
|
||||
internal val tensorHandle: COpaquePointer
|
||||
) : TorchTensor<T>, TorchTensorMemoryHolder(scope) {
|
||||
|
||||
override fun close(): Unit = dispose_tensor(tensorHandle)
|
||||
|
||||
override val dimension: Int get() = get_dim(tensorHandle)
|
||||
override val shape: IntArray
|
||||
get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray()
|
||||
override val strides: IntArray
|
||||
get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray()
|
||||
override val size: Int get() = get_numel(tensorHandle)
|
||||
override val device: Device get() = Device.fromInt(get_device(tensorHandle))
|
||||
|
||||
override fun toString(): String {
|
||||
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(tensorHandle)!!
|
||||
val stringRepresentation = nativeStringRepresentation.toKString()
|
||||
dispose_char(nativeStringRepresentation)
|
||||
return stringRepresentation
|
||||
}
|
||||
|
||||
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_double(this.tensorHandle)!!
|
||||
)
|
||||
|
||||
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_float(this.tensorHandle)!!
|
||||
)
|
||||
|
||||
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_long(this.tensorHandle)!!
|
||||
)
|
||||
|
||||
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_int(this.tensorHandle)!!
|
||||
)
|
||||
}
|
||||
|
||||
public sealed class TorchTensorOverFieldNative<T> constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorNative<T>(scope, tensorHandle), TorchTensorOverField<T> {
|
||||
override var requiresGrad: Boolean
|
||||
get() = requires_grad(tensorHandle)
|
||||
set(value) = requires_grad_(tensorHandle, value)
|
||||
}
|
||||
|
||||
|
||||
public class TorchTensorReal internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorOverFieldNative<Double>(scope, tensorHandle) {
|
||||
override fun item(): Double = get_item_double(tensorHandle)
|
||||
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Double) {
|
||||
set_double(tensorHandle, index.toCValues(), value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorFloat internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorOverFieldNative<Float>(scope, tensorHandle) {
|
||||
override fun item(): Float = get_item_float(tensorHandle)
|
||||
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Float) {
|
||||
set_float(tensorHandle, index.toCValues(), value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorLong internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorNative<Long>(scope, tensorHandle) {
|
||||
override fun item(): Long = get_item_long(tensorHandle)
|
||||
override fun get(index: IntArray): Long = get_long(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Long) {
|
||||
set_long(tensorHandle, index.toCValues(), value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorInt internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorNative<Int>(scope, tensorHandle) {
|
||||
override fun item(): Int = get_item_int(tensorHandle)
|
||||
override fun get(index: IntArray): Int = get_int(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Int) {
|
||||
set_int(tensorHandle, index.toCValues(), value)
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
internal class BenchmarkMatMul {
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMulDouble() = TorchTensorRealAlgebra {
|
||||
benchmarkMatMul(20, 10, 100000, "Real")
|
||||
benchmarkMatMul(200, 10, 10000, "Real")
|
||||
benchmarkMatMul(2000, 3, 20, "Real")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMulFloat() = TorchTensorFloatAlgebra {
|
||||
benchmarkMatMul(20, 10, 100000, "Float")
|
||||
benchmarkMatMul(200, 10, 10000, "Float")
|
||||
benchmarkMatMul(2000, 3, 20, "Float")
|
||||
if (cudaAvailable()) {
|
||||
benchmarkMatMul(20, 10, 100000, "Float", Device.CUDA(0))
|
||||
benchmarkMatMul(200, 10, 10000, "Float", Device.CUDA(0))
|
||||
benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0))
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
internal class BenchmarkRandomGenerators {
|
||||
|
||||
@Test
|
||||
fun benchmarkRand1() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand1()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkRand3() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand3()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkRand5() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand5()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkRand7() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand7()
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
internal class TestAutograd {
|
||||
@Test
|
||||
fun testAutoGrad() = TorchTensorFloatAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingAutoGrad(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBatchedAutoGrad() = TorchTensorFloatAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingBatchedAutoGrad(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import kotlin.test.*
|
||||
|
||||
|
||||
class TestTorchTensor {
|
||||
|
||||
@Test
|
||||
fun testCopying() = TorchTensorFloatAlgebra {
|
||||
withCuda {
|
||||
device -> testingCopying(device)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testNativeNoCopyDataTransferOnCPU() = memScoped {
|
||||
val data = allocArray<DoubleVar>(1)
|
||||
data[0] = 1.0
|
||||
TorchTensorRealAlgebra {
|
||||
val tensor = fromBlob(data, intArrayOf(1))
|
||||
assertEquals(tensor[intArrayOf(0)], 1.0)
|
||||
data[0] = 2.0
|
||||
assertEquals(tensor[intArrayOf(0)], 2.0)
|
||||
val tensorData = tensor.getData()
|
||||
tensorData[0] = 3.0
|
||||
assertEquals(tensor[intArrayOf(0)], 3.0)
|
||||
}
|
||||
assertEquals(data[0], 3.0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testRequiresGrad() = TorchTensorRealAlgebra {
|
||||
testingRequiresGrad()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTypeMoving() = TorchTensorFloatAlgebra {
|
||||
val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).copyToInt()
|
||||
TorchTensorIntAlgebra {
|
||||
val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
|
||||
tensorInt swap temporalTensor
|
||||
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
|
||||
}
|
||||
assertTrue(tensorInt.copyToFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testViewWithNoCopy() = TorchTensorIntAlgebra {
|
||||
withChecks {
|
||||
withCuda {
|
||||
device -> testingViewWithNoCopy(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
internal class TestTorchTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun testScalarProduct() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingScalarProduct(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMatrixMultiplication() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingMatrixMultiplication(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLinearStructure() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingLinearStructure(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTensorTransformations() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingTensorTransformations(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBatchedSVD() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingBatchedSVD(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBatchedSymEig() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingBatchedSymEig(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
package space.kscience.kmath.torch
|
||||
|
||||
import kotlin.test.*
|
||||
|
||||
|
||||
internal class TestUtils {
|
||||
@Test
|
||||
fun testSetNumThreads() {
|
||||
TorchTensorLongAlgebra {
|
||||
testingSetNumThreads()
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSeedSetting() = TorchTensorFloatAlgebra {
|
||||
withCuda {
|
||||
device -> testingSetSeed(device)
|
||||
}
|
||||
}
|
||||
}
|
0
kmath-torchscript/example/build.gradle.kts
Normal file
0
kmath-torchscript/example/build.gradle.kts
Normal file
12
kmath-torchscript/plugin/build.gradle.kts
Normal file
12
kmath-torchscript/plugin/build.gradle.kts
Normal file
@ -0,0 +1,12 @@
|
||||
plugins {
|
||||
id("ru.mipt.npm.gradle.project")
|
||||
}
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
gradlePluginPortal()
|
||||
mavenCentral()
|
||||
maven("https://maven.google.com")
|
||||
maven("https://plugins.gradle.org/m2/")
|
||||
}
|
||||
}
|
@ -0,0 +1,8 @@
|
||||
plugins {
|
||||
kotlin("jvm")
|
||||
id("ru.mipt.npm.gradle.common")
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compileOnly(kotlin("compiler-embeddable"))
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.torchscript.compiler
|
||||
|
||||
import org.jetbrains.kotlin.backend.common.extensions.IrGenerationExtension
|
||||
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
|
||||
import org.jetbrains.kotlin.com.intellij.mock.MockProject
|
||||
import org.jetbrains.kotlin.compiler.plugin.ComponentRegistrar
|
||||
import org.jetbrains.kotlin.config.CompilerConfiguration
|
||||
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
|
||||
|
||||
public class TorchscriptPluginComponentRegistrar : ComponentRegistrar {
|
||||
public override fun registerProjectComponents(project: MockProject, configuration: CompilerConfiguration): Unit =
|
||||
registerExtensions(project)
|
||||
|
||||
private companion object {
|
||||
private fun registerExtensions(project: MockProject) {
|
||||
IrGenerationExtension.registerExtension(project, TorchscriptExtension)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public object TorchscriptExtension : IrGenerationExtension {
|
||||
override fun generate(moduleFragment: IrModuleFragment, pluginContext: IrPluginContext) {
|
||||
TODO()
|
||||
}
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.torchscript.compiler.resolve
|
||||
|
||||
import org.jetbrains.kotlin.name.FqName
|
||||
|
||||
public object TorchscriptAnnotations {
|
||||
public val moduleAnnotationFqName: FqName = FqName("space.kscience.kmath.torchscript.Module")
|
||||
public val intrinsicAnnotationFqName: FqName = FqName("space.kscience.kmath.torchscript.Intrinsic")
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
#
|
||||
# Copyright 2018-2021 KMath contributors.
|
||||
# Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
#
|
||||
|
||||
space.kscience.kmath.torchscript.compiler.TorchscriptPluginComponentRegistrar
|
19
kmath-torchscript/plugin/gradle-plugin/build.gradle.kts
Normal file
19
kmath-torchscript/plugin/gradle-plugin/build.gradle.kts
Normal file
@ -0,0 +1,19 @@
|
||||
@file:Suppress("UNUSED_VARIABLE")
|
||||
|
||||
plugins {
|
||||
`kotlin-dsl`
|
||||
kotlin("kapt")
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compileOnly("com.google.auto.service:auto-service-annotations:1.0")
|
||||
compileOnly(kotlin("gradle-plugin"))
|
||||
kapt("com.google.auto.service:auto-service:1.0")
|
||||
}
|
||||
|
||||
gradlePlugin {
|
||||
val kmathTorchScriptPlugin by plugins.registering {
|
||||
id = "space.kscience.kmath.torchscript.torchscript-compiler-plugin"
|
||||
implementationClass = "space.kscience.kmath.torchscript.gradle.TorchScriptKotlinGradleSubplugin"
|
||||
}
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.torchscript.gradle
|
||||
|
||||
class TorchScriptKotlinGradleSubplugin {
|
||||
}
|
19
kmath-torchscript/plugin/settings.gradle.kts
Normal file
19
kmath-torchscript/plugin/settings.gradle.kts
Normal file
@ -0,0 +1,19 @@
|
||||
rootProject.name = "plugin"
|
||||
|
||||
pluginManagement {
|
||||
repositories {
|
||||
maven("https://repo.kotlin.link")
|
||||
mavenCentral()
|
||||
gradlePluginPortal()
|
||||
}
|
||||
|
||||
val toolsVersion = "0.10.0"
|
||||
|
||||
plugins {
|
||||
id("ru.mipt.npm.gradle.project") version toolsVersion
|
||||
id("ru.mipt.npm.gradle.mpp") version toolsVersion
|
||||
id("ru.mipt.npm.gradle.jvm") version toolsVersion
|
||||
}
|
||||
}
|
||||
|
||||
include(":compiler-plugin", ":gradle-plugin")
|
5
kmath-torchscript/runtime/build.gradle.kts
Normal file
5
kmath-torchscript/runtime/build.gradle.kts
Normal file
@ -0,0 +1,5 @@
|
||||
plugins {
|
||||
kotlin("multiplatform")
|
||||
id("ru.mipt.npm.gradle.common")
|
||||
}
|
||||
|
@ -0,0 +1,10 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.torchscript
|
||||
|
||||
public annotation class Module
|
||||
|
||||
public annotation class Intrinsic
|
@ -46,3 +46,11 @@ include(
|
||||
":examples",
|
||||
":benchmarks",
|
||||
)
|
||||
|
||||
if (System.getProperty("os.name") == "Linux") {
|
||||
include(":kmath-torch")
|
||||
include(":kmath-torchscript-library")
|
||||
include(":kmath-torchscript:example")
|
||||
include(":kmath-torchscript:runtime")
|
||||
includeBuild("kmath-torchscript/plugin")
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user