Buffer protocol for torch tensors
This commit is contained in:
parent
b72b6fb34f
commit
a229aaa6a4
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,6 +2,7 @@
|
|||||||
build/
|
build/
|
||||||
out/
|
out/
|
||||||
.idea/
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
|
||||||
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
||||||
!gradle-wrapper.jar
|
!gradle-wrapper.jar
|
||||||
|
@ -148,23 +148,28 @@ public interface Strides {
|
|||||||
/**
|
/**
|
||||||
* Array strides
|
* Array strides
|
||||||
*/
|
*/
|
||||||
public val strides: List<Int>
|
public val strides: IntArray
|
||||||
|
|
||||||
/**
|
|
||||||
* Get linear index from multidimensional index
|
|
||||||
*/
|
|
||||||
public fun offset(index: IntArray): Int
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get multidimensional from linear
|
|
||||||
*/
|
|
||||||
public fun index(offset: Int): IntArray
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
||||||
*/
|
*/
|
||||||
public val linearSize: Int
|
public val linearSize: Int
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get linear index from multidimensional index
|
||||||
|
*/
|
||||||
|
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||||
|
if (value < 0 || value >= this.shape[i])
|
||||||
|
throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
||||||
|
|
||||||
|
value * strides[i]
|
||||||
|
}.sum()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get multidimensional from linear
|
||||||
|
*/
|
||||||
|
public fun index(offset: Int): IntArray
|
||||||
|
|
||||||
// TODO introduce a fast way to calculate index of the next element?
|
// TODO introduce a fast way to calculate index of the next element?
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -183,7 +188,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
/**
|
/**
|
||||||
* Strides for memory access
|
* Strides for memory access
|
||||||
*/
|
*/
|
||||||
override val strides: List<Int> by lazy {
|
override val strides: IntArray by lazy {
|
||||||
sequence {
|
sequence {
|
||||||
var current = 1
|
var current = 1
|
||||||
yield(1)
|
yield(1)
|
||||||
@ -192,16 +197,9 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
current *= it
|
current *= it
|
||||||
yield(current)
|
yield(current)
|
||||||
}
|
}
|
||||||
}.toList()
|
}.toList().toIntArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
|
||||||
if (value < 0 || value >= this.shape[i])
|
|
||||||
throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
|
||||||
|
|
||||||
value * strides[i]
|
|
||||||
}.sum()
|
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray {
|
override fun index(offset: Int): IntArray {
|
||||||
val res = IntArray(shape.size)
|
val res = IntArray(shape.size)
|
||||||
var current = offset
|
var current = offset
|
||||||
@ -238,20 +236,22 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [NDStructure] over [Buffer].
|
* Trait for [NDStructure] over [Buffer].
|
||||||
*
|
*
|
||||||
* @param T the type of items.
|
* @param T the type of items
|
||||||
|
* @param BufferImpl implementation of [Buffer].
|
||||||
*/
|
*/
|
||||||
public abstract class NDBuffer<T> : NDStructure<T> {
|
public abstract class NDBufferTrait<T, out BufferImpl : Buffer<T>, out StridesImpl: Strides> :
|
||||||
|
NDStructure<T> {
|
||||||
/**
|
/**
|
||||||
* The underlying buffer.
|
* The underlying buffer.
|
||||||
*/
|
*/
|
||||||
public abstract val buffer: Buffer<T>
|
public abstract val buffer: BufferImpl
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The strides to access elements of [Buffer] by linear indices.
|
* The strides to access elements of [Buffer] by linear indices.
|
||||||
*/
|
*/
|
||||||
public abstract val strides: Strides
|
public abstract val strides: StridesImpl
|
||||||
|
|
||||||
override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
|
override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
|
||||||
|
|
||||||
@ -259,8 +259,8 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
|||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] }
|
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] }
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
public fun checkStridesBufferCompatibility(): Unit = require(strides.linearSize == buffer.size) {
|
||||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
"Expected buffer side of ${strides.linearSize}, but found ${buffer.size}"
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
override fun hashCode(): Int {
|
||||||
@ -269,6 +269,10 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
|
}
|
||||||
|
|
||||||
override fun toString(): String {
|
override fun toString(): String {
|
||||||
val bufferRepr: String = when (shape.size) {
|
val bufferRepr: String = when (shape.size) {
|
||||||
1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ")
|
1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ")
|
||||||
@ -282,10 +286,36 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
|||||||
}
|
}
|
||||||
return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)"
|
return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Trait for [MutableNDStructure] over [MutableBuffer].
|
||||||
|
*
|
||||||
|
* @param T the type of items
|
||||||
|
* @param MutableBufferImpl implementation of [MutableBuffer].
|
||||||
|
*/
|
||||||
|
public abstract class MutableNDBufferTrait<T, out MutableBufferImpl : MutableBuffer<T>, out StridesImpl: Strides> :
|
||||||
|
NDBufferTrait<T, MutableBufferImpl, StridesImpl>(), MutableNDStructure<T> {
|
||||||
|
override fun hashCode(): Int = 0
|
||||||
|
override fun equals(other: Any?): Boolean = false
|
||||||
|
override operator fun set(index: IntArray, value: T): Unit =
|
||||||
|
buffer.set(strides.offset(index), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default representation of [NDStructure] over [Buffer].
|
||||||
|
*
|
||||||
|
* @param T the type of items.
|
||||||
|
*/
|
||||||
|
public abstract class NDBuffer<T> : NDBufferTrait<T, Buffer<T>, Strides>()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default representation of [MutableNDStructure] over [MutableBuffer].
|
||||||
|
*
|
||||||
|
* @param T the type of items.
|
||||||
|
*/
|
||||||
|
public abstract class MutableNDBuffer<T> : MutableNDBufferTrait<T, MutableBuffer<T>, Strides>()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Boxing generic [NDStructure]
|
* Boxing generic [NDStructure]
|
||||||
*/
|
*/
|
||||||
@ -294,9 +324,7 @@ public class BufferNDStructure<T>(
|
|||||||
override val buffer: Buffer<T>,
|
override val buffer: Buffer<T>,
|
||||||
) : NDBuffer<T>() {
|
) : NDBuffer<T>() {
|
||||||
init {
|
init {
|
||||||
if (strides.linearSize != buffer.size) {
|
checkStridesBufferCompatibility()
|
||||||
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -316,20 +344,15 @@ public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mutable ND buffer based on linear [MutableBuffer].
|
* Boxing generic [MutableNDStructure].
|
||||||
*/
|
*/
|
||||||
public class MutableBufferNDStructure<T>(
|
public class MutableBufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: MutableBuffer<T>,
|
override val buffer: MutableBuffer<T>,
|
||||||
) : NDBuffer<T>(), MutableNDStructure<T> {
|
) : MutableNDBuffer<T>() {
|
||||||
|
|
||||||
init {
|
init {
|
||||||
require(strides.linearSize == buffer.size) {
|
checkStridesBufferCompatibility()
|
||||||
"Expected buffer side of ${strides.linearSize}, but found ${buffer.size}"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <reified T : Any> NDStructure<T>.combine(
|
public inline fun <reified T : Any> NDStructure<T>.combine(
|
||||||
|
95
kmath-torch/README.md
Normal file
95
kmath-torch/README.md
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
# LibTorch extension (`kmath-torch`)
|
||||||
|
|
||||||
|
This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
To install the library, you have to build & publish locally `kmath-core`, `kmath-memory` with `kmath-torch`:
|
||||||
|
```
|
||||||
|
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
|
||||||
|
```
|
||||||
|
|
||||||
|
This builds `ctorch`, a C wrapper for `LibTorch` placed inside:
|
||||||
|
|
||||||
|
`~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`
|
||||||
|
|
||||||
|
You will have to link against it in your own project. Here is an example of build script for a standalone application:
|
||||||
|
```kotlin
|
||||||
|
//build.gradle.kts
|
||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.mpp")
|
||||||
|
}
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
jcenter()
|
||||||
|
mavenLocal()
|
||||||
|
}
|
||||||
|
|
||||||
|
val home = System.getProperty("user.home")
|
||||||
|
val kver = "0.2.0-dev-4"
|
||||||
|
val cppBuildDir = "$home/.konan/third-party/kmath-torch-$kver/cpp-build"
|
||||||
|
|
||||||
|
kotlin {
|
||||||
|
explicitApiWarning()
|
||||||
|
|
||||||
|
val nativeTarget = linuxX64("your.app")
|
||||||
|
nativeTarget.apply {
|
||||||
|
binaries {
|
||||||
|
executable {
|
||||||
|
entryPoint = "your.app.main"
|
||||||
|
}
|
||||||
|
all {
|
||||||
|
linkerOpts(
|
||||||
|
"-L$cppBuildDir",
|
||||||
|
"-Wl,-rpath=$cppBuildDir",
|
||||||
|
"-lctorch"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val main by nativeTarget.compilations.getting
|
||||||
|
|
||||||
|
sourceSets {
|
||||||
|
val nativeMain by creating {
|
||||||
|
dependencies {
|
||||||
|
implementation("kscience.kmath:kmath-torch:$kver")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
main.defaultSourceSet.dependsOn(nativeMain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
```kotlin
|
||||||
|
//settings.gradle.kts
|
||||||
|
pluginManagement {
|
||||||
|
repositories {
|
||||||
|
gradlePluginPortal()
|
||||||
|
jcenter()
|
||||||
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
}
|
||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.mpp") version "0.7.1"
|
||||||
|
kotlin("jvm") version "1.4.21"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
Tensors implement the buffer protocol over `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
|
||||||
|
```kotlin
|
||||||
|
memScoped {
|
||||||
|
val intTensor: TorchTensorInt = TorchTensor.copyFromIntArray(
|
||||||
|
scope = this,
|
||||||
|
array = intArrayOf(7,8,9,2,6,5),
|
||||||
|
shape = intArrayOf(3,2))
|
||||||
|
println(intTensor)
|
||||||
|
|
||||||
|
val floatTensor: TorchTensorFloat = TorchTensor.copyFromFloatArray(
|
||||||
|
scope = this,
|
||||||
|
array = floatArrayOf(7f,8.9f,2.6f,5.6f),
|
||||||
|
shape = intArrayOf(4))
|
||||||
|
println(intTensor)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
146
kmath-torch/build.gradle.kts
Normal file
146
kmath-torch/build.gradle.kts
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import de.undercouch.gradle.tasks.download.Download
|
||||||
|
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
|
||||||
|
|
||||||
|
|
||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.mpp")
|
||||||
|
id("de.undercouch.download")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
val home = System.getProperty("user.home")
|
||||||
|
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
|
||||||
|
val cppBuildDir = "$thirdPartyDir/cpp-build"
|
||||||
|
|
||||||
|
val cmakeArchive = "cmake-3.19.2-Linux-x86_64"
|
||||||
|
val torchArchive = "libtorch"
|
||||||
|
|
||||||
|
val cmakeCmd = "$thirdPartyDir/$cmakeArchive/bin/cmake"
|
||||||
|
val ninjaCmd = "$thirdPartyDir/ninja"
|
||||||
|
|
||||||
|
val downloadCMake by tasks.registering(Download::class) {
|
||||||
|
val tarFile = "$cmakeArchive.tar.gz"
|
||||||
|
src("https://github.com/Kitware/CMake/releases/download/v3.19.2/$tarFile")
|
||||||
|
dest(File(thirdPartyDir, tarFile))
|
||||||
|
overwrite(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
val downloadNinja by tasks.registering(Download::class) {
|
||||||
|
src("https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-linux.zip")
|
||||||
|
dest(File(thirdPartyDir, "ninja-linux.zip"))
|
||||||
|
overwrite(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
val downloadTorch by tasks.registering(Download::class) {
|
||||||
|
val zipFile = "$torchArchive-cxx11-abi-shared-with-deps-1.7.1%2Bcu110.zip"
|
||||||
|
src("https://download.pytorch.org/libtorch/cu110/$zipFile")
|
||||||
|
dest(File(thirdPartyDir, "$torchArchive.zip"))
|
||||||
|
overwrite(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
val extractCMake by tasks.registering(Copy::class) {
|
||||||
|
dependsOn(downloadCMake)
|
||||||
|
from(tarTree(resources.gzip(downloadCMake.get().dest)))
|
||||||
|
into(thirdPartyDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
val extractTorch by tasks.registering(Copy::class) {
|
||||||
|
dependsOn(downloadTorch)
|
||||||
|
from(zipTree(downloadTorch.get().dest))
|
||||||
|
into(thirdPartyDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
val extractNinja by tasks.registering(Copy::class) {
|
||||||
|
dependsOn(downloadNinja)
|
||||||
|
from(zipTree(downloadNinja.get().dest))
|
||||||
|
into(thirdPartyDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
val configureCpp by tasks.registering {
|
||||||
|
dependsOn(extractCMake)
|
||||||
|
dependsOn(extractNinja)
|
||||||
|
dependsOn(extractTorch)
|
||||||
|
onlyIf { !file(cppBuildDir).exists() }
|
||||||
|
doLast {
|
||||||
|
exec {
|
||||||
|
workingDir(thirdPartyDir)
|
||||||
|
commandLine("mkdir", "-p", cppBuildDir)
|
||||||
|
}
|
||||||
|
exec {
|
||||||
|
workingDir(cppBuildDir)
|
||||||
|
commandLine(
|
||||||
|
cmakeCmd,
|
||||||
|
projectDir.resolve("ctorch"),
|
||||||
|
"-GNinja",
|
||||||
|
"-DCMAKE_MAKE_PROGRAM=$ninjaCmd",
|
||||||
|
"-DCMAKE_PREFIX_PATH=$thirdPartyDir/$torchArchive",
|
||||||
|
"-DCMAKE_BUILD_TYPE=Release"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val cleanCppBuild by tasks.registering {
|
||||||
|
onlyIf { file(cppBuildDir).exists() }
|
||||||
|
doLast {
|
||||||
|
exec {
|
||||||
|
workingDir(thirdPartyDir)
|
||||||
|
commandLine("rm", "-rf", cppBuildDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val buildCpp by tasks.registering {
|
||||||
|
dependsOn(configureCpp)
|
||||||
|
doLast {
|
||||||
|
exec {
|
||||||
|
workingDir(cppBuildDir)
|
||||||
|
commandLine(cmakeCmd, "--build", ".", "--config", "Release")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kotlin {
|
||||||
|
explicitApiWarning()
|
||||||
|
|
||||||
|
val nativeTarget = linuxX64("torch")
|
||||||
|
nativeTarget.apply {
|
||||||
|
binaries {
|
||||||
|
all {
|
||||||
|
linkerOpts(
|
||||||
|
"-L$cppBuildDir",
|
||||||
|
"-Wl,-rpath=$cppBuildDir",
|
||||||
|
"-lctorch"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val main by nativeTarget.compilations.getting {
|
||||||
|
cinterops {
|
||||||
|
val libctorch by creating {
|
||||||
|
includeDirs(projectDir.resolve("ctorch/include"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val test by nativeTarget.compilations.getting
|
||||||
|
|
||||||
|
sourceSets {
|
||||||
|
val nativeMain by creating {
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-core"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val nativeTest by creating {
|
||||||
|
dependsOn(nativeMain)
|
||||||
|
}
|
||||||
|
|
||||||
|
main.defaultSourceSet.dependsOn(nativeMain)
|
||||||
|
test.defaultSourceSet.dependsOn(nativeTest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val torch: KotlinNativeTarget by kotlin.targets
|
||||||
|
tasks[torch.compilations["main"].cinterops["libctorch"].interopProcessingTaskName]
|
||||||
|
.dependsOn(buildCpp)
|
25
kmath-torch/ctorch/CMakeLists.txt
Normal file
25
kmath-torch/ctorch/CMakeLists.txt
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.12)
|
||||||
|
|
||||||
|
project(CTorch LANGUAGES C CXX)
|
||||||
|
|
||||||
|
# Require C++17
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
|
||||||
|
# Build configuration
|
||||||
|
if(NOT CMAKE_BUILD_TYPE)
|
||||||
|
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
||||||
|
endif()
|
||||||
|
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||||
|
|
||||||
|
find_package(Torch REQUIRED)
|
||||||
|
|
||||||
|
add_library(ctorch SHARED src/ctorch.cc)
|
||||||
|
target_include_directories(ctorch PRIVATE include)
|
||||||
|
target_link_libraries(ctorch PRIVATE torch)
|
||||||
|
target_compile_options(ctorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
||||||
|
|
||||||
|
include(GNUInstallDirs)
|
||||||
|
set_target_properties(ctorch PROPERTIES PUBLIC_HEADER include/ctorch.h)
|
||||||
|
install(TARGETS ctorch
|
||||||
|
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
48
kmath-torch/ctorch/include/ctorch.h
Normal file
48
kmath-torch/ctorch/include/ctorch.h
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
#ifndef CTORCH
|
||||||
|
#define CTORCH
|
||||||
|
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C"
|
||||||
|
{
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef void *TorchTensorHandle;
|
||||||
|
|
||||||
|
int get_num_threads();
|
||||||
|
|
||||||
|
void set_num_threads(int num_threads);
|
||||||
|
|
||||||
|
bool cuda_is_available();
|
||||||
|
|
||||||
|
void set_seed(int seed);
|
||||||
|
|
||||||
|
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim);
|
||||||
|
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim);
|
||||||
|
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim);
|
||||||
|
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim);
|
||||||
|
|
||||||
|
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||||
|
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||||
|
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||||
|
int *get_data_int(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
int get_numel(TorchTensorHandle tensor_handle);
|
||||||
|
int get_dim(TorchTensorHandle tensor_handle);
|
||||||
|
int *get_shape(TorchTensorHandle tensor_handle);
|
||||||
|
int *get_strides(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
char *tensor_to_string(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
void dispose_int_array(int *ptr);
|
||||||
|
void dispose_char(char *ptr);
|
||||||
|
void dispose_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif //CTORCH
|
56
kmath-torch/ctorch/include/utils.hh
Normal file
56
kmath-torch/ctorch/include/utils.hh
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
#include <torch/torch.h>
|
||||||
|
|
||||||
|
#include "ctorch.h"
|
||||||
|
|
||||||
|
namespace ctorch
|
||||||
|
{
|
||||||
|
template <typename Dtype>
|
||||||
|
inline c10::ScalarType dtype()
|
||||||
|
{
|
||||||
|
return torch::kFloat64;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline c10::ScalarType dtype<float>()
|
||||||
|
{
|
||||||
|
return torch::kFloat32;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline c10::ScalarType dtype<long>()
|
||||||
|
{
|
||||||
|
return torch::kInt64;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline c10::ScalarType dtype<int>()
|
||||||
|
{
|
||||||
|
return torch::kInt32;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline torch::Tensor &cast(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return *static_cast<torch::Tensor *>(tensor_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline torch::Tensor copy_from_blob(Dtype *data, int *shape, int dim)
|
||||||
|
{
|
||||||
|
auto shape_vec = std::vector<int64_t>(dim);
|
||||||
|
shape_vec.assign(shape, shape + dim);
|
||||||
|
return torch::from_blob(data, shape_vec, dtype<Dtype>()).clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntArray>
|
||||||
|
inline int *to_dynamic_ints(IntArray arr)
|
||||||
|
{
|
||||||
|
size_t n = arr.size();
|
||||||
|
int *res = (int *)malloc(sizeof(int) * n);
|
||||||
|
for (size_t i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
res[i] = arr[i];
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ctorch
|
110
kmath-torch/ctorch/src/ctorch.cc
Normal file
110
kmath-torch/ctorch/src/ctorch.cc
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
#include <torch/torch.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include "ctorch.h"
|
||||||
|
#include "utils.hh"
|
||||||
|
|
||||||
|
int get_num_threads()
|
||||||
|
{
|
||||||
|
return torch::get_num_threads();
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_num_threads(int num_threads)
|
||||||
|
{
|
||||||
|
torch::set_num_threads(num_threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool cuda_is_available()
|
||||||
|
{
|
||||||
|
return torch::cuda::is_available();
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_seed(int seed)
|
||||||
|
{
|
||||||
|
torch::manual_seed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::copy_from_blob<double>(data, shape, dim));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::copy_from_blob<float>(data, shape, dim));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::copy_from_blob<long>(data, shape, dim));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::copy_from_blob<int>(data, shape, dim));
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||||
|
}
|
||||||
|
float *get_data_float(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<float>();
|
||||||
|
}
|
||||||
|
long *get_data_long(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<long>();
|
||||||
|
}
|
||||||
|
int *get_data_int(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<int>();
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_numel(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).numel();
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_dim(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).dim();
|
||||||
|
}
|
||||||
|
|
||||||
|
int *get_shape(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::to_dynamic_ints(ctorch::cast(tensor_handle).sizes());
|
||||||
|
}
|
||||||
|
|
||||||
|
int *get_strides(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::to_dynamic_ints(ctorch::cast(tensor_handle).strides());
|
||||||
|
}
|
||||||
|
|
||||||
|
char *tensor_to_string(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
std::stringstream bufrep;
|
||||||
|
bufrep << ctorch::cast(tensor_handle);
|
||||||
|
auto rep = bufrep.str();
|
||||||
|
char *crep = (char *)malloc(rep.length() + 1);
|
||||||
|
std::strcpy(crep, rep.c_str());
|
||||||
|
return crep;
|
||||||
|
}
|
||||||
|
|
||||||
|
void dispose_int_array(int *ptr)
|
||||||
|
{
|
||||||
|
free(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void dispose_char(char *ptr)
|
||||||
|
{
|
||||||
|
free(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void dispose_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
delete static_cast<torch::Tensor *>(tensor_handle);
|
||||||
|
}
|
2
kmath-torch/src/nativeInterop/cinterop/libctorch.def
Normal file
2
kmath-torch/src/nativeInterop/cinterop/libctorch.def
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
package=ctorch
|
||||||
|
headers=ctorch.h
|
@ -0,0 +1,52 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kscience.kmath.structures.*
|
||||||
|
|
||||||
|
import kotlinx.cinterop.*
|
||||||
|
import ctorch.*
|
||||||
|
|
||||||
|
public abstract class TorchTensor<T,
|
||||||
|
TVar : CPrimitiveVar,
|
||||||
|
TorchTensorBufferImpl : TorchTensorBuffer<T, TVar>> :
|
||||||
|
MutableNDBufferTrait<T, TorchTensorBufferImpl, TorchTensorStrides>() {
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
public fun copyFromFloatArray(scope: DeferScope, array: FloatArray, shape: IntArray): TorchTensorFloat {
|
||||||
|
val tensorHandle: COpaquePointer = copy_from_blob_float(
|
||||||
|
array.toCValues(), shape.toCValues(), shape.size
|
||||||
|
)!!
|
||||||
|
return TorchTensorFloat(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
|
||||||
|
}
|
||||||
|
public fun copyFromIntArray(scope: DeferScope, array: IntArray, shape: IntArray): TorchTensorInt {
|
||||||
|
val tensorHandle: COpaquePointer = copy_from_blob_int(
|
||||||
|
array.toCValues(), shape.toCValues(), shape.size
|
||||||
|
)!!
|
||||||
|
return TorchTensorInt(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String {
|
||||||
|
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(buffer.tensorHandle)!!
|
||||||
|
val stringRepresentation = nativeStringRepresentation.toKString()
|
||||||
|
dispose_char(nativeStringRepresentation)
|
||||||
|
return stringRepresentation
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorFloat internal constructor(
|
||||||
|
override val strides: TorchTensorStrides,
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
): TorchTensor<Float, FloatVar, TorchTensorBufferFloat>() {
|
||||||
|
override val buffer: TorchTensorBufferFloat = TorchTensorBufferFloat(scope, tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorInt internal constructor(
|
||||||
|
override val strides: TorchTensorStrides,
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
): TorchTensor<Int, IntVar, TorchTensorBufferInt>() {
|
||||||
|
override val buffer: TorchTensorBufferInt = TorchTensorBufferInt(scope, tensorHandle)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,67 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kscience.kmath.structures.MutableBuffer
|
||||||
|
|
||||||
|
import kotlinx.cinterop.*
|
||||||
|
import ctorch.*
|
||||||
|
|
||||||
|
public abstract class TorchTensorBuffer<T, TVar : CPrimitiveVar> internal constructor(
|
||||||
|
internal val scope: DeferScope,
|
||||||
|
internal val tensorHandle: COpaquePointer
|
||||||
|
) : MutableBuffer<T> {
|
||||||
|
init {
|
||||||
|
scope.defer(::close)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun close() {
|
||||||
|
dispose_tensor(tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract val tensorData: CPointer<TVar>
|
||||||
|
|
||||||
|
override val size: Int
|
||||||
|
get() = get_numel(tensorHandle)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public class TorchTensorBufferFloat internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensorBuffer<Float, FloatVar>(scope, tensorHandle) {
|
||||||
|
override val tensorData: CPointer<FloatVar> = get_data_float(tensorHandle)!!
|
||||||
|
|
||||||
|
override operator fun get(index: Int): Float = tensorData[index]
|
||||||
|
|
||||||
|
override operator fun set(index: Int, value: Float) {
|
||||||
|
tensorData[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun iterator(): Iterator<Float> = (1..size).map { tensorData[it - 1] }.iterator()
|
||||||
|
|
||||||
|
override fun copy(): TorchTensorBufferFloat = TorchTensorBufferFloat(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = copy_tensor(tensorHandle)!!
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorBufferInt internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensorBuffer<Int, IntVar>(scope, tensorHandle) {
|
||||||
|
override val tensorData: CPointer<IntVar> = get_data_int(tensorHandle)!!
|
||||||
|
|
||||||
|
override operator fun get(index: Int): Int = tensorData[index]
|
||||||
|
|
||||||
|
override operator fun set(index: Int, value: Int) {
|
||||||
|
tensorData[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun iterator(): Iterator<Int> = (1..size).map { tensorData[it - 1] }.iterator()
|
||||||
|
|
||||||
|
override fun copy(): TorchTensorBufferInt = TorchTensorBufferInt(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = copy_tensor(tensorHandle)!!
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,55 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kscience.kmath.structures.Strides
|
||||||
|
|
||||||
|
import kotlinx.cinterop.*
|
||||||
|
import ctorch.*
|
||||||
|
|
||||||
|
public class TorchTensorStrides internal constructor(
|
||||||
|
override val shape: IntArray,
|
||||||
|
override val strides: IntArray,
|
||||||
|
override val linearSize: Int
|
||||||
|
) : Strides {
|
||||||
|
override fun index(offset: Int): IntArray {
|
||||||
|
val nDim = shape.size
|
||||||
|
val res = IntArray(nDim)
|
||||||
|
var current = offset
|
||||||
|
var strideIndex = 0
|
||||||
|
|
||||||
|
while (strideIndex < nDim) {
|
||||||
|
res[strideIndex] = (current / strides[strideIndex])
|
||||||
|
current %= strides[strideIndex]
|
||||||
|
strideIndex++
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private inline fun intPointerToArrayAndClean(ptr: CPointer<IntVar>, nDim: Int): IntArray {
|
||||||
|
val res: IntArray = (1 .. nDim).map{ptr[it-1]}.toIntArray()
|
||||||
|
dispose_int_array(ptr)
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun getShapeFromNative(tensorHandle: COpaquePointer, nDim: Int): IntArray{
|
||||||
|
return intPointerToArrayAndClean(get_shape(tensorHandle)!!, nDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun getStridesFromNative(tensorHandle: COpaquePointer, nDim: Int): IntArray{
|
||||||
|
return intPointerToArrayAndClean(get_strides(tensorHandle)!!, nDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun populateStridesFromNative(
|
||||||
|
tensorHandle: COpaquePointer,
|
||||||
|
rawShape: IntArray? = null,
|
||||||
|
rawStrides: IntArray? = null,
|
||||||
|
rawLinearSize: Int? = null
|
||||||
|
): TorchTensorStrides {
|
||||||
|
val nDim = rawShape?.size?: rawStrides?.size?: get_dim(tensorHandle)
|
||||||
|
return TorchTensorStrides(
|
||||||
|
shape = rawShape?: getShapeFromNative(tensorHandle, nDim),
|
||||||
|
strides = rawStrides?: getStridesFromNative(tensorHandle, nDim),
|
||||||
|
linearSize = rawLinearSize?: get_numel(tensorHandle)
|
||||||
|
)
|
||||||
|
}
|
@ -0,0 +1,20 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlinx.cinterop.*
|
||||||
|
import ctorch.*
|
||||||
|
|
||||||
|
public fun getNumThreads(): Int {
|
||||||
|
return get_num_threads()
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun setNumThreads(numThreads: Int): Unit {
|
||||||
|
set_num_threads(numThreads)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun cudaAvailable(): Boolean {
|
||||||
|
return cuda_is_available()
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun setSeed(seed: Int): Unit {
|
||||||
|
set_seed(seed)
|
||||||
|
}
|
@ -0,0 +1,33 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
|
|
||||||
|
import kotlinx.cinterop.memScoped
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
|
||||||
|
internal class TestTorchTensor {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun intTensorLayout() = memScoped {
|
||||||
|
val array = intArrayOf(7,8,9,2,6,5)
|
||||||
|
val shape = intArrayOf(3,2)
|
||||||
|
val tensor = TorchTensor.copyFromIntArray(scope=this, array=array, shape=shape)
|
||||||
|
tensor.elements().forEach {
|
||||||
|
assertEquals(tensor[it.first], it.second)
|
||||||
|
}
|
||||||
|
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun floatTensorLayout() = memScoped {
|
||||||
|
val array = floatArrayOf(7.5f,8.2f,9f,2.58f,6.5f,5f)
|
||||||
|
val shape = intArrayOf(2,3)
|
||||||
|
val tensor = TorchTensor.copyFromFloatArray(this, array, shape)
|
||||||
|
tensor.elements().forEach {
|
||||||
|
assertEquals(tensor[it.first], it.second)
|
||||||
|
}
|
||||||
|
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,19 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
|
||||||
|
internal class TestUtils {
|
||||||
|
@Test
|
||||||
|
fun settingTorchThreadsCount(){
|
||||||
|
val numThreads = 2
|
||||||
|
setNumThreads(numThreads)
|
||||||
|
assertEquals(numThreads, getNumThreads())
|
||||||
|
}
|
||||||
|
@Test
|
||||||
|
fun cudaAvailability(){
|
||||||
|
assertTrue(cudaAvailable())
|
||||||
|
}
|
||||||
|
}
|
@ -42,3 +42,7 @@ include(
|
|||||||
":kmath-kotlingrad",
|
":kmath-kotlingrad",
|
||||||
":examples"
|
":examples"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if(System.getProperty("os.name") == "Linux"){
|
||||||
|
include(":kmath-torch")
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user