# LibTorch extension (`kmath-torch`)

This is a `Kotlin/Native` module; only the `linuxX64` target is supported so far. The library wraps part of the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `ATen` and `Autograd` with `KMath`.

## Installation

To install the library, build and publish `kmath-core`, `kmath-memory`, and `kmath-torch` to your local Maven repository:
```
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
```

This builds `ctorch`, a C wrapper for `LibTorch`, and places it in:

`~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`

You will have to link against it in your own project. Here is an example build script for a standalone application:
```kotlin
// build.gradle.kts
plugins {
    id("ru.mipt.npm.mpp")
}

repositories {
    jcenter()
    mavenLocal()
}

// Location of the locally built ctorch wrapper
val home = System.getProperty("user.home")
val kver = "0.2.0-dev-4"
val cppBuildDir = "$home/.konan/third-party/kmath-torch-$kver/cpp-build"

kotlin {
    explicitApiWarning()

    val nativeTarget = linuxX64("your.app")
    nativeTarget.apply {
        binaries {
            executable {
                entryPoint = "your.app.main"
            }
            all {
                // Link every binary against ctorch and record its
                // directory as a runtime library search path
                linkerOpts(
                    "-L$cppBuildDir",
                    "-Wl,-rpath=$cppBuildDir",
                    "-lctorch"
                )
            }
        }
    }

    val main by nativeTarget.compilations.getting

    sourceSets {
        val nativeMain by creating {
            dependencies {
                implementation("kscience.kmath:kmath-torch:$kver")
            }
        }
        main.defaultSourceSet.dependsOn(nativeMain)
    }
}
```
```kotlin
// settings.gradle.kts
pluginManagement {
    repositories {
        gradlePluginPortal()
        jcenter()
        maven("https://dl.bintray.com/mipt-npm/dev")
    }
    plugins {
        id("ru.mipt.npm.mpp") version "0.7.1"
        kotlin("jvm") version "1.4.21"
    }
}
```
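
The build script above only works if the `cpp-build` directory already exists, and a missing directory otherwise surfaces late, as a linker error. As a minimal sketch (not part of `kmath-torch`; the error message wording is an assumption), a guard like the following could be added to `build.gradle.kts` to fail at configuration time instead:

```kotlin
// build.gradle.kts (fragment)
// These three declarations repeat the ones in the build script; keep only one copy.
val home = System.getProperty("user.home")
val kver = "0.2.0-dev-4"
val cppBuildDir = "$home/.konan/third-party/kmath-torch-$kver/cpp-build"

// Illustrative guard: stop the build early if ctorch has not been built yet.
check(java.io.File(cppBuildDir).isDirectory) {
    "ctorch build directory not found at $cppBuildDir; " +
        "publish kmath-torch to the local Maven repository first"
}
```

The check runs on every Gradle invocation of the application project, so it is best kept out of builds that are expected to run before `kmath-torch` has been published.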

## Usage

Tensors implement the buffer protocol over `MutableNDStructure`. They can only be instantiated through the provided factory methods, and they must be created inside a memory scope:
```kotlin
import kotlinx.cinterop.memScoped

memScoped {
    // 3x2 integer tensor copied from a flat array
    val intTensor: TorchTensorInt = TorchTensor.copyFromIntArray(
        scope = this,
        array = intArrayOf(7, 8, 9, 2, 6, 5),
        shape = intArrayOf(3, 2))
    println(intTensor)

    // One-dimensional float tensor with 4 elements
    val floatTensor: TorchTensorFloat = TorchTensor.copyFromFloatArray(
        scope = this,
        array = floatArrayOf(7f, 8.9f, 2.6f, 5.6f),
        shape = intArrayOf(4))
    println(floatTensor)
}
```
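
The factory methods take a flat array plus a shape, so the array length has to equal the product of the shape dimensions. The sketch below is plain Kotlin with no `kmath-torch` dependency, and its helpers `elementCount` and `flatOffset` are hypothetical names introduced for illustration. It assumes the flat array is read in row-major (C-contiguous) order, which is the default layout of contiguous LibTorch tensors but is not stated explicitly here:

```kotlin
// Hypothetical helper: number of elements implied by a shape.
fun elementCount(shape: IntArray): Int = shape.fold(1) { acc, dim -> acc * dim }

// Hypothetical helper: offset of a multi-index in a contiguous row-major layout.
fun flatOffset(index: IntArray, shape: IntArray): Int {
    require(index.size == shape.size) { "index rank must match shape rank" }
    var offset = 0
    for (axis in shape.indices) {
        require(index[axis] in 0 until shape[axis]) { "index out of bounds on axis $axis" }
        offset = offset * shape[axis] + index[axis]
    }
    return offset
}

fun main() {
    val array = intArrayOf(7, 8, 9, 2, 6, 5)
    val shape = intArrayOf(3, 2)

    // The flat array must supply exactly one value per tensor element.
    check(array.size == elementCount(shape))

    // Print the 3x2 view row by row: [7, 8], [9, 2], [6, 5]
    for (row in 0 until shape[0]) {
        val cells = (0 until shape[1]).map { col ->
            array[flatOffset(intArrayOf(row, col), shape)]
        }
        println(cells.joinToString(prefix = "[", postfix = "]"))
    }
}
```

Under that assumption, `intArrayOf(7,8,9,2,6,5)` with shape `(3,2)` corresponds to the rows `[7, 8]`, `[9, 2]`, and `[6, 5]`.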