Refactoring project structure

This commit is contained in:
Roland Grinis 2021-01-16 21:03:11 +00:00
parent 889691a122
commit 6eb718f64a
27 changed files with 608 additions and 394 deletions

3
.gitignore vendored
View File

@ -9,3 +9,6 @@ out/
# Cache of project # Cache of project
.gradletasknamecache .gradletasknamecache
# Generated by javac -h
*.class

View File

@ -1,49 +1,70 @@
package kscience.kmath.structures package kscience.kmath.structures
import kscience.kmath.operations.* import kscience.kmath.operations.*
public interface TensorAlgebra<T, TorchTensorType : TensorStructure<T>> : Ring<TorchTensorType> { public interface TensorStructure<T> : MutableNDStructure<T> {
// A tensor can have empty shape, in which case it represents just a value
public abstract fun value(): T
}
public operator fun T.plus(other: TorchTensorType): TorchTensorType // https://proofwiki.org/wiki/Definition:Algebra_over_Ring
public operator fun TorchTensorType.plus(value: T): TorchTensorType
public operator fun TorchTensorType.plusAssign(value: T): Unit
public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit
public operator fun T.minus(other: TorchTensorType): TorchTensorType public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public operator fun TorchTensorType.minus(value: T): TorchTensorType
public operator fun TorchTensorType.minusAssign(value: T): Unit
public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit
public operator fun T.times(other: TorchTensorType): TorchTensorType public operator fun T.plus(other: TensorType): TensorType
public operator fun TorchTensorType.times(value: T): TorchTensorType public operator fun TensorType.plus(value: T): TensorType
public operator fun TorchTensorType.timesAssign(value: T): Unit public operator fun TensorType.plus(other: TensorType): TensorType
public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit public operator fun TensorType.plusAssign(value: T): Unit
public operator fun TensorType.plusAssign(other: TensorType): Unit
public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType public operator fun T.minus(other: TensorType): TensorType
public operator fun TensorType.minus(value: T): TensorType
public operator fun TensorType.minus(other: TensorType): TensorType
public operator fun TensorType.minusAssign(value: T): Unit
public operator fun TensorType.minusAssign(other: TensorType): Unit
public operator fun T.times(other: TensorType): TensorType
public operator fun TensorType.times(value: T): TensorType
public operator fun TensorType.times(other: TensorType): TensorType
public operator fun TensorType.timesAssign(value: T): Unit
public operator fun TensorType.timesAssign(other: TensorType): Unit
public operator fun TensorType.unaryMinus(): TensorType
public infix fun TensorType.dot(other: TensorType): TensorType
public infix fun TensorType.dotAssign(other: TensorType): Unit
public infix fun TensorType.dotRightAssign(other: TensorType): Unit
public fun diagonalEmbedding( public fun diagonalEmbedding(
diagonalEntries: TorchTensorType, diagonalEntries: TensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1 offset: Int = 0, dim1: Int = -2, dim2: Int = -1
): TorchTensorType ): TensorType
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType public fun TensorType.transpose(i: Int, j: Int): TensorType
public fun TorchTensorType.view(shape: IntArray): TorchTensorType public fun TensorType.transposeAssign(i: Int, j: Int): Unit
public fun TorchTensorType.abs(): TorchTensorType public fun TensorType.view(shape: IntArray): TensorType
public fun TorchTensorType.sum(): TorchTensorType
public fun TensorType.abs(): TensorType
public fun TensorType.absAssign(): Unit
public fun TensorType.sum(): TensorType
public fun TensorType.sumAssign(): Unit
} }
public interface TensorFieldAlgebra<T, TorchTensorType : TensorStructure<T>> : // https://proofwiki.org/wiki/Definition:Division_Algebra
TensorAlgebra<T, TorchTensorType>, Field<TorchTensorType> {
public operator fun TorchTensorType.divAssign(b: TorchTensorType) public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
TensorAlgebra<T, TensorType> {
public fun TorchTensorType.exp(): TorchTensorType public operator fun TensorType.div(other: TensorType): TensorType
public fun TorchTensorType.log(): TorchTensorType public operator fun TensorType.divAssign(other: TensorType)
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> public fun TensorType.exp(): TensorType
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType> public fun TensorType.expAssign(): Unit
public fun TensorType.log(): TensorType
public fun TensorType.logAssign(): Unit
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType>
} }

View File

@ -1,11 +0,0 @@
package kscience.kmath.structures
public abstract class TensorStructure<T>: MutableNDStructure<T> {
// A tensor can have empty shape, in which case it represents just a value
public abstract fun value(): T
// Tensors are mutable and might hold shared resources
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
}

View File

