forked from kscience/kmath
Buffer protocol for torch tensors
This commit is contained in:
parent
b72b6fb34f
commit
a229aaa6a4
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,6 +2,7 @@
|
||||
build/
|
||||
out/
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
||||
!gradle-wrapper.jar
|
||||
|
@ -148,23 +148,28 @@ public interface Strides {
|
||||
/**
|
||||
* Array strides
|
||||
*/
|
||||
public val strides: List<Int>
|
||||
|
||||
/**
|
||||
* Get linear index from multidimensional index
|
||||
*/
|
||||
public fun offset(index: IntArray): Int
|
||||
|
||||
/**
|
||||
* Get multidimensional from linear
|
||||
*/
|
||||
public fun index(offset: Int): IntArray
|
||||
public val strides: IntArray
|
||||
|
||||
/**
|
||||
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
||||
*/
|
||||
public val linearSize: Int
|
||||
|
||||
/**
|
||||
* Get linear index from multidimensional index
|
||||
*/
|
||||
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||
if (value < 0 || value >= this.shape[i])
|
||||
throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
||||
|
||||
value * strides[i]
|
||||
}.sum()
|
||||
|
||||
/**
|
||||
* Get multidimensional from linear
|
||||
*/
|
||||
public fun index(offset: Int): IntArray
|
||||
|
||||
// TODO introduce a fast way to calculate index of the next element?
|
||||
|
||||
/**
|
||||
@ -183,7 +188,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
||||
/**
|
||||
* Strides for memory access
|
||||
*/
|
||||
override val strides: List<Int> by lazy {
|
||||
override val strides: IntArray by lazy {
|
||||
sequence {
|
||||
var current = 1
|
||||
yield(1)
|
||||
@ -192,16 +197,9 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
||||
current *= it
|
||||
yield(current)
|
||||
}
|
||||
}.toList()
|
||||
}.toList().toIntArray()
|
||||
}
|
||||
|
||||
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||
if (value < 0 || value >= this.shape[i])
|
||||
throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
||||
|
||||
value * strides[i]
|
||||
}.sum()
|
||||
|
||||
override fun index(offset: Int): IntArray {
|
||||
val res = IntArray(shape.size)
|
||||
var current = offset
|
||||
@ -238,20 +236,22 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDStructure] over [Buffer].
|
||||
* Trait for [NDStructure] over [Buffer].
|
||||
*
|
||||
* @param T the type of items.
|
||||
* @param T the type of items
|
||||
* @param BufferImpl implementation of [Buffer].
|
||||
*/
|
||||
public abstract class NDBuffer<T> : NDStructure<T> {
|
||||
public abstract class NDBufferTrait<T, out BufferImpl : Buffer<T>, out StridesImpl: Strides> :
|
||||
NDStructure<T> {
|
||||
/**
|
||||
* The underlying buffer.
|
||||
*/
|
||||
public abstract val buffer: Buffer<T>
|
||||
public abstract val buffer: BufferImpl
|
||||
|
||||
/**
|
||||
* The strides to access elements of [Buffer] by linear indices.
|
||||
*/
|
||||
public abstract val strides: Strides
|
||||
public abstract val strides: StridesImpl
|
||||
|
||||
override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
|
||||
|
||||
@ -259,8 +259,8 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] }
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
public fun checkStridesBufferCompatibility(): Unit = require(strides.linearSize == buffer.size) {
|
||||
"Expected buffer side of ${strides.linearSize}, but found ${buffer.size}"
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
@ -269,6 +269,10 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
||||
return result
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
val bufferRepr: String = when (shape.size) {
|
||||
1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ")
|
||||
@ -282,10 +286,36 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
||||
}
|
||||
return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)"
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Trait for [MutableNDStructure] over [MutableBuffer].
|
||||
*
|
||||
* @param T the type of items
|
||||
* @param MutableBufferImpl implementation of [MutableBuffer].
|
||||
*/
|
||||
public abstract class MutableNDBufferTrait<T, out MutableBufferImpl : MutableBuffer<T>, out StridesImpl: Strides> :
|
||||
NDBufferTrait<T, MutableBufferImpl, StridesImpl>(), MutableNDStructure<T> {
|
||||
override fun hashCode(): Int = 0
|
||||
override fun equals(other: Any?): Boolean = false
|
||||
override operator fun set(index: IntArray, value: T): Unit =
|
||||
buffer.set(strides.offset(index), value)
|
||||
}
|
||||
|
||||
/**
|
||||
* Default representation of [NDStructure] over [Buffer].
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
public abstract class NDBuffer<T> : NDBufferTrait<T, Buffer<T>, Strides>()
|
||||
|
||||
/**
|
||||
* Default representation of [MutableNDStructure] over [MutableBuffer].
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
public abstract class MutableNDBuffer<T> : MutableNDBufferTrait<T, MutableBuffer<T>, Strides>()
|
||||
|
||||
/**
|
||||
* Boxing generic [NDStructure]
|
||||
*/
|
||||
@ -294,9 +324,7 @@ public class BufferNDStructure<T>(
|
||||
override val buffer: Buffer<T>,
|
||||
) : NDBuffer<T>() {
|
||||
init {
|
||||
if (strides.linearSize != buffer.size) {
|
||||
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
||||
}
|
||||
checkStridesBufferCompatibility()
|
||||
}
|
||||
}
|
||||
|
||||
@ -316,20 +344,15 @@ public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
||||
}
|
||||
|
||||
/**
|
||||
* Mutable ND buffer based on linear [MutableBuffer].
|
||||
* Boxing generic [MutableNDStructure].
|
||||
*/
|
||||
public class MutableBufferNDStructure<T>(
|
||||
override val strides: Strides,
|
||||
override val buffer: MutableBuffer<T>,
|
||||
) : NDBuffer<T>(), MutableNDStructure<T> {
|
||||
|
||||
) : MutableNDBuffer<T>() {
|
||||
init {
|
||||
require(strides.linearSize == buffer.size) {
|
||||
"Expected buffer side of ${strides.linearSize}, but found ${buffer.size}"
|
||||
}
|
||||
checkStridesBufferCompatibility()
|
||||
}
|
||||
|
||||
override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
|
||||
}
|
||||
|
||||
public inline fun <reified T : Any> NDStructure<T>.combine(
|
||||
|
95
kmath-torch/README.md
Normal file
95
kmath-torch/README.md
Normal file
@ -0,0 +1,95 @@
|
||||
# LibTorch extension (`kmath-torch`)
|
||||
|
||||
This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
|
||||
|
||||
## Installation
|
||||
To install the library, you have to build & publish locally `kmath-core`, `kmath-memory` with `kmath-torch`:
|
||||
```
|
||||
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
|
||||
```
|
||||
|
||||
This builds `ctorch`, a C wrapper for `LibTorch` placed inside:
|
||||
|
||||
`~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`
|
||||
|
||||
You will have to link against it in your own project. Here is an example of build script for a standalone application:
|
||||
```kotlin
|
||||
//build.gradle.kts
|
||||
plugins {
|
||||
id("ru.mipt.npm.mpp")
|
||||
}
|
||||
|
||||
repositories {
|
||||
jcenter()
|
||||
mavenLocal()
|
||||
}
|
||||
|
||||
val home = System.getProperty("user.home")
|
||||
val kver = "0.2.0-dev-4"
|
||||
val cppBuildDir = "$home/.konan/third-party/kmath-torch-$kver/cpp-build"
|
||||
|
||||
kotlin {
|
||||
explicitApiWarning()
|
||||
|
||||
val nativeTarget = linuxX64("your.app")
|
||||
nativeTarget.apply {
|
||||
binaries {
|
||||
executable {
|
||||
entryPoint = "your.app.main"
|
||||
}
|
||||
all {
|
||||
linkerOpts(
|
||||
"-L$cppBuildDir",
|
||||
"-Wl,-rpath=$cppBuildDir",
|
||||
"-lctorch"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val main by nativeTarget.compilations.getting
|
||||
|
||||
sourceSets {
|
||||
val nativeMain by creating {
|
||||
dependencies {
|
||||
implementation("kscience.kmath:kmath-torch:$kver")
|
||||
}
|
||||
}
|
||||
main.defaultSourceSet.dependsOn(nativeMain)
|
||||
}
|
||||
}
|
||||
```
|
||||
```kotlin
|
||||
//settings.gradle.kts
|
||||
pluginManagement {
|
||||
repositories {
|
||||
gradlePluginPortal()
|
||||
jcenter()
|
||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
}
|
||||
plugins {
|
||||
id("ru.mipt.npm.mpp") version "0.7.1"
|
||||
kotlin("jvm") version "1.4.21"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Tensors implement the buffer protocol over `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
|
||||
```kotlin
|
||||
memScoped {
|
||||
val intTensor: TorchTensorInt = TorchTensor.copyFromIntArray(
|
||||
scope = this,
|
||||
array = intArrayOf(7,8,9,2,6,5),
|
||||
shape = intArrayOf(3,2))
|
||||
println(intTensor)
|
||||
|
||||
val floatTensor: TorchTensorFloat = TorchTensor.copyFromFloatArray(
|
||||
scope = this,
|
||||
array = floatArrayOf(7f,8.9f,2.6f,5.6f),
|
||||
shape = intArrayOf(4))
|
||||
println(intTensor)
|
||||
}
|
||||
```
|
||||
|
146
kmath-torch/build.gradle.kts
Normal file
146
kmath-torch/build.gradle.kts
Normal file
@ -0,0 +1,146 @@
|
||||
import de.undercouch.gradle.tasks.download.Download
|
||||
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
|
||||
|
||||
|
||||
plugins {
|
||||
id("ru.mipt.npm.mpp")
|
||||
id("de.undercouch.download")
|
||||
}
|
||||
|
||||
|
||||
val home = System.getProperty("user.home")
|
||||
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
|
||||
val cppBuildDir = "$thirdPartyDir/cpp-build"
|
||||
|
||||
val cmakeArchive = "cmake-3.19.2-Linux-x86_64"
|
||||
val torchArchive = "libtorch"
|
||||
|
||||
val cmakeCmd = "$thirdPartyDir/$cmakeArchive/bin/cmake"
|
||||
val ninjaCmd = "$thirdPartyDir/ninja"
|
||||
|
||||
val downloadCMake by tasks.registering(Download::class) {
|
||||
val tarFile = "$cmakeArchive.tar.gz"
|
||||
src("https://github.com/Kitware/CMake/releases/download/v3.19.2/$tarFile")
|
||||
dest(File(thirdPartyDir, tarFile))
|
||||
overwrite(false)
|
||||
}
|
||||
|
||||
val downloadNinja by tasks.registering(Download::class) {
|
||||
src("https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-linux.zip")
|
||||
dest(File(thirdPartyDir, "ninja-linux.zip"))
|
||||
overwrite(false)
|
||||
}
|
||||
|
||||
val downloadTorch by tasks.registering(Download::class) {
|
||||
val zipFile = "$torchArchive-cxx11-abi-shared-with-deps-1.7.1%2Bcu110.zip"
|
||||
src("https://download.pytorch.org/libtorch/cu110/$zipFile")
|
||||
dest(File(thirdPartyDir, "$torchArchive.zip"))
|
||||
overwrite(false)
|
||||
}
|
||||
|
||||
val extractCMake by tasks.registering(Copy::class) {
|
||||
dependsOn(downloadCMake)
|
||||
from(tarTree(resources.gzip(downloadCMake.get().dest)))
|
||||
into(thirdPartyDir)
|
||||
}
|
||||
|
||||
val extractTorch by tasks.registering(Copy::class) {
|
||||
dependsOn(downloadTorch)
|
||||
from(zipTree(downloadTorch.get().dest))
|
||||
into(thirdPartyDir)
|
||||
}
|
||||
|
||||
val extractNinja by tasks.registering(Copy::class) {
|
||||
dependsOn(downloadNinja)
|
||||
from(zipTree(downloadNinja.get().dest))
|
||||
into(thirdPartyDir)
|
||||
}
|
||||
|
||||
val configureCpp by tasks.registering {
|
||||
dependsOn(extractCMake)
|
||||
dependsOn(extractNinja)
|
||||
dependsOn(extractTorch)
|
||||
onlyIf { !file(cppBuildDir).exists() }
|
||||
doLast {
|
||||
exec {
|
||||
workingDir(thirdPartyDir)
|
||||
commandLine("mkdir", "-p", cppBuildDir)
|
||||
}
|
||||
exec {
|
||||
workingDir(cppBuildDir)
|
||||
commandLine(
|
||||
cmakeCmd,
|
||||
projectDir.resolve("ctorch"),
|
||||
"-GNinja",
|
||||
"-DCMAKE_MAKE_PROGRAM=$ninjaCmd",
|
||||
"-DCMAKE_PREFIX_PATH=$thirdPartyDir/$torchArchive",
|
||||
"-DCMAKE_BUILD_TYPE=Release"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val cleanCppBuild by tasks.registering {
|
||||
onlyIf { file(cppBuildDir).exists() }
|
||||
doLast {
|
||||
exec {
|
||||
workingDir(thirdPartyDir)
|
||||
commandLine("rm", "-rf", cppBuildDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val buildCpp by tasks.registering {
|
||||
dependsOn(configureCpp)
|
||||
doLast {
|
||||
exec {
|
||||
workingDir(cppBuildDir)
|
||||
commandLine(cmakeCmd, "--build", ".", "--config", "Release")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kotlin {
|
||||
explicitApiWarning()
|
||||
|
||||
val nativeTarget = linuxX64("torch")
|
||||
nativeTarget.apply {
|
||||
binaries {
|
||||
all {
|
||||
linkerOpts(
|
||||
"-L$cppBuildDir",
|
||||
"-Wl,-rpath=$cppBuildDir",
|
||||
"-lctorch"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val main by nativeTarget.compilations.getting {
|
||||
cinterops {
|
||||
val libctorch by creating {
|
||||
includeDirs(projectDir.resolve("ctorch/include"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val test by nativeTarget.compilations.getting
|
||||
|
||||
sourceSets {
|
||||
val nativeMain by creating {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
}
|
||||
val nativeTest by creating {
|
||||
dependsOn(nativeMain)
|
||||
}
|
||||
|
||||
main.defaultSourceSet.dependsOn(nativeMain)
|
||||
test.defaultSourceSet.dependsOn(nativeTest)
|
||||
}
|
||||
}
|
||||
|
||||
val torch: KotlinNativeTarget by kotlin.targets
|
||||
tasks[torch.compilations["main"].cinterops["libctorch"].interopProcessingTaskName]
|
||||
.dependsOn(buildCpp)
|
25
kmath-torch/ctorch/CMakeLists.txt
Normal file
25
kmath-torch/ctorch/CMakeLists.txt
Normal file
@ -0,0 +1,25 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
|
||||
project(CTorch LANGUAGES C CXX)
|
||||
|
||||
# Require C++17
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
# Build configuration
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
||||
endif()
|
||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
add_library(ctorch SHARED src/ctorch.cc)
|
||||
target_include_directories(ctorch PRIVATE include)
|
||||
target_link_libraries(ctorch PRIVATE torch)
|
||||
target_compile_options(ctorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
||||
|
||||
include(GNUInstallDirs)
|
||||
set_target_properties(ctorch PROPERTIES PUBLIC_HEADER include/ctorch.h)
|
||||
install(TARGETS ctorch
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
48
kmath-torch/ctorch/include/ctorch.h
Normal file
48
kmath-torch/ctorch/include/ctorch.h
Normal file
@ -0,0 +1,48 @@
|
||||
#ifndef CTORCH
|
||||
#define CTORCH
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C"
|
||||
{
|
||||
#endif
|
||||
|
||||
typedef void *TorchTensorHandle;
|
||||
|
||||
int get_num_threads();
|
||||
|
||||
void set_num_threads(int num_threads);
|
||||
|
||||
bool cuda_is_available();
|
||||
|
||||
void set_seed(int seed);
|
||||
|
||||
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim);
|
||||
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim);
|
||||
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim);
|
||||
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim);
|
||||
|
||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
|
||||
|
||||
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||
int *get_data_int(TorchTensorHandle tensor_handle);
|
||||
|
||||
int get_numel(TorchTensorHandle tensor_handle);
|
||||
int get_dim(TorchTensorHandle tensor_handle);
|
||||
int *get_shape(TorchTensorHandle tensor_handle);
|
||||
int *get_strides(TorchTensorHandle tensor_handle);
|
||||
|
||||
char *tensor_to_string(TorchTensorHandle tensor_handle);
|
||||
|
||||
void dispose_int_array(int *ptr);
|
||||
void dispose_char(char *ptr);
|
||||
void dispose_tensor(TorchTensorHandle tensor_handle);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif //CTORCH
|
56
kmath-torch/ctorch/include/utils.hh
Normal file
56
kmath-torch/ctorch/include/utils.hh
Normal file
@ -0,0 +1,56 @@
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include "ctorch.h"
|
||||
|
||||
namespace ctorch
|
||||
{
|
||||
template <typename Dtype>
|
||||
inline c10::ScalarType dtype()
|
||||
{
|
||||
return torch::kFloat64;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline c10::ScalarType dtype<float>()
|
||||
{
|
||||
return torch::kFloat32;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline c10::ScalarType dtype<long>()
|
||||
{
|
||||
return torch::kInt64;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline c10::ScalarType dtype<int>()
|
||||
{
|
||||
return torch::kInt32;
|
||||
}
|
||||
|
||||
inline torch::Tensor &cast(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return *static_cast<torch::Tensor *>(tensor_handle);
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor copy_from_blob(Dtype *data, int *shape, int dim)
|
||||
{
|
||||
auto shape_vec = std::vector<int64_t>(dim);
|
||||
shape_vec.assign(shape, shape + dim);
|
||||
return torch::from_blob(data, shape_vec, dtype<Dtype>()).clone();
|
||||
}
|
||||
|
||||
template <typename IntArray>
|
||||
inline int *to_dynamic_ints(IntArray arr)
|
||||
{
|
||||
size_t n = arr.size();
|
||||
int *res = (int *)malloc(sizeof(int) * n);
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
res[i] = arr[i];
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace ctorch
|
110
kmath-torch/ctorch/src/ctorch.cc
Normal file
110
kmath-torch/ctorch/src/ctorch.cc
Normal file
@ -0,0 +1,110 @@
|
||||
#include <torch/torch.h>
|
||||
#include <iostream>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "ctorch.h"
|
||||
#include "utils.hh"
|
||||
|
||||
int get_num_threads()
|
||||
{
|
||||
return torch::get_num_threads();
|
||||
}
|
||||
|
||||
void set_num_threads(int num_threads)
|
||||
{
|
||||
torch::set_num_threads(num_threads);
|
||||
}
|
||||
|
||||
bool cuda_is_available()
|
||||
{
|
||||
return torch::cuda::is_available();
|
||||
}
|
||||
|
||||
void set_seed(int seed)
|
||||
{
|
||||
torch::manual_seed(seed);
|
||||
}
|
||||
|
||||
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim)
|
||||
{
|
||||
return new torch::Tensor(ctorch::copy_from_blob<double>(data, shape, dim));
|
||||
}
|
||||
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim)
|
||||
{
|
||||
return new torch::Tensor(ctorch::copy_from_blob<float>(data, shape, dim));
|
||||
}
|
||||
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim)
|
||||
{
|
||||
return new torch::Tensor(ctorch::copy_from_blob<long>(data, shape, dim));
|
||||
}
|
||||
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim)
|
||||
{
|
||||
return new torch::Tensor(ctorch::copy_from_blob<int>(data, shape, dim));
|
||||
}
|
||||
|
||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).clone());
|
||||
}
|
||||
|
||||
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||
}
|
||||
float *get_data_float(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<float>();
|
||||
}
|
||||
long *get_data_long(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<long>();
|
||||
}
|
||||
int *get_data_int(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<int>();
|
||||
}
|
||||
|
||||
int get_numel(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).numel();
|
||||
}
|
||||
|
||||
int get_dim(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).dim();
|
||||
}
|
||||
|
||||
int *get_shape(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::to_dynamic_ints(ctorch::cast(tensor_handle).sizes());
|
||||
}
|
||||
|
||||
int *get_strides(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::to_dynamic_ints(ctorch::cast(tensor_handle).strides());
|
||||
}
|
||||
|
||||
char *tensor_to_string(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
std::stringstream bufrep;
|
||||
bufrep << ctorch::cast(tensor_handle);
|
||||
auto rep = bufrep.str();
|
||||
char *crep = (char *)malloc(rep.length() + 1);
|
||||
std::strcpy(crep, rep.c_str());
|
||||
return crep;
|
||||
}
|
||||
|
||||
void dispose_int_array(int *ptr)
|
||||
{
|
||||
free(ptr);
|
||||
}
|
||||
|
||||
void dispose_char(char *ptr)
|
||||
{
|
||||
free(ptr);
|
||||
}
|
||||
|
||||
void dispose_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
delete static_cast<torch::Tensor *>(tensor_handle);
|
||||
}
|
2
kmath-torch/src/nativeInterop/cinterop/libctorch.def
Normal file
2
kmath-torch/src/nativeInterop/cinterop/libctorch.def
Normal file
@ -0,0 +1,2 @@
|
||||
package=ctorch
|
||||
headers=ctorch.h
|
@ -0,0 +1,52 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kscience.kmath.structures.*
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import ctorch.*
|
||||
|
||||
public abstract class TorchTensor<T,
|
||||
TVar : CPrimitiveVar,
|
||||
TorchTensorBufferImpl : TorchTensorBuffer<T, TVar>> :
|
||||
MutableNDBufferTrait<T, TorchTensorBufferImpl, TorchTensorStrides>() {
|
||||
|
||||
public companion object {
|
||||
public fun copyFromFloatArray(scope: DeferScope, array: FloatArray, shape: IntArray): TorchTensorFloat {
|
||||
val tensorHandle: COpaquePointer = copy_from_blob_float(
|
||||
array.toCValues(), shape.toCValues(), shape.size
|
||||
)!!
|
||||
return TorchTensorFloat(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
|
||||
}
|
||||
public fun copyFromIntArray(scope: DeferScope, array: IntArray, shape: IntArray): TorchTensorInt {
|
||||
val tensorHandle: COpaquePointer = copy_from_blob_int(
|
||||
array.toCValues(), shape.toCValues(), shape.size
|
||||
)!!
|
||||
return TorchTensorInt(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(buffer.tensorHandle)!!
|
||||
val stringRepresentation = nativeStringRepresentation.toKString()
|
||||
dispose_char(nativeStringRepresentation)
|
||||
return stringRepresentation
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public class TorchTensorFloat internal constructor(
|
||||
override val strides: TorchTensorStrides,
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
): TorchTensor<Float, FloatVar, TorchTensorBufferFloat>() {
|
||||
override val buffer: TorchTensorBufferFloat = TorchTensorBufferFloat(scope, tensorHandle)
|
||||
}
|
||||
|
||||
public class TorchTensorInt internal constructor(
|
||||
override val strides: TorchTensorStrides,
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
): TorchTensor<Int, IntVar, TorchTensorBufferInt>() {
|
||||
override val buffer: TorchTensorBufferInt = TorchTensorBufferInt(scope, tensorHandle)
|
||||
}
|
||||
|
@ -0,0 +1,67 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kscience.kmath.structures.MutableBuffer
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import ctorch.*
|
||||
|
||||
public abstract class TorchTensorBuffer<T, TVar : CPrimitiveVar> internal constructor(
|
||||
internal val scope: DeferScope,
|
||||
internal val tensorHandle: COpaquePointer
|
||||
) : MutableBuffer<T> {
|
||||
init {
|
||||
scope.defer(::close)
|
||||
}
|
||||
|
||||
internal fun close() {
|
||||
dispose_tensor(tensorHandle)
|
||||
}
|
||||
|
||||
protected abstract val tensorData: CPointer<TVar>
|
||||
|
||||
override val size: Int
|
||||
get() = get_numel(tensorHandle)
|
||||
|
||||
}
|
||||
|
||||
|
||||
public class TorchTensorBufferFloat internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorBuffer<Float, FloatVar>(scope, tensorHandle) {
|
||||
override val tensorData: CPointer<FloatVar> = get_data_float(tensorHandle)!!
|
||||
|
||||
override operator fun get(index: Int): Float = tensorData[index]
|
||||
|
||||
override operator fun set(index: Int, value: Float) {
|
||||
tensorData[index] = value
|
||||
}
|
||||
|
||||
override operator fun iterator(): Iterator<Float> = (1..size).map { tensorData[it - 1] }.iterator()
|
||||
|
||||
override fun copy(): TorchTensorBufferFloat = TorchTensorBufferFloat(
|
||||
scope = scope,
|
||||
tensorHandle = copy_tensor(tensorHandle)!!
|
||||
)
|
||||
}
|
||||
|
||||
public class TorchTensorBufferInt internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorBuffer<Int, IntVar>(scope, tensorHandle) {
|
||||
override val tensorData: CPointer<IntVar> = get_data_int(tensorHandle)!!
|
||||
|
||||
override operator fun get(index: Int): Int = tensorData[index]
|
||||
|
||||
override operator fun set(index: Int, value: Int) {
|
||||
tensorData[index] = value
|
||||
}
|
||||
|
||||
override operator fun iterator(): Iterator<Int> = (1..size).map { tensorData[it - 1] }.iterator()
|
||||
|
||||
override fun copy(): TorchTensorBufferInt = TorchTensorBufferInt(
|
||||
scope = scope,
|
||||
tensorHandle = copy_tensor(tensorHandle)!!
|
||||
)
|
||||
}
|
||||
|
@ -0,0 +1,55 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kscience.kmath.structures.Strides
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import ctorch.*
|
||||
|
||||
public class TorchTensorStrides internal constructor(
|
||||
override val shape: IntArray,
|
||||
override val strides: IntArray,
|
||||
override val linearSize: Int
|
||||
) : Strides {
|
||||
override fun index(offset: Int): IntArray {
|
||||
val nDim = shape.size
|
||||
val res = IntArray(nDim)
|
||||
var current = offset
|
||||
var strideIndex = 0
|
||||
|
||||
while (strideIndex < nDim) {
|
||||
res[strideIndex] = (current / strides[strideIndex])
|
||||
current %= strides[strideIndex]
|
||||
strideIndex++
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private inline fun intPointerToArrayAndClean(ptr: CPointer<IntVar>, nDim: Int): IntArray {
|
||||
val res: IntArray = (1 .. nDim).map{ptr[it-1]}.toIntArray()
|
||||
dispose_int_array(ptr)
|
||||
return res
|
||||
}
|
||||
|
||||
private inline fun getShapeFromNative(tensorHandle: COpaquePointer, nDim: Int): IntArray{
|
||||
return intPointerToArrayAndClean(get_shape(tensorHandle)!!, nDim)
|
||||
}
|
||||
|
||||
private inline fun getStridesFromNative(tensorHandle: COpaquePointer, nDim: Int): IntArray{
|
||||
return intPointerToArrayAndClean(get_strides(tensorHandle)!!, nDim)
|
||||
}
|
||||
|
||||
internal inline fun populateStridesFromNative(
|
||||
tensorHandle: COpaquePointer,
|
||||
rawShape: IntArray? = null,
|
||||
rawStrides: IntArray? = null,
|
||||
rawLinearSize: Int? = null
|
||||
): TorchTensorStrides {
|
||||
val nDim = rawShape?.size?: rawStrides?.size?: get_dim(tensorHandle)
|
||||
return TorchTensorStrides(
|
||||
shape = rawShape?: getShapeFromNative(tensorHandle, nDim),
|
||||
strides = rawStrides?: getStridesFromNative(tensorHandle, nDim),
|
||||
linearSize = rawLinearSize?: get_numel(tensorHandle)
|
||||
)
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import ctorch.*
|
||||
|
||||
public fun getNumThreads(): Int {
|
||||
return get_num_threads()
|
||||
}
|
||||
|
||||
public fun setNumThreads(numThreads: Int): Unit {
|
||||
set_num_threads(numThreads)
|
||||
}
|
||||
|
||||
public fun cudaAvailable(): Boolean {
|
||||
return cuda_is_available()
|
||||
}
|
||||
|
||||
public fun setSeed(seed: Int): Unit {
|
||||
set_seed(seed)
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kscience.kmath.structures.asBuffer
|
||||
|
||||
import kotlinx.cinterop.memScoped
|
||||
import kotlin.test.*
|
||||
|
||||
|
||||
internal class TestTorchTensor {
|
||||
|
||||
@Test
|
||||
fun intTensorLayout() = memScoped {
|
||||
val array = intArrayOf(7,8,9,2,6,5)
|
||||
val shape = intArrayOf(3,2)
|
||||
val tensor = TorchTensor.copyFromIntArray(scope=this, array=array, shape=shape)
|
||||
tensor.elements().forEach {
|
||||
assertEquals(tensor[it.first], it.second)
|
||||
}
|
||||
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun floatTensorLayout() = memScoped {
|
||||
val array = floatArrayOf(7.5f,8.2f,9f,2.58f,6.5f,5f)
|
||||
val shape = intArrayOf(2,3)
|
||||
val tensor = TorchTensor.copyFromFloatArray(this, array, shape)
|
||||
tensor.elements().forEach {
|
||||
assertEquals(tensor[it.first], it.second)
|
||||
}
|
||||
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
|
||||
internal class TestUtils {
|
||||
@Test
|
||||
fun settingTorchThreadsCount(){
|
||||
val numThreads = 2
|
||||
setNumThreads(numThreads)
|
||||
assertEquals(numThreads, getNumThreads())
|
||||
}
|
||||
@Test
|
||||
fun cudaAvailability(){
|
||||
assertTrue(cudaAvailable())
|
||||
}
|
||||
}
|
@ -42,3 +42,7 @@ include(
|
||||
":kmath-kotlingrad",
|
||||
":examples"
|
||||
)
|
||||
|
||||
if(System.getProperty("os.name") == "Linux"){
|
||||
include(":kmath-torch")
|
||||
}
|
Loading…
Reference in New Issue
Block a user