2020-12-30 01:42:33 +03:00
|
|
|
# LibTorch extension (`kmath-torch`)
|
|
|
|
|
2021-01-16 23:29:47 +03:00
|
|
|
This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of
|
|
|
|
the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `ATen` & `Autograd` with `KMath`.
|
2020-12-30 01:42:33 +03:00
|
|
|
|
|
|
|
## Installation
|
2021-01-16 23:29:47 +03:00
|
|
|
|
2020-12-30 01:42:33 +03:00
|
|
|
To install the library, build and publish `kmath-core`, `kmath-memory` and `kmath-torch` to your local Maven repository:
|
2021-01-16 23:29:47 +03:00
|
|
|
|
2020-12-30 01:42:33 +03:00
|
|
|
```
|
|
|
|
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
|
|
|
|
```
|
|
|
|
|
|
|
|
This builds `ctorch`, a C wrapper for `LibTorch`, and places it inside:
|
|
|
|
|
2021-01-16 23:29:47 +03:00
|
|
|
`~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`
|
2020-12-30 01:42:33 +03:00
|
|
|
|
|
|
|
You will have to link against it in your own project. Here is an example build script for a standalone application:
|
2021-01-16 23:29:47 +03:00
|
|
|
|
2020-12-30 01:42:33 +03:00
|
|
|
```kotlin
|
|
|
|
//build.gradle.kts
|
|
|
|
plugins {
|
|
|
|
id("ru.mipt.npm.mpp")
|
|
|
|
}
|
|
|
|
|
|
|
|
repositories {
|
|
|
|
jcenter()
|
|
|
|
mavenLocal()
|
|
|
|
}
|
|
|
|
|
|
|
|
val home = System.getProperty("user.home")
|
|
|
|
val kver = "0.2.0-dev-4"
|
|
|
|
val cppBuildDir = "$home/.konan/third-party/kmath-torch-$kver/cpp-build"
|
|
|
|
|
|
|
|
kotlin {
|
|
|
|
explicitApiWarning()
|
|
|
|
|
|
|
|
val nativeTarget = linuxX64("your.app")
|
|
|
|
nativeTarget.apply {
|
|
|
|
binaries {
|
|
|
|
executable {
|
|
|
|
entryPoint = "your.app.main"
|
|
|
|
}
|
|
|
|
all {
|
|
|
|
linkerOpts(
|
|
|
|
"-L$cppBuildDir",
|
|
|
|
"-Wl,-rpath=$cppBuildDir",
|
|
|
|
"-lctorch"
|
|
|
|
)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
val main by nativeTarget.compilations.getting
|
|
|
|
|
|
|
|
sourceSets {
|
|
|
|
val nativeMain by creating {
|
|
|
|
dependencies {
|
|
|
|
implementation("kscience.kmath:kmath-torch:$kver")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
main.defaultSourceSet.dependsOn(nativeMain)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
```
|
2021-01-16 23:29:47 +03:00
|
|
|
|
2020-12-30 01:42:33 +03:00
|
|
|
```kotlin
|
|
|
|
//settings.gradle.kts
|
|
|
|
pluginManagement {
|
|
|
|
repositories {
|
|
|
|
gradlePluginPortal()
|
|
|
|
jcenter()
|
|
|
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
|
|
|
}
|
|
|
|
plugins {
|
|
|
|
id("ru.mipt.npm.mpp") version "0.7.1"
|
|
|
|
kotlin("jvm") version "1.4.21"
|
|
|
|
}
|
|
|
|
}
|
|
|
|
```
|
|
|
|
|
|
|
|
## Usage
|
|
|
|
|
2021-01-16 23:29:47 +03:00
|
|
|
Tensors are implemented on top of `MutableNDStructure`. They can only be instantiated through the provided factory methods
|
|
|
|
and must be created within an algebra scope such as `TorchTensorRealAlgebra`:
|
|
|
|
|
2020-12-30 01:42:33 +03:00
|
|
|
```kotlin
|
2021-01-09 20:13:38 +03:00
|
|
|
TorchTensorRealAlgebra {
|
2021-01-06 16:20:48 +03:00
|
|
|
|
2021-01-09 20:13:38 +03:00
|
|
|
val realTensor: TorchTensorReal = copyFromArray(
|
|
|
|
array = (1..10).map { it + 50.0 }.toList().toDoubleArray(),
|
2021-01-16 23:29:47 +03:00
|
|
|
shape = intArrayOf(2, 5)
|
2021-01-06 16:20:48 +03:00
|
|
|
)
|
2021-01-09 20:13:38 +03:00
|
|
|
println(realTensor)
|
2021-01-06 16:20:48 +03:00
|
|
|
|
2021-01-09 20:13:38 +03:00
|
|
|
val gpuRealTensor: TorchTensorReal = copyFromArray(
|
|
|
|
array = (1..8).map { it * 2.5 }.toList().toDoubleArray(),
|
2021-01-06 16:20:48 +03:00
|
|
|
shape = intArrayOf(2, 2, 2),
|
2021-01-09 20:13:38 +03:00
|
|
|
device = TorchDevice.TorchCUDA(0)
|
2021-01-06 16:20:48 +03:00
|
|
|
)
|
2021-01-09 20:13:38 +03:00
|
|
|
println(gpuRealTensor)
|
2020-12-30 01:42:33 +03:00
|
|
|
}
|
|
|
|
```
|
2021-01-16 23:29:47 +03:00
|
|
|
|
2021-01-10 19:24:57 +03:00
|
|
|
Enjoy a high-performance automatic differentiation engine:
|
2021-01-16 23:29:47 +03:00
|
|
|
|
2021-01-10 19:24:57 +03:00
|
|
|
```kotlin
|
|
|
|
TorchTensorRealAlgebra {
|
|
|
|
val dim = 10
|
|
|
|
val device = TorchDevice.TorchCPU //or TorchDevice.TorchCUDA(0)
|
|
|
|
val x = randNormal(shape = intArrayOf(dim), device = device)
|
2021-01-16 23:29:47 +03:00
|
|
|
|
|
|
|
val X = randNormal(shape = intArrayOf(dim, dim), device = device)
|
|
|
|
val Q = X + X.transpose(0, 1)
|
2021-01-10 19:24:57 +03:00
|
|
|
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
2021-01-16 23:29:47 +03:00
|
|
|
|
2021-01-10 19:24:57 +03:00
|
|
|
// expression to differentiate w.r.t. x
|
2021-01-16 23:29:47 +03:00
|
|
|
val f = x.withGrad {
|
|
|
|
0.5 * (x dot (Q dot x)) + (mu dot x) + 25.3
|
|
|
|
}
|
2021-01-10 19:24:57 +03:00
|
|
|
// value of the gradient at x
|
|
|
|
val gradf = f grad x
|
2021-01-16 23:29:47 +03:00
|
|
|
// value of the hessian at x
|
|
|
|
val hessf = f hess x
|
2021-01-10 19:24:57 +03:00
|
|
|
}
|
|
|
|
```
|
2020-12-30 01:42:33 +03:00
|
|
|
|