@ -1,5 +1,6 @@
import de.undercouch.gradle.tasks.download.Download import de.undercouch.gradle.tasks.download.Download
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
import org.gradle.api.JavaVersion.VERSION_11
plugins { plugins {
@ -7,10 +8,16 @@ plugins {
id("de.undercouch.download") id("de.undercouch.download")
} }
java {
sourceCompatibility = VERSION_11
targetCompatibility = VERSION_11
}
val home = System.getProperty("user.home") val home = System.getProperty("user.home")
val javaHome = System.getProperty("java.home")
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}" val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
val cppBuildDir = "$thirdPartyDir/cpp-build" val cppBuildDir = "$thirdPartyDir/cpp-build"
val cppSources = projectDir.resolve("src/cppMain")
val cudaHome: String? = System.getenv("CUDA_HOME") val cudaHome: String? = System.getenv("CUDA_HOME")
val cudaDefault = file("/usr/local/cuda").exists() val cudaDefault = file("/usr/local/cuda").exists()
@ -77,10 +84,11 @@ val configureCpp by tasks.registering {
workingDir(cppBuildDir) workingDir(cppBuildDir)
commandLine( commandLine(
cmakeCmd, cmakeCmd,
projectDir.resolve("ctorch"), cppSources,
"-GNinja", "-GNinja",
"-DCMAKE_MAKE_PROGRAM=$ninjaCmd", "-DCMAKE_MAKE_PROGRAM=$ninjaCmd",
"-DCMAKE_PREFIX_PATH=$thirdPartyDir/$torchArchive", "-DCMAKE_PREFIX_PATH=$thirdPartyDir/$torchArchive",
"-DJAVA_HOME=$javaHome",
"-DCMAKE_BUILD_TYPE=Release" "-DCMAKE_BUILD_TYPE=Release"
) )
} }
@ -107,10 +115,23 @@ val buildCpp by tasks.registering {
} }
} }
val generateJNIHeader by tasks.registering {
doLast {
exec {
workingDir(projectDir.resolve("src/jvmMain/java/kscience/kmath/torch"))
commandLine("$javaHome/bin/javac", "-h", cppSources.resolve("include") , "JTorch.java")
}
}
}
kotlin { kotlin {
explicitApiWarning() explicitApiWarning()
val nativeTarget = linuxX64("torch") jvm {
withJava()
}
val nativeTarget = linuxX64("native")
nativeTarget.apply { nativeTarget.apply {
binaries { binaries {
all { all {
@ -128,38 +149,38 @@ kotlin {
val main by nativeTarget.compilations.getting { val main by nativeTarget.compilations.getting {
cinterops { cinterops {
val libctorch by creating { val libctorch by creating {
includeDirs(projectDir.resolve("ctorch/include")) includeDirs(cppSources.resolve("include"))
} }
} }
} }
val test by nativeTarget.compilations.getting val test by nativeTarget.compilations.getting
sourceSets { sourceSets {
val nativeMain by creating {
val commonMain by getting {
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
} }
} }
val nativeTest by creating {
dependsOn(nativeMain)
}
val nativeGPUTest by creating {
dependsOn(nativeMain)
dependsOn(nativeTest)
}
val nativeMain by getting {
main.defaultSourceSet.dependsOn(nativeMain) dependencies {
test.defaultSourceSet.dependsOn(nativeTest) api(project(":kmath-core"))
if(cudaFound) { }
test.defaultSourceSet.dependsOn(nativeGPUTest)
} }
} }
} }
val torch: KotlinNativeTarget by kotlin.targets val native: KotlinNativeTarget by kotlin.targets
tasks[torch.compilations["main"].cinterops["libctorch"].interopProcessingTaskName] tasks[native.compilations["main"].cinterops["libctorch"].interopProcessingTaskName]
.dependsOn(buildCpp) .dependsOn(buildCpp)
tasks["jvmProcessResources"].dependsOn(buildCpp)
tasks {
withType<Test>{
systemProperty("java.library.path", cppBuildDir.toString())
}
}

View File

@ -0,0 +1,52 @@
package kscience.kmath.torch
import kscience.kmath.structures.TensorStructure
public interface TorchTensor<T> : TensorStructure<T> {
public fun item(): T
public val strides: IntArray
public val size: Int
public val device: Device
override fun value(): T {
checkIsValue()
return item()
}
override fun elements(): Sequence<Pair<IntArray, T>> {
if (dimension == 0) {
return emptySequence()
}
val indices = (1..size).asSequence().map { indexFromOffset(it - 1, strides, dimension) }
return indices.map { it to get(it) }
}
}
public inline fun <T> TorchTensor<T>.isValue(): Boolean {
return (dimension == 0)
}
public inline fun <T> TorchTensor<T>.isNotValue(): Boolean = !this.isValue()
public inline fun <T> TorchTensor<T>.checkIsValue(): Unit = check(this.isValue()) {
"This tensor has shape ${shape.toList()}"
}
public interface TorchTensorOverField<T>: TorchTensor<T>
{
public var requiresGrad: Boolean
}
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
val res = IntArray(nDim)
var current = offset
var strideIndex = 0
while (strideIndex < nDim) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex++
}
return res
}

View File

@ -0,0 +1,140 @@
package kscience.kmath.torch
import kscience.kmath.structures.*
public interface TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>> :
TensorAlgebra<T, TorchTensorType> {
public fun getNumThreads(): Int
public fun setNumThreads(numThreads: Int): Unit
public fun cudaAvailable(): Boolean
public fun setSeed(seed: Int): Unit
public var checks: Boolean
public fun copyFromArray(
array: PrimitiveArrayType,
shape: IntArray,
device: Device = Device.CPU
): TorchTensorType
public fun TorchTensorType.copyToArray(): PrimitiveArrayType
public fun full(value: T, shape: IntArray, device: Device): TorchTensorType
public fun randIntegral(
low: T, high: T, shape: IntArray,
device: Device = Device.CPU
): TorchTensorType
public fun TorchTensorType.randIntegral(low: T, high: T): TorchTensorType
public fun TorchTensorType.randIntegralAssign(low: T, high: T): Unit
public fun TorchTensorType.copy(): TorchTensorType
public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType
public infix fun TorchTensorType.swap(other: TorchTensorType)
}
public interface TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>> :
TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>, TensorPartialDivisionAlgebra<T, TorchTensorType> {
public fun randUniform(shape: IntArray, device: Device = Device.CPU): TorchTensorType
public fun randNormal(shape: IntArray, device: Device = Device.CPU): TorchTensorType
public fun TorchTensorType.randUniform(): TorchTensorType
public fun TorchTensorType.randUniformAssign(): Unit
public fun TorchTensorType.randNormal(): TorchTensorType
public fun TorchTensorType.randNormalAssign(): Unit
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
this.grad(variable, false)
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType
public fun TorchTensorType.detachFromGraph(): TorchTensorType
}
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
TorchTensorAlgebraType.withChecks(block: TorchTensorAlgebraType.() -> Unit): Unit {
val state = this.checks
this.checks = true
this.block()
this.checks = state
}
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
TorchTensorAlgebraType.checkDeviceCompatible(
a: TorchTensorType, b: TorchTensorType
): Unit =
check(a.device == b.device) {
"Tensors must be on the same device"
}
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
TorchTensorAlgebraType.checkShapeCompatible(
a: TorchTensorType,
b: TorchTensorType
): Unit =
check(a.shape contentEquals b.shape) {
"Tensors must be of identical shape"
}
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
TorchTensorAlgebraType.checkLinearOperation(
a: TorchTensorType,
b: TorchTensorType
) {
if (a.isNotValue() and b.isNotValue()) {
this.checkDeviceCompatible(a, b)
this.checkShapeCompatible(a, b)
}
}
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
TorchTensorAlgebraType.checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit {
checkDeviceCompatible(a, b)
val sa = a.shape
val sb = b.shape
val na = sa.size
val nb = sb.size
var status: Boolean
if (nb == 1) {
status = sa.last() == sb[0]
} else {
status = sa.last() == sb[nb - 2]
if ((na > 2) and (nb > 2)) {
status = status and
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
}
}
check(status) { "Incompatible shapes $sa and $sb for dot product" }
}
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
check((i < dim) and (j < dim)) {
"Cannot transpose $i to $j for a tensor of dim $dim"
}
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
TorchTensorAlgebraType.checkView(a: TorchTensorType, shape: IntArray): Unit =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
TorchTensorDivisionAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
TorchTensorDivisionAlgebraType.withGradAt(
tensor: TorchTensorType,
block: TorchTensorDivisionAlgebraType.(TorchTensorType) -> TorchTensorType
): TorchTensorType {
tensor.requiresGrad = true
return this.block(tensor)
}

View File

@ -0,0 +1,14 @@
package kscience.kmath.torch
import kscience.kmath.memory.DeferScope
public abstract class TorchTensorMemoryHolder internal constructor(
public val scope: DeferScope) {
init {
scope.defer(::close)
}
protected abstract fun close(): Unit
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
}

View File

@ -0,0 +1,5 @@
package kscience.kmath.torch
internal val SEED = 987654
internal val TOLERANCE = 1e-6

View File

@ -12,14 +12,23 @@ endif()
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
find_package(JNI REQUIRED)
add_library(ctorch SHARED src/ctorch.cc) add_library(ctorch SHARED src/ctorch.cc)
target_include_directories(ctorch PRIVATE include) target_include_directories(ctorch PRIVATE include)
target_link_libraries(ctorch PRIVATE torch) target_link_libraries(ctorch PRIVATE torch)
target_compile_options(ctorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC) target_compile_options(ctorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
add_library(jtorch SHARED src/jtorch.cc)
target_include_directories(jtorch PRIVATE include ${JNI_INCLUDE_DIRS})
target_link_libraries(jtorch PRIVATE torch)
target_compile_options(jtorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
include(GNUInstallDirs) include(GNUInstallDirs)
set_target_properties(ctorch PROPERTIES PUBLIC_HEADER include/ctorch.h) set_target_properties(ctorch PROPERTIES PUBLIC_HEADER include/ctorch.h)
install(TARGETS ctorch install(TARGETS ctorch
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
install(TARGETS jtorch LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})

View File

@ -0,0 +1,53 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class kscience_kmath_torch_JTorch */
#ifndef _Included_kscience_kmath_torch_JTorch
#define _Included_kscience_kmath_torch_JTorch
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: kscience_kmath_torch_JTorch
* Method: getNumThreads
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumThreads
(JNIEnv *, jclass);
/*
* Class: kscience_kmath_torch_JTorch
* Method: setNumThreads
* Signature: (I)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads
(JNIEnv *, jclass, jint);
/*
* Class: kscience_kmath_torch_JTorch
* Method: createTensor
* Signature: ()J
*/
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor
(JNIEnv *, jclass);
/*
* Class: kscience_kmath_torch_JTorch
* Method: printTensor
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor
(JNIEnv *, jclass, jlong);
/*
* Class: kscience_kmath_torch_JTorch
* Method: disposeTensor
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor
(JNIEnv *, jclass, jlong);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -1,9 +1,10 @@
#include <torch/torch.h> #include <torch/torch.h>
#include "ctorch.h"
namespace ctorch namespace ctorch
{ {
using TorchTensorHandle = void*;
template <typename Dtype> template <typename Dtype>
inline c10::ScalarType dtype() inline c10::ScalarType dtype()
{ {

View File

@ -0,0 +1,35 @@
#include <torch/torch.h>
#include <iostream>
#include <stdlib.h>
#include "kscience_kmath_torch_JTorch.h"
#include "utils.hh"
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumThreads(JNIEnv *, jclass)
{
return torch::get_num_threads();
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads(JNIEnv *, jclass, jint num_threads)
{
torch::set_num_threads(num_threads);
}
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor(JNIEnv *, jclass)
{
auto ten = torch::randn({2, 3});
std::cout << ten << std::endl;
void *ptr = new torch::Tensor(ten);
return (long)ptr;
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor(JNIEnv *, jclass, jlong tensor_handle)
{
auto ten = ctorch::cast((void *)tensor_handle);
std::cout << ten << std::endl;
}
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor(JNIEnv *, jclass, jlong tensor_handle)
{
delete static_cast<torch::Tensor *>((void *)tensor_handle);
}

View File

@ -0,0 +1,14 @@
package kscience.kmath.torch;
class JTorch {
static {
System.loadLibrary("jtorch");
}
public static native int getNumThreads();
public static native void setNumThreads(int numThreads);
public static native long createTensor();
public static native void printTensor(long tensorHandle);
public static native void disposeTensor(long tensorHandle);
}

View File

@ -0,0 +1,25 @@
package kscience.kmath.torch
public fun getNumThreads(): Int {
return JTorch.getNumThreads()
}
public fun setNumThreads(numThreads: Int): Unit {
JTorch.setNumThreads(numThreads)
}
public fun cudaAvailable(): Boolean {
TODO("Implementation not available yet")
}
public fun setSeed(seed: Int): Unit {
TODO("Implementation not available yet")
}
public fun runCPD(): Unit {
val tensorHandle = JTorch.createTensor()
JTorch.printTensor(tensorHandle)
JTorch.disposeTensor(tensorHandle)
}

View File

@ -0,0 +1,19 @@
package kscience.kmath.torch
import kotlin.test.*
class TestUtils {
@Test
fun testSetNumThreads() {
val numThreads = 2
setNumThreads(numThreads)
assertEquals(numThreads, getNumThreads())
}
@Test
fun testCPD() {
runCPD()
}
}

View File

@ -1,2 +1,2 @@
package=kscience.kmath.ctorch package=kscience.kmath.torch.ctorch
headers=ctorch.h headers=ctorch.h

View File

@ -1,163 +1,90 @@
package kscience.kmath.torch package kscience.kmath.torch
import kscience.kmath.structures.*
import kscience.kmath.memory.DeferScope import kscience.kmath.memory.DeferScope
import kscience.kmath.memory.withDeferScope import kscience.kmath.memory.withDeferScope
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.ctorch.* import kscience.kmath.torch.ctorch.*
public sealed class TorchTensorAlgebra< public sealed class TorchTensorAlgebraNative<
T, T,
TVar : CPrimitiveVar, TVar : CPrimitiveVar,
PrimitiveArrayType, PrimitiveArrayType,
TorchTensorType : TorchTensor<T>> constructor( TorchTensorType : TorchTensorNative<T>> constructor(
internal val scope: DeferScope internal val scope: DeferScope
) : ) : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType> {
TensorAlgebra<T, TorchTensorType> {
override fun getNumThreads(): Int {
return get_num_threads()
}
override fun setNumThreads(numThreads: Int): Unit {
set_num_threads(numThreads)
}
override fun cudaAvailable(): Boolean {
return cuda_is_available()
}
override fun setSeed(seed: Int): Unit {
set_seed(seed)
}
override var checks: Boolean = false
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
public abstract fun copyFromArray(
array: PrimitiveArrayType,
shape: IntArray,
device: Device = Device.CPU
): TorchTensorType
public abstract fun TorchTensorType.copyToArray(): PrimitiveArrayType
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
public abstract fun TorchTensorType.getData(): CPointer<TVar> public abstract fun TorchTensorType.getData(): CPointer<TVar>
public abstract fun full(value: T, shape: IntArray, device: Device): TorchTensorType override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
public abstract fun randIntegral( return wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
low: T, high: T, shape: IntArray,
device: Device = Device.CPU
): TorchTensorType
public abstract fun TorchTensorType.randIntegral(low: T, high: T): TorchTensorType
public abstract fun TorchTensorType.randIntegralAssign(low: T, high: T): Unit
override val zero: TorchTensorType
get() = number(0)
override val one: TorchTensorType
get() = number(1)
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
check(a.device == b.device) {
"Tensors must be on the same device"
} }
protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType): Unit = override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
check(a.shape contentEquals b.shape) { if (checks) checkLinearOperation(this, other)
"Tensors must be of identical shape" times_tensor_assign(this.tensorHandle, other.tensorHandle)
} }
protected inline fun checkLinearOperation(a: TorchTensorType, b: TorchTensorType) { override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
if (a.isNotValue() and b.isNotValue()) { if (checks) checkLinearOperation(this, other)
checkDeviceCompatible(a, b) return wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
checkShapeCompatible(a, b)
}
} }
override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType = override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
this.times(b, safe = true) if (checks) checkLinearOperation(this, other)
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
public fun TorchTensorType.times(b: TorchTensorType, safe: Boolean): TorchTensorType {
if (safe) checkLinearOperation(this, b)
return wrap(times_tensor(this.tensorHandle, b.tensorHandle)!!)
} }
override operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit = override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
this.timesAssign(b, safe = true) if (checks) checkLinearOperation(this, other)
return wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
public fun TorchTensorType.timesAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkLinearOperation(this, b)
times_tensor_assign(this.tensorHandle, b.tensorHandle)
} }
override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
if (checks) checkLinearOperation(this, other)
override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType = minus_tensor_assign(this.tensorHandle, other.tensorHandle)
this.plus(b, safe = true)
public fun TorchTensorType.plus(b: TorchTensorType, safe: Boolean): TorchTensorType {
if (safe) checkLinearOperation(this, b)
return wrap(plus_tensor(this.tensorHandle, b.tensorHandle)!!)
}
override operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit =
this.plusAssign(b, false)
public fun TorchTensorType.plusAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkLinearOperation(this, b)
plus_tensor_assign(this.tensorHandle, b.tensorHandle)
}
override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType =
this.minus(b, safe = true)
public fun TorchTensorType.minus(b: TorchTensorType, safe: Boolean): TorchTensorType {
if (safe) checkLinearOperation(this, b)
return wrap(minus_tensor(this.tensorHandle, b.tensorHandle)!!)
}
override operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit =
this.minusAssign(b, safe = true)
public fun TorchTensorType.minusAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkLinearOperation(this, b)
minus_tensor_assign(this.tensorHandle, b.tensorHandle)
} }
override operator fun TorchTensorType.unaryMinus(): TorchTensorType = override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
wrap(unary_minus(this.tensorHandle)!!) wrap(unary_minus(this.tensorHandle)!!)
private inline fun checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit { override infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType {
checkDeviceCompatible(a, b) if (checks) checkDotOperation(this, other)
val sa = a.shape return wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
val sb = b.shape
val na = sa.size
val nb = sb.size
var status: Boolean
if (nb == 1) {
status = sa.last() == sb[0]
} else {
status = sa.last() == sb[nb - 2]
if ((na > 2) and (nb > 2)) {
status = status and
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
}
}
check(status) { "Incompatible shapes $sa and $sb for dot product" }
} }
override infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType = override infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
this.dot(b, safe = true) if (checks) checkDotOperation(this, other)
matmul_assign(this.tensorHandle, other.tensorHandle)
public fun TorchTensorType.dot(b: TorchTensorType, safe: Boolean): TorchTensorType {
if (safe) checkDotOperation(this, b)
return wrap(matmul(this.tensorHandle, b.tensorHandle)!!)
} }
public infix fun TorchTensorType.dotAssign(b: TorchTensorType): Unit = override infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
this.dotAssign(b, safe = true) if (checks) checkDotOperation(this, other)
matmul_right_assign(this.tensorHandle, other.tensorHandle)
public fun TorchTensorType.dotAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkDotOperation(this, b)
matmul_assign(this.tensorHandle, b.tensorHandle)
}
public infix fun TorchTensorType.dotRightAssign(b: TorchTensorType): Unit =
this.dotRightAssign(b, safe = true)
public fun TorchTensorType.dotRightAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkDotOperation(this, b)
matmul_right_assign(this.tensorHandle, b.tensorHandle)
} }
override fun diagonalEmbedding( override fun diagonalEmbedding(
@ -165,106 +92,78 @@ public sealed class TorchTensorAlgebra<
): TorchTensorType = ): TorchTensorType =
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!) wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
private inline fun checkTranspose(dim: Int, i: Int, j: Int): Unit = override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType {
check((i < dim) and (j < dim)) { if (checks) checkTranspose(this.dimension, i, j)
"Cannot transpose $i to $j for a tensor of dim $dim"
}
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
this.transpose(i, j, safe = true)
public fun TorchTensorType.transpose(i: Int, j: Int, safe: Boolean): TorchTensorType {
if (safe) checkTranspose(this.dimension, i, j)
return wrap(transpose_tensor(tensorHandle, i, j)!!) return wrap(transpose_tensor(tensorHandle, i, j)!!)
} }
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit = override fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
this.transposeAssign(i, j, safe = true) if (checks) checkTranspose(this.dimension, i, j)
public fun TorchTensorType.transposeAssign(i: Int, j: Int, safe: Boolean): Unit {
if (safe) checkTranspose(this.dimension, i, j)
transpose_tensor_assign(tensorHandle, i, j) transpose_tensor_assign(tensorHandle, i, j)
} }
private inline fun checkView(a: TorchTensorType, shape: IntArray): Unit = override fun TorchTensorType.view(shape: IntArray): TorchTensorType {
check(a.shape.reduce(Int::times) == shape.reduce(Int::times)) if (checks) checkView(this, shape)
override fun TorchTensorType.view(shape: IntArray): TorchTensorType =
this.view(shape, safe = true)
public fun TorchTensorType.view(shape: IntArray, safe: Boolean): TorchTensorType {
if (safe) checkView(this, shape)
return wrap(view_tensor(this.tensorHandle, shape.toCValues(), shape.size)!!) return wrap(view_tensor(this.tensorHandle, shape.toCValues(), shape.size)!!)
} }
override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!) override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
public fun TorchTensorType.absAssign(): Unit { override fun TorchTensorType.absAssign(): Unit {
abs_tensor_assign(tensorHandle) abs_tensor_assign(tensorHandle)
} }
override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!) override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
public fun TorchTensorType.sumAssign(): Unit { override fun TorchTensorType.sumAssign(): Unit {
sum_tensor_assign(tensorHandle) sum_tensor_assign(tensorHandle)
} }
public fun TorchTensorType.copy(): TorchTensorType = override fun TorchTensorType.copy(): TorchTensorType =
wrap(copy_tensor(this.tensorHandle)!!) wrap(copy_tensor(this.tensorHandle)!!)
public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType = override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
wrap(copy_to_device(this.tensorHandle, device.toInt())!!) wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
public infix fun TorchTensorType.swap(otherTensor: TorchTensorType): Unit { override infix fun TorchTensorType.swap(other: TorchTensorType): Unit {
swap_tensors(this.tensorHandle, otherTensor.tensorHandle) swap_tensors(this.tensorHandle, other.tensorHandle)
} }
} }
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar, public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitiveVar,
PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>>(scope: DeferScope) : PrimitiveArrayType, TorchTensorType : TorchTensorOverFieldNative<T>>(scope: DeferScope) :
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope), TorchTensorAlgebraNative<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
TensorFieldAlgebra<T, TorchTensorType> { TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType> {
override operator fun TorchTensorType.div(b: TorchTensorType): TorchTensorType = override operator fun TorchTensorType.div(b: TorchTensorType): TorchTensorType {
this.div(b, safe = true) if (checks) checkLinearOperation(this, b)
public fun TorchTensorType.div(b: TorchTensorType, safe: Boolean): TorchTensorType {
if (safe) checkLinearOperation(this, b)
return wrap(div_tensor(this.tensorHandle, b.tensorHandle)!!) return wrap(div_tensor(this.tensorHandle, b.tensorHandle)!!)
} }
override operator fun TorchTensorType.divAssign(b: TorchTensorType): Unit = override operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
this.divAssign(b, safe = true) if (checks) checkLinearOperation(this, other)
div_tensor_assign(this.tensorHandle, other.tensorHandle)
public fun TorchTensorType.divAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkLinearOperation(this, b)
div_tensor_assign(this.tensorHandle, b.tensorHandle)
} }
override fun divide(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a / b override fun TorchTensorType.randUniform(): TorchTensorType =
public abstract fun randUniform(shape: IntArray, device: Device = Device.CPU): TorchTensorType
public abstract fun randNormal(shape: IntArray, device: Device = Device.CPU): TorchTensorType
public fun TorchTensorType.randUniform(): TorchTensorType =
wrap(rand_like(this.tensorHandle)!!) wrap(rand_like(this.tensorHandle)!!)
public fun TorchTensorType.randUniformAssign(): Unit { override fun TorchTensorType.randUniformAssign(): Unit {
rand_like_assign(this.tensorHandle) rand_like_assign(this.tensorHandle)
} }
public fun TorchTensorType.randNormal(): TorchTensorType = override fun TorchTensorType.randNormal(): TorchTensorType =
wrap(randn_like(this.tensorHandle)!!) wrap(randn_like(this.tensorHandle)!!)
public fun TorchTensorType.randNormalAssign(): Unit { override fun TorchTensorType.randNormalAssign(): Unit {
randn_like_assign(this.tensorHandle) randn_like_assign(this.tensorHandle)
} }
override fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!) override fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
public fun TorchTensorType.expAssign(): Unit { override fun TorchTensorType.expAssign(): Unit {
exp_tensor_assign(tensorHandle) exp_tensor_assign(tensorHandle)
} }
override fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!) override fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
public fun TorchTensorType.logAssign(): Unit { override fun TorchTensorType.logAssign(): Unit {
log_tensor_assign(tensorHandle) log_tensor_assign(tensorHandle)
} }
@ -283,31 +182,26 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
return Pair(wrap(S), wrap(V)) return Pair(wrap(S), wrap(V))
} }
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType { override fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean): TorchTensorType {
this.checkIsValue() if (checks) this.checkIsValue()
return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!) return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!)
} }
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType = override infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
this.grad(variable, false) if (checks) this.checkIsValue()
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
this.checkIsValue()
return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!) return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!)
} }
public fun TorchTensorType.detachFromGraph(): TorchTensorType = override fun TorchTensorType.detachFromGraph(): TorchTensorType =
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!) wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
} }
public class TorchTensorRealAlgebra(scope: DeferScope) : public class TorchTensorRealAlgebra(scope: DeferScope) :
TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) { TorchTensorPartialDivisionAlgebraNative<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal = override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
TorchTensorReal(scope = scope, tensorHandle = tensorHandle) TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
override fun number(value: Number): TorchTensorReal =
full(value.toDouble(), intArrayOf(1), Device.CPU).sum()
override fun TorchTensorReal.copyToArray(): DoubleArray = override fun TorchTensorReal.copyToArray(): DoubleArray =
this.elements().map { it.second }.toList().toDoubleArray() this.elements().map { it.second }.toList().toDoubleArray()
@ -360,8 +254,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
times_double_assign(value, this.tensorHandle) times_double_assign(value, this.tensorHandle)
} }
override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal = override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!) wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
@ -377,13 +269,10 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
} }
public class TorchTensorFloatAlgebra(scope: DeferScope) : public class TorchTensorFloatAlgebra(scope: DeferScope) :
TorchTensorFieldAlgebra<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) { TorchTensorPartialDivisionAlgebraNative<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat = override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle) TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
override fun number(value: Number): TorchTensorFloat =
full(value.toFloat(), intArrayOf(1), Device.CPU).sum()
override fun TorchTensorFloat.copyToArray(): FloatArray = override fun TorchTensorFloat.copyToArray(): FloatArray =
this.elements().map { it.second }.toList().toFloatArray() this.elements().map { it.second }.toList().toFloatArray()
@ -436,8 +325,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
times_float_assign(value, this.tensorHandle) times_float_assign(value, this.tensorHandle)
} }
override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat = override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!) wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
@ -453,13 +340,10 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
} }
public class TorchTensorLongAlgebra(scope: DeferScope) : public class TorchTensorLongAlgebra(scope: DeferScope) :
TorchTensorAlgebra<Long, LongVar, LongArray, TorchTensorLong>(scope) { TorchTensorAlgebraNative<Long, LongVar, LongArray, TorchTensorLong>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong = override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong =
TorchTensorLong(scope = scope, tensorHandle = tensorHandle) TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
override fun number(value: Number): TorchTensorLong =
full(value.toLong(), intArrayOf(1), Device.CPU).sum()
override fun TorchTensorLong.copyToArray(): LongArray = override fun TorchTensorLong.copyToArray(): LongArray =
this.elements().map { it.second }.toList().toLongArray() this.elements().map { it.second }.toList().toLongArray()
@ -516,20 +400,15 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
times_long_assign(value, this.tensorHandle) times_long_assign(value, this.tensorHandle)
} }
override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong = override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!) wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
} }
public class TorchTensorIntAlgebra(scope: DeferScope) : public class TorchTensorIntAlgebra(scope: DeferScope) :
TorchTensorAlgebra<Int, IntVar, IntArray, TorchTensorInt>(scope) { TorchTensorAlgebraNative<Int, IntVar, IntArray, TorchTensorInt>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt = override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt =
TorchTensorInt(scope = scope, tensorHandle = tensorHandle) TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
override fun number(value: Number): TorchTensorInt =
full(value.toInt(), intArrayOf(1), Device.CPU).sum()
override fun TorchTensorInt.copyToArray(): IntArray = override fun TorchTensorInt.copyToArray(): IntArray =
this.elements().map { it.second }.toList().toIntArray() this.elements().map { it.second }.toList().toIntArray()
@ -586,12 +465,11 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
times_int_assign(value, this.tensorHandle) times_int_assign(value, this.tensorHandle)
} }
override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt = override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!) wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
} }
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R = public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
withDeferScope { TorchTensorRealAlgebra(this).block() } withDeferScope { TorchTensorRealAlgebra(this).block() }
@ -604,12 +482,3 @@ public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() ->
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R = public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
withDeferScope { TorchTensorIntAlgebra(this).block() } withDeferScope { TorchTensorIntAlgebra(this).block() }
public inline fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
this.requiresGrad = true
return TorchTensorRealAlgebra(this.scope).block()
}
public inline fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
this.requiresGrad = true
return TorchTensorFloatAlgebra(this.scope).block()
}

View File

@ -1,32 +1,25 @@
package kscience.kmath.torch package kscience.kmath.torch
import kscience.kmath.structures.TensorStructure
import kscience.kmath.memory.DeferScope import kscience.kmath.memory.DeferScope
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.ctorch.* import kscience.kmath.torch.ctorch.*
public sealed class TorchTensor<T> constructor( public sealed class TorchTensorNative<T> constructor(
public val scope: DeferScope, scope: DeferScope,
internal val tensorHandle: COpaquePointer internal val tensorHandle: COpaquePointer
) : TensorStructure<T>() { ) : TorchTensor<T>, TorchTensorMemoryHolder(scope) {
init {
scope.defer(::close)
}
private fun close(): Unit = dispose_tensor(tensorHandle) override fun close(): Unit = dispose_tensor(tensorHandle)
protected abstract fun item(): T
override val dimension: Int get() = get_dim(tensorHandle) override val dimension: Int get() = get_dim(tensorHandle)
override val shape: IntArray override val shape: IntArray
get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray() get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray()
public val strides: IntArray override val strides: IntArray
get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray() get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray()
public val size: Int get() = get_numel(tensorHandle) override val size: Int get() = get_numel(tensorHandle)
public val device: Device get() = Device.fromInt(get_device(tensorHandle)) override val device: Device get() = Device.fromInt(get_device(tensorHandle))
override fun toString(): String { override fun toString(): String {
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(tensorHandle)!! val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(tensorHandle)!!
@ -35,25 +28,6 @@ public sealed class TorchTensor<T> constructor(
return stringRepresentation return stringRepresentation
} }
override fun elements(): Sequence<Pair<IntArray, T>> {
if (dimension == 0) {
return emptySequence()
}
val indices = (1..size).asSequence().map { indexFromOffset(it - 1, strides, dimension) }
return indices.map { it to get(it) }
}
public inline fun isValue(): Boolean = dimension == 0
public inline fun isNotValue(): Boolean = !isValue()
internal inline fun checkIsValue() = check(isValue()) {
"This tensor has shape ${shape.toList()}"
}
override fun value(): T {
checkIsValue()
return item()
}
public fun copyToDouble(): TorchTensorReal = TorchTensorReal( public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
scope = scope, scope = scope,
tensorHandle = copy_to_double(this.tensorHandle)!! tensorHandle = copy_to_double(this.tensorHandle)!!
@ -76,11 +50,11 @@ public sealed class TorchTensor<T> constructor(
} }
public sealed class TorchTensorOverField<T> constructor( public sealed class TorchTensorOverFieldNative<T> constructor(
scope: DeferScope, scope: DeferScope,
tensorHandle: COpaquePointer tensorHandle: COpaquePointer
) : TorchTensor<T>(scope, tensorHandle) { ) : TorchTensorNative<T>(scope, tensorHandle), TorchTensorOverField<T> {
public var requiresGrad: Boolean override var requiresGrad: Boolean
get() = requires_grad(tensorHandle) get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value) set(value) = requires_grad_(tensorHandle, value)
} }
@ -89,7 +63,7 @@ public sealed class TorchTensorOverField<T> constructor(
public class TorchTensorReal internal constructor( public class TorchTensorReal internal constructor(
scope: DeferScope, scope: DeferScope,
tensorHandle: COpaquePointer tensorHandle: COpaquePointer
) : TorchTensorOverField<Double>(scope, tensorHandle) { ) : TorchTensorOverFieldNative<Double>(scope, tensorHandle) {
override fun item(): Double = get_item_double(tensorHandle) override fun item(): Double = get_item_double(tensorHandle)
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues()) override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Double) { override fun set(index: IntArray, value: Double) {
@ -100,7 +74,7 @@ public class TorchTensorReal internal constructor(
public class TorchTensorFloat internal constructor( public class TorchTensorFloat internal constructor(
scope: DeferScope, scope: DeferScope,
tensorHandle: COpaquePointer tensorHandle: COpaquePointer
) : TorchTensorOverField<Float>(scope, tensorHandle) { ) : TorchTensorOverFieldNative<Float>(scope, tensorHandle) {
override fun item(): Float = get_item_float(tensorHandle) override fun item(): Float = get_item_float(tensorHandle)
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues()) override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Float) { override fun set(index: IntArray, value: Float) {
@ -111,7 +85,7 @@ public class TorchTensorFloat internal constructor(
public class TorchTensorLong internal constructor( public class TorchTensorLong internal constructor(
scope: DeferScope, scope: DeferScope,
tensorHandle: COpaquePointer tensorHandle: COpaquePointer
) : TorchTensor<Long>(scope, tensorHandle) { ) : TorchTensorNative<Long>(scope, tensorHandle) {
override fun item(): Long = get_item_long(tensorHandle) override fun item(): Long = get_item_long(tensorHandle)
override fun get(index: IntArray): Long = get_long(tensorHandle, index.toCValues()) override fun get(index: IntArray): Long = get_long(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Long) { override fun set(index: IntArray, value: Long) {
@ -122,24 +96,10 @@ public class TorchTensorLong internal constructor(
public class TorchTensorInt internal constructor( public class TorchTensorInt internal constructor(
scope: DeferScope, scope: DeferScope,
tensorHandle: COpaquePointer tensorHandle: COpaquePointer
) : TorchTensor<Int>(scope, tensorHandle) { ) : TorchTensorNative<Int>(scope, tensorHandle) {
override fun item(): Int = get_item_int(tensorHandle) override fun item(): Int = get_item_int(tensorHandle)
override fun get(index: IntArray): Int = get_int(tensorHandle, index.toCValues()) override fun get(index: IntArray): Int = get_int(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Int) { override fun set(index: IntArray, value: Int) {
set_int(tensorHandle, index.toCValues(), value) set_int(tensorHandle, index.toCValues(), value)
} }
} }
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
val res = IntArray(nDim)
var current = offset
var strideIndex = 0
while (strideIndex < nDim) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex++
}
return res
}

View File

@ -1,20 +0,0 @@
package kscience.kmath.torch
import kotlinx.cinterop.*
import kscience.kmath.ctorch.*
public fun getNumThreads(): Int {
return get_num_threads()
}
public fun setNumThreads(numThreads: Int): Unit {
set_num_threads(numThreads)
}
public fun cudaAvailable(): Boolean {
return cuda_is_available()
}
public fun setSeed(seed: Int): Unit {
set_seed(seed)
}

View File

@ -14,8 +14,8 @@ internal fun benchmarkingMatMultDouble(
setSeed(SEED) setSeed(SEED)
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device) val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device) val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
repeat(numWarmUp) { lhs.dotAssign(rhs, false) } repeat(numWarmUp) { lhs dotAssign rhs }
val measuredTime = measureTime { repeat(numIter) { lhs.dotAssign(rhs, false) } } val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
println(" ${measuredTime / numIter} p.o. with $numIter iterations") println(" ${measuredTime / numIter} p.o. with $numIter iterations")
} }
} }
@ -31,8 +31,8 @@ internal fun benchmarkingMatMultFloat(
setSeed(SEED) setSeed(SEED)
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device) val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device) val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
repeat(numWarmUp) { lhs.dotAssign(rhs, false) } repeat(numWarmUp) { lhs dotAssign rhs }
val measuredTime = measureTime { repeat(numIter) { lhs.dotAssign(rhs, false) } } val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
println(" ${measuredTime / numIter} p.o. with $numIter iterations") println(" ${measuredTime / numIter} p.o. with $numIter iterations")
} }
} }

View File

@ -8,12 +8,12 @@ internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
val tensorX = randNormal(shape = intArrayOf(dim), device = device) val tensorX = randNormal(shape = intArrayOf(dim), device = device)
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device) val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(0,1) val tensorSigma = randFeatures + randFeatures.transpose(0, 1)
val tensorMu = randNormal(shape = intArrayOf(dim), device = device) val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
val expressionAtX = tensorX.withGrad { val expressionAtX = withGradAt(tensorX, { x ->
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9 0.5 * (x dot (tensorSigma dot x)) + (tensorMu dot x) + 25.9
} })
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true) val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
val hessianAtX = expressionAtX hess tensorX val hessianAtX = expressionAtX hess tensorX
@ -25,21 +25,23 @@ internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
} }
} }
internal fun testingBatchedAutoGrad(bath: IntArray, internal fun testingBatchedAutoGrad(
bath: IntArray,
dim: Int, dim: Int,
device: Device = Device.CPU): Unit { device: Device = Device.CPU
): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
setSeed(SEED) setSeed(SEED)
val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device) val tensorX = randNormal(shape = bath + intArrayOf(1, dim), device = device)
val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device) val randFeatures = randNormal(shape = bath + intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(-2,-1) val tensorSigma = randFeatures + randFeatures.transpose(-2, -1)
val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device) val tensorMu = randNormal(shape = bath + intArrayOf(1, dim), device = device)
val expressionAtX = tensorX.withGrad{ val expressionAtX = withGradAt(tensorX, { x ->
val tensorXt = tensorX.transpose(-1,-2) val xt = x.transpose(-1, -2)
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2 0.5 * (x dot (tensorSigma dot xt)) + (tensorMu dot xt) + 58.2
} })
expressionAtX.sumAssign() expressionAtX.sumAssign()
val gradientAtX = expressionAtX grad tensorX val gradientAtX = expressionAtX grad tensorX
@ -53,8 +55,8 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
internal class TestAutograd { internal class TestAutograd {
@Test @Test
fun testAutoGrad() = testingAutoGrad(dim = 3) fun testAutoGrad() = testingAutoGrad(dim = 100)
@Test @Test
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30) fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2, 10), dim = 30)
} }

View File

@ -69,4 +69,5 @@ class TestTorchTensor {
viewTensor[intArrayOf(0, 0)] = 10 viewTensor[intArrayOf(0, 0)] = 10
assertEquals(tensor[intArrayOf(0)], 10) assertEquals(tensor[intArrayOf(0)], 10)
} }
} }

View File

@ -51,6 +51,7 @@ internal fun testingMatrixMultiplication(device: Device = Device.CPU): Unit {
internal fun testingLinearStructure(device: Device = Device.CPU): Unit { internal fun testingLinearStructure(device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
withChecks {
val shape = intArrayOf(3) val shape = intArrayOf(3)
val tensorA = full(value = -4.5, shape = shape, device = device) val tensorA = full(value = -4.5, shape = shape, device = device)
val tensorB = full(value = 10.9, shape = shape, device = device) val tensorB = full(value = 10.9, shape = shape, device = device)
@ -81,7 +82,8 @@ internal fun testingLinearStructure(device: Device = Device.CPU): Unit {
val error = (expected - result).abs().sum().value() + val error = (expected - result).abs().sum().value() +
(expected - assignResult).abs().sum().value() (expected - assignResult).abs().sum().value()
assertTrue(error < TOLERANCE) assertTrue(error < TOLERANCE)
} println(expected)
}}
} }
internal fun testingTensorTransformations(device: Device = Device.CPU): Unit { internal fun testingTensorTransformations(device: Device = Device.CPU): Unit {

View File

@ -3,9 +3,6 @@ package kscience.kmath.torch
import kotlin.test.* import kotlin.test.*
internal val SEED = 987654
internal val TOLERANCE = 1e-6
internal fun testingSetSeed(device: Device = Device.CPU): Unit { internal fun testingSetSeed(device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
setSeed(SEED) setSeed(SEED)
@ -22,10 +19,12 @@ internal fun testingSetSeed(device: Device = Device.CPU): Unit {
internal class TestUtils { internal class TestUtils {
@Test @Test
fun testSetNumThreads() { fun testSetNumThreads() {
TorchTensorRealAlgebra {
val numThreads = 2 val numThreads = 2
setNumThreads(numThreads) setNumThreads(numThreads)
assertEquals(numThreads, getNumThreads()) assertEquals(numThreads, getNumThreads())
} }
}
@Test @Test
fun testSetSeed() = testingSetSeed() fun testSetSeed() = testingSetSeed()