Refactoring project structure
This commit is contained in:
parent
889691a122
commit
6eb718f64a
3
.gitignore
vendored
3
.gitignore
vendored
@ -9,3 +9,6 @@ out/
|
||||
|
||||
# Cache of project
|
||||
.gradletasknamecache
|
||||
|
||||
# Generated by javac -h
|
||||
*.class
|
||||
|
@ -1,49 +1,70 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
|
||||
import kscience.kmath.operations.*
|
||||
|
||||
public interface TensorAlgebra<T, TorchTensorType : TensorStructure<T>> : Ring<TorchTensorType> {
|
||||
public interface TensorStructure<T> : MutableNDStructure<T> {
|
||||
// A tensor can have empty shape, in which case it represents just a value
|
||||
public abstract fun value(): T
|
||||
}
|
||||
|
||||
public operator fun T.plus(other: TorchTensorType): TorchTensorType
|
||||
public operator fun TorchTensorType.plus(value: T): TorchTensorType
|
||||
public operator fun TorchTensorType.plusAssign(value: T): Unit
|
||||
public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit
|
||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||
|
||||
public operator fun T.minus(other: TorchTensorType): TorchTensorType
|
||||
public operator fun TorchTensorType.minus(value: T): TorchTensorType
|
||||
public operator fun TorchTensorType.minusAssign(value: T): Unit
|
||||
public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit
|
||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
|
||||
public operator fun T.times(other: TorchTensorType): TorchTensorType
|
||||
public operator fun TorchTensorType.times(value: T): TorchTensorType
|
||||
public operator fun TorchTensorType.timesAssign(value: T): Unit
|
||||
public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit
|
||||
public operator fun T.plus(other: TensorType): TensorType
|
||||
public operator fun TensorType.plus(value: T): TensorType
|
||||
public operator fun TensorType.plus(other: TensorType): TensorType
|
||||
public operator fun TensorType.plusAssign(value: T): Unit
|
||||
public operator fun TensorType.plusAssign(other: TensorType): Unit
|
||||
|
||||
public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType
|
||||
public operator fun T.minus(other: TensorType): TensorType
|
||||
public operator fun TensorType.minus(value: T): TensorType
|
||||
public operator fun TensorType.minus(other: TensorType): TensorType
|
||||
public operator fun TensorType.minusAssign(value: T): Unit
|
||||
public operator fun TensorType.minusAssign(other: TensorType): Unit
|
||||
|
||||
public operator fun T.times(other: TensorType): TensorType
|
||||
public operator fun TensorType.times(value: T): TensorType
|
||||
public operator fun TensorType.times(other: TensorType): TensorType
|
||||
public operator fun TensorType.timesAssign(value: T): Unit
|
||||
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
||||
public operator fun TensorType.unaryMinus(): TensorType
|
||||
|
||||
|
||||
public infix fun TensorType.dot(other: TensorType): TensorType
|
||||
public infix fun TensorType.dotAssign(other: TensorType): Unit
|
||||
public infix fun TensorType.dotRightAssign(other: TensorType): Unit
|
||||
|
||||
public fun diagonalEmbedding(
|
||||
diagonalEntries: TorchTensorType,
|
||||
diagonalEntries: TensorType,
|
||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||
): TorchTensorType
|
||||
): TensorType
|
||||
|
||||
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType
|
||||
public fun TorchTensorType.view(shape: IntArray): TorchTensorType
|
||||
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
||||
public fun TensorType.transposeAssign(i: Int, j: Int): Unit
|
||||
|
||||
public fun TorchTensorType.abs(): TorchTensorType
|
||||
public fun TorchTensorType.sum(): TorchTensorType
|
||||
public fun TensorType.view(shape: IntArray): TensorType
|
||||
|
||||
public fun TensorType.abs(): TensorType
|
||||
public fun TensorType.absAssign(): Unit
|
||||
public fun TensorType.sum(): TensorType
|
||||
public fun TensorType.sumAssign(): Unit
|
||||
}
|
||||
|
||||
public interface TensorFieldAlgebra<T, TorchTensorType : TensorStructure<T>> :
|
||||
TensorAlgebra<T, TorchTensorType>, Field<TorchTensorType> {
|
||||
// https://proofwiki.org/wiki/Definition:Division_Algebra
|
||||
|
||||
public operator fun TorchTensorType.divAssign(b: TorchTensorType)
|
||||
public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
|
||||
TensorAlgebra<T, TensorType> {
|
||||
|
||||
public fun TorchTensorType.exp(): TorchTensorType
|
||||
public fun TorchTensorType.log(): TorchTensorType
|
||||
public operator fun TensorType.div(other: TensorType): TensorType
|
||||
public operator fun TensorType.divAssign(other: TensorType)
|
||||
|
||||
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType>
|
||||
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType>
|
||||
public fun TensorType.exp(): TensorType
|
||||
public fun TensorType.expAssign(): Unit
|
||||
public fun TensorType.log(): TensorType
|
||||
public fun TensorType.logAssign(): Unit
|
||||
|
||||
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
||||
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType>
|
||||
|
||||
}
|
@ -1,11 +0,0 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
public abstract class TensorStructure<T>: MutableNDStructure<T> {
|
||||
|
||||
// A tensor can have empty shape, in which case it represents just a value
|
||||
public abstract fun value(): T
|
||||
|
||||
// Tensors are mutable and might hold shared resources
|
||||
override fun equals(other: Any?): Boolean = false
|
||||
override fun hashCode(): Int = 0
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
import de.undercouch.gradle.tasks.download.Download
|
||||
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
|
||||
import org.gradle.api.JavaVersion.VERSION_11
|
||||
|
||||
|
||||
plugins {
|
||||
@ -7,10 +8,16 @@ plugins {
|
||||
id("de.undercouch.download")
|
||||
}
|
||||
|
||||
java {
|
||||
sourceCompatibility = VERSION_11
|
||||
targetCompatibility = VERSION_11
|
||||
}
|
||||
|
||||
val home = System.getProperty("user.home")
|
||||
val javaHome = System.getProperty("java.home")
|
||||
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
|
||||
val cppBuildDir = "$thirdPartyDir/cpp-build"
|
||||
val cppSources = projectDir.resolve("src/cppMain")
|
||||
|
||||
val cudaHome: String? = System.getenv("CUDA_HOME")
|
||||
val cudaDefault = file("/usr/local/cuda").exists()
|
||||
@ -77,10 +84,11 @@ val configureCpp by tasks.registering {
|
||||
workingDir(cppBuildDir)
|
||||
commandLine(
|
||||
cmakeCmd,
|
||||
projectDir.resolve("ctorch"),
|
||||
cppSources,
|
||||
"-GNinja",
|
||||
"-DCMAKE_MAKE_PROGRAM=$ninjaCmd",
|
||||
"-DCMAKE_PREFIX_PATH=$thirdPartyDir/$torchArchive",
|
||||
"-DJAVA_HOME=$javaHome",
|
||||
"-DCMAKE_BUILD_TYPE=Release"
|
||||
)
|
||||
}
|
||||
@ -107,10 +115,23 @@ val buildCpp by tasks.registering {
|
||||
}
|
||||
}
|
||||
|
||||
val generateJNIHeader by tasks.registering {
|
||||
doLast {
|
||||
exec {
|
||||
workingDir(projectDir.resolve("src/jvmMain/java/kscience/kmath/torch"))
|
||||
commandLine("$javaHome/bin/javac", "-h", cppSources.resolve("include") , "JTorch.java")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kotlin {
|
||||
explicitApiWarning()
|
||||
|
||||
val nativeTarget = linuxX64("torch")
|
||||
jvm {
|
||||
withJava()
|
||||
}
|
||||
|
||||
val nativeTarget = linuxX64("native")
|
||||
nativeTarget.apply {
|
||||
binaries {
|
||||
all {
|
||||
@ -128,38 +149,38 @@ kotlin {
|
||||
val main by nativeTarget.compilations.getting {
|
||||
cinterops {
|
||||
val libctorch by creating {
|
||||
includeDirs(projectDir.resolve("ctorch/include"))
|
||||
includeDirs(cppSources.resolve("include"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val test by nativeTarget.compilations.getting
|
||||
|
||||
|
||||
sourceSets {
|
||||
val nativeMain by creating {
|
||||
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
}
|
||||
val nativeTest by creating {
|
||||
dependsOn(nativeMain)
|
||||
}
|
||||
val nativeGPUTest by creating {
|
||||
dependsOn(nativeMain)
|
||||
dependsOn(nativeTest)
|
||||
}
|
||||
|
||||
|
||||
main.defaultSourceSet.dependsOn(nativeMain)
|
||||
test.defaultSourceSet.dependsOn(nativeTest)
|
||||
if(cudaFound) {
|
||||
test.defaultSourceSet.dependsOn(nativeGPUTest)
|
||||
val nativeMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
val torch: KotlinNativeTarget by kotlin.targets
|
||||
tasks[torch.compilations["main"].cinterops["libctorch"].interopProcessingTaskName]
|
||||
val native: KotlinNativeTarget by kotlin.targets
|
||||
tasks[native.compilations["main"].cinterops["libctorch"].interopProcessingTaskName]
|
||||
.dependsOn(buildCpp)
|
||||
|
||||
tasks["jvmProcessResources"].dependsOn(buildCpp)
|
||||
|
||||
tasks {
|
||||
withType<Test>{
|
||||
systemProperty("java.library.path", cppBuildDir.toString())
|
||||
}
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kscience.kmath.structures.TensorStructure
|
||||
|
||||
|
||||
|
||||
public interface TorchTensor<T> : TensorStructure<T> {
|
||||
public fun item(): T
|
||||
public val strides: IntArray
|
||||
public val size: Int
|
||||
public val device: Device
|
||||
override fun value(): T {
|
||||
checkIsValue()
|
||||
return item()
|
||||
}
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||
if (dimension == 0) {
|
||||
return emptySequence()
|
||||
}
|
||||
val indices = (1..size).asSequence().map { indexFromOffset(it - 1, strides, dimension) }
|
||||
return indices.map { it to get(it) }
|
||||
}
|
||||
}
|
||||
|
||||
public inline fun <T> TorchTensor<T>.isValue(): Boolean {
|
||||
return (dimension == 0)
|
||||
}
|
||||
|
||||
public inline fun <T> TorchTensor<T>.isNotValue(): Boolean = !this.isValue()
|
||||
|
||||
public inline fun <T> TorchTensor<T>.checkIsValue(): Unit = check(this.isValue()) {
|
||||
"This tensor has shape ${shape.toList()}"
|
||||
}
|
||||
|
||||
public interface TorchTensorOverField<T>: TorchTensor<T>
|
||||
{
|
||||
public var requiresGrad: Boolean
|
||||
}
|
||||
|
||||
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||
val res = IntArray(nDim)
|
||||
var current = offset
|
||||
var strideIndex = 0
|
||||
|
||||
while (strideIndex < nDim) {
|
||||
res[strideIndex] = (current / strides[strideIndex])
|
||||
current %= strides[strideIndex]
|
||||
strideIndex++
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
@ -0,0 +1,140 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kscience.kmath.structures.*
|
||||
|
||||
public interface TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>> :
|
||||
TensorAlgebra<T, TorchTensorType> {
|
||||
|
||||
public fun getNumThreads(): Int
|
||||
public fun setNumThreads(numThreads: Int): Unit
|
||||
public fun cudaAvailable(): Boolean
|
||||
public fun setSeed(seed: Int): Unit
|
||||
|
||||
public var checks: Boolean
|
||||
|
||||
public fun copyFromArray(
|
||||
array: PrimitiveArrayType,
|
||||
shape: IntArray,
|
||||
device: Device = Device.CPU
|
||||
): TorchTensorType
|
||||
|
||||
public fun TorchTensorType.copyToArray(): PrimitiveArrayType
|
||||
|
||||
public fun full(value: T, shape: IntArray, device: Device): TorchTensorType
|
||||
|
||||
public fun randIntegral(
|
||||
low: T, high: T, shape: IntArray,
|
||||
device: Device = Device.CPU
|
||||
): TorchTensorType
|
||||
|
||||
public fun TorchTensorType.randIntegral(low: T, high: T): TorchTensorType
|
||||
public fun TorchTensorType.randIntegralAssign(low: T, high: T): Unit
|
||||
|
||||
public fun TorchTensorType.copy(): TorchTensorType
|
||||
public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType
|
||||
public infix fun TorchTensorType.swap(other: TorchTensorType)
|
||||
|
||||
}
|
||||
|
||||
public interface TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>> :
|
||||
TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>, TensorPartialDivisionAlgebra<T, TorchTensorType> {
|
||||
|
||||
public fun randUniform(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
||||
public fun randNormal(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
||||
public fun TorchTensorType.randUniform(): TorchTensorType
|
||||
public fun TorchTensorType.randUniformAssign(): Unit
|
||||
public fun TorchTensorType.randNormal(): TorchTensorType
|
||||
public fun TorchTensorType.randNormalAssign(): Unit
|
||||
|
||||
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType
|
||||
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||
this.grad(variable, false)
|
||||
|
||||
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType
|
||||
public fun TorchTensorType.detachFromGraph(): TorchTensorType
|
||||
}
|
||||
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.withChecks(block: TorchTensorAlgebraType.() -> Unit): Unit {
|
||||
val state = this.checks
|
||||
this.checks = true
|
||||
this.block()
|
||||
this.checks = state
|
||||
}
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.checkDeviceCompatible(
|
||||
a: TorchTensorType, b: TorchTensorType
|
||||
): Unit =
|
||||
check(a.device == b.device) {
|
||||
"Tensors must be on the same device"
|
||||
}
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.checkShapeCompatible(
|
||||
a: TorchTensorType,
|
||||
b: TorchTensorType
|
||||
): Unit =
|
||||
check(a.shape contentEquals b.shape) {
|
||||
"Tensors must be of identical shape"
|
||||
}
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.checkLinearOperation(
|
||||
a: TorchTensorType,
|
||||
b: TorchTensorType
|
||||
) {
|
||||
if (a.isNotValue() and b.isNotValue()) {
|
||||
this.checkDeviceCompatible(a, b)
|
||||
this.checkShapeCompatible(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit {
|
||||
checkDeviceCompatible(a, b)
|
||||
val sa = a.shape
|
||||
val sb = b.shape
|
||||
val na = sa.size
|
||||
val nb = sb.size
|
||||
var status: Boolean
|
||||
if (nb == 1) {
|
||||
status = sa.last() == sb[0]
|
||||
} else {
|
||||
status = sa.last() == sb[nb - 2]
|
||||
if ((na > 2) and (nb > 2)) {
|
||||
status = status and
|
||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||
}
|
||||
}
|
||||
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
||||
}
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||
check((i < dim) and (j < dim)) {
|
||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||
}
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorAlgebraType.checkView(a: TorchTensorType, shape: IntArray): Unit =
|
||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
||||
|
||||
|
||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
|
||||
TorchTensorDivisionAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||
TorchTensorDivisionAlgebraType.withGradAt(
|
||||
tensor: TorchTensorType,
|
||||
block: TorchTensorDivisionAlgebraType.(TorchTensorType) -> TorchTensorType
|
||||
): TorchTensorType {
|
||||
tensor.requiresGrad = true
|
||||
return this.block(tensor)
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kscience.kmath.memory.DeferScope
|
||||
|
||||
public abstract class TorchTensorMemoryHolder internal constructor(
|
||||
public val scope: DeferScope) {
|
||||
init {
|
||||
scope.defer(::close)
|
||||
}
|
||||
protected abstract fun close(): Unit
|
||||
|
||||
override fun equals(other: Any?): Boolean = false
|
||||
override fun hashCode(): Int = 0
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
internal val SEED = 987654
|
||||
internal val TOLERANCE = 1e-6
|
||||
|
@ -12,14 +12,23 @@ endif()
|
||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||
|
||||
find_package(Torch REQUIRED)
|
||||
find_package(JNI REQUIRED)
|
||||
|
||||
add_library(ctorch SHARED src/ctorch.cc)
|
||||
target_include_directories(ctorch PRIVATE include)
|
||||
target_link_libraries(ctorch PRIVATE torch)
|
||||
target_compile_options(ctorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
||||
|
||||
add_library(jtorch SHARED src/jtorch.cc)
|
||||
target_include_directories(jtorch PRIVATE include ${JNI_INCLUDE_DIRS})
|
||||
target_link_libraries(jtorch PRIVATE torch)
|
||||
target_compile_options(jtorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
set_target_properties(ctorch PROPERTIES PUBLIC_HEADER include/ctorch.h)
|
||||
install(TARGETS ctorch
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||
|
||||
install(TARGETS jtorch LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
@ -0,0 +1,53 @@
|
||||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class kscience_kmath_torch_JTorch */
|
||||
|
||||
#ifndef _Included_kscience_kmath_torch_JTorch
|
||||
#define _Included_kscience_kmath_torch_JTorch
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: getNumThreads
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumThreads
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: setNumThreads
|
||||
* Signature: (I)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads
|
||||
(JNIEnv *, jclass, jint);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: createTensor
|
||||
* Signature: ()J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: printTensor
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: kscience_kmath_torch_JTorch
|
||||
* Method: disposeTensor
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
@ -1,9 +1,10 @@
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include "ctorch.h"
|
||||
|
||||
namespace ctorch
|
||||
{
|
||||
|
||||
using TorchTensorHandle = void*;
|
||||
|
||||
template <typename Dtype>
|
||||
inline c10::ScalarType dtype()
|
||||
{
|
35
kmath-torch/src/cppMain/src/jtorch.cc
Normal file
35
kmath-torch/src/cppMain/src/jtorch.cc
Normal file
@ -0,0 +1,35 @@
|
||||
#include <torch/torch.h>
|
||||
#include <iostream>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "kscience_kmath_torch_JTorch.h"
|
||||
#include "utils.hh"
|
||||
|
||||
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumThreads(JNIEnv *, jclass)
|
||||
{
|
||||
return torch::get_num_threads();
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads(JNIEnv *, jclass, jint num_threads)
|
||||
{
|
||||
torch::set_num_threads(num_threads);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor(JNIEnv *, jclass)
|
||||
{
|
||||
auto ten = torch::randn({2, 3});
|
||||
std::cout << ten << std::endl;
|
||||
void *ptr = new torch::Tensor(ten);
|
||||
return (long)ptr;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
auto ten = ctorch::cast((void *)tensor_handle);
|
||||
std::cout << ten << std::endl;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||
{
|
||||
delete static_cast<torch::Tensor *>((void *)tensor_handle);
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
package kscience.kmath.torch;
|
||||
|
||||
class JTorch {
|
||||
|
||||
static {
|
||||
System.loadLibrary("jtorch");
|
||||
}
|
||||
|
||||
public static native int getNumThreads();
|
||||
public static native void setNumThreads(int numThreads);
|
||||
public static native long createTensor();
|
||||
public static native void printTensor(long tensorHandle);
|
||||
public static native void disposeTensor(long tensorHandle);
|
||||
}
|
25
kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/Utils.kt
Normal file
25
kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/Utils.kt
Normal file
@ -0,0 +1,25 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
|
||||
public fun getNumThreads(): Int {
|
||||
return JTorch.getNumThreads()
|
||||
}
|
||||
|
||||
public fun setNumThreads(numThreads: Int): Unit {
|
||||
JTorch.setNumThreads(numThreads)
|
||||
}
|
||||
|
||||
public fun cudaAvailable(): Boolean {
|
||||
TODO("Implementation not available yet")
|
||||
}
|
||||
|
||||
public fun setSeed(seed: Int): Unit {
|
||||
TODO("Implementation not available yet")
|
||||
}
|
||||
|
||||
|
||||
public fun runCPD(): Unit {
|
||||
val tensorHandle = JTorch.createTensor()
|
||||
JTorch.printTensor(tensorHandle)
|
||||
JTorch.disposeTensor(tensorHandle)
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.*
|
||||
|
||||
|
||||
class TestUtils {
|
||||
|
||||
@Test
|
||||
fun testSetNumThreads() {
|
||||
val numThreads = 2
|
||||
setNumThreads(numThreads)
|
||||
assertEquals(numThreads, getNumThreads())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCPD() {
|
||||
runCPD()
|
||||
}
|
||||
}
|
@ -1,2 +1,2 @@
|
||||
package=kscience.kmath.ctorch
|
||||
package=kscience.kmath.torch.ctorch
|
||||
headers=ctorch.h
|
@ -1,163 +1,90 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
|
||||
import kscience.kmath.structures.*
|
||||
import kscience.kmath.memory.DeferScope
|
||||
import kscience.kmath.memory.withDeferScope
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import kscience.kmath.ctorch.*
|
||||
import kscience.kmath.torch.ctorch.*
|
||||
|
||||
public sealed class TorchTensorAlgebra<
|
||||
public sealed class TorchTensorAlgebraNative<
|
||||
T,
|
||||
TVar : CPrimitiveVar,
|
||||
PrimitiveArrayType,
|
||||
TorchTensorType : TorchTensor<T>> constructor(
|
||||
TorchTensorType : TorchTensorNative<T>> constructor(
|
||||
internal val scope: DeferScope
|
||||
) :
|
||||
TensorAlgebra<T, TorchTensorType> {
|
||||
) : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||
|
||||
override fun getNumThreads(): Int {
|
||||
return get_num_threads()
|
||||
}
|
||||
|
||||
override fun setNumThreads(numThreads: Int): Unit {
|
||||
set_num_threads(numThreads)
|
||||
}
|
||||
|
||||
override fun cudaAvailable(): Boolean {
|
||||
return cuda_is_available()
|
||||
}
|
||||
|
||||
override fun setSeed(seed: Int): Unit {
|
||||
set_seed(seed)
|
||||
}
|
||||
|
||||
|
||||
override var checks: Boolean = false
|
||||
|
||||
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
|
||||
|
||||
public abstract fun copyFromArray(
|
||||
array: PrimitiveArrayType,
|
||||
shape: IntArray,
|
||||
device: Device = Device.CPU
|
||||
): TorchTensorType
|
||||
|
||||
public abstract fun TorchTensorType.copyToArray(): PrimitiveArrayType
|
||||
|
||||
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
|
||||
public abstract fun TorchTensorType.getData(): CPointer<TVar>
|
||||
|
||||
public abstract fun full(value: T, shape: IntArray, device: Device): TorchTensorType
|
||||
|
||||
public abstract fun randIntegral(
|
||||
low: T, high: T, shape: IntArray,
|
||||
device: Device = Device.CPU
|
||||
): TorchTensorType
|
||||
|
||||
public abstract fun TorchTensorType.randIntegral(low: T, high: T): TorchTensorType
|
||||
public abstract fun TorchTensorType.randIntegralAssign(low: T, high: T): Unit
|
||||
|
||||
override val zero: TorchTensorType
|
||||
get() = number(0)
|
||||
override val one: TorchTensorType
|
||||
get() = number(1)
|
||||
|
||||
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
|
||||
check(a.device == b.device) {
|
||||
"Tensors must be on the same device"
|
||||
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
}
|
||||
|
||||
protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
|
||||
check(a.shape contentEquals b.shape) {
|
||||
"Tensors must be of identical shape"
|
||||
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
protected inline fun checkLinearOperation(a: TorchTensorType, b: TorchTensorType) {
|
||||
if (a.isNotValue() and b.isNotValue()) {
|
||||
checkDeviceCompatible(a, b)
|
||||
checkShapeCompatible(a, b)
|
||||
}
|
||||
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType =
|
||||
this.times(b, safe = true)
|
||||
|
||||
public fun TorchTensorType.times(b: TorchTensorType, safe: Boolean): TorchTensorType {
|
||||
if (safe) checkLinearOperation(this, b)
|
||||
return wrap(times_tensor(this.tensorHandle, b.tensorHandle)!!)
|
||||
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit =
|
||||
this.timesAssign(b, safe = true)
|
||||
|
||||
public fun TorchTensorType.timesAssign(b: TorchTensorType, safe: Boolean): Unit {
|
||||
if (safe) checkLinearOperation(this, b)
|
||||
times_tensor_assign(this.tensorHandle, b.tensorHandle)
|
||||
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b
|
||||
|
||||
override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType =
|
||||
this.plus(b, safe = true)
|
||||
|
||||
public fun TorchTensorType.plus(b: TorchTensorType, safe: Boolean): TorchTensorType {
|
||||
if (safe) checkLinearOperation(this, b)
|
||||
return wrap(plus_tensor(this.tensorHandle, b.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit =
|
||||
this.plusAssign(b, false)
|
||||
|
||||
public fun TorchTensorType.plusAssign(b: TorchTensorType, safe: Boolean): Unit {
|
||||
if (safe) checkLinearOperation(this, b)
|
||||
plus_tensor_assign(this.tensorHandle, b.tensorHandle)
|
||||
}
|
||||
|
||||
override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
|
||||
|
||||
override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType =
|
||||
this.minus(b, safe = true)
|
||||
|
||||
public fun TorchTensorType.minus(b: TorchTensorType, safe: Boolean): TorchTensorType {
|
||||
if (safe) checkLinearOperation(this, b)
|
||||
return wrap(minus_tensor(this.tensorHandle, b.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit =
|
||||
this.minusAssign(b, safe = true)
|
||||
|
||||
public fun TorchTensorType.minusAssign(b: TorchTensorType, safe: Boolean): Unit {
|
||||
if (safe) checkLinearOperation(this, b)
|
||||
minus_tensor_assign(this.tensorHandle, b.tensorHandle)
|
||||
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||
wrap(unary_minus(this.tensorHandle)!!)
|
||||
|
||||
private inline fun checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit {
|
||||
checkDeviceCompatible(a, b)
|
||||
val sa = a.shape
|
||||
val sb = b.shape
|
||||
val na = sa.size
|
||||
val nb = sb.size
|
||||
var status: Boolean
|
||||
if (nb == 1) {
|
||||
status = sa.last() == sb[0]
|
||||
} else {
|
||||
status = sa.last() == sb[nb - 2]
|
||||
if ((na > 2) and (nb > 2)) {
|
||||
status = status and
|
||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||
}
|
||||
}
|
||||
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
||||
override infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
return wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType =
|
||||
this.dot(b, safe = true)
|
||||
|
||||
public fun TorchTensorType.dot(b: TorchTensorType, safe: Boolean): TorchTensorType {
|
||||
if (safe) checkDotOperation(this, b)
|
||||
return wrap(matmul(this.tensorHandle, b.tensorHandle)!!)
|
||||
override infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
matmul_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public infix fun TorchTensorType.dotAssign(b: TorchTensorType): Unit =
|
||||
this.dotAssign(b, safe = true)
|
||||
|
||||
public fun TorchTensorType.dotAssign(b: TorchTensorType, safe: Boolean): Unit {
|
||||
if (safe) checkDotOperation(this, b)
|
||||
matmul_assign(this.tensorHandle, b.tensorHandle)
|
||||
}
|
||||
|
||||
public infix fun TorchTensorType.dotRightAssign(b: TorchTensorType): Unit =
|
||||
this.dotRightAssign(b, safe = true)
|
||||
|
||||
public fun TorchTensorType.dotRightAssign(b: TorchTensorType, safe: Boolean): Unit {
|
||||
if (safe) checkDotOperation(this, b)
|
||||
matmul_right_assign(this.tensorHandle, b.tensorHandle)
|
||||
override infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
matmul_right_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(
|
||||
@ -165,106 +92,78 @@ public sealed class TorchTensorAlgebra<
|
||||
): TorchTensorType =
|
||||
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
|
||||
|
||||
private inline fun checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||
check((i < dim) and (j < dim)) {
|
||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||
}
|
||||
|
||||
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
|
||||
this.transpose(i, j, safe = true)
|
||||
|
||||
public fun TorchTensorType.transpose(i: Int, j: Int, safe: Boolean): TorchTensorType {
|
||||
if (safe) checkTranspose(this.dimension, i, j)
|
||||
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType {
|
||||
if (checks) checkTranspose(this.dimension, i, j)
|
||||
return wrap(transpose_tensor(tensorHandle, i, j)!!)
|
||||
}
|
||||
|
||||
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit =
|
||||
this.transposeAssign(i, j, safe = true)
|
||||
|
||||
public fun TorchTensorType.transposeAssign(i: Int, j: Int, safe: Boolean): Unit {
|
||||
if (safe) checkTranspose(this.dimension, i, j)
|
||||
override fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
||||
if (checks) checkTranspose(this.dimension, i, j)
|
||||
transpose_tensor_assign(tensorHandle, i, j)
|
||||
}
|
||||
|
||||
private inline fun checkView(a: TorchTensorType, shape: IntArray): Unit =
|
||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
||||
|
||||
override fun TorchTensorType.view(shape: IntArray): TorchTensorType =
|
||||
this.view(shape, safe = true)
|
||||
|
||||
public fun TorchTensorType.view(shape: IntArray, safe: Boolean): TorchTensorType {
|
||||
if (safe) checkView(this, shape)
|
||||
override fun TorchTensorType.view(shape: IntArray): TorchTensorType {
|
||||
if (checks) checkView(this, shape)
|
||||
return wrap(view_tensor(this.tensorHandle, shape.toCValues(), shape.size)!!)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
|
||||
public fun TorchTensorType.absAssign(): Unit {
|
||||
override fun TorchTensorType.absAssign(): Unit {
|
||||
abs_tensor_assign(tensorHandle)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
|
||||
public fun TorchTensorType.sumAssign(): Unit {
|
||||
override fun TorchTensorType.sumAssign(): Unit {
|
||||
sum_tensor_assign(tensorHandle)
|
||||
}
|
||||
|
||||
public fun TorchTensorType.copy(): TorchTensorType =
|
||||
override fun TorchTensorType.copy(): TorchTensorType =
|
||||
wrap(copy_tensor(this.tensorHandle)!!)
|
||||
|
||||
public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
|
||||
override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
|
||||
wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||
|
||||
public infix fun TorchTensorType.swap(otherTensor: TorchTensorType): Unit {
|
||||
swap_tensors(this.tensorHandle, otherTensor.tensorHandle)
|
||||
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit {
|
||||
swap_tensors(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
}
|
||||
|
||||
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||
PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>>(scope: DeferScope) :
|
||||
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
|
||||
TensorFieldAlgebra<T, TorchTensorType> {
|
||||
public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitiveVar,
|
||||
PrimitiveArrayType, TorchTensorType : TorchTensorOverFieldNative<T>>(scope: DeferScope) :
|
||||
TorchTensorAlgebraNative<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
|
||||
TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||
|
||||
override operator fun TorchTensorType.div(b: TorchTensorType): TorchTensorType =
|
||||
this.div(b, safe = true)
|
||||
|
||||
public fun TorchTensorType.div(b: TorchTensorType, safe: Boolean): TorchTensorType {
|
||||
if (safe) checkLinearOperation(this, b)
|
||||
override operator fun TorchTensorType.div(b: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, b)
|
||||
return wrap(div_tensor(this.tensorHandle, b.tensorHandle)!!)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.divAssign(b: TorchTensorType): Unit =
|
||||
this.divAssign(b, safe = true)
|
||||
|
||||
public fun TorchTensorType.divAssign(b: TorchTensorType, safe: Boolean): Unit {
|
||||
if (safe) checkLinearOperation(this, b)
|
||||
div_tensor_assign(this.tensorHandle, b.tensorHandle)
|
||||
override operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override fun divide(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a / b
|
||||
|
||||
public abstract fun randUniform(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
||||
public abstract fun randNormal(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
||||
|
||||
public fun TorchTensorType.randUniform(): TorchTensorType =
|
||||
override fun TorchTensorType.randUniform(): TorchTensorType =
|
||||
wrap(rand_like(this.tensorHandle)!!)
|
||||
|
||||
public fun TorchTensorType.randUniformAssign(): Unit {
|
||||
override fun TorchTensorType.randUniformAssign(): Unit {
|
||||
rand_like_assign(this.tensorHandle)
|
||||
}
|
||||
|
||||
public fun TorchTensorType.randNormal(): TorchTensorType =
|
||||
override fun TorchTensorType.randNormal(): TorchTensorType =
|
||||
wrap(randn_like(this.tensorHandle)!!)
|
||||
|
||||
public fun TorchTensorType.randNormalAssign(): Unit {
|
||||
override fun TorchTensorType.randNormalAssign(): Unit {
|
||||
randn_like_assign(this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
|
||||
public fun TorchTensorType.expAssign(): Unit {
|
||||
override fun TorchTensorType.expAssign(): Unit {
|
||||
exp_tensor_assign(tensorHandle)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
|
||||
public fun TorchTensorType.logAssign(): Unit {
|
||||
override fun TorchTensorType.logAssign(): Unit {
|
||||
log_tensor_assign(tensorHandle)
|
||||
}
|
||||
|
||||
@ -283,31 +182,26 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||
return Pair(wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType {
|
||||
this.checkIsValue()
|
||||
override fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean): TorchTensorType {
|
||||
if (checks) this.checkIsValue()
|
||||
return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!)
|
||||
}
|
||||
|
||||
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||
this.grad(variable, false)
|
||||
|
||||
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
|
||||
this.checkIsValue()
|
||||
override infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
|
||||
if (checks) this.checkIsValue()
|
||||
return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||
}
|
||||
|
||||
public fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||
override fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
||||
|
||||
}
|
||||
|
||||
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
|
||||
TorchTensorPartialDivisionAlgebraNative<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
|
||||
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun number(value: Number): TorchTensorReal =
|
||||
full(value.toDouble(), intArrayOf(1), Device.CPU).sum()
|
||||
|
||||
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
||||
this.elements().map { it.second }.toList().toDoubleArray()
|
||||
|
||||
@ -360,8 +254,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
times_double_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
|
||||
|
||||
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
@ -377,13 +269,10 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
}
|
||||
|
||||
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
TorchTensorFieldAlgebra<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
|
||||
TorchTensorPartialDivisionAlgebraNative<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
|
||||
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun number(value: Number): TorchTensorFloat =
|
||||
full(value.toFloat(), intArrayOf(1), Device.CPU).sum()
|
||||
|
||||
override fun TorchTensorFloat.copyToArray(): FloatArray =
|
||||
this.elements().map { it.second }.toList().toFloatArray()
|
||||
|
||||
@ -436,8 +325,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
times_float_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
|
||||
|
||||
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
||||
@ -453,13 +340,10 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
}
|
||||
|
||||
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
TorchTensorAlgebra<Long, LongVar, LongArray, TorchTensorLong>(scope) {
|
||||
TorchTensorAlgebraNative<Long, LongVar, LongArray, TorchTensorLong>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong =
|
||||
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun number(value: Number): TorchTensorLong =
|
||||
full(value.toLong(), intArrayOf(1), Device.CPU).sum()
|
||||
|
||||
override fun TorchTensorLong.copyToArray(): LongArray =
|
||||
this.elements().map { it.second }.toList().toLongArray()
|
||||
|
||||
@ -516,20 +400,15 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
times_long_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
|
||||
|
||||
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
}
|
||||
|
||||
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
TorchTensorAlgebra<Int, IntVar, IntArray, TorchTensorInt>(scope) {
|
||||
TorchTensorAlgebraNative<Int, IntVar, IntArray, TorchTensorInt>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt =
|
||||
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun number(value: Number): TorchTensorInt =
|
||||
full(value.toInt(), intArrayOf(1), Device.CPU).sum()
|
||||
|
||||
override fun TorchTensorInt.copyToArray(): IntArray =
|
||||
this.elements().map { it.second }.toList().toIntArray()
|
||||
|
||||
@ -586,12 +465,11 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
times_int_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
|
||||
|
||||
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
}
|
||||
|
||||
|
||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorRealAlgebra(this).block() }
|
||||
|
||||
@ -604,12 +482,3 @@ public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() ->
|
||||
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorIntAlgebra(this).block() }
|
||||
|
||||
public inline fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
|
||||
this.requiresGrad = true
|
||||
return TorchTensorRealAlgebra(this.scope).block()
|
||||
}
|
||||
|
||||
public inline fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
|
||||
this.requiresGrad = true
|
||||
return TorchTensorFloatAlgebra(this.scope).block()
|
||||
}
|
@ -1,32 +1,25 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
|
||||
import kscience.kmath.structures.TensorStructure
|
||||
import kscience.kmath.memory.DeferScope
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import kscience.kmath.ctorch.*
|
||||
import kscience.kmath.torch.ctorch.*
|
||||
|
||||
|
||||
public sealed class TorchTensor<T> constructor(
|
||||
public val scope: DeferScope,
|
||||
public sealed class TorchTensorNative<T> constructor(
|
||||
scope: DeferScope,
|
||||
internal val tensorHandle: COpaquePointer
|
||||
) : TensorStructure<T>() {
|
||||
init {
|
||||
scope.defer(::close)
|
||||
}
|
||||
) : TorchTensor<T>, TorchTensorMemoryHolder(scope) {
|
||||
|
||||
private fun close(): Unit = dispose_tensor(tensorHandle)
|
||||
|
||||
protected abstract fun item(): T
|
||||
override fun close(): Unit = dispose_tensor(tensorHandle)
|
||||
|
||||
override val dimension: Int get() = get_dim(tensorHandle)
|
||||
override val shape: IntArray
|
||||
get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray()
|
||||
public val strides: IntArray
|
||||
override val strides: IntArray
|
||||
get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray()
|
||||
public val size: Int get() = get_numel(tensorHandle)
|
||||
public val device: Device get() = Device.fromInt(get_device(tensorHandle))
|
||||
override val size: Int get() = get_numel(tensorHandle)
|
||||
override val device: Device get() = Device.fromInt(get_device(tensorHandle))
|
||||
|
||||
override fun toString(): String {
|
||||
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(tensorHandle)!!
|
||||
@ -35,25 +28,6 @@ public sealed class TorchTensor<T> constructor(
|
||||
return stringRepresentation
|
||||
}
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||
if (dimension == 0) {
|
||||
return emptySequence()
|
||||
}
|
||||
val indices = (1..size).asSequence().map { indexFromOffset(it - 1, strides, dimension) }
|
||||
return indices.map { it to get(it) }
|
||||
}
|
||||
|
||||
public inline fun isValue(): Boolean = dimension == 0
|
||||
public inline fun isNotValue(): Boolean = !isValue()
|
||||
internal inline fun checkIsValue() = check(isValue()) {
|
||||
"This tensor has shape ${shape.toList()}"
|
||||
}
|
||||
|
||||
override fun value(): T {
|
||||
checkIsValue()
|
||||
return item()
|
||||
}
|
||||
|
||||
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_double(this.tensorHandle)!!
|
||||
@ -76,11 +50,11 @@ public sealed class TorchTensor<T> constructor(
|
||||
|
||||
}
|
||||
|
||||
public sealed class TorchTensorOverField<T> constructor(
|
||||
public sealed class TorchTensorOverFieldNative<T> constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensor<T>(scope, tensorHandle) {
|
||||
public var requiresGrad: Boolean
|
||||
) : TorchTensorNative<T>(scope, tensorHandle), TorchTensorOverField<T> {
|
||||
override var requiresGrad: Boolean
|
||||
get() = requires_grad(tensorHandle)
|
||||
set(value) = requires_grad_(tensorHandle, value)
|
||||
}
|
||||
@ -89,7 +63,7 @@ public sealed class TorchTensorOverField<T> constructor(
|
||||
public class TorchTensorReal internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorOverField<Double>(scope, tensorHandle) {
|
||||
) : TorchTensorOverFieldNative<Double>(scope, tensorHandle) {
|
||||
override fun item(): Double = get_item_double(tensorHandle)
|
||||
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Double) {
|
||||
@ -100,7 +74,7 @@ public class TorchTensorReal internal constructor(
|
||||
public class TorchTensorFloat internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensorOverField<Float>(scope, tensorHandle) {
|
||||
) : TorchTensorOverFieldNative<Float>(scope, tensorHandle) {
|
||||
override fun item(): Float = get_item_float(tensorHandle)
|
||||
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Float) {
|
||||
@ -111,7 +85,7 @@ public class TorchTensorFloat internal constructor(
|
||||
public class TorchTensorLong internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensor<Long>(scope, tensorHandle) {
|
||||
) : TorchTensorNative<Long>(scope, tensorHandle) {
|
||||
override fun item(): Long = get_item_long(tensorHandle)
|
||||
override fun get(index: IntArray): Long = get_long(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Long) {
|
||||
@ -122,24 +96,10 @@ public class TorchTensorLong internal constructor(
|
||||
public class TorchTensorInt internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensor<Int>(scope, tensorHandle) {
|
||||
) : TorchTensorNative<Int>(scope, tensorHandle) {
|
||||
override fun item(): Int = get_item_int(tensorHandle)
|
||||
override fun get(index: IntArray): Int = get_int(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Int) {
|
||||
set_int(tensorHandle, index.toCValues(), value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||
val res = IntArray(nDim)
|
||||
var current = offset
|
||||
var strideIndex = 0
|
||||
|
||||
while (strideIndex < nDim) {
|
||||
res[strideIndex] = (current / strides[strideIndex])
|
||||
current %= strides[strideIndex]
|
||||
strideIndex++
|
||||
}
|
||||
return res
|
||||
}
|
@ -1,20 +0,0 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import kscience.kmath.ctorch.*
|
||||
|
||||
public fun getNumThreads(): Int {
|
||||
return get_num_threads()
|
||||
}
|
||||
|
||||
public fun setNumThreads(numThreads: Int): Unit {
|
||||
set_num_threads(numThreads)
|
||||
}
|
||||
|
||||
public fun cudaAvailable(): Boolean {
|
||||
return cuda_is_available()
|
||||
}
|
||||
|
||||
public fun setSeed(seed: Int): Unit {
|
||||
set_seed(seed)
|
||||
}
|
@ -14,8 +14,8 @@ internal fun benchmarkingMatMultDouble(
|
||||
setSeed(SEED)
|
||||
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
repeat(numWarmUp) { lhs.dotAssign(rhs, false) }
|
||||
val measuredTime = measureTime { repeat(numIter) { lhs.dotAssign(rhs, false) } }
|
||||
repeat(numWarmUp) { lhs dotAssign rhs }
|
||||
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
|
||||
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||
}
|
||||
}
|
||||
@ -31,8 +31,8 @@ internal fun benchmarkingMatMultFloat(
|
||||
setSeed(SEED)
|
||||
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
repeat(numWarmUp) { lhs.dotAssign(rhs, false) }
|
||||
val measuredTime = measureTime { repeat(numIter) { lhs.dotAssign(rhs, false) } }
|
||||
repeat(numWarmUp) { lhs dotAssign rhs }
|
||||
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
|
||||
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||
}
|
||||
}
|
||||
|
@ -8,12 +8,12 @@ internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
|
||||
|
||||
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
||||
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
|
||||
val tensorSigma = randFeatures + randFeatures.transpose(0, 1)
|
||||
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
||||
|
||||
val expressionAtX = tensorX.withGrad {
|
||||
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
|
||||
}
|
||||
val expressionAtX = withGradAt(tensorX, { x ->
|
||||
0.5 * (x dot (tensorSigma dot x)) + (tensorMu dot x) + 25.9
|
||||
})
|
||||
|
||||
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
||||
val hessianAtX = expressionAtX hess tensorX
|
||||
@ -25,21 +25,23 @@ internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
|
||||
}
|
||||
}
|
||||
|
||||
internal fun testingBatchedAutoGrad(bath: IntArray,
|
||||
internal fun testingBatchedAutoGrad(
|
||||
bath: IntArray,
|
||||
dim: Int,
|
||||
device: Device = Device.CPU): Unit {
|
||||
device: Device = Device.CPU
|
||||
): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
setSeed(SEED)
|
||||
|
||||
val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
||||
val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device)
|
||||
val tensorSigma = randFeatures + randFeatures.transpose(-2,-1)
|
||||
val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
||||
val tensorX = randNormal(shape = bath + intArrayOf(1, dim), device = device)
|
||||
val randFeatures = randNormal(shape = bath + intArrayOf(dim, dim), device = device)
|
||||
val tensorSigma = randFeatures + randFeatures.transpose(-2, -1)
|
||||
val tensorMu = randNormal(shape = bath + intArrayOf(1, dim), device = device)
|
||||
|
||||
val expressionAtX = tensorX.withGrad{
|
||||
val tensorXt = tensorX.transpose(-1,-2)
|
||||
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2
|
||||
}
|
||||
val expressionAtX = withGradAt(tensorX, { x ->
|
||||
val xt = x.transpose(-1, -2)
|
||||
0.5 * (x dot (tensorSigma dot xt)) + (tensorMu dot xt) + 58.2
|
||||
})
|
||||
expressionAtX.sumAssign()
|
||||
|
||||
val gradientAtX = expressionAtX grad tensorX
|
||||
@ -53,8 +55,8 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
|
||||
|
||||
internal class TestAutograd {
|
||||
@Test
|
||||
fun testAutoGrad() = testingAutoGrad(dim = 3)
|
||||
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
||||
|
||||
@Test
|
||||
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30)
|
||||
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2, 10), dim = 30)
|
||||
}
|
@ -69,4 +69,5 @@ class TestTorchTensor {
|
||||
viewTensor[intArrayOf(0, 0)] = 10
|
||||
assertEquals(tensor[intArrayOf(0)], 10)
|
||||
}
|
||||
|
||||
}
|
@ -51,6 +51,7 @@ internal fun testingMatrixMultiplication(device: Device = Device.CPU): Unit {
|
||||
|
||||
internal fun testingLinearStructure(device: Device = Device.CPU): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
val shape = intArrayOf(3)
|
||||
val tensorA = full(value = -4.5, shape = shape, device = device)
|
||||
val tensorB = full(value = 10.9, shape = shape, device = device)
|
||||
@ -81,7 +82,8 @@ internal fun testingLinearStructure(device: Device = Device.CPU): Unit {
|
||||
val error = (expected - result).abs().sum().value() +
|
||||
(expected - assignResult).abs().sum().value()
|
||||
assertTrue(error < TOLERANCE)
|
||||
}
|
||||
println(expected)
|
||||
}}
|
||||
}
|
||||
|
||||
internal fun testingTensorTransformations(device: Device = Device.CPU): Unit {
|
||||
|
@ -3,9 +3,6 @@ package kscience.kmath.torch
|
||||
import kotlin.test.*
|
||||
|
||||
|
||||
internal val SEED = 987654
|
||||
internal val TOLERANCE = 1e-6
|
||||
|
||||
internal fun testingSetSeed(device: Device = Device.CPU): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
setSeed(SEED)
|
||||
@ -22,10 +19,12 @@ internal fun testingSetSeed(device: Device = Device.CPU): Unit {
|
||||
internal class TestUtils {
|
||||
@Test
|
||||
fun testSetNumThreads() {
|
||||
TorchTensorRealAlgebra {
|
||||
val numThreads = 2
|
||||
setNumThreads(numThreads)
|
||||
assertEquals(numThreads, getNumThreads())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSetSeed() = testingSetSeed()
|
||||
|
Loading…
Reference in New Issue
Block a user