kmath/kmath-torch/README.md
2021-03-01 17:04:13 +00:00

LibTorch extension (kmath-torch)

This is a Kotlin/Native and JVM module; only linuxX64 is supported so far. The library wraps part of the PyTorch C++ API, focusing on integrating ATen and Autograd with KMath.

Installation

To install the library, build kmath-core, kmath-memory and kmath-torch and publish them to your local Maven repository:

./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal

This builds ctorch, a C wrapper, and jtorch, a JNI wrapper, for LibTorch. Both are placed inside:

~/.konan/third-party/kmath-torch-0.2.0/cpp-build

You will have to link against these wrappers in your own project.
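As a rough sketch of the Maven side only, a consuming project's build.gradle.kts could pull the locally published artifacts as shown here. The group id `kscience.kmath` is an assumption and should be checked against what actually lands in ~/.m2/repository; the version follows the 0.2.0 in the path above. Linking the native wrappers from cpp-build is not covered by this fragment.

```kotlin
// build.gradle.kts of a consuming project (sketch, not taken from the kmath build)
repositories {
    mavenLocal()   // resolves artifacts published by publishToMavenLocal
    mavenCentral()
}

dependencies {
    // Assumed coordinates: verify the group id under ~/.m2/repository.
    implementation("kscience.kmath:kmath-torch:0.2.0")
}
```

In a Kotlin Multiplatform project the same dependency line would go inside the relevant source set's `dependencies` block instead.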

Usage

Tensors are implemented over MutableNDStructure. They can only be created through the provided factory methods, and must be scoped within a TensorAlgebra instance:

TorchTensorRealAlgebra {

    val realTensor: TorchTensorReal = copyFromArray(
        array = (1..10).map { it + 50.0 }.toDoubleArray(),
        shape = intArrayOf(2, 5)
    )
    println(realTensor)

    val gpuRealTensor: TorchTensorReal = copyFromArray(
        array = (1..8).map { it * 2.5 }.toDoubleArray(),
        shape = intArrayOf(2, 2, 2),
        device = Device.CUDA(0)
    )
    println(gpuRealTensor)
}
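The flat array passed to copyFromArray has to contain one value per tensor element, so its length must equal the product of the shape. The sketch below, in plain Kotlin with no kmath-torch dependency, illustrates how a multi-index would map to a position in that array under row-major ordering. That copyFromArray reads the array in row-major order is an assumption based on LibTorch's default contiguous layout, and `rowMajorOffset` is a hypothetical helper, not part of the library.

```kotlin
// Hypothetical helper (not kmath-torch API): row-major offset of a multi-index.
fun rowMajorOffset(index: IntArray, shape: IntArray): Int {
    require(index.size == shape.size) { "index rank must match shape rank" }
    var offset = 0
    for (axis in shape.indices) {
        require(index[axis] in 0 until shape[axis]) { "index out of bounds on axis $axis" }
        // Horner-style accumulation: the last axis varies fastest.
        offset = offset * shape[axis] + index[axis]
    }
    return offset
}

fun main() {
    val array = (1..10).map { it + 50.0 }.toDoubleArray()
    val shape = intArrayOf(2, 5)
    // One value per element: 2 * 5 = 10 entries.
    check(array.size == shape.reduce { acc, extent -> acc * extent })
    // Element [1, 2] sits at offset 1 * 5 + 2 = 7, which holds 58.0.
    println(array[rowMajorOffset(intArrayOf(1, 2), shape)])
}
```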

A high-performance automatic differentiation engine is also available:

TorchTensorRealAlgebra {
    val dim = 10
    val device = Device.CPU // or Device.CUDA(0)
    
    val tensorX = randNormal(shape = intArrayOf(dim), device = device)
    val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
    val tensorSigma = randFeatures + randFeatures.transpose(0, 1)
    val tensorMu = randNormal(shape = intArrayOf(dim), device = device)

    // expression to differentiate w.r.t. x evaluated at x = tensorX
    val expressionAtX = withGradAt(tensorX, { x ->
        0.5 * (x dot (tensorSigma dot x)) + (tensorMu dot x) + 25.9
    })

    // value of the gradient at x = tensorX
    val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
    // value of the hessian at x = tensorX
    val hessianAtX = expressionAtX hess tensorX
}
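Because tensorSigma is symmetric by construction, the expression above is the quadratic f(x) = 0.5 xᵀΣx + μᵀx + 25.9, whose gradient is Σx + μ and whose Hessian is Σ. The following sketch checks the gradient formula against central finite differences. It is plain Kotlin on small hand-picked values and does not use kmath-torch; every helper in it is a hypothetical stand-in rather than library API.

```kotlin
import kotlin.math.abs

// Plain-Kotlin stand-ins for the tensor operations; none of this is kmath-torch API.
fun dot(a: DoubleArray, b: DoubleArray): Double = a.indices.sumOf { a[it] * b[it] }

fun matVec(m: Array<DoubleArray>, v: DoubleArray): DoubleArray =
    DoubleArray(m.size) { row -> dot(m[row], v) }

// f(x) = 0.5 * x.(sigma x) + mu.x + 25.9
fun quadratic(sigma: Array<DoubleArray>, mu: DoubleArray, x: DoubleArray): Double =
    0.5 * dot(x, matVec(sigma, x)) + dot(mu, x) + 25.9

// Analytic gradient sigma x + mu; this form relies on sigma being symmetric.
fun analyticGradient(sigma: Array<DoubleArray>, mu: DoubleArray, x: DoubleArray): DoubleArray {
    val sx = matVec(sigma, x)
    return DoubleArray(x.size) { sx[it] + mu[it] }
}

// Central finite differences, one coordinate at a time.
fun numericGradient(f: (DoubleArray) -> Double, x: DoubleArray, h: Double = 1e-5): DoubleArray =
    DoubleArray(x.size) { i ->
        val plus = x.copyOf().also { it[i] += h }
        val minus = x.copyOf().also { it[i] -= h }
        (f(plus) - f(minus)) / (2 * h)
    }

fun main() {
    val sigma = arrayOf(doubleArrayOf(2.0, 1.0), doubleArrayOf(1.0, 4.0)) // symmetric
    val mu = doubleArrayOf(0.5, -1.0)
    val x = doubleArrayOf(1.0, 2.0)

    val exact = analyticGradient(sigma, mu, x)              // sigma x + mu = [4.5, 8.0]
    val approx = numericGradient({ quadratic(sigma, mu, it) }, x)

    // Reports whether the two gradients agree to within 1e-6.
    println(exact.indices.all { abs(exact[it] - approx[it]) < 1e-6 })
}
```

The same comparison could be run against gradientAtX as a sanity check, with hessianAtX expected to match tensorSigma up to floating-point error.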

Contributed by Roland Grinis