Compare commits
41 Commits
dev
...
feature/to
Author | SHA1 | Date | |
---|---|---|---|
|
bd736c6716 | ||
|
86528fc943 | ||
|
3e3d433a64 | ||
|
0a4e7acb4c | ||
|
6594ffc965 | ||
|
9b3258b06b | ||
|
120189fd89 | ||
|
40895e5936 | ||
|
c141c04e99 | ||
|
c9dfb6a08c | ||
|
391eb28cad | ||
|
17e6ebbc14 | ||
|
d599d1132b | ||
|
274d1a3105 | ||
|
b30ca920e1 | ||
|
6eb718f64a | ||
|
889691a122 | ||
|
e5205d5afd | ||
|
ed4ac2623d | ||
|
97ef57697d | ||
|
4f4fcba559 | ||
|
ca2082405a | ||
|
0b784474b4 | ||
|
fbb414731b | ||
|
7d25aa2834 | ||
|
ef570254e6 | ||
|
80f28dbcd5 | ||
|
ca3cca65ef | ||
|
39f3a87bbd | ||
|
524b1d80d1 | ||
|
8967691b7d | ||
|
7105331149 | ||
|
7894799e8e | ||
|
0cb2c3f0da | ||
|
9b1a958491 | ||
|
cfe93886ac | ||
|
fb9d612081 | ||
|
d97f8857a0 | ||
|
0fc29b40c5 | ||
|
32e4b68061 | ||
|
a229aaa6a4 |
3
.github/workflows/build.yml
vendored
3
.github/workflows/build.yml
vendored
@ -12,6 +12,9 @@ jobs:
|
|||||||
uses: actions/setup-java@v1
|
uses: actions/setup-java@v1
|
||||||
with:
|
with:
|
||||||
java-version: 11
|
java-version: 11
|
||||||
|
- name: Install build-essential
|
||||||
|
run: |
|
||||||
|
sudo apt install -y build-essential
|
||||||
- name: Grant execute permission for gradlew
|
- name: Grant execute permission for gradlew
|
||||||
run: chmod +x gradlew
|
run: chmod +x gradlew
|
||||||
- name: Install Chrome
|
- name: Install Chrome
|
||||||
|
@ -254,7 +254,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [NDStructure] over [Buffer].
|
* Trait for [NDStructure] over [Buffer].
|
||||||
*
|
*
|
||||||
* @param T the type of items.
|
* @param T the type of items.
|
||||||
* @param strides The strides to access elements of [Buffer] by linear indices.
|
* @param strides The strides to access elements of [Buffer] by linear indices.
|
||||||
@ -281,6 +281,7 @@ public open class NDBuffer<T>(
|
|||||||
it to this[it]
|
it to this[it]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false)
|
return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false)
|
||||||
}
|
}
|
||||||
@ -346,4 +347,4 @@ public inline fun <reified T : Any> NDStructure<T>.combine(
|
|||||||
): NDStructure<T> {
|
): NDStructure<T> {
|
||||||
require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
|
require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
|
||||||
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
||||||
}
|
}
|
@ -36,13 +36,11 @@ inline public fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): In
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public class TensorStrides(override val shape: IntArray) : Strides {
|
||||||
public class TensorStrides(override val shape: IntArray): Strides
|
|
||||||
{
|
|
||||||
override val strides: IntArray
|
override val strides: IntArray
|
||||||
get() = stridesFromShape(shape)
|
get() = stridesFromShape(shape)
|
||||||
|
|
||||||
override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
|
override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray =
|
override fun index(offset: Int): IntArray =
|
||||||
indexFromOffset(offset, strides, shape.size)
|
indexFromOffset(offset, strides, shape.size)
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
package space.kscience.kmath.memory
|
||||||
|
|
||||||
|
public expect class DeferScope {
|
||||||
|
public inline fun defer(crossinline block: () -> Unit)
|
||||||
|
}
|
||||||
|
|
||||||
|
public expect inline fun <R> withDeferScope(block: DeferScope.() -> R): R
|
@ -0,0 +1,30 @@
|
|||||||
|
package space.kscience.kmath.memory
|
||||||
|
|
||||||
|
private typealias Deferred = () -> Unit
|
||||||
|
|
||||||
|
public actual class DeferScope {
|
||||||
|
@PublishedApi
|
||||||
|
internal val deferred: MutableList<Deferred> = mutableListOf()
|
||||||
|
|
||||||
|
@PublishedApi
|
||||||
|
internal fun executeAllDeferred() {
|
||||||
|
deferred.forEach(Deferred::invoke)
|
||||||
|
deferred.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
public actual inline fun defer(crossinline block: () -> Unit) {
|
||||||
|
deferred += {
|
||||||
|
try {
|
||||||
|
block()
|
||||||
|
} catch (ignored: Throwable) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
|
||||||
|
val ds = DeferScope()
|
||||||
|
val r = ds.block()
|
||||||
|
ds.executeAllDeferred()
|
||||||
|
return r
|
||||||
|
}
|
@ -0,0 +1,30 @@
|
|||||||
|
package space.kscience.kmath.memory
|
||||||
|
|
||||||
|
private typealias Deferred = () -> Unit
|
||||||
|
|
||||||
|
public actual class DeferScope {
|
||||||
|
@PublishedApi
|
||||||
|
internal val deferred: MutableList<Deferred> = mutableListOf()
|
||||||
|
|
||||||
|
@PublishedApi
|
||||||
|
internal fun executeAllDeferred() {
|
||||||
|
deferred.forEach(Deferred::invoke)
|
||||||
|
deferred.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
public actual inline fun defer(crossinline block: () -> Unit) {
|
||||||
|
deferred += {
|
||||||
|
try {
|
||||||
|
block()
|
||||||
|
} catch (ignored: Throwable) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
|
||||||
|
val ds = DeferScope()
|
||||||
|
val r = ds.block()
|
||||||
|
ds.executeAllDeferred()
|
||||||
|
return r
|
||||||
|
}
|
@ -0,0 +1,7 @@
|
|||||||
|
package space.kscience.kmath.memory
|
||||||
|
|
||||||
|
import kotlinx.cinterop.memScoped
|
||||||
|
|
||||||
|
public actual typealias DeferScope = kotlinx.cinterop.DeferScope
|
||||||
|
|
||||||
|
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R = memScoped(block)
|
66
kmath-torch/README.md
Normal file
66
kmath-torch/README.md
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
# LibTorch extension (`kmath-torch`)
|
||||||
|
|
||||||
|
This is a `Kotlin/Native` & `JVM` module, with only `linuxX64` supported so far. The library wraps some of
|
||||||
|
the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
To install the library, you have to build & publish locally `kmath-core`, `kmath-memory` with `kmath-torch`:
|
||||||
|
|
||||||
|
```
|
||||||
|
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
|
||||||
|
```
|
||||||
|
|
||||||
|
This builds `ctorch` a C wrapper and `jtorch` a JNI wrapper for `LibTorch`, placed inside:
|
||||||
|
|
||||||
|
`~/.konan/third-party/kmath-torch-0.2.0/cpp-build`
|
||||||
|
|
||||||
|
You will have to link against it in your own project.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
Tensors are implemented over the `MutableNDStructure`. They can only be created through provided factory methods
|
||||||
|
and require scoping within a `TensorAlgebra` instance:
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
|
||||||
|
val realTensor: TorchTensorReal = copyFromArray(
|
||||||
|
array = (1..10).map { it + 50.0 }.toList().toDoubleArray(),
|
||||||
|
shape = intArrayOf(2, 5)
|
||||||
|
)
|
||||||
|
println(realTensor)
|
||||||
|
|
||||||
|
val gpuRealTensor: TorchTensorReal = copyFromArray(
|
||||||
|
array = (1..8).map { it * 2.5 }.toList().toDoubleArray(),
|
||||||
|
shape = intArrayOf(2, 2, 2),
|
||||||
|
device = Device.CUDA(0)
|
||||||
|
)
|
||||||
|
println(gpuRealTensor)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
High performance automatic differentiation engine is available:
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
val dim = 10
|
||||||
|
val device = Device.CPU //or Device.CUDA(0)
|
||||||
|
|
||||||
|
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||||
|
val tensorSigma = randFeatures + randFeatures.transpose(0, 1)
|
||||||
|
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
|
||||||
|
// expression to differentiate w.r.t. x evaluated at x = tensorX
|
||||||
|
val expressionAtX = withGradAt(tensorX, { x ->
|
||||||
|
0.5 * (x dot (tensorSigma dot x)) + (tensorMu dot x) + 25.9
|
||||||
|
})
|
||||||
|
|
||||||
|
// value of the gradient at x = tensorX
|
||||||
|
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
||||||
|
// value of the hessian at x = tensorX
|
||||||
|
val hessianAtX = expressionAtX hess tensorX
|
||||||
|
}
|
||||||
|
```
|
||||||
|
Contributed by [Roland Grinis](https://github.com/rgrit91)
|
397
kmath-torch/api/kmath-torch.api
Normal file
397
kmath-torch/api/kmath-torch.api
Normal file
@ -0,0 +1,397 @@
|
|||||||
|
public abstract class space/kscience/kmath/torch/Device {
|
||||||
|
public static final field Companion Lspace/kscience/kmath/torch/Device$Companion;
|
||||||
|
public final fun toInt ()I
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/Device$CPU : space/kscience/kmath/torch/Device {
|
||||||
|
public static final field INSTANCE Lspace/kscience/kmath/torch/Device$CPU;
|
||||||
|
public fun toString ()Ljava/lang/String;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/Device$CUDA : space/kscience/kmath/torch/Device {
|
||||||
|
public fun <init> (I)V
|
||||||
|
public final fun component1 ()I
|
||||||
|
public final fun copy (I)Lspace/kscience/kmath/torch/Device$CUDA;
|
||||||
|
public static synthetic fun copy$default (Lspace/kscience/kmath/torch/Device$CUDA;IILjava/lang/Object;)Lspace/kscience/kmath/torch/Device$CUDA;
|
||||||
|
public fun equals (Ljava/lang/Object;)Z
|
||||||
|
public final fun getIndex ()I
|
||||||
|
public fun hashCode ()I
|
||||||
|
public fun toString ()Ljava/lang/String;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/Device$Companion {
|
||||||
|
public final fun fromInt (I)Lspace/kscience/kmath/torch/Device;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract interface class space/kscience/kmath/torch/TorchTensor : space/kscience/kmath/tensors/TensorStructure {
|
||||||
|
public abstract fun elements ()Lkotlin/sequences/Sequence;
|
||||||
|
public abstract fun getDevice ()Lspace/kscience/kmath/torch/Device;
|
||||||
|
public abstract fun getSize ()I
|
||||||
|
public abstract fun getStrides ()[I
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensor$DefaultImpls {
|
||||||
|
public static fun elements (Lspace/kscience/kmath/torch/TorchTensor;)Lkotlin/sequences/Sequence;
|
||||||
|
public static fun getDimension (Lspace/kscience/kmath/torch/TorchTensor;)I
|
||||||
|
public static fun value (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract interface class space/kscience/kmath/torch/TorchTensorAlgebra : space/kscience/kmath/tensors/TensorAlgebra {
|
||||||
|
public abstract fun copy (Lspace/kscience/kmath/torch/TorchTensor;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public abstract fun copyFromArray (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public abstract fun copyToArray (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||||
|
public abstract fun copyToDevice (Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public abstract fun cudaAvailable ()Z
|
||||||
|
public abstract fun full (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public abstract fun getChecks ()Z
|
||||||
|
public abstract fun getNumThreads ()I
|
||||||
|
public abstract fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public abstract fun randIntegral (Lspace/kscience/kmath/torch/TorchTensor;JJ)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public abstract fun randIntegralAssign (Lspace/kscience/kmath/torch/TorchTensor;JJ)V
|
||||||
|
public abstract fun setChecks (Z)V
|
||||||
|
public abstract fun setNumThreads (I)V
|
||||||
|
public abstract fun setSeed (I)V
|
||||||
|
public abstract fun swap (Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/TorchTensor;)V
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorAlgebra$DefaultImpls {
|
||||||
|
public static synthetic fun copyFromArray$default (Lspace/kscience/kmath/torch/TorchTensorAlgebra;Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;ILjava/lang/Object;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public static synthetic fun randIntegral$default (Lspace/kscience/kmath/torch/TorchTensorAlgebra;JJ[ILspace/kscience/kmath/torch/Device;ILjava/lang/Object;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract class space/kscience/kmath/torch/TorchTensorAlgebraJVM : space/kscience/kmath/torch/TorchTensorAlgebra {
|
||||||
|
public synthetic fun <init> (Lspace/kscience/kmath/memory/DeferScope;Lkotlin/jvm/internal/DefaultConstructorMarker;)V
|
||||||
|
public synthetic fun abs (Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun abs (Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun absAssign (Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||||
|
public fun absAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||||
|
public synthetic fun copy (Lspace/kscience/kmath/torch/TorchTensor;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun copy (Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun copyToDevice (Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun copyToDevice (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public fun cudaAvailable ()Z
|
||||||
|
public synthetic fun diagonalEmbedding (Lspace/kscience/kmath/tensors/TensorStructure;III)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun diagonalEmbedding (Lspace/kscience/kmath/torch/TorchTensorJVM;III)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun dot (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun dot (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun dotAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||||
|
public fun dotAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||||
|
public synthetic fun dotRightAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||||
|
public fun dotRightAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||||
|
public fun getChecks ()Z
|
||||||
|
public fun getNumThreads ()I
|
||||||
|
public synthetic fun minus (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun minus (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun minusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||||
|
public fun minusAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||||
|
public synthetic fun plus (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun plus (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun plusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||||
|
public fun plusAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||||
|
public synthetic fun randIntegral (Lspace/kscience/kmath/torch/TorchTensor;JJ)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun randIntegral (Lspace/kscience/kmath/torch/TorchTensorJVM;JJ)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun randIntegralAssign (Lspace/kscience/kmath/torch/TorchTensor;JJ)V
|
||||||
|
public fun randIntegralAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;JJ)V
|
||||||
|
public fun setChecks (Z)V
|
||||||
|
public fun setNumThreads (I)V
|
||||||
|
public fun setSeed (I)V
|
||||||
|
public synthetic fun sum (Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun sum (Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun sumAssign (Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||||
|
public fun sumAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||||
|
public synthetic fun swap (Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/TorchTensor;)V
|
||||||
|
public fun swap (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||||
|
public synthetic fun times (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun times (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun timesAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||||
|
public fun timesAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;Lspace/kscience/kmath/torch/TorchTensorJVM;)V
|
||||||
|
public synthetic fun transpose (Lspace/kscience/kmath/tensors/TensorStructure;II)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun transpose (Lspace/kscience/kmath/torch/TorchTensorJVM;II)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun transposeAssign (Lspace/kscience/kmath/tensors/TensorStructure;II)V
|
||||||
|
public fun transposeAssign (Lspace/kscience/kmath/torch/TorchTensorJVM;II)V
|
||||||
|
public synthetic fun unaryMinus (Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun unaryMinus (Lspace/kscience/kmath/torch/TorchTensorJVM;)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
public synthetic fun view (Lspace/kscience/kmath/tensors/TensorStructure;[I)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun view (Lspace/kscience/kmath/torch/TorchTensorJVM;[I)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorAlgebraJVMKt {
|
||||||
|
public static final fun TorchTensorFloatAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
||||||
|
public static final fun TorchTensorIntAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
||||||
|
public static final fun TorchTensorLongAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
||||||
|
public static final fun TorchTensorRealAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorAlgebraKt {
|
||||||
|
public static final fun checkDeviceCompatible (Lspace/kscience/kmath/torch/TorchTensorAlgebra;Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/TorchTensor;)V
|
||||||
|
public static final fun checkDotOperation (Lspace/kscience/kmath/torch/TorchTensorAlgebra;Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/TorchTensor;)V
|
||||||
|
public static final fun checkLinearOperation (Lspace/kscience/kmath/torch/TorchTensorAlgebra;Lspace/kscience/kmath/torch/TorchTensor;Lspace/kscience/kmath/torch/TorchTensor;)V
|
||||||
|
public static final fun withChecks (Lspace/kscience/kmath/torch/TorchTensorAlgebra;Lkotlin/jvm/functions/Function1;)V
|
||||||
|
public static final fun withGradAt (Lspace/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra;Lspace/kscience/kmath/torch/TorchTensorOverField;Lkotlin/jvm/functions/Function2;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorFloat : space/kscience/kmath/torch/TorchTensorOverFieldJVM {
|
||||||
|
public fun get ([I)Ljava/lang/Float;
|
||||||
|
public synthetic fun get ([I)Ljava/lang/Object;
|
||||||
|
public fun item ()Ljava/lang/Float;
|
||||||
|
public synthetic fun item ()Ljava/lang/Object;
|
||||||
|
public fun set ([IF)V
|
||||||
|
public synthetic fun set ([ILjava/lang/Object;)V
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorFloatAlgebra : space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebraJVM {
|
||||||
|
public fun <init> (Lspace/kscience/kmath/memory/DeferScope;)V
|
||||||
|
public synthetic fun copyFromArray (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun copyFromArray ([F[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public synthetic fun copyToArray (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||||
|
public fun copyToArray (Lspace/kscience/kmath/torch/TorchTensorFloat;)[F
|
||||||
|
public fun full (F[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public synthetic fun full (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun minus (FLspace/kscience/kmath/torch/TorchTensorFloat;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public synthetic fun minus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun minus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun minus (Lspace/kscience/kmath/torch/TorchTensorFloat;F)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public synthetic fun minusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun minusAssign (Lspace/kscience/kmath/torch/TorchTensorFloat;F)V
|
||||||
|
public fun plus (FLspace/kscience/kmath/torch/TorchTensorFloat;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public synthetic fun plus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun plus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun plus (Lspace/kscience/kmath/torch/TorchTensorFloat;F)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public synthetic fun plusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun plusAssign (Lspace/kscience/kmath/torch/TorchTensorFloat;F)V
|
||||||
|
public synthetic fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public fun randNormal ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public synthetic fun randNormal ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public fun randUniform ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public synthetic fun randUniform ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public fun times (FLspace/kscience/kmath/torch/TorchTensorFloat;)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public synthetic fun times (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun times (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun times (Lspace/kscience/kmath/torch/TorchTensorFloat;F)Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public synthetic fun timesAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun timesAssign (Lspace/kscience/kmath/torch/TorchTensorFloat;F)V
|
||||||
|
public synthetic fun wrap$kmath_torch (J)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorInt : space/kscience/kmath/torch/TorchTensorOverFieldJVM {
|
||||||
|
public fun get ([I)Ljava/lang/Integer;
|
||||||
|
public synthetic fun get ([I)Ljava/lang/Object;
|
||||||
|
public fun item ()Ljava/lang/Integer;
|
||||||
|
public synthetic fun item ()Ljava/lang/Object;
|
||||||
|
public fun set ([II)V
|
||||||
|
public synthetic fun set ([ILjava/lang/Object;)V
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorIntAlgebra : space/kscience/kmath/torch/TorchTensorAlgebraJVM {
|
||||||
|
public fun <init> (Lspace/kscience/kmath/memory/DeferScope;)V
|
||||||
|
public synthetic fun copyFromArray (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun copyFromArray ([I[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||||
|
public synthetic fun copyToArray (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||||
|
public fun copyToArray (Lspace/kscience/kmath/torch/TorchTensorInt;)[I
|
||||||
|
public fun full (I[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||||
|
public synthetic fun full (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun minus (ILspace/kscience/kmath/torch/TorchTensorInt;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||||
|
public synthetic fun minus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun minus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun minus (Lspace/kscience/kmath/torch/TorchTensorInt;I)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||||
|
public synthetic fun minusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun minusAssign (Lspace/kscience/kmath/torch/TorchTensorInt;I)V
|
||||||
|
public fun plus (ILspace/kscience/kmath/torch/TorchTensorInt;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||||
|
public synthetic fun plus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun plus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun plus (Lspace/kscience/kmath/torch/TorchTensorInt;I)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||||
|
public synthetic fun plusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun plusAssign (Lspace/kscience/kmath/torch/TorchTensorInt;I)V
|
||||||
|
public synthetic fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||||
|
public fun times (ILspace/kscience/kmath/torch/TorchTensorInt;)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||||
|
public synthetic fun times (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun times (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun times (Lspace/kscience/kmath/torch/TorchTensorInt;I)Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||||
|
public synthetic fun timesAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun timesAssign (Lspace/kscience/kmath/torch/TorchTensorInt;I)V
|
||||||
|
public synthetic fun wrap$kmath_torch (J)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract class space/kscience/kmath/torch/TorchTensorJVM : space/kscience/kmath/torch/TorchTensorMemoryHolder, space/kscience/kmath/torch/TorchTensor {
|
||||||
|
public synthetic fun <init> (Lspace/kscience/kmath/memory/DeferScope;JLkotlin/jvm/internal/DefaultConstructorMarker;)V
|
||||||
|
protected fun close ()V
|
||||||
|
public final fun copyToDouble ()Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public final fun copyToFloat ()Lspace/kscience/kmath/torch/TorchTensorFloat;
|
||||||
|
public final fun copyToInt ()Lspace/kscience/kmath/torch/TorchTensorInt;
|
||||||
|
public final fun copyToLong ()Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||||
|
public fun elements ()Lkotlin/sequences/Sequence;
|
||||||
|
public fun getDevice ()Lspace/kscience/kmath/torch/Device;
|
||||||
|
public fun getDimension ()I
|
||||||
|
public fun getShape ()[I
|
||||||
|
public fun getSize ()I
|
||||||
|
public fun getStrides ()[I
|
||||||
|
public fun toString ()Ljava/lang/String;
|
||||||
|
public fun value ()Ljava/lang/Object;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorLong : space/kscience/kmath/torch/TorchTensorOverFieldJVM {
|
||||||
|
public fun get ([I)Ljava/lang/Long;
|
||||||
|
public synthetic fun get ([I)Ljava/lang/Object;
|
||||||
|
public fun item ()Ljava/lang/Long;
|
||||||
|
public synthetic fun item ()Ljava/lang/Object;
|
||||||
|
public fun set ([IJ)V
|
||||||
|
public synthetic fun set ([ILjava/lang/Object;)V
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorLongAlgebra : space/kscience/kmath/torch/TorchTensorAlgebraJVM {
|
||||||
|
public fun <init> (Lspace/kscience/kmath/memory/DeferScope;)V
|
||||||
|
public synthetic fun copyFromArray (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun copyFromArray ([J[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||||
|
public synthetic fun copyToArray (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||||
|
public fun copyToArray (Lspace/kscience/kmath/torch/TorchTensorLong;)[J
|
||||||
|
public fun full (J[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||||
|
public synthetic fun full (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun minus (JLspace/kscience/kmath/torch/TorchTensorLong;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||||
|
public synthetic fun minus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun minus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun minus (Lspace/kscience/kmath/torch/TorchTensorLong;J)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||||
|
public synthetic fun minusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun minusAssign (Lspace/kscience/kmath/torch/TorchTensorLong;J)V
|
||||||
|
public fun plus (JLspace/kscience/kmath/torch/TorchTensorLong;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||||
|
public synthetic fun plus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun plus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun plus (Lspace/kscience/kmath/torch/TorchTensorLong;J)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||||
|
public synthetic fun plusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun plusAssign (Lspace/kscience/kmath/torch/TorchTensorLong;J)V
|
||||||
|
public synthetic fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||||
|
public fun times (JLspace/kscience/kmath/torch/TorchTensorLong;)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||||
|
public synthetic fun times (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun times (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun times (Lspace/kscience/kmath/torch/TorchTensorLong;J)Lspace/kscience/kmath/torch/TorchTensorLong;
|
||||||
|
public synthetic fun timesAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun timesAssign (Lspace/kscience/kmath/torch/TorchTensorLong;J)V
|
||||||
|
public synthetic fun wrap$kmath_torch (J)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract class space/kscience/kmath/torch/TorchTensorMemoryHolder {
|
||||||
|
protected abstract fun close ()V
|
||||||
|
public fun equals (Ljava/lang/Object;)Z
|
||||||
|
public final fun getScope ()Lspace/kscience/kmath/memory/DeferScope;
|
||||||
|
public fun hashCode ()I
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract interface class space/kscience/kmath/torch/TorchTensorOverField : space/kscience/kmath/torch/TorchTensor {
|
||||||
|
public abstract fun getRequiresGrad ()Z
|
||||||
|
public abstract fun setRequiresGrad (Z)V
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorOverField$DefaultImpls {
|
||||||
|
public static fun elements (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lkotlin/sequences/Sequence;
|
||||||
|
public static fun getDimension (Lspace/kscience/kmath/torch/TorchTensorOverField;)I
|
||||||
|
public static fun value (Lspace/kscience/kmath/torch/TorchTensorOverField;)Ljava/lang/Object;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract class space/kscience/kmath/torch/TorchTensorOverFieldJVM : space/kscience/kmath/torch/TorchTensorJVM, space/kscience/kmath/torch/TorchTensorOverField {
|
||||||
|
public synthetic fun <init> (Lspace/kscience/kmath/memory/DeferScope;JLkotlin/jvm/internal/DefaultConstructorMarker;)V
|
||||||
|
public fun getRequiresGrad ()Z
|
||||||
|
public fun setRequiresGrad (Z)V
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract interface class space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra : space/kscience/kmath/tensors/TensorPartialDivisionAlgebra, space/kscience/kmath/torch/TorchTensorAlgebra {
|
||||||
|
public abstract fun detachFromGraph (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public abstract fun grad (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public abstract fun grad (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;Z)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public abstract fun hess (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public abstract fun randNormal (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public abstract fun randNormal ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public abstract fun randNormalAssign (Lspace/kscience/kmath/torch/TorchTensorOverField;)V
|
||||||
|
public abstract fun randUniform (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public abstract fun randUniform ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public abstract fun randUniformAssign (Lspace/kscience/kmath/torch/TorchTensorOverField;)V
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra$DefaultImpls {
|
||||||
|
public static fun grad (Lspace/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra;Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public static synthetic fun grad$default (Lspace/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra;Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;ZILjava/lang/Object;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public static synthetic fun randNormal$default (Lspace/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra;[ILspace/kscience/kmath/torch/Device;ILjava/lang/Object;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public static synthetic fun randUniform$default (Lspace/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra;[ILspace/kscience/kmath/torch/Device;ILjava/lang/Object;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract class space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebraJVM : space/kscience/kmath/torch/TorchTensorAlgebraJVM, space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebra {
|
||||||
|
public synthetic fun <init> (Lspace/kscience/kmath/memory/DeferScope;Lkotlin/jvm/internal/DefaultConstructorMarker;)V
|
||||||
|
public synthetic fun detachFromGraph (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public fun detachFromGraph (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||||
|
public synthetic fun div (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun div (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||||
|
public synthetic fun divAssign (Lspace/kscience/kmath/tensors/TensorStructure;Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||||
|
public fun divAssign (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)V
|
||||||
|
public synthetic fun exp (Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun exp (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||||
|
public synthetic fun expAssign (Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||||
|
public fun expAssign (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)V
|
||||||
|
public synthetic fun grad (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public synthetic fun grad (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;Z)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public fun grad (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||||
|
public fun grad (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Z)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||||
|
public synthetic fun hess (Lspace/kscience/kmath/torch/TorchTensorOverField;Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public fun hess (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||||
|
public synthetic fun log (Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun log (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||||
|
public synthetic fun logAssign (Lspace/kscience/kmath/tensors/TensorStructure;)V
|
||||||
|
public fun logAssign (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)V
|
||||||
|
public synthetic fun randNormal (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public fun randNormal (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||||
|
public synthetic fun randNormalAssign (Lspace/kscience/kmath/torch/TorchTensorOverField;)V
|
||||||
|
public fun randNormalAssign (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)V
|
||||||
|
public synthetic fun randUniform (Lspace/kscience/kmath/torch/TorchTensorOverField;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public fun randUniform (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;
|
||||||
|
public synthetic fun randUniformAssign (Lspace/kscience/kmath/torch/TorchTensorOverField;)V
|
||||||
|
public fun randUniformAssign (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)V
|
||||||
|
public synthetic fun svd (Lspace/kscience/kmath/tensors/TensorStructure;)Lkotlin/Triple;
|
||||||
|
public fun svd (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;)Lkotlin/Triple;
|
||||||
|
public synthetic fun symEig (Lspace/kscience/kmath/tensors/TensorStructure;Z)Lkotlin/Pair;
|
||||||
|
public fun symEig (Lspace/kscience/kmath/torch/TorchTensorOverFieldJVM;Z)Lkotlin/Pair;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorReal : space/kscience/kmath/torch/TorchTensorOverFieldJVM {
|
||||||
|
public fun get ([I)Ljava/lang/Double;
|
||||||
|
public synthetic fun get ([I)Ljava/lang/Object;
|
||||||
|
public fun item ()Ljava/lang/Double;
|
||||||
|
public synthetic fun item ()Ljava/lang/Object;
|
||||||
|
public fun set ([ID)V
|
||||||
|
public synthetic fun set ([ILjava/lang/Object;)V
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/torch/TorchTensorRealAlgebra : space/kscience/kmath/torch/TorchTensorPartialDivisionAlgebraJVM {
|
||||||
|
public fun <init> (Lspace/kscience/kmath/memory/DeferScope;)V
|
||||||
|
public synthetic fun copyFromArray (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun copyFromArray ([D[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public synthetic fun copyToArray (Lspace/kscience/kmath/torch/TorchTensor;)Ljava/lang/Object;
|
||||||
|
public fun copyToArray (Lspace/kscience/kmath/torch/TorchTensorReal;)[D
|
||||||
|
public fun full (D[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public synthetic fun full (Ljava/lang/Object;[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun minus (DLspace/kscience/kmath/torch/TorchTensorReal;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public synthetic fun minus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun minus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun minus (Lspace/kscience/kmath/torch/TorchTensorReal;D)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public synthetic fun minusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun minusAssign (Lspace/kscience/kmath/torch/TorchTensorReal;D)V
|
||||||
|
public fun plus (DLspace/kscience/kmath/torch/TorchTensorReal;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public synthetic fun plus (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun plus (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun plus (Lspace/kscience/kmath/torch/TorchTensorReal;D)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public synthetic fun plusAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun plusAssign (Lspace/kscience/kmath/torch/TorchTensorReal;D)V
|
||||||
|
public synthetic fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensor;
|
||||||
|
public fun randIntegral (JJ[ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public synthetic fun randNormal ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public fun randNormal ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public synthetic fun randUniform ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorOverField;
|
||||||
|
public fun randUniform ([ILspace/kscience/kmath/torch/Device;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public fun times (DLspace/kscience/kmath/torch/TorchTensorReal;)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public synthetic fun times (Ljava/lang/Object;Lspace/kscience/kmath/tensors/TensorStructure;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public synthetic fun times (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)Lspace/kscience/kmath/tensors/TensorStructure;
|
||||||
|
public fun times (Lspace/kscience/kmath/torch/TorchTensorReal;D)Lspace/kscience/kmath/torch/TorchTensorReal;
|
||||||
|
public synthetic fun timesAssign (Lspace/kscience/kmath/tensors/TensorStructure;Ljava/lang/Object;)V
|
||||||
|
public fun timesAssign (Lspace/kscience/kmath/torch/TorchTensorReal;D)V
|
||||||
|
public synthetic fun wrap$kmath_torch (J)Lspace/kscience/kmath/torch/TorchTensorJVM;
|
||||||
|
}
|
||||||
|
|
194
kmath-torch/build.gradle.kts
Normal file
194
kmath-torch/build.gradle.kts
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
import de.undercouch.gradle.tasks.download.Download
|
||||||
|
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
|
||||||
|
import org.gradle.api.JavaVersion.VERSION_11
|
||||||
|
|
||||||
|
|
||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.gradle.mpp")
|
||||||
|
id("de.undercouch.download")
|
||||||
|
}
|
||||||
|
|
||||||
|
java {
|
||||||
|
sourceCompatibility = VERSION_11
|
||||||
|
targetCompatibility = VERSION_11
|
||||||
|
}
|
||||||
|
|
||||||
|
val home = System.getProperty("user.home")
|
||||||
|
val javaHome = System.getProperty("java.home")
|
||||||
|
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
|
||||||
|
val cppBuildDir = "$thirdPartyDir/cpp-build"
|
||||||
|
val cppSources = projectDir.resolve("src/cppMain")
|
||||||
|
|
||||||
|
val cudaHome: String? = System.getenv("CUDA_HOME")
|
||||||
|
val cudaDefault = file("/usr/local/cuda").exists()
|
||||||
|
val cudaFound = cudaHome?.isNotEmpty() ?: false or cudaDefault
|
||||||
|
|
||||||
|
val cmakeArchive = "cmake-3.19.2-Linux-x86_64"
|
||||||
|
val torchArchive = "libtorch"
|
||||||
|
|
||||||
|
val cmakeCmd = "$thirdPartyDir/$cmakeArchive/bin/cmake"
|
||||||
|
val ninjaCmd = "$thirdPartyDir/ninja"
|
||||||
|
|
||||||
|
val downloadCMake by tasks.registering(Download::class) {
|
||||||
|
val tarFile = "$cmakeArchive.tar.gz"
|
||||||
|
src("https://github.com/Kitware/CMake/releases/download/v3.19.2/$tarFile")
|
||||||
|
dest(File(thirdPartyDir, tarFile))
|
||||||
|
overwrite(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
val downloadNinja by tasks.registering(Download::class) {
|
||||||
|
src("https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-linux.zip")
|
||||||
|
dest(File(thirdPartyDir, "ninja-linux.zip"))
|
||||||
|
overwrite(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
val downloadTorch by tasks.registering(Download::class) {
|
||||||
|
val abiMeta = "$torchArchive-cxx11-abi-shared-with-deps-1.7.1%2B"
|
||||||
|
val cudaUrl = "https://download.pytorch.org/libtorch/cu110/${abiMeta}cu110.zip"
|
||||||
|
val cpuUrl = "https://download.pytorch.org/libtorch/cpu/${abiMeta}cpu.zip"
|
||||||
|
val url = if (cudaFound) cudaUrl else cpuUrl
|
||||||
|
src(url)
|
||||||
|
dest(File(thirdPartyDir, "$torchArchive.zip"))
|
||||||
|
overwrite(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
val extractCMake by tasks.registering(Copy::class) {
|
||||||
|
dependsOn(downloadCMake)
|
||||||
|
from(tarTree(resources.gzip(downloadCMake.get().dest)))
|
||||||
|
into(thirdPartyDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
val extractTorch by tasks.registering(Copy::class) {
|
||||||
|
dependsOn(downloadTorch)
|
||||||
|
from(zipTree(downloadTorch.get().dest))
|
||||||
|
into(thirdPartyDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
val extractNinja by tasks.registering(Copy::class) {
|
||||||
|
dependsOn(downloadNinja)
|
||||||
|
from(zipTree(downloadNinja.get().dest))
|
||||||
|
into(thirdPartyDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
val configureCpp by tasks.registering {
|
||||||
|
dependsOn(extractCMake)
|
||||||
|
dependsOn(extractNinja)
|
||||||
|
dependsOn(extractTorch)
|
||||||
|
onlyIf { !file(cppBuildDir).exists() }
|
||||||
|
doLast {
|
||||||
|
exec {
|
||||||
|
workingDir(thirdPartyDir)
|
||||||
|
commandLine("mkdir", "-p", cppBuildDir)
|
||||||
|
}
|
||||||
|
exec {
|
||||||
|
workingDir(cppBuildDir)
|
||||||
|
commandLine(
|
||||||
|
cmakeCmd,
|
||||||
|
cppSources,
|
||||||
|
"-GNinja",
|
||||||
|
"-DCMAKE_MAKE_PROGRAM=$ninjaCmd",
|
||||||
|
"-DCMAKE_PREFIX_PATH=$thirdPartyDir/$torchArchive",
|
||||||
|
"-DJAVA_HOME=$javaHome",
|
||||||
|
"-DCMAKE_BUILD_TYPE=Release"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val cleanCppBuild by tasks.registering {
|
||||||
|
onlyIf { file(cppBuildDir).exists() }
|
||||||
|
doLast {
|
||||||
|
exec {
|
||||||
|
workingDir(thirdPartyDir)
|
||||||
|
commandLine("rm", "-rf", cppBuildDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val buildCpp by tasks.registering {
|
||||||
|
dependsOn(configureCpp)
|
||||||
|
doLast {
|
||||||
|
exec {
|
||||||
|
workingDir(cppBuildDir)
|
||||||
|
commandLine(cmakeCmd, "--build", ".", "--config", "Release")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val generateJNIHeader by tasks.registering {
|
||||||
|
doLast {
|
||||||
|
exec {
|
||||||
|
workingDir(projectDir.resolve("src/jvmMain/java/space/kscience/kmath/torch"))
|
||||||
|
commandLine("$javaHome/bin/javac", "-h", cppSources.resolve("include") , "JTorch.java")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kotlin {
|
||||||
|
explicitApiWarning()
|
||||||
|
|
||||||
|
jvm {
|
||||||
|
withJava()
|
||||||
|
}
|
||||||
|
|
||||||
|
val nativeTarget = linuxX64("native")
|
||||||
|
nativeTarget.apply {
|
||||||
|
binaries {
|
||||||
|
all {
|
||||||
|
linkerOpts(
|
||||||
|
"-L$cppBuildDir",
|
||||||
|
"-Wl,-rpath=$cppBuildDir",
|
||||||
|
"-lctorch"
|
||||||
|
)
|
||||||
|
optimized = true
|
||||||
|
debuggable = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val main by nativeTarget.compilations.getting {
|
||||||
|
cinterops {
|
||||||
|
val libctorch by creating {
|
||||||
|
includeDirs(cppSources.resolve("include"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val test by nativeTarget.compilations.getting
|
||||||
|
|
||||||
|
sourceSets {
|
||||||
|
val commonMain by getting {
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-core"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val nativeMain by getting {
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-core"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val jvmMain by getting {
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-core"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val native: KotlinNativeTarget by kotlin.targets
|
||||||
|
tasks[native.compilations["main"].cinterops["libctorch"].interopProcessingTaskName]
|
||||||
|
.dependsOn(buildCpp)
|
||||||
|
|
||||||
|
tasks["jvmProcessResources"].dependsOn(buildCpp)
|
||||||
|
|
||||||
|
tasks {
|
||||||
|
withType<Test>{
|
||||||
|
systemProperty("java.library.path", cppBuildDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No JS implementation
|
||||||
|
project.gradle.startParameter.excludedTaskNames.add("jsTest")
|
||||||
|
project.gradle.startParameter.excludedTaskNames.add("jsBrowserTest")
|
@ -0,0 +1,24 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
|
||||||
|
public sealed class Device {
|
||||||
|
public object CPU: Device() {
|
||||||
|
override fun toString(): String {
|
||||||
|
return "CPU"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
public data class CUDA(val index: Int): Device()
|
||||||
|
public fun toInt(): Int {
|
||||||
|
when(this) {
|
||||||
|
is CPU -> return 0
|
||||||
|
is CUDA -> return this.index + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
public companion object {
|
||||||
|
public fun fromInt(deviceInt: Int): Device {
|
||||||
|
return if (deviceInt == 0) CPU else CUDA(
|
||||||
|
deviceInt - 1
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,26 @@
|
|||||||
|
@file:Suppress("NOTHING_TO_INLINE")
|
||||||
|
|
||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import space.kscience.kmath.tensors.*
|
||||||
|
|
||||||
|
public interface TorchTensor<T> : TensorStructure<T> {
|
||||||
|
|
||||||
|
public val strides: IntArray
|
||||||
|
public val size: Int
|
||||||
|
public val device: Device
|
||||||
|
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||||
|
if (dimension == 0) {
|
||||||
|
return emptySequence()
|
||||||
|
}
|
||||||
|
val indices = (1..size).asSequence().map { indexFromOffset(it - 1, strides, dimension) }
|
||||||
|
return indices.map { it to get(it) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface TorchTensorOverField<T>: TorchTensor<T>
|
||||||
|
{
|
||||||
|
public var requiresGrad: Boolean
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,105 @@
|
|||||||
|
@file:Suppress("NOTHING_TO_INLINE")
|
||||||
|
|
||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import space.kscience.kmath.tensors.*
|
||||||
|
|
||||||
|
public interface TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>> :
|
||||||
|
TensorAlgebra<T, TorchTensorType> {
|
||||||
|
|
||||||
|
public fun getNumThreads(): Int
|
||||||
|
public fun setNumThreads(numThreads: Int): Unit
|
||||||
|
public fun cudaAvailable(): Boolean
|
||||||
|
public fun setSeed(seed: Int): Unit
|
||||||
|
|
||||||
|
public var checks: Boolean
|
||||||
|
|
||||||
|
public fun copyFromArray(
|
||||||
|
array: PrimitiveArrayType,
|
||||||
|
shape: IntArray,
|
||||||
|
device: Device = Device.CPU
|
||||||
|
): TorchTensorType
|
||||||
|
|
||||||
|
public fun TorchTensorType.copyToArray(): PrimitiveArrayType
|
||||||
|
|
||||||
|
public fun full(value: T, shape: IntArray, device: Device): TorchTensorType
|
||||||
|
|
||||||
|
public fun randIntegral(
|
||||||
|
low: Long, high: Long, shape: IntArray,
|
||||||
|
device: Device = Device.CPU
|
||||||
|
): TorchTensorType
|
||||||
|
public fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType
|
||||||
|
public fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit
|
||||||
|
|
||||||
|
public fun TorchTensorType.copy(): TorchTensorType
|
||||||
|
public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType
|
||||||
|
public infix fun TorchTensorType.swap(other: TorchTensorType)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>> :
|
||||||
|
TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>, TensorPartialDivisionAlgebra<T, TorchTensorType> {
|
||||||
|
|
||||||
|
public fun randUniform(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
||||||
|
public fun randNormal(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
||||||
|
public fun TorchTensorType.randUniform(): TorchTensorType
|
||||||
|
public fun TorchTensorType.randUniformAssign(): Unit
|
||||||
|
public fun TorchTensorType.randNormal(): TorchTensorType
|
||||||
|
public fun TorchTensorType.randNormalAssign(): Unit
|
||||||
|
|
||||||
|
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType
|
||||||
|
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||||
|
this.grad(variable, false)
|
||||||
|
|
||||||
|
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType
|
||||||
|
public fun TorchTensorType.detachFromGraph(): TorchTensorType
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.withChecks(block: TorchTensorAlgebraType.() -> Unit): Unit {
|
||||||
|
val state = this.checks
|
||||||
|
this.checks = true
|
||||||
|
this.block()
|
||||||
|
this.checks = state
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.checkDeviceCompatible(
|
||||||
|
a: TorchTensorType, b: TorchTensorType
|
||||||
|
): Unit =
|
||||||
|
check(a.device == b.device) {
|
||||||
|
"Tensors must be on the same device"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.checkLinearOperation(
|
||||||
|
a: TorchTensorType,
|
||||||
|
b: TorchTensorType
|
||||||
|
) {
|
||||||
|
if (a.isNotValue() and b.isNotValue()) {
|
||||||
|
checkDeviceCompatible(a, b)
|
||||||
|
checkShapeCompatible(a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit {
|
||||||
|
checkDeviceCompatible(a, b)
|
||||||
|
checkDot(a,b)
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
|
||||||
|
TorchTensorDivisionAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorDivisionAlgebraType.withGradAt(
|
||||||
|
tensor: TorchTensorType,
|
||||||
|
block: TorchTensorDivisionAlgebraType.(TorchTensorType) -> TorchTensorType
|
||||||
|
): TorchTensorType {
|
||||||
|
tensor.requiresGrad = true
|
||||||
|
return this.block(tensor)
|
||||||
|
}
|
@ -0,0 +1,15 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import space.kscience.kmath.memory.DeferScope
|
||||||
|
|
||||||
|
public abstract class TorchTensorMemoryHolder internal constructor(
|
||||||
|
public val scope: DeferScope
|
||||||
|
) {
|
||||||
|
init {
|
||||||
|
scope.defer(::close)
|
||||||
|
}
|
||||||
|
protected abstract fun close(): Unit
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean = false
|
||||||
|
override fun hashCode(): Int = 0
|
||||||
|
}
|
@ -0,0 +1,24 @@
|
|||||||
|
@file:Suppress("NOTHING_TO_INLINE")
|
||||||
|
|
||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.time.measureTime
|
||||||
|
|
||||||
|
internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.benchmarkMatMul(
|
||||||
|
scale: Int,
|
||||||
|
numWarmUp: Int,
|
||||||
|
numIter: Int,
|
||||||
|
fieldName: String,
|
||||||
|
device: Device = Device.CPU
|
||||||
|
): Unit {
|
||||||
|
println("Benchmarking $scale x $scale $fieldName matrices on $device: ")
|
||||||
|
setSeed(SEED)
|
||||||
|
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||||
|
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||||
|
repeat(numWarmUp) { lhs dotAssign rhs }
|
||||||
|
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
|
||||||
|
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,132 @@
|
|||||||
|
@file:Suppress("NOTHING_TO_INLINE")
|
||||||
|
|
||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.time.measureTime
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.benchmarkRand(
|
||||||
|
samples: Int,
|
||||||
|
numWarmUp: Int,
|
||||||
|
numIter: Int,
|
||||||
|
device: Device,
|
||||||
|
distName: String,
|
||||||
|
initBock: TorchTensorAlgebraType.(IntArray, Device) -> TorchTensorType,
|
||||||
|
runBlock: TorchTensorAlgebraType.(TorchTensorType) -> Unit
|
||||||
|
): Unit{
|
||||||
|
println("Benchmarking generation of $samples $distName samples on $device: ")
|
||||||
|
setSeed(SEED)
|
||||||
|
val shape = intArrayOf(samples)
|
||||||
|
val tensor = this.initBock(shape,device)
|
||||||
|
repeat(numWarmUp) { this.runBlock(tensor) }
|
||||||
|
val measuredTime = measureTime { repeat(numIter) { this.runBlock(tensor) } }
|
||||||
|
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.benchmarkRandNormal(
|
||||||
|
samples: Int,
|
||||||
|
numWarmUp: Int,
|
||||||
|
numIter: Int,
|
||||||
|
device: Device = Device.CPU): Unit{
|
||||||
|
benchmarkRand(
|
||||||
|
samples,
|
||||||
|
numWarmUp,
|
||||||
|
numIter,
|
||||||
|
device,
|
||||||
|
"Normal",
|
||||||
|
{sh, dc -> randNormal(shape = sh, device = dc)},
|
||||||
|
{ten -> ten.randNormalAssign() }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.benchmarkRandUniform(
|
||||||
|
samples: Int,
|
||||||
|
numWarmUp: Int,
|
||||||
|
numIter: Int,
|
||||||
|
device: Device = Device.CPU): Unit{
|
||||||
|
benchmarkRand(
|
||||||
|
samples,
|
||||||
|
numWarmUp,
|
||||||
|
numIter,
|
||||||
|
device,
|
||||||
|
"Uniform",
|
||||||
|
{sh, dc -> randUniform(shape = sh, device = dc)},
|
||||||
|
{ten -> ten.randUniformAssign() }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.benchmarkRandIntegral(
|
||||||
|
samples: Int,
|
||||||
|
numWarmUp: Int,
|
||||||
|
numIter: Int,
|
||||||
|
device: Device = Device.CPU): Unit{
|
||||||
|
benchmarkRand(
|
||||||
|
samples,
|
||||||
|
numWarmUp,
|
||||||
|
numIter,
|
||||||
|
device,
|
||||||
|
"integer [0,100]",
|
||||||
|
{sh, dc -> randIntegral(0, 100, shape = sh, device = dc)},
|
||||||
|
{ten -> ten.randIntegralAssign(0, 100) }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.benchmarkingRand1(): Unit {
|
||||||
|
benchmarkRandNormal(10, 10, 100000)
|
||||||
|
benchmarkRandUniform(10, 10, 100000)
|
||||||
|
benchmarkRandIntegral(10, 10, 100000)
|
||||||
|
if(cudaAvailable()) {
|
||||||
|
benchmarkRandNormal(10, 10, 100000, device = Device.CUDA(0))
|
||||||
|
benchmarkRandUniform(10, 10, 100000, device = Device.CUDA(0))
|
||||||
|
benchmarkRandIntegral(10, 10, 100000, device = Device.CUDA(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.benchmarkingRand3(): Unit {
|
||||||
|
benchmarkRandNormal(1000, 10, 10000)
|
||||||
|
benchmarkRandUniform(1000, 10, 10000)
|
||||||
|
benchmarkRandIntegral(1000, 10, 10000)
|
||||||
|
if(cudaAvailable()) {
|
||||||
|
benchmarkRandNormal(1000, 10, 100000, device = Device.CUDA(0))
|
||||||
|
benchmarkRandUniform(1000, 10, 100000, device = Device.CUDA(0))
|
||||||
|
benchmarkRandIntegral(1000, 10, 100000, device = Device.CUDA(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.benchmarkingRand5(): Unit {
|
||||||
|
benchmarkRandNormal(100000, 5, 100)
|
||||||
|
benchmarkRandUniform(100000, 5, 100)
|
||||||
|
benchmarkRandIntegral(100000, 5, 100)
|
||||||
|
if(cudaAvailable()){
|
||||||
|
benchmarkRandNormal(100000, 10, 100000, device = Device.CUDA(0))
|
||||||
|
benchmarkRandUniform(100000, 10, 100000, device = Device.CUDA(0))
|
||||||
|
benchmarkRandIntegral(100000, 10, 100000, device = Device.CUDA(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.benchmarkingRand7(): Unit {
|
||||||
|
benchmarkRandNormal(10000000, 3, 20)
|
||||||
|
benchmarkRandUniform(10000000, 3, 20)
|
||||||
|
benchmarkRandIntegral(10000000, 3, 20)
|
||||||
|
if(cudaAvailable()){
|
||||||
|
benchmarkRandNormal(10000000, 10, 10000, device = Device.CUDA(0))
|
||||||
|
benchmarkRandUniform(10000000, 10, 10000, device = Device.CUDA(0))
|
||||||
|
benchmarkRandIntegral(10000000, 10, 10000, device = Device.CUDA(0))
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,53 @@
|
|||||||
|
@file:Suppress("NOTHING_TO_INLINE")
|
||||||
|
|
||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingAutoGrad(device: Device = Device.CPU): Unit {
|
||||||
|
setSeed(SEED)
|
||||||
|
val dim = 3
|
||||||
|
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||||
|
val tensorSigma = randFeatures + randFeatures.transpose(0, 1)
|
||||||
|
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
|
||||||
|
val expressionAtX = withGradAt(tensorX, { x ->
|
||||||
|
0.5f * (x dot (tensorSigma dot x)) + (tensorMu dot x) + 25.9f
|
||||||
|
})
|
||||||
|
|
||||||
|
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
||||||
|
val hessianAtX = expressionAtX hess tensorX
|
||||||
|
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
|
||||||
|
|
||||||
|
val error = (gradientAtX - expectedGradientAtX).abs().sum().value() +
|
||||||
|
(hessianAtX - tensorSigma).abs().sum().value()
|
||||||
|
assertTrue(error < TOLERANCE)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingBatchedAutoGrad(device: Device = Device.CPU): Unit {
|
||||||
|
setSeed(SEED)
|
||||||
|
val batch = intArrayOf(2)
|
||||||
|
val dim = 2
|
||||||
|
val tensorX = randNormal(shape = batch + intArrayOf(1, dim), device = device)
|
||||||
|
val randFeatures = randNormal(shape = batch + intArrayOf(dim, dim), device = device)
|
||||||
|
val tensorSigma = randFeatures + randFeatures.transpose(-2, -1)
|
||||||
|
val tensorMu = randNormal(shape = batch + intArrayOf(1, dim), device = device)
|
||||||
|
|
||||||
|
val expressionAtX = withGradAt(tensorX, { x ->
|
||||||
|
val xt = x.transpose(-1, -2)
|
||||||
|
0.5f * (x dot (tensorSigma dot xt)) + (tensorMu dot xt) + 58.2f
|
||||||
|
})
|
||||||
|
expressionAtX.sumAssign()
|
||||||
|
|
||||||
|
val gradientAtX = expressionAtX grad tensorX
|
||||||
|
val expectedGradientAtX = (tensorX dot tensorSigma) + tensorMu
|
||||||
|
|
||||||
|
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
||||||
|
assertTrue(error < TOLERANCE)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,53 @@
|
|||||||
|
@file:Suppress("NOTHING_TO_INLINE")
|
||||||
|
|
||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingCopying(device: Device = Device.CPU): Unit {
|
||||||
|
val array = (1..24).map { 10f * it * it }.toFloatArray()
|
||||||
|
val shape = intArrayOf(2, 3, 4)
|
||||||
|
val tensor = copyFromArray(array, shape = shape, device = device)
|
||||||
|
val copyOfTensor = tensor.copy()
|
||||||
|
tensor[intArrayOf(1, 2, 3)] = 0.1f
|
||||||
|
assertTrue(copyOfTensor.copyToArray() contentEquals array)
|
||||||
|
assertEquals(0.1f, tensor[intArrayOf(1, 2, 3)])
|
||||||
|
if(device != Device.CPU){
|
||||||
|
val normalCpu = randNormal(intArrayOf(2, 3))
|
||||||
|
val normalGpu = normalCpu.copyToDevice(device)
|
||||||
|
assertTrue(normalCpu.copyToArray() contentEquals normalGpu.copyToArray())
|
||||||
|
|
||||||
|
val uniformGpu = randUniform(intArrayOf(3,2),device)
|
||||||
|
val uniformCpu = uniformGpu.copyToDevice(Device.CPU)
|
||||||
|
assertTrue(uniformGpu.copyToArray() contentEquals uniformCpu.copyToArray())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingRequiresGrad(): Unit {
|
||||||
|
val tensor = randNormal(intArrayOf(3))
|
||||||
|
assertTrue(!tensor.requiresGrad)
|
||||||
|
tensor.requiresGrad = true
|
||||||
|
assertTrue(tensor.requiresGrad)
|
||||||
|
tensor.requiresGrad = false
|
||||||
|
assertTrue(!tensor.requiresGrad)
|
||||||
|
tensor.requiresGrad = true
|
||||||
|
val detachedTensor = tensor.detachFromGraph()
|
||||||
|
assertTrue(!detachedTensor.requiresGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensor<Int>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<Int, IntArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingViewWithNoCopy(device: Device = Device.CPU) {
|
||||||
|
val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6), device)
|
||||||
|
val viewTensor = tensor.view(intArrayOf(2, 3))
|
||||||
|
assertTrue(viewTensor.shape contentEquals intArrayOf(2, 3))
|
||||||
|
viewTensor[intArrayOf(0, 0)] = 10
|
||||||
|
assertEquals(tensor[intArrayOf(0)], 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,133 @@
|
|||||||
|
@file:Suppress("NOTHING_TO_INLINE")
|
||||||
|
|
||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import space.kscience.kmath.linear.RealMatrixContext
|
||||||
|
import space.kscience.kmath.linear.Matrix
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
|
import kotlin.math.*
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingScalarProduct(device: Device = Device.CPU): Unit {
|
||||||
|
val lhs = randUniform(shape = intArrayOf(3), device = device)
|
||||||
|
val rhs = randUniform(shape = intArrayOf(3), device = device)
|
||||||
|
val product = lhs dot rhs
|
||||||
|
var expected = 0.0
|
||||||
|
lhs.elements().forEach {
|
||||||
|
expected += it.second * rhs[it.first]
|
||||||
|
}
|
||||||
|
assertTrue(abs(expected - product.value()) < TOLERANCE)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingMatrixMultiplication(device: Device = Device.CPU): Unit {
|
||||||
|
setSeed(SEED)
|
||||||
|
|
||||||
|
val lhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||||
|
val rhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||||
|
val product = lhsTensor dot rhsTensor
|
||||||
|
|
||||||
|
val expected: Matrix<Double> = RealMatrixContext {
|
||||||
|
val lhs = produce(3, 3) { i, j -> lhsTensor[intArrayOf(i, j)] }
|
||||||
|
val rhs = produce(3, 3) { i, j -> rhsTensor[intArrayOf(i, j)] }
|
||||||
|
lhs dot rhs
|
||||||
|
}
|
||||||
|
|
||||||
|
val lhsTensorCopy = lhsTensor.copy()
|
||||||
|
val rhsTensorCopy = rhsTensor.copy()
|
||||||
|
|
||||||
|
lhsTensorCopy dotAssign rhsTensor
|
||||||
|
lhsTensor dotRightAssign rhsTensorCopy
|
||||||
|
|
||||||
|
var error = 0.0
|
||||||
|
product.elements().forEach {
|
||||||
|
error += abs(expected[it.first] - it.second) +
|
||||||
|
abs(expected[it.first] - lhsTensorCopy[it.first]) +
|
||||||
|
abs(expected[it.first] - rhsTensorCopy[it.first])
|
||||||
|
}
|
||||||
|
assertTrue(error < TOLERANCE)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingLinearStructure(device: Device = Device.CPU): Unit {
|
||||||
|
|
||||||
|
val shape = intArrayOf(3)
|
||||||
|
val tensorA = full(value = -4.5, shape = shape, device = device)
|
||||||
|
val tensorB = full(value = 10.9, shape = shape, device = device)
|
||||||
|
val tensorC = full(value = 789.3, shape = shape, device = device)
|
||||||
|
val tensorD = full(value = -72.9, shape = shape, device = device)
|
||||||
|
val tensorE = full(value = 553.1, shape = shape, device = device)
|
||||||
|
val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4
|
||||||
|
val expected = copyFromArray(
|
||||||
|
array = (1..3).map {
|
||||||
|
15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4
|
||||||
|
}
|
||||||
|
.toDoubleArray(),
|
||||||
|
shape = shape,
|
||||||
|
device = device
|
||||||
|
)
|
||||||
|
|
||||||
|
val assignResult = full(value = 0.0, shape = shape, device = device)
|
||||||
|
tensorA *= 15.8
|
||||||
|
tensorB *= 1.5
|
||||||
|
tensorB *= -tensorD
|
||||||
|
tensorC *= 0.02
|
||||||
|
tensorC /= tensorE
|
||||||
|
assignResult += tensorA
|
||||||
|
assignResult -= tensorB
|
||||||
|
assignResult += tensorC
|
||||||
|
assignResult += -39.4
|
||||||
|
|
||||||
|
val error = (expected - result).abs().sum().value() +
|
||||||
|
(expected - assignResult).abs().sum().value()
|
||||||
|
assertTrue(error < TOLERANCE)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingTensorTransformations(device: Device = Device.CPU): Unit {
|
||||||
|
setSeed(SEED)
|
||||||
|
val tensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||||
|
val result = tensor.exp().log()
|
||||||
|
val assignResult = tensor.copy()
|
||||||
|
assignResult.transposeAssign(0, 1)
|
||||||
|
assignResult.expAssign()
|
||||||
|
assignResult.logAssign()
|
||||||
|
assignResult.transposeAssign(0, 1)
|
||||||
|
val error = tensor - result
|
||||||
|
error.absAssign()
|
||||||
|
error.sumAssign()
|
||||||
|
error += (tensor - assignResult).abs().sum()
|
||||||
|
assertTrue(error.value() < TOLERANCE)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingBatchedSVD(device: Device = Device.CPU): Unit {
|
||||||
|
val tensor = randNormal(shape = intArrayOf(7, 5, 3), device = device)
|
||||||
|
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||||
|
val error = tensor - (tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1)))
|
||||||
|
assertTrue(error.abs().sum().value() < TOLERANCE)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Double>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Double, DoubleArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingBatchedSymEig(device: Device = Device.CPU): Unit {
|
||||||
|
val tensor = randNormal(shape = intArrayOf(5, 5), device = device)
|
||||||
|
val tensorSigma = tensor + tensor.transpose(-2, -1)
|
||||||
|
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||||
|
val error = tensorSigma - (tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1)))
|
||||||
|
assertTrue(error.abs().sum().value() < TOLERANCE)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
|||||||
|
@file:Suppress("NOTHING_TO_INLINE")
|
||||||
|
|
||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal val SEED = 987654
|
||||||
|
internal val TOLERANCE = 1e-6
|
||||||
|
|
||||||
|
internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.withCuda(block: TorchTensorAlgebraType.(Device) -> Unit): Unit {
|
||||||
|
this.block(Device.CPU)
|
||||||
|
if (cudaAvailable()) this.block(Device.CUDA(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingSetNumThreads(): Unit {
|
||||||
|
val numThreads = 2
|
||||||
|
setNumThreads(numThreads)
|
||||||
|
assertEquals(numThreads, getNumThreads())
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorPartialDivisionAlgebra<Float, FloatArray, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.testingSetSeed(device: Device = Device.CPU): Unit {
|
||||||
|
setSeed(SEED)
|
||||||
|
val integral = randIntegral(0, 100, IntArray(0), device = device).value()
|
||||||
|
val normal = randNormal(IntArray(0), device = device).value()
|
||||||
|
val uniform = randUniform(IntArray(0), device = device).value()
|
||||||
|
setSeed(SEED)
|
||||||
|
val nextIntegral = randIntegral(0, 100, IntArray(0), device = device).value()
|
||||||
|
val nextNormal = randNormal(IntArray(0), device = device).value()
|
||||||
|
val nextUniform = randUniform(IntArray(0), device = device).value()
|
||||||
|
assertEquals(normal, nextNormal)
|
||||||
|
assertEquals(uniform, nextUniform)
|
||||||
|
assertEquals(integral, nextIntegral)
|
||||||
|
}
|
34
kmath-torch/src/cppMain/CMakeLists.txt
Normal file
34
kmath-torch/src/cppMain/CMakeLists.txt
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.12)
|
||||||
|
|
||||||
|
project(CTorch LANGUAGES C CXX)
|
||||||
|
|
||||||
|
# Require C++17
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
|
||||||
|
# Build configuration
|
||||||
|
if(NOT CMAKE_BUILD_TYPE)
|
||||||
|
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
||||||
|
endif()
|
||||||
|
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||||
|
|
||||||
|
find_package(Torch REQUIRED)
|
||||||
|
find_package(JNI REQUIRED)
|
||||||
|
|
||||||
|
add_library(ctorch SHARED src/ctorch.cc)
|
||||||
|
target_include_directories(ctorch PRIVATE include)
|
||||||
|
target_link_libraries(ctorch PRIVATE torch)
|
||||||
|
target_compile_options(ctorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
||||||
|
|
||||||
|
add_library(jtorch SHARED src/jtorch.cc)
|
||||||
|
target_include_directories(jtorch PRIVATE include ${JNI_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(jtorch PRIVATE torch)
|
||||||
|
target_compile_options(jtorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
||||||
|
|
||||||
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
|
set_target_properties(ctorch PROPERTIES PUBLIC_HEADER include/ctorch.h)
|
||||||
|
install(TARGETS ctorch
|
||||||
|
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||||
|
|
||||||
|
install(TARGETS jtorch LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
160
kmath-torch/src/cppMain/include/ctorch.h
Normal file
160
kmath-torch/src/cppMain/include/ctorch.h
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
#ifndef CTORCH
|
||||||
|
#define CTORCH
|
||||||
|
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C"
|
||||||
|
{
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef void *TorchTensorHandle;
|
||||||
|
|
||||||
|
int get_num_threads();
|
||||||
|
|
||||||
|
void set_num_threads(int num_threads);
|
||||||
|
|
||||||
|
bool cuda_is_available();
|
||||||
|
|
||||||
|
void set_seed(int seed);
|
||||||
|
|
||||||
|
TorchTensorHandle empty_tensor();
|
||||||
|
|
||||||
|
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
|
||||||
|
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
|
||||||
|
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
|
||||||
|
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
|
||||||
|
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);
|
||||||
|
TorchTensorHandle copy_to_double(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle copy_to_float(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle copy_to_long(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle copy_to_int(TorchTensorHandle tensor_handle);
|
||||||
|
void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle);
|
||||||
|
TorchTensorHandle view_tensor(TorchTensorHandle tensor_handle, int *shape, int dim);
|
||||||
|
|
||||||
|
char *tensor_to_string(TorchTensorHandle tensor_handle);
|
||||||
|
void dispose_char(char *ptr);
|
||||||
|
void dispose_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
int get_dim(TorchTensorHandle tensor_handle);
|
||||||
|
int get_numel(TorchTensorHandle tensor_handle);
|
||||||
|
int get_shape_at(TorchTensorHandle tensor_handle, int d);
|
||||||
|
int get_stride_at(TorchTensorHandle tensor_handle, int d);
|
||||||
|
int get_device(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||||
|
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||||
|
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||||
|
int *get_data_int(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
double get_item_double(TorchTensorHandle tensor_handle);
|
||||||
|
float get_item_float(TorchTensorHandle tensor_handle);
|
||||||
|
long get_item_long(TorchTensorHandle tensor_handle);
|
||||||
|
int get_item_int(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
double get_double(TorchTensorHandle tensor_handle, int *index);
|
||||||
|
float get_float(TorchTensorHandle tensor_handle, int *index);
|
||||||
|
long get_long(TorchTensorHandle tensor_handle, int *index);
|
||||||
|
int get_int(TorchTensorHandle tensor_handle, int *index);
|
||||||
|
void set_double(TorchTensorHandle tensor_handle, int *index, double value);
|
||||||
|
void set_float(TorchTensorHandle tensor_handle, int *index, float value);
|
||||||
|
void set_long(TorchTensorHandle tensor_handle, int *index, long value);
|
||||||
|
void set_int(TorchTensorHandle tensor_handle, int *index, int value);
|
||||||
|
|
||||||
|
TorchTensorHandle rand_double(int *shape, int shape_size, int device);
|
||||||
|
TorchTensorHandle randn_double(int *shape, int shape_size, int device);
|
||||||
|
TorchTensorHandle rand_float(int *shape, int shape_size, int device);
|
||||||
|
TorchTensorHandle randn_float(int *shape, int shape_size, int device);
|
||||||
|
|
||||||
|
TorchTensorHandle randint_double(long low, long high, int *shape, int shape_size, int device);
|
||||||
|
TorchTensorHandle randint_float(long low, long high, int *shape, int shape_size, int device);
|
||||||
|
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device);
|
||||||
|
TorchTensorHandle randint_int(long low, long high, int *shape, int shape_size, int device);
|
||||||
|
|
||||||
|
TorchTensorHandle rand_like(TorchTensorHandle tensor_handle);
|
||||||
|
void rand_like_assign(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle randn_like(TorchTensorHandle tensor_handle);
|
||||||
|
void randn_like_assign(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle randint_like(TorchTensorHandle tensor_handle, long low, long high);
|
||||||
|
void randint_like_assign(TorchTensorHandle tensor_handle, long low, long high);
|
||||||
|
|
||||||
|
|
||||||
|
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
|
||||||
|
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
|
||||||
|
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device);
|
||||||
|
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device);
|
||||||
|
|
||||||
|
TorchTensorHandle times_double(double value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle times_float(float value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle times_long(long value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle times_int(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
|
void times_double_assign(double value, TorchTensorHandle other);
|
||||||
|
void times_float_assign(float value, TorchTensorHandle other);
|
||||||
|
void times_long_assign(long value, TorchTensorHandle other);
|
||||||
|
void times_int_assign(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
|
TorchTensorHandle plus_double(double value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle plus_float(float value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle plus_long(long value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle plus_int(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
|
void plus_double_assign(double value, TorchTensorHandle other);
|
||||||
|
void plus_float_assign(float value, TorchTensorHandle other);
|
||||||
|
void plus_long_assign(long value, TorchTensorHandle other);
|
||||||
|
void plus_int_assign(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
|
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
void abs_tensor_assign(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
|
||||||
|
void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j);
|
||||||
|
|
||||||
|
TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
void exp_tensor_assign(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
void log_tensor_assign(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
void sum_tensor_assign(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
|
||||||
|
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2);
|
||||||
|
|
||||||
|
void svd_tensor(TorchTensorHandle tensor_handle,
|
||||||
|
TorchTensorHandle U_handle,
|
||||||
|
TorchTensorHandle S_handle,
|
||||||
|
TorchTensorHandle V_handle);
|
||||||
|
|
||||||
|
void symeig_tensor(TorchTensorHandle tensor_handle,
|
||||||
|
TorchTensorHandle S_handle,
|
||||||
|
TorchTensorHandle V_handle,
|
||||||
|
bool eigenvectors);
|
||||||
|
|
||||||
|
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||||
|
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||||
|
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph);
|
||||||
|
TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||||
|
TorchTensorHandle autohess_tensor_given_grad(TorchTensorHandle value, TorchTensorHandle variable, TorchTensorHandle gradient);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif //CTORCH
|
@ -0,0 +1,813 @@
|
|||||||
|
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||||
|
#include <jni.h>
|
||||||
|
/* Header for class space_kscience_kmath_torch_JTorch */
|
||||||
|
|
||||||
|
#ifndef _Included_space_kscience_kmath_torch_JTorch
|
||||||
|
#define _Included_space_kscience_kmath_torch_JTorch
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getNumThreads
|
||||||
|
* Signature: ()I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getNumThreads
|
||||||
|
(JNIEnv *, jclass);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: setNumThreads
|
||||||
|
* Signature: (I)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setNumThreads
|
||||||
|
(JNIEnv *, jclass, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: cudaIsAvailable
|
||||||
|
* Signature: ()Z
|
||||||
|
*/
|
||||||
|
JNIEXPORT jboolean JNICALL Java_space_kscience_kmath_torch_JTorch_cudaIsAvailable
|
||||||
|
(JNIEnv *, jclass);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: setSeed
|
||||||
|
* Signature: (I)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setSeed
|
||||||
|
(JNIEnv *, jclass, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: emptyTensor
|
||||||
|
* Signature: ()J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_emptyTensor
|
||||||
|
(JNIEnv *, jclass);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: fromBlobDouble
|
||||||
|
* Signature: ([D[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobDouble
|
||||||
|
(JNIEnv *, jclass, jdoubleArray, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: fromBlobFloat
|
||||||
|
* Signature: ([F[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobFloat
|
||||||
|
(JNIEnv *, jclass, jfloatArray, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: fromBlobLong
|
||||||
|
* Signature: ([J[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobLong
|
||||||
|
(JNIEnv *, jclass, jlongArray, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: fromBlobInt
|
||||||
|
* Signature: ([I[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobInt
|
||||||
|
(JNIEnv *, jclass, jintArray, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: copyTensor
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyTensor
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: copyToDevice
|
||||||
|
* Signature: (JI)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToDevice
|
||||||
|
(JNIEnv *, jclass, jlong, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: copyToDouble
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToDouble
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: copyToFloat
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToFloat
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: copyToLong
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToLong
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: copyToInt
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToInt
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: swapTensors
|
||||||
|
* Signature: (JJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_swapTensors
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: viewTensor
|
||||||
|
* Signature: (J[I)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_viewTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: tensorToString
|
||||||
|
* Signature: (J)Ljava/lang/String;
|
||||||
|
*/
|
||||||
|
JNIEXPORT jstring JNICALL Java_space_kscience_kmath_torch_JTorch_tensorToString
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: disposeTensor
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_disposeTensor
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getDim
|
||||||
|
* Signature: (J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getDim
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getNumel
|
||||||
|
* Signature: (J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getNumel
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getShapeAt
|
||||||
|
* Signature: (JI)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getShapeAt
|
||||||
|
(JNIEnv *, jclass, jlong, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getStrideAt
|
||||||
|
* Signature: (JI)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getStrideAt
|
||||||
|
(JNIEnv *, jclass, jlong, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getDevice
|
||||||
|
* Signature: (J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getDevice
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getItemDouble
|
||||||
|
* Signature: (J)D
|
||||||
|
*/
|
||||||
|
JNIEXPORT jdouble JNICALL Java_space_kscience_kmath_torch_JTorch_getItemDouble
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getItemFloat
|
||||||
|
* Signature: (J)F
|
||||||
|
*/
|
||||||
|
JNIEXPORT jfloat JNICALL Java_space_kscience_kmath_torch_JTorch_getItemFloat
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getItemLong
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_getItemLong
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getItemInt
|
||||||
|
* Signature: (J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getItemInt
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getDouble
|
||||||
|
* Signature: (J[I)D
|
||||||
|
*/
|
||||||
|
JNIEXPORT jdouble JNICALL Java_space_kscience_kmath_torch_JTorch_getDouble
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getFloat
|
||||||
|
* Signature: (J[I)F
|
||||||
|
*/
|
||||||
|
JNIEXPORT jfloat JNICALL Java_space_kscience_kmath_torch_JTorch_getFloat
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getLong
|
||||||
|
* Signature: (J[I)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_getLong
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: getInt
|
||||||
|
* Signature: (J[I)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getInt
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: setDouble
|
||||||
|
* Signature: (J[ID)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setDouble
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray, jdouble);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: setFloat
|
||||||
|
* Signature: (J[IF)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setFloat
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray, jfloat);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: setLong
|
||||||
|
* Signature: (J[IJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setLong
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: setInt
|
||||||
|
* Signature: (J[II)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setInt
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randDouble
|
||||||
|
* Signature: ([II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randDouble
|
||||||
|
(JNIEnv *, jclass, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randnDouble
|
||||||
|
* Signature: ([II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnDouble
|
||||||
|
(JNIEnv *, jclass, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randFloat
|
||||||
|
* Signature: ([II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randFloat
|
||||||
|
(JNIEnv *, jclass, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randnFloat
|
||||||
|
* Signature: ([II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnFloat
|
||||||
|
(JNIEnv *, jclass, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randintDouble
|
||||||
|
* Signature: (JJ[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintDouble
|
||||||
|
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randintFloat
|
||||||
|
* Signature: (JJ[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintFloat
|
||||||
|
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randintLong
|
||||||
|
* Signature: (JJ[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintLong
|
||||||
|
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randintInt
|
||||||
|
* Signature: (JJ[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintInt
|
||||||
|
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randLike
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randLike
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randLikeAssign
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randLikeAssign
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randnLike
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnLike
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randnLikeAssign
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randnLikeAssign
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randintLike
|
||||||
|
* Signature: (JJJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintLike
|
||||||
|
(JNIEnv *, jclass, jlong, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: randintLikeAssign
|
||||||
|
* Signature: (JJJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randintLikeAssign
|
||||||
|
(JNIEnv *, jclass, jlong, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: fullDouble
|
||||||
|
* Signature: (D[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullDouble
|
||||||
|
(JNIEnv *, jclass, jdouble, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: fullFloat
|
||||||
|
* Signature: (F[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullFloat
|
||||||
|
(JNIEnv *, jclass, jfloat, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: fullLong
|
||||||
|
* Signature: (J[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullLong
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: fullInt
|
||||||
|
* Signature: (I[II)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullInt
|
||||||
|
(JNIEnv *, jclass, jint, jintArray, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: timesDouble
|
||||||
|
* Signature: (DJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesDouble
|
||||||
|
(JNIEnv *, jclass, jdouble, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: timesFloat
|
||||||
|
* Signature: (FJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesFloat
|
||||||
|
(JNIEnv *, jclass, jfloat, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: timesLong
|
||||||
|
* Signature: (JJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesLong
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: timesInt
|
||||||
|
* Signature: (IJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesInt
|
||||||
|
(JNIEnv *, jclass, jint, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: timesDoubleAssign
|
||||||
|
* Signature: (DJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesDoubleAssign
|
||||||
|
(JNIEnv *, jclass, jdouble, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: timesFloatAssign
|
||||||
|
* Signature: (FJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesFloatAssign
|
||||||
|
(JNIEnv *, jclass, jfloat, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: timesLongAssign
|
||||||
|
* Signature: (JJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesLongAssign
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: timesIntAssign
|
||||||
|
* Signature: (IJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesIntAssign
|
||||||
|
(JNIEnv *, jclass, jint, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: plusDouble
|
||||||
|
* Signature: (DJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusDouble
|
||||||
|
(JNIEnv *, jclass, jdouble, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: plusFloat
|
||||||
|
* Signature: (FJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusFloat
|
||||||
|
(JNIEnv *, jclass, jfloat, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: plusLong
|
||||||
|
* Signature: (JJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusLong
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: plusInt
|
||||||
|
* Signature: (IJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusInt
|
||||||
|
(JNIEnv *, jclass, jint, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: plusDoubleAssign
|
||||||
|
* Signature: (DJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusDoubleAssign
|
||||||
|
(JNIEnv *, jclass, jdouble, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: plusFloatAssign
|
||||||
|
* Signature: (FJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusFloatAssign
|
||||||
|
(JNIEnv *, jclass, jfloat, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: plusLongAssign
|
||||||
|
* Signature: (JJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusLongAssign
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: plusIntAssign
|
||||||
|
* Signature: (IJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusIntAssign
|
||||||
|
(JNIEnv *, jclass, jint, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: timesTensor
|
||||||
|
* Signature: (JJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: timesTensorAssign
|
||||||
|
* Signature: (JJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesTensorAssign
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: divTensor
|
||||||
|
* Signature: (JJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_divTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: divTensorAssign
|
||||||
|
* Signature: (JJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_divTensorAssign
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: plusTensor
|
||||||
|
* Signature: (JJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: plusTensorAssign
|
||||||
|
* Signature: (JJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusTensorAssign
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: minusTensor
|
||||||
|
* Signature: (JJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_minusTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: minusTensorAssign
|
||||||
|
* Signature: (JJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_minusTensorAssign
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: unaryMinus
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_unaryMinus
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: absTensor
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_absTensor
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: absTensorAssign
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_absTensorAssign
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: transposeTensor
|
||||||
|
* Signature: (JII)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_transposeTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jint, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: transposeTensorAssign
|
||||||
|
* Signature: (JII)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_transposeTensorAssign
|
||||||
|
(JNIEnv *, jclass, jlong, jint, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: expTensor
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_expTensor
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: expTensorAssign
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_expTensorAssign
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: logTensor
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_logTensor
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: logTensorAssign
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_logTensorAssign
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: sumTensor
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_sumTensor
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: sumTensorAssign
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_sumTensorAssign
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: matmul
|
||||||
|
* Signature: (JJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_matmul
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: matmulAssign
|
||||||
|
* Signature: (JJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_matmulAssign
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: matmulRightAssign
|
||||||
|
* Signature: (JJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_matmulRightAssign
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: diagEmbed
|
||||||
|
* Signature: (JIII)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_diagEmbed
|
||||||
|
(JNIEnv *, jclass, jlong, jint, jint, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: svdTensor
|
||||||
|
* Signature: (JJJJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_svdTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jlong, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: symeigTensor
|
||||||
|
* Signature: (JJJZ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_symeigTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jlong, jlong, jboolean);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: requiresGrad
|
||||||
|
* Signature: (J)Z
|
||||||
|
*/
|
||||||
|
JNIEXPORT jboolean JNICALL Java_space_kscience_kmath_torch_JTorch_requiresGrad
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: setRequiresGrad
|
||||||
|
* Signature: (JZ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setRequiresGrad
|
||||||
|
(JNIEnv *, jclass, jlong, jboolean);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: detachFromGraph
|
||||||
|
* Signature: (J)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_detachFromGraph
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: autogradTensor
|
||||||
|
* Signature: (JJZ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_autogradTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jlong, jboolean);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_torch_JTorch
|
||||||
|
* Method: autohessTensor
|
||||||
|
* Signature: (JJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_autohessTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
151
kmath-torch/src/cppMain/include/utils.hh
Normal file
151
kmath-torch/src/cppMain/include/utils.hh
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
#include <torch/torch.h>
|
||||||
|
|
||||||
|
namespace ctorch
|
||||||
|
{
|
||||||
|
|
||||||
|
using TorchTensorHandle = void *;
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline c10::ScalarType dtype()
|
||||||
|
{
|
||||||
|
return torch::kFloat64;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline c10::ScalarType dtype<float>()
|
||||||
|
{
|
||||||
|
return torch::kFloat32;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline c10::ScalarType dtype<long>()
|
||||||
|
{
|
||||||
|
return torch::kInt64;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline c10::ScalarType dtype<int>()
|
||||||
|
{
|
||||||
|
return torch::kInt32;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Handle>
|
||||||
|
inline torch::Tensor &cast(const Handle &tensor_handle)
|
||||||
|
{
|
||||||
|
return *static_cast<torch::Tensor *>((TorchTensorHandle)tensor_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Handle>
|
||||||
|
inline void dispose_tensor(const Handle &tensor_handle)
|
||||||
|
{
|
||||||
|
delete static_cast<torch::Tensor *>((TorchTensorHandle)tensor_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string tensor_to_string(const torch::Tensor &tensor)
|
||||||
|
{
|
||||||
|
std::stringstream bufrep;
|
||||||
|
bufrep << tensor;
|
||||||
|
return bufrep.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline char *tensor_to_char(const torch::Tensor &tensor)
|
||||||
|
{
|
||||||
|
auto rep = tensor_to_string(tensor);
|
||||||
|
char *crep = (char *)malloc(rep.length() + 1);
|
||||||
|
std::strcpy(crep, rep.c_str());
|
||||||
|
return crep;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int device_to_int(const torch::Tensor &tensor)
|
||||||
|
{
|
||||||
|
return (tensor.device().type() == torch::kCPU) ? 0 : 1 + tensor.device().index();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline torch::Device int_to_device(int device_int)
|
||||||
|
{
|
||||||
|
return (device_int == 0) ? torch::kCPU : torch::Device(torch::kCUDA, device_int - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<int64_t> to_vec_int(int *arr, int arr_size)
|
||||||
|
{
|
||||||
|
auto vec = std::vector<int64_t>(arr_size);
|
||||||
|
vec.assign(arr, arr + arr_size);
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<at::indexing::TensorIndex> to_index(int *arr, int arr_size)
|
||||||
|
{
|
||||||
|
std::vector<at::indexing::TensorIndex> index;
|
||||||
|
for (int i = 0; i < arr_size; i++)
|
||||||
|
{
|
||||||
|
index.emplace_back(arr[i]);
|
||||||
|
}
|
||||||
|
return index;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline torch::Tensor from_blob(Dtype *data, const std::vector<int64_t> &shape, torch::Device device, bool copy)
|
||||||
|
{
|
||||||
|
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename NumType>
|
||||||
|
inline NumType get(const torch::Tensor &tensor, int *index)
|
||||||
|
{
|
||||||
|
return tensor.index(to_index(index, tensor.dim())).item<NumType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename NumType>
|
||||||
|
inline void set(const torch::Tensor &tensor, int *index, NumType value)
|
||||||
|
{
|
||||||
|
tensor.index(to_index(index, tensor.dim())) = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline torch::Tensor randn(const std::vector<int64_t> &shape, torch::Device device)
|
||||||
|
{
|
||||||
|
return torch::randn(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline torch::Tensor rand(const std::vector<int64_t> &shape, torch::Device device)
|
||||||
|
{
|
||||||
|
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline torch::Tensor randint(long low, long high, const std::vector<int64_t> &shape, torch::Device device)
|
||||||
|
{
|
||||||
|
return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline torch::Tensor full(Dtype value, const std::vector<int64_t> &shape, torch::Device device)
|
||||||
|
{
|
||||||
|
return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline torch::Tensor hessian(const torch::Tensor &value, const torch::Tensor &variable)
|
||||||
|
{
|
||||||
|
auto nelem = variable.numel();
|
||||||
|
auto hess = value.new_zeros({nelem, nelem});
|
||||||
|
auto grad = torch::autograd::grad({value}, {variable}, {}, torch::nullopt, true)[0].view(nelem);
|
||||||
|
int i = 0;
|
||||||
|
for (int j = 0; j < nelem; j++)
|
||||||
|
{
|
||||||
|
auto row = grad[j].requires_grad()
|
||||||
|
? torch::autograd::grad({grad[i]}, {variable}, {}, true, true, true)[0].view(nelem).slice(0, j, nelem)
|
||||||
|
: grad[j].new_zeros(nelem - j);
|
||||||
|
hess[i].slice(0, i, nelem).add_(row.type_as(hess));
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
auto ndim = variable.dim();
|
||||||
|
auto sizes = variable.sizes().data();
|
||||||
|
auto shape = std::vector<int64_t>(ndim);
|
||||||
|
shape.assign(sizes, sizes + ndim);
|
||||||
|
shape.reserve(2 * ndim);
|
||||||
|
std::copy_n(shape.begin(), ndim, std::back_inserter(shape));
|
||||||
|
return (hess + torch::triu(hess, 1).t()).view(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ctorch
|
465
kmath-torch/src/cppMain/src/ctorch.cc
Normal file
465
kmath-torch/src/cppMain/src/ctorch.cc
Normal file
@ -0,0 +1,465 @@
|
|||||||
|
#include <torch/torch.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include "ctorch.h"
|
||||||
|
#include "utils.hh"
|
||||||
|
|
||||||
|
int get_num_threads()
|
||||||
|
{
|
||||||
|
return torch::get_num_threads();
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_num_threads(int num_threads)
|
||||||
|
{
|
||||||
|
torch::set_num_threads(num_threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool cuda_is_available()
|
||||||
|
{
|
||||||
|
return torch::cuda::is_available();
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_seed(int seed)
|
||||||
|
{
|
||||||
|
torch::manual_seed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle empty_tensor()
|
||||||
|
{
|
||||||
|
return new torch::Tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||||
|
}
|
||||||
|
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::from_blob<float>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||||
|
}
|
||||||
|
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::from_blob<long>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||||
|
}
|
||||||
|
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::from_blob<int>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).clone());
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_to_double(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<double>(), false, true));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_to_float(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<float>(), false, true));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_to_long(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<long>(), false, true));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_to_int(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<int>(), false, true));
|
||||||
|
}
|
||||||
|
void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)
|
||||||
|
{
|
||||||
|
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
|
||||||
|
}
|
||||||
|
TorchTensorHandle view_tensor(TorchTensorHandle tensor_handle, int *shape, int dim)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).view(ctorch::to_vec_int(shape, dim)));
|
||||||
|
}
|
||||||
|
|
||||||
|
char *tensor_to_string(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::tensor_to_char(ctorch::cast(tensor_handle));
|
||||||
|
}
|
||||||
|
void dispose_char(char *ptr)
|
||||||
|
{
|
||||||
|
free(ptr);
|
||||||
|
}
|
||||||
|
void dispose_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::dispose_tensor(tensor_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_dim(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).dim();
|
||||||
|
}
|
||||||
|
int get_numel(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).numel();
|
||||||
|
}
|
||||||
|
int get_shape_at(TorchTensorHandle tensor_handle, int d)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).size(d);
|
||||||
|
}
|
||||||
|
int get_stride_at(TorchTensorHandle tensor_handle, int d)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).stride(d);
|
||||||
|
}
|
||||||
|
int get_device(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||||
|
}
|
||||||
|
float *get_data_float(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<float>();
|
||||||
|
}
|
||||||
|
long *get_data_long(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<long>();
|
||||||
|
}
|
||||||
|
int *get_data_int(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<int>();
|
||||||
|
}
|
||||||
|
|
||||||
|
double get_item_double(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<double>();
|
||||||
|
}
|
||||||
|
float get_item_float(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<float>();
|
||||||
|
}
|
||||||
|
long get_item_long(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<long>();
|
||||||
|
}
|
||||||
|
int get_item_int(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<int>();
|
||||||
|
}
|
||||||
|
|
||||||
|
double get_double(TorchTensorHandle tensor_handle, int *index)
|
||||||
|
{
|
||||||
|
return ctorch::get<double>(ctorch::cast(tensor_handle), index);
|
||||||
|
}
|
||||||
|
float get_float(TorchTensorHandle tensor_handle, int *index)
|
||||||
|
{
|
||||||
|
return ctorch::get<float>(ctorch::cast(tensor_handle), index);
|
||||||
|
}
|
||||||
|
long get_long(TorchTensorHandle tensor_handle, int *index)
|
||||||
|
{
|
||||||
|
return ctorch::get<long>(ctorch::cast(tensor_handle), index);
|
||||||
|
}
|
||||||
|
int get_int(TorchTensorHandle tensor_handle, int *index)
|
||||||
|
{
|
||||||
|
return ctorch::get<int>(ctorch::cast(tensor_handle), index);
|
||||||
|
}
|
||||||
|
void set_double(TorchTensorHandle tensor_handle, int *index, double value)
|
||||||
|
{
|
||||||
|
ctorch::set<double>(ctorch::cast(tensor_handle), index, value);
|
||||||
|
}
|
||||||
|
void set_float(TorchTensorHandle tensor_handle, int *index, float value)
|
||||||
|
{
|
||||||
|
ctorch::set<float>(ctorch::cast(tensor_handle), index, value);
|
||||||
|
}
|
||||||
|
void set_long(TorchTensorHandle tensor_handle, int *index, long value)
|
||||||
|
{
|
||||||
|
ctorch::set<long>(ctorch::cast(tensor_handle), index, value);
|
||||||
|
}
|
||||||
|
void set_int(TorchTensorHandle tensor_handle, int *index, int value)
|
||||||
|
{
|
||||||
|
ctorch::set<int>(ctorch::cast(tensor_handle), index, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle rand_double(int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::rand<double>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle randn_double(int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::randn<double>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle rand_float(int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle randn_float(int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::randn<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle randint_double(long low, long high, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::randint<double>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle randint_float(long low, long high, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::randint<float>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::randint<long>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle randint_int(long low, long high, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::randint<int>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle rand_like(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(torch::rand_like(ctorch::cast(tensor_handle)));
|
||||||
|
}
|
||||||
|
void rand_like_assign(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = torch::rand_like(ctorch::cast(tensor_handle));
|
||||||
|
}
|
||||||
|
TorchTensorHandle randn_like(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(torch::randn_like(ctorch::cast(tensor_handle)));
|
||||||
|
}
|
||||||
|
void randn_like_assign(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle));
|
||||||
|
}
|
||||||
|
TorchTensorHandle randint_like(TorchTensorHandle tensor_handle, long low, long high)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
|
||||||
|
}
|
||||||
|
void randint_like_assign(TorchTensorHandle tensor_handle, long low, long high)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::full<double>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::full<float>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::full<long>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle times_double(double value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
TorchTensorHandle times_float(float value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
TorchTensorHandle times_long(long value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
TorchTensorHandle times_int(int value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
void times_double_assign(double value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
void times_float_assign(float value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
void times_long_assign(long value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
void times_int_assign(int value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle plus_double(double value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(other) + value);
|
||||||
|
}
|
||||||
|
TorchTensorHandle plus_float(float value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(other) + value);
|
||||||
|
}
|
||||||
|
TorchTensorHandle plus_long(long value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(other) + value);
|
||||||
|
}
|
||||||
|
TorchTensorHandle plus_int(int value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(other) + value);
|
||||||
|
}
|
||||||
|
void plus_double_assign(double value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
void plus_float_assign(float value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
void plus_long_assign(long value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
void plus_int_assign(int value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) *= ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) /= ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) += ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(-ctorch::cast(tensor_handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
||||||
|
}
|
||||||
|
void abs_tensor_assign(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs();
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
|
||||||
|
}
|
||||||
|
void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j);
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).exp());
|
||||||
|
}
|
||||||
|
void exp_tensor_assign(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp();
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).log());
|
||||||
|
}
|
||||||
|
void log_tensor_assign(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log();
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).sum());
|
||||||
|
}
|
||||||
|
void sum_tensor_assign(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||||
|
}
|
||||||
|
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
|
||||||
|
}
|
||||||
|
|
||||||
|
void svd_tensor(TorchTensorHandle tensor_handle,
|
||||||
|
TorchTensorHandle U_handle,
|
||||||
|
TorchTensorHandle S_handle,
|
||||||
|
TorchTensorHandle V_handle)
|
||||||
|
{
|
||||||
|
auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
|
||||||
|
ctorch::cast(U_handle) = U;
|
||||||
|
ctorch::cast(S_handle) = S;
|
||||||
|
ctorch::cast(V_handle) = V;
|
||||||
|
}
|
||||||
|
|
||||||
|
void symeig_tensor(TorchTensorHandle tensor_handle,
|
||||||
|
TorchTensorHandle S_handle,
|
||||||
|
TorchTensorHandle V_handle,
|
||||||
|
bool eigenvectors)
|
||||||
|
{
|
||||||
|
auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
|
||||||
|
ctorch::cast(S_handle) = S;
|
||||||
|
ctorch::cast(V_handle) = V;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool requires_grad(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).requires_grad();
|
||||||
|
}
|
||||||
|
void requires_grad_(TorchTensorHandle tensor_handle, bool status)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||||
|
}
|
||||||
|
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||||
|
}
|
||||||
|
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable, bool retain_graph)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
|
||||||
|
}
|
||||||
|
TorchTensorHandle autohess_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
|
||||||
|
}
|
572
kmath-torch/src/cppMain/src/jtorch.cc
Normal file
572
kmath-torch/src/cppMain/src/jtorch.cc
Normal file
@ -0,0 +1,572 @@
|
|||||||
|
#include <torch/torch.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include "space_kscience_kmath_torch_JTorch.h"
|
||||||
|
#include "utils.hh"
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getNumThreads(JNIEnv *, jclass)
|
||||||
|
{
|
||||||
|
return torch::get_num_threads();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setNumThreads(JNIEnv *, jclass, jint num_threads)
|
||||||
|
{
|
||||||
|
torch::set_num_threads(num_threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jboolean JNICALL Java_space_kscience_kmath_torch_JTorch_cudaIsAvailable(JNIEnv *, jclass)
|
||||||
|
{
|
||||||
|
return torch::cuda::is_available();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setSeed(JNIEnv *, jclass, jint seed)
|
||||||
|
{
|
||||||
|
torch::manual_seed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_emptyTensor(JNIEnv *, jclass)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobDouble(JNIEnv *env, jclass, jdoubleArray data, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::from_blob<double>(
|
||||||
|
env->GetDoubleArrayElements(data, 0),
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device), true));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobFloat(JNIEnv *env, jclass, jfloatArray data, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::from_blob<float>(
|
||||||
|
env->GetFloatArrayElements(data, 0),
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device), true));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobLong(JNIEnv *env, jclass, jlongArray data, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::from_blob<long>(
|
||||||
|
env->GetLongArrayElements(data, 0),
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device), true));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fromBlobInt(JNIEnv *env, jclass, jintArray data, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::from_blob<int>(
|
||||||
|
env->GetIntArrayElements(data, 0),
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device), true));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToDevice(JNIEnv *, jclass, jlong tensor_handle, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToDouble(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<double>(), false, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToFloat(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<float>(), false, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToLong(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<long>(), false, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_copyToInt(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<int>(), false, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_swapTensors(JNIEnv *, jclass, jlong lhs_handle, jlong rhs_handle)
|
||||||
|
{
|
||||||
|
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_viewTensor(JNIEnv *env, jclass, jlong tensor_handle, jintArray shape)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::cast(tensor_handle).view(ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape))));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jstring JNICALL Java_space_kscience_kmath_torch_JTorch_tensorToString(JNIEnv *env, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return env->NewStringUTF(ctorch::tensor_to_string(ctorch::cast(tensor_handle)).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_disposeTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::dispose_tensor(tensor_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getDim(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).dim();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getNumel(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).numel();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getShapeAt(JNIEnv *, jclass, jlong tensor_handle, jint d)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).size(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getStrideAt(JNIEnv *, jclass, jlong tensor_handle, jint d)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).stride(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getDevice(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jdouble JNICALL Java_space_kscience_kmath_torch_JTorch_getItemDouble(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<double>();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jfloat JNICALL Java_space_kscience_kmath_torch_JTorch_getItemFloat(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<float>();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_getItemLong(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<long>();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getItemInt(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<int>();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jdouble JNICALL Java_space_kscience_kmath_torch_JTorch_getDouble(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||||
|
{
|
||||||
|
return ctorch::get<double>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jfloat JNICALL Java_space_kscience_kmath_torch_JTorch_getFloat(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||||
|
{
|
||||||
|
return ctorch::get<float>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_getLong(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||||
|
{
|
||||||
|
return ctorch::get<long>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_torch_JTorch_getInt(JNIEnv *env, jclass, jlong tensor_handle, jintArray index)
|
||||||
|
{
|
||||||
|
return ctorch::get<int>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setDouble(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jdouble value)
|
||||||
|
{
|
||||||
|
ctorch::set<double>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setFloat(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jfloat value)
|
||||||
|
{
|
||||||
|
ctorch::set<float>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setLong(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jlong value)
|
||||||
|
{
|
||||||
|
ctorch::set<long>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setInt(JNIEnv *env, jclass, jlong tensor_handle, jintArray index, jint value)
|
||||||
|
{
|
||||||
|
ctorch::set<int>(ctorch::cast(tensor_handle), env->GetIntArrayElements(index, 0), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randDouble(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::rand<double>(
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnDouble(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::randn<double>(
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randFloat(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::rand<float>(
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnFloat(JNIEnv *env, jclass, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::randn<float>(
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintDouble(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::randint<double>(low, high,
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintFloat(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::randint<float>(low, high,
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintLong(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::randint<long>(low, high,
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintInt(JNIEnv *env, jclass, jlong low, jlong high, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::randint<int>(low, high,
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randLike(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(torch::rand_like(ctorch::cast(tensor_handle)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randLikeAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = torch::rand_like(ctorch::cast(tensor_handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randnLike(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(torch::randn_like(ctorch::cast(tensor_handle)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randnLikeAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = torch::randn_like(ctorch::cast(tensor_handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_randintLike(JNIEnv *, jclass, jlong tensor_handle, jlong low, jlong high)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(torch::randint_like(ctorch::cast(tensor_handle), low, high));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_randintLikeAssign(JNIEnv *, jclass, jlong tensor_handle, jlong low, jlong high)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = torch::randint_like(ctorch::cast(tensor_handle), low, high);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullDouble(JNIEnv *env, jclass, jdouble value, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::full<double>(
|
||||||
|
value,
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullFloat(JNIEnv *env, jclass, jfloat value, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::full<float>(
|
||||||
|
value,
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullLong(JNIEnv *env, jclass, jlong value, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::full<long>(
|
||||||
|
value,
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_fullInt(JNIEnv *env, jclass, jint value, jintArray shape, jint device)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(
|
||||||
|
ctorch::full<int>(
|
||||||
|
value,
|
||||||
|
ctorch::to_vec_int(env->GetIntArrayElements(shape, 0), env->GetArrayLength(shape)),
|
||||||
|
ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesDouble(JNIEnv *, jclass, jdouble value, jlong other)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesFloat(JNIEnv *, jclass, jfloat value, jlong other)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesLong(JNIEnv *, jclass, jlong value, jlong other)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesInt(JNIEnv *, jclass, jint value, jlong other)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesDoubleAssign(JNIEnv *, jclass, jdouble value, jlong other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesFloatAssign(JNIEnv *, jclass, jfloat value, jlong other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesLongAssign(JNIEnv *, jclass, jlong value, jlong other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesIntAssign(JNIEnv *, jclass, jint value, jlong other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusDouble(JNIEnv *, jclass, jdouble value, jlong other)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusFloat(JNIEnv *, jclass, jfloat value, jlong other)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusLong(JNIEnv *, jclass, jlong value, jlong other)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusInt(JNIEnv *, jclass, jint value, jlong other)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(value + ctorch::cast(other));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusDoubleAssign(JNIEnv *, jclass, jdouble value, jlong other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusFloatAssign(JNIEnv *, jclass, jfloat value, jlong other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusLongAssign(JNIEnv *, jclass, jlong value, jlong other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusIntAssign(JNIEnv *, jclass, jint value, jlong other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_timesTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_timesTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) *= ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_divTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_divTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) /= ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_plusTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_plusTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) += ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_minusTensor(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_minusTensorAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_unaryMinus(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(-ctorch::cast(tensor_handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_absTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_absTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_transposeTensor(JNIEnv *, jclass, jlong tensor_handle, jint i, jint j)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_transposeTensorAssign(JNIEnv *, jclass, jlong tensor_handle, jint i, jint j)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_expTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).exp());
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_expTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_logTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).log());
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_logTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_sumTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).sum());
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_sumTensorAssign(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_matmul(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_matmulAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_matmulRightAssign(JNIEnv *, jclass, jlong lhs, jlong rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL
|
||||||
|
Java_space_kscience_kmath_torch_JTorch_diagEmbed(JNIEnv *, jclass, jlong diags_handle, jint offset, jint dim1, jint dim2)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL
|
||||||
|
Java_space_kscience_kmath_torch_JTorch_svdTensor(JNIEnv *, jclass, jlong tensor_handle, jlong U_handle, jlong S_handle, jlong V_handle)
|
||||||
|
{
|
||||||
|
auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
|
||||||
|
ctorch::cast(U_handle) = U;
|
||||||
|
ctorch::cast(S_handle) = S;
|
||||||
|
ctorch::cast(V_handle) = V;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL
|
||||||
|
Java_space_kscience_kmath_torch_JTorch_symeigTensor(JNIEnv *, jclass, jlong tensor_handle, jlong S_handle, jlong V_handle, jboolean eigenvectors)
|
||||||
|
{
|
||||||
|
auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
|
||||||
|
ctorch::cast(S_handle) = S;
|
||||||
|
ctorch::cast(V_handle) = V;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jboolean JNICALL Java_space_kscience_kmath_torch_JTorch_requiresGrad(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).requires_grad();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_torch_JTorch_setRequiresGrad(JNIEnv *, jclass, jlong tensor_handle, jboolean status)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_detachFromGraph(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL
|
||||||
|
Java_space_kscience_kmath_torch_JTorch_autogradTensor(JNIEnv *, jclass, jlong value, jlong variable, jboolean retain_graph)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)}, {}, retain_graph)[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_torch_JTorch_autohessTensor(JNIEnv *, jclass, jlong value, jlong variable)
|
||||||
|
{
|
||||||
|
return (long)new torch::Tensor(ctorch::hessian(ctorch::cast(value), ctorch::cast(variable)));
|
||||||
|
}
|
@ -0,0 +1,208 @@
|
|||||||
|
package space.kscience.kmath.torch;
|
||||||
|
|
||||||
|
class JTorch {
|
||||||
|
|
||||||
|
static {
|
||||||
|
System.loadLibrary("jtorch");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static native int getNumThreads();
|
||||||
|
|
||||||
|
public static native void setNumThreads(int numThreads);
|
||||||
|
|
||||||
|
public static native boolean cudaIsAvailable();
|
||||||
|
|
||||||
|
public static native void setSeed(int seed);
|
||||||
|
|
||||||
|
public static native long emptyTensor();
|
||||||
|
|
||||||
|
public static native long fromBlobDouble(double[] data, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long fromBlobFloat(float[] data, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long fromBlobLong(long[] data, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long fromBlobInt(int[] data, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long copyTensor(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long copyToDevice(long tensorHandle, int device);
|
||||||
|
|
||||||
|
public static native long copyToDouble(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long copyToFloat(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long copyToLong(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long copyToInt(long tensorHandle);
|
||||||
|
|
||||||
|
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
||||||
|
|
||||||
|
public static native long viewTensor(long tensorHandle, int[] shape);
|
||||||
|
|
||||||
|
public static native String tensorToString(long tensorHandle);
|
||||||
|
|
||||||
|
public static native void disposeTensor(long tensorHandle);
|
||||||
|
|
||||||
|
public static native int getDim(long tensorHandle);
|
||||||
|
|
||||||
|
public static native int getNumel(long tensorHandle);
|
||||||
|
|
||||||
|
public static native int getShapeAt(long tensorHandle, int d);
|
||||||
|
|
||||||
|
public static native int getStrideAt(long tensorHandle, int d);
|
||||||
|
|
||||||
|
public static native int getDevice(long tensorHandle);
|
||||||
|
|
||||||
|
public static native double getItemDouble(long tensorHandle);
|
||||||
|
|
||||||
|
public static native float getItemFloat(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long getItemLong(long tensorHandle);
|
||||||
|
|
||||||
|
public static native int getItemInt(long tensorHandle);
|
||||||
|
|
||||||
|
public static native double getDouble(long tensorHandle, int[] index);
|
||||||
|
|
||||||
|
public static native float getFloat(long tensorHandle, int[] index);
|
||||||
|
|
||||||
|
public static native long getLong(long tensorHandle, int[] index);
|
||||||
|
|
||||||
|
public static native int getInt(long tensorHandle, int[] index);
|
||||||
|
|
||||||
|
public static native void setDouble(long tensorHandle, int[] index, double value);
|
||||||
|
|
||||||
|
public static native void setFloat(long tensorHandle, int[] index, float value);
|
||||||
|
|
||||||
|
public static native void setLong(long tensorHandle, int[] index, long value);
|
||||||
|
|
||||||
|
public static native void setInt(long tensorHandle, int[] index, int value);
|
||||||
|
|
||||||
|
public static native long randDouble(int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long randnDouble(int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long randFloat(int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long randnFloat(int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long randintDouble(long low, long high, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long randintFloat(long low, long high, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long randintLong(long low, long high, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long randintInt(long low, long high, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long randLike(long tensorHandle);
|
||||||
|
|
||||||
|
public static native void randLikeAssign(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long randnLike(long tensorHandle);
|
||||||
|
|
||||||
|
public static native void randnLikeAssign(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long randintLike(long tensorHandle, long low, long high);
|
||||||
|
|
||||||
|
public static native void randintLikeAssign(long tensorHandle, long low, long high);
|
||||||
|
|
||||||
|
public static native long fullDouble(double value, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long fullFloat(float value, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long fullLong(long value, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long fullInt(int value, int[] shape, int device);
|
||||||
|
|
||||||
|
public static native long timesDouble(double value, long other);
|
||||||
|
|
||||||
|
public static native long timesFloat(float value, long other);
|
||||||
|
|
||||||
|
public static native long timesLong(long value, long other);
|
||||||
|
|
||||||
|
public static native long timesInt(int value, long other);
|
||||||
|
|
||||||
|
public static native void timesDoubleAssign(double value, long other);
|
||||||
|
|
||||||
|
public static native void timesFloatAssign(float value, long other);
|
||||||
|
|
||||||
|
public static native void timesLongAssign(long value, long other);
|
||||||
|
|
||||||
|
public static native void timesIntAssign(int value, long other);
|
||||||
|
|
||||||
|
public static native long plusDouble(double value, long other);
|
||||||
|
|
||||||
|
public static native long plusFloat(float value, long other);
|
||||||
|
|
||||||
|
public static native long plusLong(long value, long other);
|
||||||
|
|
||||||
|
public static native long plusInt(int value, long other);
|
||||||
|
|
||||||
|
public static native void plusDoubleAssign(double value, long other);
|
||||||
|
|
||||||
|
public static native void plusFloatAssign(float value, long other);
|
||||||
|
|
||||||
|
public static native void plusLongAssign(long value, long other);
|
||||||
|
|
||||||
|
public static native void plusIntAssign(int value, long other);
|
||||||
|
|
||||||
|
public static native long timesTensor(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native void timesTensorAssign(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native long divTensor(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native void divTensorAssign(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native long plusTensor(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native void plusTensorAssign(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native long minusTensor(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native void minusTensorAssign(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native long unaryMinus(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long absTensor(long tensorHandle);
|
||||||
|
|
||||||
|
public static native void absTensorAssign(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long transposeTensor(long tensorHandle, int i, int j);
|
||||||
|
|
||||||
|
public static native void transposeTensorAssign(long tensorHandle, int i, int j);
|
||||||
|
|
||||||
|
public static native long expTensor(long tensorHandle);
|
||||||
|
|
||||||
|
public static native void expTensorAssign(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long logTensor(long tensorHandle);
|
||||||
|
|
||||||
|
public static native void logTensorAssign(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long sumTensor(long tensorHandle);
|
||||||
|
|
||||||
|
public static native void sumTensorAssign(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long matmul(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native void matmulAssign(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native void matmulRightAssign(long lhs, long rhs);
|
||||||
|
|
||||||
|
public static native long diagEmbed(long diagsHandle, int offset, int dim1, int dim2);
|
||||||
|
|
||||||
|
public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle);
|
||||||
|
|
||||||
|
public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle, boolean eigenvectors);
|
||||||
|
|
||||||
|
public static native boolean requiresGrad(long tensorHandle);
|
||||||
|
|
||||||
|
public static native void setRequiresGrad(long tensorHandle, boolean status);
|
||||||
|
|
||||||
|
public static native long detachFromGraph(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long autogradTensor(long value, long variable, boolean retainGraph);
|
||||||
|
|
||||||
|
public static native long autohessTensor(long value, long variable);
|
||||||
|
}
|
@ -0,0 +1,390 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import space.kscience.kmath.memory.DeferScope
|
||||||
|
import space.kscience.kmath.memory.withDeferScope
|
||||||
|
import space.kscience.kmath.tensors.*
|
||||||
|
|
||||||
|
public sealed class TorchTensorAlgebraJVM<
|
||||||
|
T,
|
||||||
|
PrimitiveArrayType,
|
||||||
|
TorchTensorType : TorchTensorJVM<T>> constructor(
|
||||||
|
internal val scope: DeferScope
|
||||||
|
) : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||||
|
override fun getNumThreads(): Int {
|
||||||
|
return JTorch.getNumThreads()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun setNumThreads(numThreads: Int): Unit {
|
||||||
|
JTorch.setNumThreads(numThreads)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cudaAvailable(): Boolean {
|
||||||
|
return JTorch.cudaIsAvailable()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun setSeed(seed: Int): Unit {
|
||||||
|
JTorch.setSeed(seed)
|
||||||
|
}
|
||||||
|
|
||||||
|
override var checks: Boolean = false
|
||||||
|
|
||||||
|
internal abstract fun wrap(tensorHandle: Long): TorchTensorType
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
return wrap(JTorch.timesTensor(this.tensorHandle, other.tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
return wrap(JTorch.plusTensor(this.tensorHandle, other.tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
return wrap(JTorch.minusTensor(this.tensorHandle, other.tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||||
|
wrap(JTorch.unaryMinus(this.tensorHandle))
|
||||||
|
|
||||||
|
override infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkDotOperation(this, other)
|
||||||
|
return wrap(JTorch.matmul(this.tensorHandle, other.tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
override infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkDotOperation(this, other)
|
||||||
|
JTorch.matmulAssign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkDotOperation(this, other)
|
||||||
|
JTorch.matmulRightAssign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun diagonalEmbedding(
|
||||||
|
diagonalEntries: TorchTensorType, offset: Int, dim1: Int, dim2: Int
|
||||||
|
): TorchTensorType =
|
||||||
|
wrap(JTorch.diagEmbed(diagonalEntries.tensorHandle, offset, dim1, dim2))
|
||||||
|
|
||||||
|
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType {
|
||||||
|
if (checks) checkTranspose(this.dimension, i, j)
|
||||||
|
return wrap(JTorch.transposeTensor(tensorHandle, i, j))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
||||||
|
if (checks) checkTranspose(this.dimension, i, j)
|
||||||
|
JTorch.transposeTensorAssign(tensorHandle, i, j)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.view(shape: IntArray): TorchTensorType {
|
||||||
|
if (checks) checkView(this, shape)
|
||||||
|
return wrap(JTorch.viewTensor(this.tensorHandle, shape))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.abs(): TorchTensorType = wrap(JTorch.absTensor(tensorHandle))
|
||||||
|
override fun TorchTensorType.absAssign(): Unit = JTorch.absTensorAssign(tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorType.sum(): TorchTensorType = wrap(JTorch.sumTensor(tensorHandle))
|
||||||
|
override fun TorchTensorType.sumAssign(): Unit = JTorch.sumTensorAssign(tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
|
||||||
|
wrap(JTorch.randintLike(this.tensorHandle, low, high))
|
||||||
|
|
||||||
|
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit =
|
||||||
|
JTorch.randintLikeAssign(this.tensorHandle, low, high)
|
||||||
|
|
||||||
|
override fun TorchTensorType.copy(): TorchTensorType =
|
||||||
|
wrap(JTorch.copyTensor(this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
|
||||||
|
wrap(JTorch.copyToDevice(this.tensorHandle, device.toInt()))
|
||||||
|
|
||||||
|
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit =
|
||||||
|
JTorch.swapTensors(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public sealed class TorchTensorPartialDivisionAlgebraJVM<T, PrimitiveArrayType,
|
||||||
|
TorchTensorType : TorchTensorOverFieldJVM<T>>(scope: DeferScope) :
|
||||||
|
TorchTensorAlgebraJVM<T, PrimitiveArrayType, TorchTensorType>(scope),
|
||||||
|
TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
return wrap(JTorch.divTensor(this.tensorHandle, other.tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
JTorch.divTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.randUniform(): TorchTensorType =
|
||||||
|
wrap(JTorch.randLike(this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorType.randUniformAssign(): Unit =
|
||||||
|
JTorch.randLikeAssign(this.tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorType.randNormal(): TorchTensorType =
|
||||||
|
wrap(JTorch.randnLike(this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorType.randNormalAssign(): Unit =
|
||||||
|
JTorch.randnLikeAssign(this.tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorType.exp(): TorchTensorType = wrap(JTorch.expTensor(tensorHandle))
|
||||||
|
override fun TorchTensorType.expAssign(): Unit = JTorch.expTensorAssign(tensorHandle)
|
||||||
|
override fun TorchTensorType.log(): TorchTensorType = wrap(JTorch.logTensor(tensorHandle))
|
||||||
|
override fun TorchTensorType.logAssign(): Unit = JTorch.logTensorAssign(tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
|
||||||
|
val U = JTorch.emptyTensor()
|
||||||
|
val V = JTorch.emptyTensor()
|
||||||
|
val S = JTorch.emptyTensor()
|
||||||
|
JTorch.svdTensor(this.tensorHandle, U, S, V)
|
||||||
|
return Triple(wrap(U), wrap(S), wrap(V))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.symEig(eigenvectors: Boolean): Pair<TorchTensorType, TorchTensorType> {
|
||||||
|
val V = JTorch.emptyTensor()
|
||||||
|
val S = JTorch.emptyTensor()
|
||||||
|
JTorch.symeigTensor(this.tensorHandle, S, V, eigenvectors)
|
||||||
|
return Pair(wrap(S), wrap(V))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean): TorchTensorType {
|
||||||
|
if (checks) this.checkIsValue()
|
||||||
|
return wrap(JTorch.autogradTensor(this.tensorHandle, variable.tensorHandle, retainGraph))
|
||||||
|
}
|
||||||
|
|
||||||
|
override infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) this.checkIsValue()
|
||||||
|
return wrap(JTorch.autohessTensor(this.tensorHandle, variable.tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||||
|
wrap(JTorch.detachFromGraph(this.tensorHandle))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorPartialDivisionAlgebraJVM<Double, DoubleArray, TorchTensorReal>(scope) {
|
||||||
|
override fun wrap(tensorHandle: Long): TorchTensorReal =
|
||||||
|
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
||||||
|
this.elements().map { it.second }.toList().toDoubleArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): TorchTensorReal =
|
||||||
|
wrap(JTorch.fromBlobDouble(array, shape, device.toInt()))
|
||||||
|
|
||||||
|
override fun randNormal(shape: IntArray, device: Device): TorchTensorReal =
|
||||||
|
wrap(JTorch.randnDouble(shape, device.toInt()))
|
||||||
|
|
||||||
|
override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
|
||||||
|
wrap(JTorch.randDouble(shape, device.toInt()))
|
||||||
|
|
||||||
|
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorReal =
|
||||||
|
wrap(JTorch.randintDouble(low, high, shape, device.toInt()))
|
||||||
|
|
||||||
|
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
|
||||||
|
wrap(JTorch.plusDouble(this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorReal.plus(value: Double): TorchTensorReal =
|
||||||
|
wrap(JTorch.plusDouble(value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorReal.plusAssign(value: Double): Unit =
|
||||||
|
JTorch.plusDoubleAssign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal =
|
||||||
|
wrap(JTorch.plusDouble(-this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorReal.minus(value: Double): TorchTensorReal =
|
||||||
|
wrap(JTorch.plusDouble(-value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorReal.minusAssign(value: Double): Unit =
|
||||||
|
JTorch.plusDoubleAssign(-value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal =
|
||||||
|
wrap(JTorch.timesDouble(this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorReal.times(value: Double): TorchTensorReal =
|
||||||
|
wrap(JTorch.timesDouble(value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorReal.timesAssign(value: Double): Unit =
|
||||||
|
JTorch.timesDoubleAssign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||||
|
wrap(JTorch.fullDouble(value, shape, device.toInt()))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorPartialDivisionAlgebraJVM<Float, FloatArray, TorchTensorFloat>(scope) {
|
||||||
|
override fun wrap(tensorHandle: Long): TorchTensorFloat =
|
||||||
|
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.copyToArray(): FloatArray =
|
||||||
|
this.elements().map { it.second }.toList().toFloatArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
|
wrap(JTorch.fromBlobFloat(array, shape, device.toInt()))
|
||||||
|
|
||||||
|
override fun randNormal(shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
|
wrap(JTorch.randnFloat(shape, device.toInt()))
|
||||||
|
|
||||||
|
override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
|
wrap(JTorch.randFloat(shape, device.toInt()))
|
||||||
|
|
||||||
|
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
|
wrap(JTorch.randintFloat(low, high, shape, device.toInt()))
|
||||||
|
|
||||||
|
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
|
||||||
|
wrap(JTorch.plusFloat(this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
|
||||||
|
wrap(JTorch.plusFloat(value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.plusAssign(value: Float): Unit =
|
||||||
|
JTorch.plusFloatAssign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
|
||||||
|
wrap(JTorch.plusFloat(-this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
|
||||||
|
wrap(JTorch.plusFloat(-value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.minusAssign(value: Float): Unit =
|
||||||
|
JTorch.plusFloatAssign(-value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
|
||||||
|
wrap(JTorch.timesFloat(this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
|
||||||
|
wrap(JTorch.timesFloat(value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.timesAssign(value: Float): Unit =
|
||||||
|
JTorch.timesFloatAssign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
|
wrap(JTorch.fullFloat(value, shape, device.toInt()))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorAlgebraJVM<Long, LongArray, TorchTensorLong>(scope) {
|
||||||
|
override fun wrap(tensorHandle: Long): TorchTensorLong =
|
||||||
|
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.copyToArray(): LongArray =
|
||||||
|
this.elements().map { it.second }.toList().toLongArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): TorchTensorLong =
|
||||||
|
wrap(JTorch.fromBlobLong(array, shape, device.toInt()))
|
||||||
|
|
||||||
|
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||||
|
wrap(JTorch.randintLong(low, high, shape, device.toInt()))
|
||||||
|
|
||||||
|
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
|
||||||
|
wrap(JTorch.plusLong(this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
|
||||||
|
wrap(JTorch.plusLong(value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorLong.plusAssign(value: Long): Unit =
|
||||||
|
JTorch.plusLongAssign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
|
||||||
|
wrap(JTorch.plusLong(-this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
|
||||||
|
wrap(JTorch.plusLong(-value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorLong.minusAssign(value: Long): Unit =
|
||||||
|
JTorch.plusLongAssign(-value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
|
||||||
|
wrap(JTorch.timesLong(this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
|
||||||
|
wrap(JTorch.timesLong(value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorLong.timesAssign(value: Long): Unit =
|
||||||
|
JTorch.timesLongAssign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||||
|
wrap(JTorch.fullLong(value, shape, device.toInt()))
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorAlgebraJVM<Int, IntArray, TorchTensorInt>(scope) {
|
||||||
|
override fun wrap(tensorHandle: Long): TorchTensorInt =
|
||||||
|
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.copyToArray(): IntArray =
|
||||||
|
this.elements().map { it.second }.toList().toIntArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): TorchTensorInt =
|
||||||
|
wrap(JTorch.fromBlobInt(array, shape, device.toInt()))
|
||||||
|
|
||||||
|
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorInt =
|
||||||
|
wrap(JTorch.randintInt(low, high, shape, device.toInt()))
|
||||||
|
|
||||||
|
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
|
||||||
|
wrap(JTorch.plusInt(this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
|
||||||
|
wrap(JTorch.plusInt(value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorInt.plusAssign(value: Int): Unit =
|
||||||
|
JTorch.plusIntAssign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
|
||||||
|
wrap(JTorch.plusInt(-this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
|
||||||
|
wrap(JTorch.plusInt(-value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorInt.minusAssign(value: Int): Unit =
|
||||||
|
JTorch.plusIntAssign(-value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
|
||||||
|
wrap(JTorch.timesInt(this, other.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
|
||||||
|
wrap(JTorch.timesInt(value, this.tensorHandle))
|
||||||
|
|
||||||
|
override fun TorchTensorInt.timesAssign(value: Int): Unit =
|
||||||
|
JTorch.timesIntAssign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||||
|
wrap(JTorch.fullInt(value, shape, device.toInt()))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||||
|
withDeferScope { TorchTensorRealAlgebra(this).block() }
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
||||||
|
withDeferScope { TorchTensorFloatAlgebra(this).block() }
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
|
||||||
|
withDeferScope { TorchTensorLongAlgebra(this).block() }
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||||
|
withDeferScope { TorchTensorIntAlgebra(this).block() }
|
@ -0,0 +1,94 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import space.kscience.kmath.memory.DeferScope
|
||||||
|
|
||||||
|
public sealed class TorchTensorJVM<T> constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
internal val tensorHandle: Long
|
||||||
|
) : TorchTensor<T>, TorchTensorMemoryHolder(scope)
|
||||||
|
{
|
||||||
|
override fun close(): Unit = JTorch.disposeTensor(tensorHandle)
|
||||||
|
|
||||||
|
override val dimension: Int get() = JTorch.getDim(tensorHandle)
|
||||||
|
override val shape: IntArray
|
||||||
|
get() = (1..dimension).map { JTorch.getShapeAt(tensorHandle, it - 1) }.toIntArray()
|
||||||
|
override val strides: IntArray
|
||||||
|
get() = (1..dimension).map { JTorch.getStrideAt(tensorHandle, it - 1) }.toIntArray()
|
||||||
|
override val size: Int get() = JTorch.getNumel(tensorHandle)
|
||||||
|
override val device: Device get() = Device.fromInt(JTorch.getDevice(tensorHandle))
|
||||||
|
|
||||||
|
override fun toString(): String = JTorch.tensorToString(tensorHandle)
|
||||||
|
|
||||||
|
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = JTorch.copyToDouble(this.tensorHandle)
|
||||||
|
)
|
||||||
|
|
||||||
|
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = JTorch.copyToFloat(this.tensorHandle)
|
||||||
|
)
|
||||||
|
|
||||||
|
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = JTorch.copyToLong(this.tensorHandle)
|
||||||
|
)
|
||||||
|
|
||||||
|
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = JTorch.copyToInt(this.tensorHandle)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
public sealed class TorchTensorOverFieldJVM<T> constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: Long
|
||||||
|
) : TorchTensorJVM<T>(scope, tensorHandle), TorchTensorOverField<T> {
|
||||||
|
override var requiresGrad: Boolean
|
||||||
|
get() = JTorch.requiresGrad(tensorHandle)
|
||||||
|
set(value) = JTorch.setRequiresGrad(tensorHandle, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorReal internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: Long
|
||||||
|
) : TorchTensorOverFieldJVM<Double>(scope, tensorHandle) {
|
||||||
|
override fun item(): Double = JTorch.getItemDouble(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Double = JTorch.getDouble(tensorHandle, index)
|
||||||
|
override fun set(index: IntArray, value: Double) {
|
||||||
|
JTorch.setDouble(tensorHandle, index, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorFloat internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: Long
|
||||||
|
) : TorchTensorOverFieldJVM<Float>(scope, tensorHandle) {
|
||||||
|
override fun item(): Float = JTorch.getItemFloat(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Float = JTorch.getFloat(tensorHandle, index)
|
||||||
|
override fun set(index: IntArray, value: Float) {
|
||||||
|
JTorch.setFloat(tensorHandle, index, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorLong internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: Long
|
||||||
|
) : TorchTensorOverFieldJVM<Long>(scope, tensorHandle) {
|
||||||
|
override fun item(): Long = JTorch.getItemLong(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Long = JTorch.getLong(tensorHandle, index)
|
||||||
|
override fun set(index: IntArray, value: Long) {
|
||||||
|
JTorch.setLong(tensorHandle, index, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorInt internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: Long
|
||||||
|
) : TorchTensorOverFieldJVM<Int>(scope, tensorHandle) {
|
||||||
|
override fun item(): Int = JTorch.getItemInt(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Int = JTorch.getInt(tensorHandle, index)
|
||||||
|
override fun set(index: IntArray, value: Int) {
|
||||||
|
JTorch.setInt(tensorHandle, index, value)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,26 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkMatMul {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkMatMulDouble() = TorchTensorRealAlgebra {
|
||||||
|
benchmarkMatMul(20, 10, 100000, "Real")
|
||||||
|
benchmarkMatMul(200, 10, 10000, "Real")
|
||||||
|
benchmarkMatMul(2000, 3, 20, "Real")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkMatMulFloat() = TorchTensorFloatAlgebra {
|
||||||
|
benchmarkMatMul(20, 10, 100000, "Float")
|
||||||
|
benchmarkMatMul(200, 10, 10000, "Float")
|
||||||
|
benchmarkMatMul(2000, 3, 20, "Float")
|
||||||
|
if (cudaAvailable()) {
|
||||||
|
benchmarkMatMul(20, 10, 100000, "Float", Device.CUDA(0))
|
||||||
|
benchmarkMatMul(200, 10, 10000, "Float", Device.CUDA(0))
|
||||||
|
benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,27 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkRandomGenerators {
|
||||||
|
@Test
|
||||||
|
fun benchmarkRand1() = TorchTensorFloatAlgebra{
|
||||||
|
benchmarkingRand1()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkRand3() = TorchTensorFloatAlgebra{
|
||||||
|
benchmarkingRand3()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkRand5() = TorchTensorFloatAlgebra{
|
||||||
|
benchmarkingRand5()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkRand7() = TorchTensorFloatAlgebra{
|
||||||
|
benchmarkingRand7()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,24 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
|
||||||
|
class TestAutograd {
|
||||||
|
@Test
|
||||||
|
fun testAutoGrad() = TorchTensorFloatAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingAutoGrad(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedAutoGrad() = TorchTensorFloatAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingBatchedAutoGrad(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,39 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
|
||||||
|
class TestTorchTensor {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCopying() = TorchTensorFloatAlgebra {
|
||||||
|
withCuda { device ->
|
||||||
|
testingCopying(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testRequiresGrad() = TorchTensorRealAlgebra {
|
||||||
|
testingRequiresGrad()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTypeMoving() = TorchTensorFloatAlgebra {
|
||||||
|
val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).copyToInt()
|
||||||
|
TorchTensorIntAlgebra {
|
||||||
|
val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
|
||||||
|
tensorInt swap temporalTensor
|
||||||
|
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
|
||||||
|
}
|
||||||
|
assertTrue(tensorInt.copyToFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testViewWithNoCopy() = TorchTensorIntAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda {
|
||||||
|
device -> testingViewWithNoCopy(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,63 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
|
||||||
|
class TestTorchTensorAlgebra {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testScalarProduct() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingScalarProduct(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMatrixMultiplication() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingMatrixMultiplication(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLinearStructure() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingLinearStructure(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTensorTransformations() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingTensorTransformations(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedSVD() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingBatchedSVD(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedSymEig() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingBatchedSymEig(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,21 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtils {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSetNumThreads() {
|
||||||
|
TorchTensorLongAlgebra {
|
||||||
|
testingSetNumThreads()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSeedSetting() = TorchTensorFloatAlgebra {
|
||||||
|
withCuda { device ->
|
||||||
|
testingSetSeed(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
2
kmath-torch/src/nativeInterop/cinterop/libctorch.def
Normal file
2
kmath-torch/src/nativeInterop/cinterop/libctorch.def
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
package=space.kscience.kmath.torch.ctorch
|
||||||
|
headers=ctorch.h
|
@ -0,0 +1,444 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import space.kscience.kmath.memory.DeferScope
|
||||||
|
import space.kscience.kmath.memory.withDeferScope
|
||||||
|
|
||||||
|
import kotlinx.cinterop.*
|
||||||
|
import space.kscience.kmath.tensors.*
|
||||||
|
import space.kscience.kmath.torch.ctorch.*
|
||||||
|
|
||||||
|
public sealed class TorchTensorAlgebraNative<
|
||||||
|
T,
|
||||||
|
TVar : CPrimitiveVar,
|
||||||
|
PrimitiveArrayType,
|
||||||
|
TorchTensorType : TorchTensorNative<T>> constructor(
|
||||||
|
internal val scope: DeferScope
|
||||||
|
) : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||||
|
|
||||||
|
override fun getNumThreads(): Int {
|
||||||
|
return get_num_threads()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun setNumThreads(numThreads: Int): Unit {
|
||||||
|
set_num_threads(numThreads)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cudaAvailable(): Boolean {
|
||||||
|
return cuda_is_available()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun setSeed(seed: Int): Unit {
|
||||||
|
set_seed(seed)
|
||||||
|
}
|
||||||
|
|
||||||
|
override var checks: Boolean = false
|
||||||
|
|
||||||
|
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
|
||||||
|
|
||||||
|
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
|
||||||
|
public abstract fun TorchTensorType.getData(): CPointer<TVar>
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
return wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
return wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
return wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||||
|
wrap(unary_minus(this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkDotOperation(this, other)
|
||||||
|
return wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
override infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkDotOperation(this, other)
|
||||||
|
matmul_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkDotOperation(this, other)
|
||||||
|
matmul_right_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun diagonalEmbedding(
|
||||||
|
diagonalEntries: TorchTensorType, offset: Int, dim1: Int, dim2: Int
|
||||||
|
): TorchTensorType =
|
||||||
|
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType {
|
||||||
|
if (checks) checkTranspose(this.dimension, i, j)
|
||||||
|
return wrap(transpose_tensor(tensorHandle, i, j)!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
||||||
|
if (checks) checkTranspose(this.dimension, i, j)
|
||||||
|
transpose_tensor_assign(tensorHandle, i, j)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.view(shape: IntArray): TorchTensorType {
|
||||||
|
if (checks) checkView(this, shape)
|
||||||
|
return wrap(view_tensor(this.tensorHandle, shape.toCValues(), shape.size)!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
|
||||||
|
override fun TorchTensorType.absAssign(): Unit = abs_tensor_assign(tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
|
||||||
|
override fun TorchTensorType.sumAssign(): Unit = sum_tensor_assign(tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
|
||||||
|
wrap(randint_like(this.tensorHandle, low, high)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit =
|
||||||
|
randint_like_assign(this.tensorHandle, low, high)
|
||||||
|
|
||||||
|
override fun TorchTensorType.copy(): TorchTensorType =
|
||||||
|
wrap(copy_tensor(this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
|
||||||
|
wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||||
|
|
||||||
|
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit =
|
||||||
|
swap_tensors(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitiveVar,
|
||||||
|
PrimitiveArrayType, TorchTensorType : TorchTensorOverFieldNative<T>>(scope: DeferScope) :
|
||||||
|
TorchTensorAlgebraNative<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
|
||||||
|
TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
return wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.randUniform(): TorchTensorType =
|
||||||
|
wrap(rand_like(this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorType.randUniformAssign(): Unit =
|
||||||
|
rand_like_assign(this.tensorHandle)
|
||||||
|
|
||||||
|
|
||||||
|
override fun TorchTensorType.randNormal(): TorchTensorType =
|
||||||
|
wrap(randn_like(this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorType.randNormalAssign(): Unit =
|
||||||
|
randn_like_assign(this.tensorHandle)
|
||||||
|
|
||||||
|
|
||||||
|
override fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
|
||||||
|
override fun TorchTensorType.expAssign(): Unit = exp_tensor_assign(tensorHandle)
|
||||||
|
override fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
|
||||||
|
override fun TorchTensorType.logAssign(): Unit = log_tensor_assign(tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
|
||||||
|
val U = empty_tensor()!!
|
||||||
|
val V = empty_tensor()!!
|
||||||
|
val S = empty_tensor()!!
|
||||||
|
svd_tensor(this.tensorHandle, U, S, V)
|
||||||
|
return Triple(wrap(U), wrap(S), wrap(V))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.symEig(eigenvectors: Boolean): Pair<TorchTensorType, TorchTensorType> {
|
||||||
|
val V = empty_tensor()!!
|
||||||
|
val S = empty_tensor()!!
|
||||||
|
symeig_tensor(this.tensorHandle, S, V, eigenvectors)
|
||||||
|
return Pair(wrap(S), wrap(V))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean): TorchTensorType {
|
||||||
|
if (checks) this.checkIsValue()
|
||||||
|
return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
override infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) this.checkIsValue()
|
||||||
|
return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||||
|
wrap(detach_from_graph(this.tensorHandle)!!)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorPartialDivisionAlgebraNative<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
|
||||||
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
|
||||||
|
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
||||||
|
this.elements().map { it.second }.toList().toDoubleArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): TorchTensorReal =
|
||||||
|
wrap(from_blob_double(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||||
|
|
||||||
|
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
|
||||||
|
wrap(from_blob_double(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorReal.getData(): CPointer<DoubleVar> {
|
||||||
|
require(this.device is Device.CPU) {
|
||||||
|
"This tensor is not on available on CPU"
|
||||||
|
}
|
||||||
|
return get_data_double(this.tensorHandle)!!
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun randNormal(shape: IntArray, device: Device): TorchTensorReal =
|
||||||
|
wrap(randn_double(shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
|
||||||
|
wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorReal =
|
||||||
|
wrap(randint_double(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
|
||||||
|
wrap(plus_double(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorReal.plus(value: Double): TorchTensorReal =
|
||||||
|
wrap(plus_double(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorReal.plusAssign(value: Double): Unit {
|
||||||
|
plus_double_assign(value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal =
|
||||||
|
wrap(plus_double(-this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorReal.minus(value: Double): TorchTensorReal =
|
||||||
|
wrap(plus_double(-value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorReal.minusAssign(value: Double): Unit {
|
||||||
|
plus_double_assign(-value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal =
|
||||||
|
wrap(times_double(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorReal.times(value: Double): TorchTensorReal =
|
||||||
|
wrap(times_double(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorReal.timesAssign(value: Double): Unit {
|
||||||
|
times_double_assign(value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||||
|
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorPartialDivisionAlgebraNative<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
|
||||||
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
|
||||||
|
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.copyToArray(): FloatArray =
|
||||||
|
this.elements().map { it.second }.toList().toFloatArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
|
wrap(from_blob_float(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||||
|
|
||||||
|
override fun fromBlob(arrayBlob: CPointer<FloatVar>, shape: IntArray): TorchTensorFloat =
|
||||||
|
wrap(from_blob_float(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.getData(): CPointer<FloatVar> {
|
||||||
|
require(this.device is Device.CPU) {
|
||||||
|
"This tensor is not on available on CPU"
|
||||||
|
}
|
||||||
|
return get_data_float(this.tensorHandle)!!
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun randNormal(shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
|
wrap(randn_float(shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
|
wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
|
wrap(randint_float(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
|
||||||
|
wrap(plus_float(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
|
||||||
|
wrap(plus_float(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.plusAssign(value: Float): Unit =
|
||||||
|
plus_float_assign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
|
||||||
|
wrap(plus_float(-this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
|
||||||
|
wrap(plus_float(-value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.minusAssign(value: Float): Unit =
|
||||||
|
plus_float_assign(-value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
|
||||||
|
wrap(times_float(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
|
||||||
|
wrap(times_float(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.timesAssign(value: Float): Unit =
|
||||||
|
times_float_assign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
|
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorAlgebraNative<Long, LongVar, LongArray, TorchTensorLong>(scope) {
|
||||||
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong =
|
||||||
|
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.copyToArray(): LongArray =
|
||||||
|
this.elements().map { it.second }.toList().toLongArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): TorchTensorLong =
|
||||||
|
wrap(from_blob_long(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||||
|
|
||||||
|
override fun fromBlob(arrayBlob: CPointer<LongVar>, shape: IntArray): TorchTensorLong =
|
||||||
|
wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.getData(): CPointer<LongVar> {
|
||||||
|
check(this.device is Device.CPU) {
|
||||||
|
"This tensor is not on available on CPU"
|
||||||
|
}
|
||||||
|
return get_data_long(this.tensorHandle)!!
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||||
|
wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
|
||||||
|
wrap(plus_long(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
|
||||||
|
wrap(plus_long(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.plusAssign(value: Long): Unit =
|
||||||
|
plus_long_assign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
|
||||||
|
wrap(plus_long(-this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
|
||||||
|
wrap(plus_long(-value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.minusAssign(value: Long): Unit =
|
||||||
|
plus_long_assign(-value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
|
||||||
|
wrap(times_long(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
|
||||||
|
wrap(times_long(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.timesAssign(value: Long): Unit =
|
||||||
|
times_long_assign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||||
|
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorAlgebraNative<Int, IntVar, IntArray, TorchTensorInt>(scope) {
|
||||||
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt =
|
||||||
|
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.copyToArray(): IntArray =
|
||||||
|
this.elements().map { it.second }.toList().toIntArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): TorchTensorInt =
|
||||||
|
wrap(from_blob_int(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||||
|
|
||||||
|
override fun fromBlob(arrayBlob: CPointer<IntVar>, shape: IntArray): TorchTensorInt =
|
||||||
|
wrap(from_blob_int(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.getData(): CPointer<IntVar> {
|
||||||
|
require(this.device is Device.CPU) {
|
||||||
|
"This tensor is not on available on CPU"
|
||||||
|
}
|
||||||
|
return get_data_int(this.tensorHandle)!!
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorInt =
|
||||||
|
wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
|
||||||
|
wrap(plus_int(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
|
||||||
|
wrap(plus_int(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.plusAssign(value: Int): Unit =
|
||||||
|
plus_int_assign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
|
||||||
|
wrap(plus_int(-this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
|
||||||
|
wrap(plus_int(-value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.minusAssign(value: Int): Unit =
|
||||||
|
plus_int_assign(-value, this.tensorHandle)
|
||||||
|
|
||||||
|
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
|
||||||
|
wrap(times_int(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
|
||||||
|
wrap(times_int(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.timesAssign(value: Int): Unit =
|
||||||
|
times_int_assign(value, this.tensorHandle)
|
||||||
|
|
||||||
|
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||||
|
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||||
|
withDeferScope { TorchTensorRealAlgebra(this).block() }
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
||||||
|
withDeferScope { TorchTensorFloatAlgebra(this).block() }
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
|
||||||
|
withDeferScope { TorchTensorLongAlgebra(this).block() }
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||||
|
withDeferScope { TorchTensorIntAlgebra(this).block() }
|
||||||
|
|
@ -0,0 +1,104 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import space.kscience.kmath.memory.DeferScope
|
||||||
|
|
||||||
|
import kotlinx.cinterop.*
|
||||||
|
import space.kscience.kmath.torch.ctorch.*
|
||||||
|
|
||||||
|
|
||||||
|
public sealed class TorchTensorNative<T> constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
internal val tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensor<T>, TorchTensorMemoryHolder(scope) {
|
||||||
|
|
||||||
|
override fun close(): Unit = dispose_tensor(tensorHandle)
|
||||||
|
|
||||||
|
override val dimension: Int get() = get_dim(tensorHandle)
|
||||||
|
override val shape: IntArray
|
||||||
|
get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray()
|
||||||
|
override val strides: IntArray
|
||||||
|
get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray()
|
||||||
|
override val size: Int get() = get_numel(tensorHandle)
|
||||||
|
override val device: Device get() = Device.fromInt(get_device(tensorHandle))
|
||||||
|
|
||||||
|
override fun toString(): String {
|
||||||
|
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(tensorHandle)!!
|
||||||
|
val stringRepresentation = nativeStringRepresentation.toKString()
|
||||||
|
dispose_char(nativeStringRepresentation)
|
||||||
|
return stringRepresentation
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = copy_to_double(this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
|
||||||
|
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = copy_to_float(this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
|
||||||
|
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = copy_to_long(this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
|
||||||
|
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = copy_to_int(this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
public sealed class TorchTensorOverFieldNative<T> constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensorNative<T>(scope, tensorHandle), TorchTensorOverField<T> {
|
||||||
|
override var requiresGrad: Boolean
|
||||||
|
get() = requires_grad(tensorHandle)
|
||||||
|
set(value) = requires_grad_(tensorHandle, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public class TorchTensorReal internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensorOverFieldNative<Double>(scope, tensorHandle) {
|
||||||
|
override fun item(): Double = get_item_double(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
|
||||||
|
override fun set(index: IntArray, value: Double) {
|
||||||
|
set_double(tensorHandle, index.toCValues(), value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorFloat internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensorOverFieldNative<Float>(scope, tensorHandle) {
|
||||||
|
override fun item(): Float = get_item_float(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
|
||||||
|
override fun set(index: IntArray, value: Float) {
|
||||||
|
set_float(tensorHandle, index.toCValues(), value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorLong internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensorNative<Long>(scope, tensorHandle) {
|
||||||
|
override fun item(): Long = get_item_long(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Long = get_long(tensorHandle, index.toCValues())
|
||||||
|
override fun set(index: IntArray, value: Long) {
|
||||||
|
set_long(tensorHandle, index.toCValues(), value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorInt internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensorNative<Int>(scope, tensorHandle) {
|
||||||
|
override fun item(): Int = get_item_int(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Int = get_int(tensorHandle, index.toCValues())
|
||||||
|
override fun set(index: IntArray, value: Int) {
|
||||||
|
set_int(tensorHandle, index.toCValues(), value)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,26 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
|
||||||
|
internal class BenchmarkMatMul {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkMatMulDouble() = TorchTensorRealAlgebra {
|
||||||
|
benchmarkMatMul(20, 10, 100000, "Real")
|
||||||
|
benchmarkMatMul(200, 10, 10000, "Real")
|
||||||
|
benchmarkMatMul(2000, 3, 20, "Real")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkMatMulFloat() = TorchTensorFloatAlgebra {
|
||||||
|
benchmarkMatMul(20, 10, 100000, "Float")
|
||||||
|
benchmarkMatMul(200, 10, 10000, "Float")
|
||||||
|
benchmarkMatMul(2000, 3, 20, "Float")
|
||||||
|
if (cudaAvailable()) {
|
||||||
|
benchmarkMatMul(20, 10, 100000, "Float", Device.CUDA(0))
|
||||||
|
benchmarkMatMul(200, 10, 10000, "Float", Device.CUDA(0))
|
||||||
|
benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,28 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
|
||||||
|
internal class BenchmarkRandomGenerators {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkRand1() = TorchTensorFloatAlgebra{
|
||||||
|
benchmarkingRand1()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkRand3() = TorchTensorFloatAlgebra{
|
||||||
|
benchmarkingRand3()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkRand5() = TorchTensorFloatAlgebra{
|
||||||
|
benchmarkingRand5()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkRand7() = TorchTensorFloatAlgebra{
|
||||||
|
benchmarkingRand7()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,24 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
|
||||||
|
internal class TestAutograd {
|
||||||
|
@Test
|
||||||
|
fun testAutoGrad() = TorchTensorFloatAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingAutoGrad(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedAutoGrad() = TorchTensorFloatAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingBatchedAutoGrad(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,57 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlinx.cinterop.*
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
|
||||||
|
class TestTorchTensor {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCopying() = TorchTensorFloatAlgebra {
|
||||||
|
withCuda {
|
||||||
|
device -> testingCopying(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testNativeNoCopyDataTransferOnCPU() = memScoped {
|
||||||
|
val data = allocArray<DoubleVar>(1)
|
||||||
|
data[0] = 1.0
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
val tensor = fromBlob(data, intArrayOf(1))
|
||||||
|
assertEquals(tensor[intArrayOf(0)], 1.0)
|
||||||
|
data[0] = 2.0
|
||||||
|
assertEquals(tensor[intArrayOf(0)], 2.0)
|
||||||
|
val tensorData = tensor.getData()
|
||||||
|
tensorData[0] = 3.0
|
||||||
|
assertEquals(tensor[intArrayOf(0)], 3.0)
|
||||||
|
}
|
||||||
|
assertEquals(data[0], 3.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testRequiresGrad() = TorchTensorRealAlgebra {
|
||||||
|
testingRequiresGrad()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTypeMoving() = TorchTensorFloatAlgebra {
|
||||||
|
val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).copyToInt()
|
||||||
|
TorchTensorIntAlgebra {
|
||||||
|
val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
|
||||||
|
tensorInt swap temporalTensor
|
||||||
|
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
|
||||||
|
}
|
||||||
|
assertTrue(tensorInt.copyToFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testViewWithNoCopy() = TorchTensorIntAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda {
|
||||||
|
device -> testingViewWithNoCopy(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,62 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
|
||||||
|
internal class TestTorchTensorAlgebra {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testScalarProduct() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingScalarProduct(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMatrixMultiplication() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingMatrixMultiplication(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLinearStructure() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingLinearStructure(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTensorTransformations() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingTensorTransformations(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedSVD() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingBatchedSVD(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedSymEig() = TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
|
withCuda { device ->
|
||||||
|
testingBatchedSymEig(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,20 @@
|
|||||||
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
|
||||||
|
internal class TestUtils {
|
||||||
|
@Test
|
||||||
|
fun testSetNumThreads() {
|
||||||
|
TorchTensorLongAlgebra {
|
||||||
|
testingSetNumThreads()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSeedSetting() = TorchTensorFloatAlgebra {
|
||||||
|
withCuda {
|
||||||
|
device -> testingSetSeed(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -42,3 +42,7 @@ include(
|
|||||||
":kmath-kotlingrad",
|
":kmath-kotlingrad",
|
||||||
":examples"
|
":examples"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if(System.getProperty("os.name") == "Linux"){
|
||||||
|
include(":kmath-torch")
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user