forked from kscience/kmath
LibTorch extension (kmath-torch)
This is a Kotlin/Native module; only linuxX64 is supported so far. The library wraps parts of the PyTorch C++ API, focusing on integrating ATen and Autograd with KMath.
Installation
To install the library, build kmath-core, kmath-memory and kmath-torch and publish them to your local Maven repository:
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
This builds ctorch, a C wrapper for LibTorch, and places it in:
~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build
You will have to link against it in your own project. Here is an example build script for a standalone application:
//build.gradle.kts
plugins {
    id("ru.mipt.npm.mpp")
}

repositories {
    jcenter()
    mavenLocal()
}

val home = System.getProperty("user.home")
val kver = "0.2.0-dev-4"
val cppBuildDir = "$home/.konan/third-party/kmath-torch-$kver/cpp-build"

kotlin {
    explicitApiWarning()

    val nativeTarget = linuxX64("your.app")
    nativeTarget.apply {
        binaries {
            executable {
                entryPoint = "your.app.main"
            }
            all {
                linkerOpts(
                    "-L$cppBuildDir",
                    "-Wl,-rpath=$cppBuildDir",
                    "-lctorch"
                )
            }
        }
    }

    val main by nativeTarget.compilations.getting

    sourceSets {
        val nativeMain by creating {
            dependencies {
                implementation("kscience.kmath:kmath-torch:$kver")
            }
        }
        main.defaultSourceSet.dependsOn(nativeMain)
    }
}
//settings.gradle.kts
pluginManagement {
    repositories {
        gradlePluginPortal()
        jcenter()
        maven("https://dl.bintray.com/mipt-npm/dev")
    }

    plugins {
        id("ru.mipt.npm.mpp") version "0.7.1"
        kotlin("jvm") version "1.4.21"
    }
}
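If ctorch has not been built and published yet, the build above only fails at link time, with a generic linker error that does not name the cause. A minimal sketch of an earlier, clearer failure follows. The helper ctorchLinkerOpts is hypothetical and not part of kmath-torch; it assumes the library file is named libctorch.so or libctorch.a, the names the linker looks for when given -lctorch, and it reuses the directory layout shown above.

```kotlin
import java.io.File

// Hypothetical helper (not part of kmath-torch): returns the same linker options
// as the build script above, but fails early if the ctorch library is missing.
fun ctorchLinkerOpts(home: String, kmathVersion: String): List<String> {
    val cppBuildDir = File("$home/.konan/third-party/kmath-torch-$kmathVersion/cpp-build")
    // Assumption: -lctorch resolves to libctorch.so (shared) or libctorch.a (static).
    val candidates = listOf("libctorch.so", "libctorch.a").map { File(cppBuildDir, it) }
    check(candidates.any { it.isFile }) {
        "ctorch not found in ${cppBuildDir.path}; publish kmath-torch to Maven Local first"
    }
    return listOf(
        "-L${cppBuildDir.path}",          // where the linker searches for ctorch
        "-Wl,-rpath=${cppBuildDir.path}", // where the executable finds it at run time
        "-lctorch"
    )
}
```

In the build script it could replace the literal list as linkerOpts(*ctorchLinkerOpts(home, kver).toTypedArray()). Note that the check runs while Gradle configures the project, so every task fails until ctorch has been built.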
Usage
Tensors implement the buffer protocol over MutableNDStructure. They can only be instantiated through the provided factory methods, and each tensor must be created within a memory scope:
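import kotlinx.cinterop.memScoped

fun main() {
    memScoped {
        // Tensors created with scope = this are tied to this memScoped block.
        // Six values interpreted as a 3x2 tensor:
        val intTensor: TorchTensorInt = TorchTensor.copyFromIntArray(
            scope = this,
            array = intArrayOf(7, 8, 9, 2, 6, 5),
            shape = intArrayOf(3, 2))
        println(intTensor)
        // Four values as a one-dimensional tensor:
        val floatTensor: TorchTensorFloat = TorchTensor.copyFromFloatArray(
            scope = this,
            array = floatArrayOf(7f, 8.9f, 2.6f, 5.6f),
            shape = intArrayOf(4))
        println(floatTensor)
    }
}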
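The scope argument and the shape parameter are easier to follow with a small model. The plain-Kotlin sketch below is hypothetical: Scope, scoped, FlatIntTensor and rowMajorStrides are illustrative names, not kmath-torch API. It assumes two things. First, a tensor's storage is released when its scope ends, the way memScoped frees memory allocated inside it. Second, the flat array is read in row-major (C) order, LibTorch's default layout for contiguous tensors. Under those assumptions, intArrayOf(7, 8, 9, 2, 6, 5) with shape (3, 2) holds the rows [7, 8], [9, 2] and [6, 5].

```kotlin
// Hypothetical stand-in for a native memory scope: release actions registered
// with defer() run in reverse order when the scoped block finishes.
class Scope {
    private val releases = ArrayDeque<() -> Unit>()
    fun defer(release: () -> Unit) = releases.addLast(release)
    fun close() {
        while (releases.isNotEmpty()) {
            val release = releases.removeLast()
            release()
        }
    }
}

// Runs the block with a fresh scope and always closes the scope afterwards.
fun <R> scoped(block: Scope.() -> R): R {
    val scope = Scope()
    try {
        return scope.block()
    } finally {
        scope.close()
    }
}

// Row-major strides: strides[i] is the product of all dimensions after i.
fun rowMajorStrides(shape: IntArray): IntArray {
    val strides = IntArray(shape.size)
    var step = 1
    for (i in shape.indices.reversed()) {
        strides[i] = step
        step *= shape[i]
    }
    return strides
}

// Hypothetical tensor: a copy of the flat data plus a shape, tied to a scope.
class FlatIntTensor(scope: Scope, array: IntArray, val shape: IntArray) {
    init {
        require(shape.fold(1) { acc, dim -> acc * dim } == array.size) {
            "shape does not match the number of elements"
        }
    }

    private var data: IntArray? = array.copyOf()
    val strides: IntArray = rowMajorStrides(shape)
    val isReleased: Boolean get() = data == null

    init {
        // Stands in for freeing native memory when the scope ends.
        scope.defer { data = null }
    }

    operator fun get(vararg index: Int): Int {
        require(index.size == shape.size) { "expected ${shape.size} indices" }
        var offset = 0
        for (i in index.indices) {
            require(index[i] in 0 until shape[i]) { "index out of bounds" }
            offset += index[i] * strides[i]
        }
        return checkNotNull(data) { "tensor used after its scope ended" }[offset]
    }
}
```

For example, inside scoped { ... } the tensor FlatIntTensor(this, intArrayOf(7, 8, 9, 2, 6, 5), intArrayOf(3, 2)) has strides (2, 1), so element [1, 0] sits at offset 2. Because release is deferred to the end of the block, a reference that outlives its scope points at freed storage; the sketch turns that into an error rather than undefined behaviour.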