forked from kscience/kmath
Corrected readme file
This commit is contained in:
parent
ed4ac2623d
commit
e5205d5afd
@ -1,9 +1,12 @@
|
||||
# LibTorch extension (`kmath-torch`)
|
||||
|
||||
This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
|
||||
This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of
|
||||
the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
|
||||
|
||||
## Installation
|
||||
|
||||
To install the library, you have to build & publish locally `kmath-core`, `kmath-memory` with `kmath-torch`:
|
||||
|
||||
```
|
||||
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
|
||||
```
|
||||
@ -13,6 +16,7 @@ This builds `ctorch`, a C wrapper for `LibTorch` placed inside:
|
||||
`~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`
|
||||
|
||||
You will have to link against it in your own project. Here is an example of build script for a standalone application:
|
||||
|
||||
```kotlin
|
||||
//build.gradle.kts
|
||||
plugins {
|
||||
@ -59,6 +63,7 @@ kotlin {
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```kotlin
|
||||
//settings.gradle.kts
|
||||
pluginManagement {
|
||||
@ -76,13 +81,15 @@ pluginManagement {
|
||||
|
||||
## Usage
|
||||
|
||||
Tensors are implemented over the `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
|
||||
Tensors are implemented over the `MutableNDStructure`. They can only be instantiated through provided factory methods
|
||||
and require scoping:
|
||||
|
||||
```kotlin
|
||||
TorchTensorRealAlgebra {
|
||||
|
||||
val realTensor: TorchTensorReal = copyFromArray(
|
||||
array = (1..10).map { it + 50.0 }.toList().toDoubleArray(),
|
||||
shape = intArrayOf(2,5)
|
||||
shape = intArrayOf(2, 5)
|
||||
)
|
||||
println(realTensor)
|
||||
|
||||
@ -94,23 +101,27 @@ TorchTensorRealAlgebra {
|
||||
println(gpuRealTensor)
|
||||
}
|
||||
```
|
||||
|
||||
Enjoy a high performance automatic differentiation engine:
|
||||
|
||||
```kotlin
|
||||
TorchTensorRealAlgebra {
|
||||
val dim = 10
|
||||
val device = TorchDevice.TorchCPU //or TorchDevice.TorchCUDA(0)
|
||||
val x = randNormal(shape = intArrayOf(dim), device = device)
|
||||
// x is the variable
|
||||
x.requiresGrad = true
|
||||
|
||||
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
|
||||
val Q = X + X.transpose(0,1)
|
||||
val X = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||
val Q = X + X.transpose(0, 1)
|
||||
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
||||
|
||||
// expression to differentiate w.r.t. x
|
||||
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + 25.3
|
||||
val f = x.withGrad {
|
||||
0.5 * (x dot (Q dot x)) + (mu dot x) + 25.3
|
||||
}
|
||||
// value of the gradient at x
|
||||
val gradf = f grad x
|
||||
// value of the hessian at x
|
||||
val hessf = f hess x
|
||||
}
|
||||
```
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user