Buffer protocol for torch tensors

This commit is contained in:
rgrit91 2020-12-29 22:42:33 +00:00
parent b72b6fb34f
commit a229aaa6a4
16 changed files with 796 additions and 40 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
build/ build/
out/ out/
.idea/ .idea/
.vscode/
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored) # Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
!gradle-wrapper.jar !gradle-wrapper.jar

View File

@ -148,23 +148,28 @@ public interface Strides {
/** /**
* Array strides * Array strides
*/ */
public val strides: List<Int> public val strides: IntArray
/**
* Get linear index from multidimensional index
*/
public fun offset(index: IntArray): Int
/**
* Get multidimensional from linear
*/
public fun index(offset: Int): IntArray
/** /**
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides * The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
*/ */
public val linearSize: Int public val linearSize: Int
/**
* Get linear index from multidimensional index
*/
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= this.shape[i])
throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
/**
* Get multidimensional from linear
*/
public fun index(offset: Int): IntArray
// TODO introduce a fast way to calculate index of the next element? // TODO introduce a fast way to calculate index of the next element?
/** /**
@ -183,7 +188,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
/** /**
* Strides for memory access * Strides for memory access
*/ */
override val strides: List<Int> by lazy { override val strides: IntArray by lazy {
sequence { sequence {
var current = 1 var current = 1
yield(1) yield(1)
@ -192,16 +197,9 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
current *= it current *= it
yield(current) yield(current)
} }
}.toList() }.toList().toIntArray()
} }
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= this.shape[i])
throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
override fun index(offset: Int): IntArray { override fun index(offset: Int): IntArray {
val res = IntArray(shape.size) val res = IntArray(shape.size)
var current = offset var current = offset
@ -238,20 +236,22 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
} }
/** /**
* Represents [NDStructure] over [Buffer]. * Trait for [NDStructure] over [Buffer].
* *
* @param T the type of items. * @param T the type of items
* @param BufferImpl implementation of [Buffer].
*/ */
public abstract class NDBuffer<T> : NDStructure<T> { public abstract class NDBufferTrait<T, out BufferImpl : Buffer<T>, out StridesImpl: Strides> :
NDStructure<T> {
/** /**
* The underlying buffer. * The underlying buffer.
*/ */
public abstract val buffer: Buffer<T> public abstract val buffer: BufferImpl
/** /**
* The strides to access elements of [Buffer] by linear indices. * The strides to access elements of [Buffer] by linear indices.
*/ */
public abstract val strides: Strides public abstract val strides: StridesImpl
override operator fun get(index: IntArray): T = buffer[strides.offset(index)] override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
@ -259,8 +259,8 @@ public abstract class NDBuffer<T> : NDStructure<T> {
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] } override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] }
override fun equals(other: Any?): Boolean { public fun checkStridesBufferCompatibility(): Unit = require(strides.linearSize == buffer.size) {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false) "Expected buffer side of ${strides.linearSize}, but found ${buffer.size}"
} }
override fun hashCode(): Int { override fun hashCode(): Int {
@ -269,6 +269,10 @@ public abstract class NDBuffer<T> : NDStructure<T> {
return result return result
} }
override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
override fun toString(): String { override fun toString(): String {
val bufferRepr: String = when (shape.size) { val bufferRepr: String = when (shape.size) {
1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ") 1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ")
@ -282,10 +286,36 @@ public abstract class NDBuffer<T> : NDStructure<T> {
} }
return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)" return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)"
} }
} }
/**
* Trait for [MutableNDStructure] over [MutableBuffer].
*
* @param T the type of items
* @param MutableBufferImpl implementation of [MutableBuffer].
*/
public abstract class MutableNDBufferTrait<T, out MutableBufferImpl : MutableBuffer<T>, out StridesImpl: Strides> :
NDBufferTrait<T, MutableBufferImpl, StridesImpl>(), MutableNDStructure<T> {
override fun hashCode(): Int = 0
override fun equals(other: Any?): Boolean = false
override operator fun set(index: IntArray, value: T): Unit =
buffer.set(strides.offset(index), value)
}
/**
* Default representation of [NDStructure] over [Buffer].
*
* @param T the type of items.
*/
public abstract class NDBuffer<T> : NDBufferTrait<T, Buffer<T>, Strides>()
/**
* Default representation of [MutableNDStructure] over [MutableBuffer].
*
* @param T the type of items.
*/
public abstract class MutableNDBuffer<T> : MutableNDBufferTrait<T, MutableBuffer<T>, Strides>()
/** /**
* Boxing generic [NDStructure] * Boxing generic [NDStructure]
*/ */
@ -294,9 +324,7 @@ public class BufferNDStructure<T>(
override val buffer: Buffer<T>, override val buffer: Buffer<T>,
) : NDBuffer<T>() { ) : NDBuffer<T>() {
init { init {
if (strides.linearSize != buffer.size) { checkStridesBufferCompatibility()
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
}
} }
} }
@ -316,22 +344,17 @@ public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
} }
/** /**
* Mutable ND buffer based on linear [MutableBuffer]. * Boxing generic [MutableNDStructure].
*/ */
public class MutableBufferNDStructure<T>( public class MutableBufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: MutableBuffer<T>, override val buffer: MutableBuffer<T>,
) : NDBuffer<T>(), MutableNDStructure<T> { ) : MutableNDBuffer<T>() {
init { init {
require(strides.linearSize == buffer.size) { checkStridesBufferCompatibility()
"Expected buffer side of ${strides.linearSize}, but found ${buffer.size}"
} }
} }
override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
}
public inline fun <reified T : Any> NDStructure<T>.combine( public inline fun <reified T : Any> NDStructure<T>.combine(
struct: NDStructure<T>, struct: NDStructure<T>,
crossinline block: (T, T) -> T, crossinline block: (T, T) -> T,

95
kmath-torch/README.md Normal file
View File

@ -0,0 +1,95 @@
# LibTorch extension (`kmath-torch`)
This is a `Kotlin/Native` module, with only `linuxX64` supported so far. This library wraps some of the [PyTorch C++ API](https://pytorch.org/cppdocs), focusing on integrating `Aten` & `Autograd` with `KMath`.
## Installation
To install the library, you have to build & publish locally `kmath-core`, `kmath-memory` with `kmath-torch`:
```
./gradlew -q :kmath-core:publishToMavenLocal :kmath-memory:publishToMavenLocal :kmath-torch:publishToMavenLocal
```
This builds `ctorch`, a C wrapper for `LibTorch` placed inside:
`~/.konan/third-party/kmath-torch-0.2.0-dev-4/cpp-build`
You will have to link against it in your own project. Here is an example of build script for a standalone application:
```kotlin
//build.gradle.kts
plugins {
id("ru.mipt.npm.mpp")
}
repositories {
jcenter()
mavenLocal()
}
val home = System.getProperty("user.home")
val kver = "0.2.0-dev-4"
val cppBuildDir = "$home/.konan/third-party/kmath-torch-$kver/cpp-build"
kotlin {
explicitApiWarning()
val nativeTarget = linuxX64("your.app")
nativeTarget.apply {
binaries {
executable {
entryPoint = "your.app.main"
}
all {
linkerOpts(
"-L$cppBuildDir",
"-Wl,-rpath=$cppBuildDir",
"-lctorch"
)
}
}
}
val main by nativeTarget.compilations.getting
sourceSets {
val nativeMain by creating {
dependencies {
implementation("kscience.kmath:kmath-torch:$kver")
}
}
main.defaultSourceSet.dependsOn(nativeMain)
}
}
```
```kotlin
//settings.gradle.kts
pluginManagement {
repositories {
gradlePluginPortal()
jcenter()
maven("https://dl.bintray.com/mipt-npm/dev")
}
plugins {
id("ru.mipt.npm.mpp") version "0.7.1"
kotlin("jvm") version "1.4.21"
}
}
```
## Usage
Tensors implement the buffer protocol over `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
```kotlin
memScoped {
val intTensor: TorchTensorInt = TorchTensor.copyFromIntArray(
scope = this,
array = intArrayOf(7,8,9,2,6,5),
shape = intArrayOf(3,2))
println(intTensor)
val floatTensor: TorchTensorFloat = TorchTensor.copyFromFloatArray(
scope = this,
array = floatArrayOf(7f,8.9f,2.6f,5.6f),
shape = intArrayOf(4))
println(intTensor)
}
```

View File

@ -0,0 +1,146 @@
import de.undercouch.gradle.tasks.download.Download
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
plugins {
id("ru.mipt.npm.mpp")
id("de.undercouch.download")
}
val home = System.getProperty("user.home")
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
val cppBuildDir = "$thirdPartyDir/cpp-build"
val cmakeArchive = "cmake-3.19.2-Linux-x86_64"
val torchArchive = "libtorch"
val cmakeCmd = "$thirdPartyDir/$cmakeArchive/bin/cmake"
val ninjaCmd = "$thirdPartyDir/ninja"
val downloadCMake by tasks.registering(Download::class) {
val tarFile = "$cmakeArchive.tar.gz"
src("https://github.com/Kitware/CMake/releases/download/v3.19.2/$tarFile")
dest(File(thirdPartyDir, tarFile))
overwrite(false)
}
val downloadNinja by tasks.registering(Download::class) {
src("https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-linux.zip")
dest(File(thirdPartyDir, "ninja-linux.zip"))
overwrite(false)
}
val downloadTorch by tasks.registering(Download::class) {
val zipFile = "$torchArchive-cxx11-abi-shared-with-deps-1.7.1%2Bcu110.zip"
src("https://download.pytorch.org/libtorch/cu110/$zipFile")
dest(File(thirdPartyDir, "$torchArchive.zip"))
overwrite(false)
}
val extractCMake by tasks.registering(Copy::class) {
dependsOn(downloadCMake)
from(tarTree(resources.gzip(downloadCMake.get().dest)))
into(thirdPartyDir)
}
val extractTorch by tasks.registering(Copy::class) {
dependsOn(downloadTorch)
from(zipTree(downloadTorch.get().dest))
into(thirdPartyDir)
}
val extractNinja by tasks.registering(Copy::class) {
dependsOn(downloadNinja)
from(zipTree(downloadNinja.get().dest))
into(thirdPartyDir)
}
val configureCpp by tasks.registering {
dependsOn(extractCMake)
dependsOn(extractNinja)
dependsOn(extractTorch)
onlyIf { !file(cppBuildDir).exists() }
doLast {
exec {
workingDir(thirdPartyDir)
commandLine("mkdir", "-p", cppBuildDir)
}
exec {
workingDir(cppBuildDir)
commandLine(
cmakeCmd,
projectDir.resolve("ctorch"),
"-GNinja",
"-DCMAKE_MAKE_PROGRAM=$ninjaCmd",
"-DCMAKE_PREFIX_PATH=$thirdPartyDir/$torchArchive",
"-DCMAKE_BUILD_TYPE=Release"
)
}
}
}
val cleanCppBuild by tasks.registering {
onlyIf { file(cppBuildDir).exists() }
doLast {
exec {
workingDir(thirdPartyDir)
commandLine("rm", "-rf", cppBuildDir)
}
}
}
val buildCpp by tasks.registering {
dependsOn(configureCpp)
doLast {
exec {
workingDir(cppBuildDir)
commandLine(cmakeCmd, "--build", ".", "--config", "Release")
}
}
}
kotlin {
explicitApiWarning()
val nativeTarget = linuxX64("torch")
nativeTarget.apply {
binaries {
all {
linkerOpts(
"-L$cppBuildDir",
"-Wl,-rpath=$cppBuildDir",
"-lctorch"
)
}
}
}
val main by nativeTarget.compilations.getting {
cinterops {
val libctorch by creating {
includeDirs(projectDir.resolve("ctorch/include"))
}
}
}
val test by nativeTarget.compilations.getting
sourceSets {
val nativeMain by creating {
dependencies {
api(project(":kmath-core"))
}
}
val nativeTest by creating {
dependsOn(nativeMain)
}
main.defaultSourceSet.dependsOn(nativeMain)
test.defaultSourceSet.dependsOn(nativeTest)
}
}
val torch: KotlinNativeTarget by kotlin.targets
tasks[torch.compilations["main"].cinterops["libctorch"].interopProcessingTaskName]
.dependsOn(buildCpp)

View File

@ -0,0 +1,25 @@
cmake_minimum_required(VERSION 3.12)
project(CTorch LANGUAGES C CXX)
# Require C++17
set(CMAKE_CXX_STANDARD 17)
# Build configuration
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
endif()
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
find_package(Torch REQUIRED)
add_library(ctorch SHARED src/ctorch.cc)
target_include_directories(ctorch PRIVATE include)
target_link_libraries(ctorch PRIVATE torch)
target_compile_options(ctorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
include(GNUInstallDirs)
set_target_properties(ctorch PROPERTIES PUBLIC_HEADER include/ctorch.h)
install(TARGETS ctorch
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})

View File

@ -0,0 +1,48 @@
#ifndef CTORCH
#define CTORCH
#include <stdbool.h>
#ifdef __cplusplus
extern "C"
{
#endif
typedef void *TorchTensorHandle;
int get_num_threads();
void set_num_threads(int num_threads);
bool cuda_is_available();
void set_seed(int seed);
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim);
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim);
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim);
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim);
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
double *get_data_double(TorchTensorHandle tensor_handle);
float *get_data_float(TorchTensorHandle tensor_handle);
long *get_data_long(TorchTensorHandle tensor_handle);
int *get_data_int(TorchTensorHandle tensor_handle);
int get_numel(TorchTensorHandle tensor_handle);
int get_dim(TorchTensorHandle tensor_handle);
int *get_shape(TorchTensorHandle tensor_handle);
int *get_strides(TorchTensorHandle tensor_handle);
char *tensor_to_string(TorchTensorHandle tensor_handle);
void dispose_int_array(int *ptr);
void dispose_char(char *ptr);
void dispose_tensor(TorchTensorHandle tensor_handle);
#ifdef __cplusplus
}
#endif
#endif //CTORCH

View File

@ -0,0 +1,56 @@
#include <torch/torch.h>
#include "ctorch.h"
namespace ctorch
{
template <typename Dtype>
inline c10::ScalarType dtype()
{
return torch::kFloat64;
}
template <>
inline c10::ScalarType dtype<float>()
{
return torch::kFloat32;
}
template <>
inline c10::ScalarType dtype<long>()
{
return torch::kInt64;
}
template <>
inline c10::ScalarType dtype<int>()
{
return torch::kInt32;
}
inline torch::Tensor &cast(TorchTensorHandle tensor_handle)
{
return *static_cast<torch::Tensor *>(tensor_handle);
}
template <typename Dtype>
inline torch::Tensor copy_from_blob(Dtype *data, int *shape, int dim)
{
auto shape_vec = std::vector<int64_t>(dim);
shape_vec.assign(shape, shape + dim);
return torch::from_blob(data, shape_vec, dtype<Dtype>()).clone();
}
template <typename IntArray>
inline int *to_dynamic_ints(IntArray arr)
{
size_t n = arr.size();
int *res = (int *)malloc(sizeof(int) * n);
for (size_t i = 0; i < n; i++)
{
res[i] = arr[i];
}
return res;
}
} // namespace ctorch

View File

@ -0,0 +1,110 @@
#include <torch/torch.h>
#include <iostream>
#include <stdlib.h>
#include "ctorch.h"
#include "utils.hh"
int get_num_threads()
{
return torch::get_num_threads();
}
void set_num_threads(int num_threads)
{
torch::set_num_threads(num_threads);
}
bool cuda_is_available()
{
return torch::cuda::is_available();
}
void set_seed(int seed)
{
torch::manual_seed(seed);
}
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim)
{
return new torch::Tensor(ctorch::copy_from_blob<double>(data, shape, dim));
}
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim)
{
return new torch::Tensor(ctorch::copy_from_blob<float>(data, shape, dim));
}
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim)
{
return new torch::Tensor(ctorch::copy_from_blob<long>(data, shape, dim));
}
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim)
{
return new torch::Tensor(ctorch::copy_from_blob<int>(data, shape, dim));
}
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).clone());
}
double *get_data_double(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<double>();
}
float *get_data_float(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<float>();
}
long *get_data_long(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<long>();
}
int *get_data_int(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<int>();
}
int get_numel(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).numel();
}
int get_dim(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).dim();
}
int *get_shape(TorchTensorHandle tensor_handle)
{
return ctorch::to_dynamic_ints(ctorch::cast(tensor_handle).sizes());
}
int *get_strides(TorchTensorHandle tensor_handle)
{
return ctorch::to_dynamic_ints(ctorch::cast(tensor_handle).strides());
}
char *tensor_to_string(TorchTensorHandle tensor_handle)
{
std::stringstream bufrep;
bufrep << ctorch::cast(tensor_handle);
auto rep = bufrep.str();
char *crep = (char *)malloc(rep.length() + 1);
std::strcpy(crep, rep.c_str());
return crep;
}
void dispose_int_array(int *ptr)
{
free(ptr);
}
void dispose_char(char *ptr)
{
free(ptr);
}
void dispose_tensor(TorchTensorHandle tensor_handle)
{
delete static_cast<torch::Tensor *>(tensor_handle);
}

View File

@ -0,0 +1,2 @@
package=ctorch
headers=ctorch.h

View File

@ -0,0 +1,52 @@
package kscience.kmath.torch
import kscience.kmath.structures.*
import kotlinx.cinterop.*
import ctorch.*
public abstract class TorchTensor<T,
TVar : CPrimitiveVar,
TorchTensorBufferImpl : TorchTensorBuffer<T, TVar>> :
MutableNDBufferTrait<T, TorchTensorBufferImpl, TorchTensorStrides>() {
public companion object {
public fun copyFromFloatArray(scope: DeferScope, array: FloatArray, shape: IntArray): TorchTensorFloat {
val tensorHandle: COpaquePointer = copy_from_blob_float(
array.toCValues(), shape.toCValues(), shape.size
)!!
return TorchTensorFloat(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
}
public fun copyFromIntArray(scope: DeferScope, array: IntArray, shape: IntArray): TorchTensorInt {
val tensorHandle: COpaquePointer = copy_from_blob_int(
array.toCValues(), shape.toCValues(), shape.size
)!!
return TorchTensorInt(populateStridesFromNative(tensorHandle, rawShape = shape), scope, tensorHandle)
}
}
override fun toString(): String {
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(buffer.tensorHandle)!!
val stringRepresentation = nativeStringRepresentation.toKString()
dispose_char(nativeStringRepresentation)
return stringRepresentation
}
}
public class TorchTensorFloat internal constructor(
override val strides: TorchTensorStrides,
scope: DeferScope,
tensorHandle: COpaquePointer
): TorchTensor<Float, FloatVar, TorchTensorBufferFloat>() {
override val buffer: TorchTensorBufferFloat = TorchTensorBufferFloat(scope, tensorHandle)
}
public class TorchTensorInt internal constructor(
override val strides: TorchTensorStrides,
scope: DeferScope,
tensorHandle: COpaquePointer
): TorchTensor<Int, IntVar, TorchTensorBufferInt>() {
override val buffer: TorchTensorBufferInt = TorchTensorBufferInt(scope, tensorHandle)
}

View File

@ -0,0 +1,67 @@
package kscience.kmath.torch
import kscience.kmath.structures.MutableBuffer
import kotlinx.cinterop.*
import ctorch.*
public abstract class TorchTensorBuffer<T, TVar : CPrimitiveVar> internal constructor(
internal val scope: DeferScope,
internal val tensorHandle: COpaquePointer
) : MutableBuffer<T> {
init {
scope.defer(::close)
}
internal fun close() {
dispose_tensor(tensorHandle)
}
protected abstract val tensorData: CPointer<TVar>
override val size: Int
get() = get_numel(tensorHandle)
}
public class TorchTensorBufferFloat internal constructor(
scope: DeferScope,
tensorHandle: COpaquePointer
) : TorchTensorBuffer<Float, FloatVar>(scope, tensorHandle) {
override val tensorData: CPointer<FloatVar> = get_data_float(tensorHandle)!!
override operator fun get(index: Int): Float = tensorData[index]
override operator fun set(index: Int, value: Float) {
tensorData[index] = value
}
override operator fun iterator(): Iterator<Float> = (1..size).map { tensorData[it - 1] }.iterator()
override fun copy(): TorchTensorBufferFloat = TorchTensorBufferFloat(
scope = scope,
tensorHandle = copy_tensor(tensorHandle)!!
)
}
public class TorchTensorBufferInt internal constructor(
scope: DeferScope,
tensorHandle: COpaquePointer
) : TorchTensorBuffer<Int, IntVar>(scope, tensorHandle) {
override val tensorData: CPointer<IntVar> = get_data_int(tensorHandle)!!
override operator fun get(index: Int): Int = tensorData[index]
override operator fun set(index: Int, value: Int) {
tensorData[index] = value
}
override operator fun iterator(): Iterator<Int> = (1..size).map { tensorData[it - 1] }.iterator()
override fun copy(): TorchTensorBufferInt = TorchTensorBufferInt(
scope = scope,
tensorHandle = copy_tensor(tensorHandle)!!
)
}

View File

@ -0,0 +1,55 @@
package kscience.kmath.torch
import kscience.kmath.structures.Strides
import kotlinx.cinterop.*
import ctorch.*
public class TorchTensorStrides internal constructor(
override val shape: IntArray,
override val strides: IntArray,
override val linearSize: Int
) : Strides {
override fun index(offset: Int): IntArray {
val nDim = shape.size
val res = IntArray(nDim)
var current = offset
var strideIndex = 0
while (strideIndex < nDim) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex++
}
return res
}
}
private inline fun intPointerToArrayAndClean(ptr: CPointer<IntVar>, nDim: Int): IntArray {
val res: IntArray = (1 .. nDim).map{ptr[it-1]}.toIntArray()
dispose_int_array(ptr)
return res
}
private inline fun getShapeFromNative(tensorHandle: COpaquePointer, nDim: Int): IntArray{
return intPointerToArrayAndClean(get_shape(tensorHandle)!!, nDim)
}
private inline fun getStridesFromNative(tensorHandle: COpaquePointer, nDim: Int): IntArray{
return intPointerToArrayAndClean(get_strides(tensorHandle)!!, nDim)
}
internal inline fun populateStridesFromNative(
tensorHandle: COpaquePointer,
rawShape: IntArray? = null,
rawStrides: IntArray? = null,
rawLinearSize: Int? = null
): TorchTensorStrides {
val nDim = rawShape?.size?: rawStrides?.size?: get_dim(tensorHandle)
return TorchTensorStrides(
shape = rawShape?: getShapeFromNative(tensorHandle, nDim),
strides = rawStrides?: getStridesFromNative(tensorHandle, nDim),
linearSize = rawLinearSize?: get_numel(tensorHandle)
)
}

View File

@ -0,0 +1,20 @@
package kscience.kmath.torch
import kotlinx.cinterop.*
import ctorch.*
public fun getNumThreads(): Int {
return get_num_threads()
}
public fun setNumThreads(numThreads: Int): Unit {
set_num_threads(numThreads)
}
public fun cudaAvailable(): Boolean {
return cuda_is_available()
}
public fun setSeed(seed: Int): Unit {
set_seed(seed)
}

View File

@ -0,0 +1,33 @@
package kscience.kmath.torch
import kscience.kmath.structures.asBuffer
import kotlinx.cinterop.memScoped
import kotlin.test.*
internal class TestTorchTensor {
@Test
fun intTensorLayout() = memScoped {
val array = intArrayOf(7,8,9,2,6,5)
val shape = intArrayOf(3,2)
val tensor = TorchTensor.copyFromIntArray(scope=this, array=array, shape=shape)
tensor.elements().forEach {
assertEquals(tensor[it.first], it.second)
}
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
}
@Test
fun floatTensorLayout() = memScoped {
val array = floatArrayOf(7.5f,8.2f,9f,2.58f,6.5f,5f)
val shape = intArrayOf(2,3)
val tensor = TorchTensor.copyFromFloatArray(this, array, shape)
tensor.elements().forEach {
assertEquals(tensor[it.first], it.second)
}
assertTrue(tensor.buffer.contentEquals(array.asBuffer()))
}
}

View File

@ -0,0 +1,19 @@
package kscience.kmath.torch
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class TestUtils {
@Test
fun settingTorchThreadsCount(){
val numThreads = 2
setNumThreads(numThreads)
assertEquals(numThreads, getNumThreads())
}
@Test
fun cudaAvailability(){
assertTrue(cudaAvailable())
}
}

View File

@ -42,3 +42,7 @@ include(
":kmath-kotlingrad", ":kmath-kotlingrad",
":examples" ":examples"
) )
if(System.getProperty("os.name") == "Linux"){
include(":kmath-torch")
}