Corrected readme file
This commit is contained in:
parent
ed4ac2623d
commit
e5205d5afd
@ -1,18 +1,22 @@
|
|||||||
# LibTorch extension (`kmath-torch`)
|
# LibTorch extension (`kmath-torch`)
|
||||||
|
|
||||||
This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
|
This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of
|
||||||
|
the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
To install the library, you have to build & publish locally `kmath-core`, `kmath-memory` with `kmath-torch`:
|
To install the library, you have to build & publish locally `kmath-core`, `kmath-memory` with `kmath-torch`:
|
||||||
|
|
||||||
```
|
```
|
||||||
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
|
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
|
||||||
```
|
```
|
||||||
|
|
||||||
This builds `ctorch`, a C wrapper for `LibTorch` placed inside:
|
This builds `ctorch`, a C wrapper for `LibTorch` placed inside:
|
||||||
|
|
||||||
`~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`
|
`~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`
|
||||||
|
|
||||||
You will have to link against it in your own project. Here is an example of build script for a standalone application:
|
You will have to link against it in your own project. Here is an example of build script for a standalone application:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
//build.gradle.kts
|
//build.gradle.kts
|
||||||
plugins {
|
plugins {
|
||||||
@ -59,6 +63,7 @@ kotlin {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
//settings.gradle.kts
|
//settings.gradle.kts
|
||||||
pluginManagement {
|
pluginManagement {
|
||||||
@ -76,13 +81,15 @@ pluginManagement {
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Tensors are implemented over the `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
|
Tensors are implemented over the `MutableNDStructure`. They can only be instantiated through provided factory methods
|
||||||
|
and require scoping:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
|
|
||||||
val realTensor: TorchTensorReal = copyFromArray(
|
val realTensor: TorchTensorReal = copyFromArray(
|
||||||
array = (1..10).map { it + 50.0 }.toList().toDoubleArray(),
|
array = (1..10).map { it + 50.0 }.toList().toDoubleArray(),
|
||||||
shape = intArrayOf(2,5)
|
shape = intArrayOf(2, 5)
|
||||||
)
|
)
|
||||||
println(realTensor)
|
println(realTensor)
|
||||||
|
|
||||||
@ -94,23 +101,27 @@ TorchTensorRealAlgebra {
|
|||||||
println(gpuRealTensor)
|
println(gpuRealTensor)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Enjoy a high performance automatic differentiation engine:
|
Enjoy a high performance automatic differentiation engine:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
val dim = 10
|
val dim = 10
|
||||||
val device = TorchDevice.TorchCPU //or TorchDevice.TorchCUDA(0)
|
val device = TorchDevice.TorchCPU //or TorchDevice.TorchCUDA(0)
|
||||||
val x = randNormal(shape = intArrayOf(dim), device = device)
|
val x = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
// x is the variable
|
|
||||||
x.requiresGrad = true
|
val X = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||||
|
val Q = X + X.transpose(0, 1)
|
||||||
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
|
|
||||||
val Q = X + X.transpose(0,1)
|
|
||||||
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
|
||||||
// expression to differentiate w.r.t. x
|
// expression to differentiate w.r.t. x
|
||||||
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + 25.3
|
val f = x.withGrad {
|
||||||
|
0.5 * (x dot (Q dot x)) + (mu dot x) + 25.3
|
||||||
|
}
|
||||||
// value of the gradient at x
|
// value of the gradient at x
|
||||||
val gradf = f grad x
|
val gradf = f grad x
|
||||||
|
// value of the hessian at x
|
||||||
|
val hessf = f hess x
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user