forked from kscience/kmath
Refactoring project structure
This commit is contained in:
parent
889691a122
commit
6eb718f64a
3
.gitignore
vendored
3
.gitignore
vendored
@ -9,3 +9,6 @@ out/
|
|||||||
|
|
||||||
# Cache of project
|
# Cache of project
|
||||||
.gradletasknamecache
|
.gradletasknamecache
|
||||||
|
|
||||||
|
# Generated by javac -h
|
||||||
|
*.class
|
||||||
|
@ -1,49 +1,70 @@
|
|||||||
package kscience.kmath.structures
|
package kscience.kmath.structures
|
||||||
|
|
||||||
|
|
||||||
import kscience.kmath.operations.*
|
import kscience.kmath.operations.*
|
||||||
|
|
||||||
public interface TensorAlgebra<T, TorchTensorType : TensorStructure<T>> : Ring<TorchTensorType> {
|
public interface TensorStructure<T> : MutableNDStructure<T> {
|
||||||
|
// A tensor can have empty shape, in which case it represents just a value
|
||||||
|
public abstract fun value(): T
|
||||||
|
}
|
||||||
|
|
||||||
public operator fun T.plus(other: TorchTensorType): TorchTensorType
|
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||||
public operator fun TorchTensorType.plus(value: T): TorchTensorType
|
|
||||||
public operator fun TorchTensorType.plusAssign(value: T): Unit
|
|
||||||
public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit
|
|
||||||
|
|
||||||
public operator fun T.minus(other: TorchTensorType): TorchTensorType
|
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||||
public operator fun TorchTensorType.minus(value: T): TorchTensorType
|
|
||||||
public operator fun TorchTensorType.minusAssign(value: T): Unit
|
|
||||||
public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit
|
|
||||||
|
|
||||||
public operator fun T.times(other: TorchTensorType): TorchTensorType
|
public operator fun T.plus(other: TensorType): TensorType
|
||||||
public operator fun TorchTensorType.times(value: T): TorchTensorType
|
public operator fun TensorType.plus(value: T): TensorType
|
||||||
public operator fun TorchTensorType.timesAssign(value: T): Unit
|
public operator fun TensorType.plus(other: TensorType): TensorType
|
||||||
public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit
|
public operator fun TensorType.plusAssign(value: T): Unit
|
||||||
|
public operator fun TensorType.plusAssign(other: TensorType): Unit
|
||||||
|
|
||||||
public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType
|
public operator fun T.minus(other: TensorType): TensorType
|
||||||
|
public operator fun TensorType.minus(value: T): TensorType
|
||||||
|
public operator fun TensorType.minus(other: TensorType): TensorType
|
||||||
|
public operator fun TensorType.minusAssign(value: T): Unit
|
||||||
|
public operator fun TensorType.minusAssign(other: TensorType): Unit
|
||||||
|
|
||||||
|
public operator fun T.times(other: TensorType): TensorType
|
||||||
|
public operator fun TensorType.times(value: T): TensorType
|
||||||
|
public operator fun TensorType.times(other: TensorType): TensorType
|
||||||
|
public operator fun TensorType.timesAssign(value: T): Unit
|
||||||
|
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
||||||
|
public operator fun TensorType.unaryMinus(): TensorType
|
||||||
|
|
||||||
|
|
||||||
|
public infix fun TensorType.dot(other: TensorType): TensorType
|
||||||
|
public infix fun TensorType.dotAssign(other: TensorType): Unit
|
||||||
|
public infix fun TensorType.dotRightAssign(other: TensorType): Unit
|
||||||
|
|
||||||
public fun diagonalEmbedding(
|
public fun diagonalEmbedding(
|
||||||
diagonalEntries: TorchTensorType,
|
diagonalEntries: TensorType,
|
||||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||||
): TorchTensorType
|
): TensorType
|
||||||
|
|
||||||
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType
|
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
||||||
public fun TorchTensorType.view(shape: IntArray): TorchTensorType
|
public fun TensorType.transposeAssign(i: Int, j: Int): Unit
|
||||||
|
|
||||||
public fun TorchTensorType.abs(): TorchTensorType
|
public fun TensorType.view(shape: IntArray): TensorType
|
||||||
public fun TorchTensorType.sum(): TorchTensorType
|
|
||||||
|
|
||||||
|
public fun TensorType.abs(): TensorType
|
||||||
|
public fun TensorType.absAssign(): Unit
|
||||||
|
public fun TensorType.sum(): TensorType
|
||||||
|
public fun TensorType.sumAssign(): Unit
|
||||||
}
|
}
|
||||||
|
|
||||||
public interface TensorFieldAlgebra<T, TorchTensorType : TensorStructure<T>> :
|
// https://proofwiki.org/wiki/Definition:Division_Algebra
|
||||||
TensorAlgebra<T, TorchTensorType>, Field<TorchTensorType> {
|
|
||||||
|
|
||||||
public operator fun TorchTensorType.divAssign(b: TorchTensorType)
|
public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
|
||||||
|
TensorAlgebra<T, TensorType> {
|
||||||
|
|
||||||
public fun TorchTensorType.exp(): TorchTensorType
|
public operator fun TensorType.div(other: TensorType): TensorType
|
||||||
public fun TorchTensorType.log(): TorchTensorType
|
public operator fun TensorType.divAssign(other: TensorType)
|
||||||
|
|
||||||
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType>
|
public fun TensorType.exp(): TensorType
|
||||||
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType>
|
public fun TensorType.expAssign(): Unit
|
||||||
|
public fun TensorType.log(): TensorType
|
||||||
|
public fun TensorType.logAssign(): Unit
|
||||||
|
|
||||||
|
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
||||||
|
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType>
|
||||||
|
|
||||||
}
|
}
|
@ -1,11 +0,0 @@
|
|||||||
package kscience.kmath.structures
|
|
||||||
|
|
||||||
public abstract class TensorStructure<T>: MutableNDStructure<T> {
|
|
||||||
|
|
||||||
// A tensor can have empty shape, in which case it represents just a value
|
|
||||||
public abstract fun value(): T
|
|
||||||
|
|
||||||
// Tensors are mutable and might hold shared resources
|
|
||||||
override fun equals(other: Any?): Boolean = false
|
|
||||||
override fun hashCode(): Int = 0
|
|
||||||
}
|
|
@ -1,5 +1,6 @@
|
|||||||
import de.undercouch.gradle.tasks.download.Download
|
import de.undercouch.gradle.tasks.download.Download
|
||||||
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
|
import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget
|
||||||
|
import org.gradle.api.JavaVersion.VERSION_11
|
||||||
|
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
@ -7,10 +8,16 @@ plugins {
|
|||||||
id("de.undercouch.download")
|
id("de.undercouch.download")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
java {
|
||||||
|
sourceCompatibility = VERSION_11
|
||||||
|
targetCompatibility = VERSION_11
|
||||||
|
}
|
||||||
|
|
||||||
val home = System.getProperty("user.home")
|
val home = System.getProperty("user.home")
|
||||||
|
val javaHome = System.getProperty("java.home")
|
||||||
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
|
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
|
||||||
val cppBuildDir = "$thirdPartyDir/cpp-build"
|
val cppBuildDir = "$thirdPartyDir/cpp-build"
|
||||||
|
val cppSources = projectDir.resolve("src/cppMain")
|
||||||
|
|
||||||
val cudaHome: String? = System.getenv("CUDA_HOME")
|
val cudaHome: String? = System.getenv("CUDA_HOME")
|
||||||
val cudaDefault = file("/usr/local/cuda").exists()
|
val cudaDefault = file("/usr/local/cuda").exists()
|
||||||
@ -77,10 +84,11 @@ val configureCpp by tasks.registering {
|
|||||||
workingDir(cppBuildDir)
|
workingDir(cppBuildDir)
|
||||||
commandLine(
|
commandLine(
|
||||||
cmakeCmd,
|
cmakeCmd,
|
||||||
projectDir.resolve("ctorch"),
|
cppSources,
|
||||||
"-GNinja",
|
"-GNinja",
|
||||||
"-DCMAKE_MAKE_PROGRAM=$ninjaCmd",
|
"-DCMAKE_MAKE_PROGRAM=$ninjaCmd",
|
||||||
"-DCMAKE_PREFIX_PATH=$thirdPartyDir/$torchArchive",
|
"-DCMAKE_PREFIX_PATH=$thirdPartyDir/$torchArchive",
|
||||||
|
"-DJAVA_HOME=$javaHome",
|
||||||
"-DCMAKE_BUILD_TYPE=Release"
|
"-DCMAKE_BUILD_TYPE=Release"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -107,10 +115,23 @@ val buildCpp by tasks.registering {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val generateJNIHeader by tasks.registering {
|
||||||
|
doLast {
|
||||||
|
exec {
|
||||||
|
workingDir(projectDir.resolve("src/jvmMain/java/kscience/kmath/torch"))
|
||||||
|
commandLine("$javaHome/bin/javac", "-h", cppSources.resolve("include") , "JTorch.java")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kotlin {
|
kotlin {
|
||||||
explicitApiWarning()
|
explicitApiWarning()
|
||||||
|
|
||||||
val nativeTarget = linuxX64("torch")
|
jvm {
|
||||||
|
withJava()
|
||||||
|
}
|
||||||
|
|
||||||
|
val nativeTarget = linuxX64("native")
|
||||||
nativeTarget.apply {
|
nativeTarget.apply {
|
||||||
binaries {
|
binaries {
|
||||||
all {
|
all {
|
||||||
@ -128,38 +149,38 @@ kotlin {
|
|||||||
val main by nativeTarget.compilations.getting {
|
val main by nativeTarget.compilations.getting {
|
||||||
cinterops {
|
cinterops {
|
||||||
val libctorch by creating {
|
val libctorch by creating {
|
||||||
includeDirs(projectDir.resolve("ctorch/include"))
|
includeDirs(cppSources.resolve("include"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val test by nativeTarget.compilations.getting
|
val test by nativeTarget.compilations.getting
|
||||||
|
|
||||||
|
|
||||||
sourceSets {
|
sourceSets {
|
||||||
val nativeMain by creating {
|
|
||||||
|
val commonMain by getting {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val nativeTest by creating {
|
|
||||||
dependsOn(nativeMain)
|
|
||||||
}
|
|
||||||
val nativeGPUTest by creating {
|
|
||||||
dependsOn(nativeMain)
|
|
||||||
dependsOn(nativeTest)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
val nativeMain by getting {
|
||||||
main.defaultSourceSet.dependsOn(nativeMain)
|
dependencies {
|
||||||
test.defaultSourceSet.dependsOn(nativeTest)
|
api(project(":kmath-core"))
|
||||||
if(cudaFound) {
|
}
|
||||||
test.defaultSourceSet.dependsOn(nativeGPUTest)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val torch: KotlinNativeTarget by kotlin.targets
|
val native: KotlinNativeTarget by kotlin.targets
|
||||||
tasks[torch.compilations["main"].cinterops["libctorch"].interopProcessingTaskName]
|
tasks[native.compilations["main"].cinterops["libctorch"].interopProcessingTaskName]
|
||||||
.dependsOn(buildCpp)
|
.dependsOn(buildCpp)
|
||||||
|
|
||||||
|
tasks["jvmProcessResources"].dependsOn(buildCpp)
|
||||||
|
|
||||||
|
tasks {
|
||||||
|
withType<Test>{
|
||||||
|
systemProperty("java.library.path", cppBuildDir.toString())
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,52 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kscience.kmath.structures.TensorStructure
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
public interface TorchTensor<T> : TensorStructure<T> {
|
||||||
|
public fun item(): T
|
||||||
|
public val strides: IntArray
|
||||||
|
public val size: Int
|
||||||
|
public val device: Device
|
||||||
|
override fun value(): T {
|
||||||
|
checkIsValue()
|
||||||
|
return item()
|
||||||
|
}
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||||
|
if (dimension == 0) {
|
||||||
|
return emptySequence()
|
||||||
|
}
|
||||||
|
val indices = (1..size).asSequence().map { indexFromOffset(it - 1, strides, dimension) }
|
||||||
|
return indices.map { it to get(it) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T> TorchTensor<T>.isValue(): Boolean {
|
||||||
|
return (dimension == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T> TorchTensor<T>.isNotValue(): Boolean = !this.isValue()
|
||||||
|
|
||||||
|
public inline fun <T> TorchTensor<T>.checkIsValue(): Unit = check(this.isValue()) {
|
||||||
|
"This tensor has shape ${shape.toList()}"
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface TorchTensorOverField<T>: TorchTensor<T>
|
||||||
|
{
|
||||||
|
public var requiresGrad: Boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||||
|
val res = IntArray(nDim)
|
||||||
|
var current = offset
|
||||||
|
var strideIndex = 0
|
||||||
|
|
||||||
|
while (strideIndex < nDim) {
|
||||||
|
res[strideIndex] = (current / strides[strideIndex])
|
||||||
|
current %= strides[strideIndex]
|
||||||
|
strideIndex++
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,140 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kscience.kmath.structures.*
|
||||||
|
|
||||||
|
public interface TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>> :
|
||||||
|
TensorAlgebra<T, TorchTensorType> {
|
||||||
|
|
||||||
|
public fun getNumThreads(): Int
|
||||||
|
public fun setNumThreads(numThreads: Int): Unit
|
||||||
|
public fun cudaAvailable(): Boolean
|
||||||
|
public fun setSeed(seed: Int): Unit
|
||||||
|
|
||||||
|
public var checks: Boolean
|
||||||
|
|
||||||
|
public fun copyFromArray(
|
||||||
|
array: PrimitiveArrayType,
|
||||||
|
shape: IntArray,
|
||||||
|
device: Device = Device.CPU
|
||||||
|
): TorchTensorType
|
||||||
|
|
||||||
|
public fun TorchTensorType.copyToArray(): PrimitiveArrayType
|
||||||
|
|
||||||
|
public fun full(value: T, shape: IntArray, device: Device): TorchTensorType
|
||||||
|
|
||||||
|
public fun randIntegral(
|
||||||
|
low: T, high: T, shape: IntArray,
|
||||||
|
device: Device = Device.CPU
|
||||||
|
): TorchTensorType
|
||||||
|
|
||||||
|
public fun TorchTensorType.randIntegral(low: T, high: T): TorchTensorType
|
||||||
|
public fun TorchTensorType.randIntegralAssign(low: T, high: T): Unit
|
||||||
|
|
||||||
|
public fun TorchTensorType.copy(): TorchTensorType
|
||||||
|
public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType
|
||||||
|
public infix fun TorchTensorType.swap(other: TorchTensorType)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>> :
|
||||||
|
TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>, TensorPartialDivisionAlgebra<T, TorchTensorType> {
|
||||||
|
|
||||||
|
public fun randUniform(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
||||||
|
public fun randNormal(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
||||||
|
public fun TorchTensorType.randUniform(): TorchTensorType
|
||||||
|
public fun TorchTensorType.randUniformAssign(): Unit
|
||||||
|
public fun TorchTensorType.randNormal(): TorchTensorType
|
||||||
|
public fun TorchTensorType.randNormalAssign(): Unit
|
||||||
|
|
||||||
|
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType
|
||||||
|
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||||
|
this.grad(variable, false)
|
||||||
|
|
||||||
|
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType
|
||||||
|
public fun TorchTensorType.detachFromGraph(): TorchTensorType
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.withChecks(block: TorchTensorAlgebraType.() -> Unit): Unit {
|
||||||
|
val state = this.checks
|
||||||
|
this.checks = true
|
||||||
|
this.block()
|
||||||
|
this.checks = state
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.checkDeviceCompatible(
|
||||||
|
a: TorchTensorType, b: TorchTensorType
|
||||||
|
): Unit =
|
||||||
|
check(a.device == b.device) {
|
||||||
|
"Tensors must be on the same device"
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.checkShapeCompatible(
|
||||||
|
a: TorchTensorType,
|
||||||
|
b: TorchTensorType
|
||||||
|
): Unit =
|
||||||
|
check(a.shape contentEquals b.shape) {
|
||||||
|
"Tensors must be of identical shape"
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.checkLinearOperation(
|
||||||
|
a: TorchTensorType,
|
||||||
|
b: TorchTensorType
|
||||||
|
) {
|
||||||
|
if (a.isNotValue() and b.isNotValue()) {
|
||||||
|
this.checkDeviceCompatible(a, b)
|
||||||
|
this.checkShapeCompatible(a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit {
|
||||||
|
checkDeviceCompatible(a, b)
|
||||||
|
val sa = a.shape
|
||||||
|
val sb = b.shape
|
||||||
|
val na = sa.size
|
||||||
|
val nb = sb.size
|
||||||
|
var status: Boolean
|
||||||
|
if (nb == 1) {
|
||||||
|
status = sa.last() == sb[0]
|
||||||
|
} else {
|
||||||
|
status = sa.last() == sb[nb - 2]
|
||||||
|
if ((na > 2) and (nb > 2)) {
|
||||||
|
status = status and
|
||||||
|
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||||
|
check((i < dim) and (j < dim)) {
|
||||||
|
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorAlgebraType.checkView(a: TorchTensorType, shape: IntArray): Unit =
|
||||||
|
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
||||||
|
|
||||||
|
|
||||||
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
|
||||||
|
TorchTensorDivisionAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
TorchTensorDivisionAlgebraType.withGradAt(
|
||||||
|
tensor: TorchTensorType,
|
||||||
|
block: TorchTensorDivisionAlgebraType.(TorchTensorType) -> TorchTensorType
|
||||||
|
): TorchTensorType {
|
||||||
|
tensor.requiresGrad = true
|
||||||
|
return this.block(tensor)
|
||||||
|
}
|
@ -0,0 +1,14 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kscience.kmath.memory.DeferScope
|
||||||
|
|
||||||
|
public abstract class TorchTensorMemoryHolder internal constructor(
|
||||||
|
public val scope: DeferScope) {
|
||||||
|
init {
|
||||||
|
scope.defer(::close)
|
||||||
|
}
|
||||||
|
protected abstract fun close(): Unit
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean = false
|
||||||
|
override fun hashCode(): Int = 0
|
||||||
|
}
|
@ -0,0 +1,5 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
internal val SEED = 987654
|
||||||
|
internal val TOLERANCE = 1e-6
|
||||||
|
|
@ -12,14 +12,23 @@ endif()
|
|||||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||||
|
|
||||||
find_package(Torch REQUIRED)
|
find_package(Torch REQUIRED)
|
||||||
|
find_package(JNI REQUIRED)
|
||||||
|
|
||||||
add_library(ctorch SHARED src/ctorch.cc)
|
add_library(ctorch SHARED src/ctorch.cc)
|
||||||
target_include_directories(ctorch PRIVATE include)
|
target_include_directories(ctorch PRIVATE include)
|
||||||
target_link_libraries(ctorch PRIVATE torch)
|
target_link_libraries(ctorch PRIVATE torch)
|
||||||
target_compile_options(ctorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
target_compile_options(ctorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
||||||
|
|
||||||
|
add_library(jtorch SHARED src/jtorch.cc)
|
||||||
|
target_include_directories(jtorch PRIVATE include ${JNI_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(jtorch PRIVATE torch)
|
||||||
|
target_compile_options(jtorch PRIVATE -Wall -Wextra -Wpedantic -O3 -fPIC)
|
||||||
|
|
||||||
include(GNUInstallDirs)
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
set_target_properties(ctorch PROPERTIES PUBLIC_HEADER include/ctorch.h)
|
set_target_properties(ctorch PROPERTIES PUBLIC_HEADER include/ctorch.h)
|
||||||
install(TARGETS ctorch
|
install(TARGETS ctorch
|
||||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||||
|
|
||||||
|
install(TARGETS jtorch LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
@ -0,0 +1,53 @@
|
|||||||
|
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||||
|
#include <jni.h>
|
||||||
|
/* Header for class kscience_kmath_torch_JTorch */
|
||||||
|
|
||||||
|
#ifndef _Included_kscience_kmath_torch_JTorch
|
||||||
|
#define _Included_kscience_kmath_torch_JTorch
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
/*
|
||||||
|
* Class: kscience_kmath_torch_JTorch
|
||||||
|
* Method: getNumThreads
|
||||||
|
* Signature: ()I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumThreads
|
||||||
|
(JNIEnv *, jclass);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: kscience_kmath_torch_JTorch
|
||||||
|
* Method: setNumThreads
|
||||||
|
* Signature: (I)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads
|
||||||
|
(JNIEnv *, jclass, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: kscience_kmath_torch_JTorch
|
||||||
|
* Method: createTensor
|
||||||
|
* Signature: ()J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor
|
||||||
|
(JNIEnv *, jclass);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: kscience_kmath_torch_JTorch
|
||||||
|
* Method: printTensor
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: kscience_kmath_torch_JTorch
|
||||||
|
* Method: disposeTensor
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
@ -1,9 +1,10 @@
|
|||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
|
|
||||||
#include "ctorch.h"
|
|
||||||
|
|
||||||
namespace ctorch
|
namespace ctorch
|
||||||
{
|
{
|
||||||
|
|
||||||
|
using TorchTensorHandle = void*;
|
||||||
|
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
inline c10::ScalarType dtype()
|
inline c10::ScalarType dtype()
|
||||||
{
|
{
|
35
kmath-torch/src/cppMain/src/jtorch.cc
Normal file
35
kmath-torch/src/cppMain/src/jtorch.cc
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
#include <torch/torch.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include "kscience_kmath_torch_JTorch.h"
|
||||||
|
#include "utils.hh"
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_kscience_kmath_torch_JTorch_getNumThreads(JNIEnv *, jclass)
|
||||||
|
{
|
||||||
|
return torch::get_num_threads();
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_setNumThreads(JNIEnv *, jclass, jint num_threads)
|
||||||
|
{
|
||||||
|
torch::set_num_threads(num_threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jlong JNICALL Java_kscience_kmath_torch_JTorch_createTensor(JNIEnv *, jclass)
|
||||||
|
{
|
||||||
|
auto ten = torch::randn({2, 3});
|
||||||
|
std::cout << ten << std::endl;
|
||||||
|
void *ptr = new torch::Tensor(ten);
|
||||||
|
return (long)ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_printTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
auto ten = ctorch::cast((void *)tensor_handle);
|
||||||
|
std::cout << ten << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void JNICALL Java_kscience_kmath_torch_JTorch_disposeTensor(JNIEnv *, jclass, jlong tensor_handle)
|
||||||
|
{
|
||||||
|
delete static_cast<torch::Tensor *>((void *)tensor_handle);
|
||||||
|
}
|
@ -0,0 +1,14 @@
|
|||||||
|
package kscience.kmath.torch;
|
||||||
|
|
||||||
|
class JTorch {
|
||||||
|
|
||||||
|
static {
|
||||||
|
System.loadLibrary("jtorch");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static native int getNumThreads();
|
||||||
|
public static native void setNumThreads(int numThreads);
|
||||||
|
public static native long createTensor();
|
||||||
|
public static native void printTensor(long tensorHandle);
|
||||||
|
public static native void disposeTensor(long tensorHandle);
|
||||||
|
}
|
25
kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/Utils.kt
Normal file
25
kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/Utils.kt
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
|
||||||
|
public fun getNumThreads(): Int {
|
||||||
|
return JTorch.getNumThreads()
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun setNumThreads(numThreads: Int): Unit {
|
||||||
|
JTorch.setNumThreads(numThreads)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun cudaAvailable(): Boolean {
|
||||||
|
TODO("Implementation not available yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun setSeed(seed: Int): Unit {
|
||||||
|
TODO("Implementation not available yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public fun runCPD(): Unit {
|
||||||
|
val tensorHandle = JTorch.createTensor()
|
||||||
|
JTorch.printTensor(tensorHandle)
|
||||||
|
JTorch.disposeTensor(tensorHandle)
|
||||||
|
}
|
@ -0,0 +1,19 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtils {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSetNumThreads() {
|
||||||
|
val numThreads = 2
|
||||||
|
setNumThreads(numThreads)
|
||||||
|
assertEquals(numThreads, getNumThreads())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCPD() {
|
||||||
|
runCPD()
|
||||||
|
}
|
||||||
|
}
|
@ -1,2 +1,2 @@
|
|||||||
package=kscience.kmath.ctorch
|
package=kscience.kmath.torch.ctorch
|
||||||
headers=ctorch.h
|
headers=ctorch.h
|
@ -1,163 +1,90 @@
|
|||||||
package kscience.kmath.torch
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
|
||||||
import kscience.kmath.structures.*
|
|
||||||
import kscience.kmath.memory.DeferScope
|
import kscience.kmath.memory.DeferScope
|
||||||
import kscience.kmath.memory.withDeferScope
|
import kscience.kmath.memory.withDeferScope
|
||||||
|
|
||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import kscience.kmath.ctorch.*
|
import kscience.kmath.torch.ctorch.*
|
||||||
|
|
||||||
public sealed class TorchTensorAlgebra<
|
public sealed class TorchTensorAlgebraNative<
|
||||||
T,
|
T,
|
||||||
TVar : CPrimitiveVar,
|
TVar : CPrimitiveVar,
|
||||||
PrimitiveArrayType,
|
PrimitiveArrayType,
|
||||||
TorchTensorType : TorchTensor<T>> constructor(
|
TorchTensorType : TorchTensorNative<T>> constructor(
|
||||||
internal val scope: DeferScope
|
internal val scope: DeferScope
|
||||||
) :
|
) : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||||
TensorAlgebra<T, TorchTensorType> {
|
|
||||||
|
override fun getNumThreads(): Int {
|
||||||
|
return get_num_threads()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun setNumThreads(numThreads: Int): Unit {
|
||||||
|
set_num_threads(numThreads)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cudaAvailable(): Boolean {
|
||||||
|
return cuda_is_available()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun setSeed(seed: Int): Unit {
|
||||||
|
set_seed(seed)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
override var checks: Boolean = false
|
||||||
|
|
||||||
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
|
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
|
||||||
|
|
||||||
public abstract fun copyFromArray(
|
|
||||||
array: PrimitiveArrayType,
|
|
||||||
shape: IntArray,
|
|
||||||
device: Device = Device.CPU
|
|
||||||
): TorchTensorType
|
|
||||||
|
|
||||||
public abstract fun TorchTensorType.copyToArray(): PrimitiveArrayType
|
|
||||||
|
|
||||||
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
|
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
|
||||||
public abstract fun TorchTensorType.getData(): CPointer<TVar>
|
public abstract fun TorchTensorType.getData(): CPointer<TVar>
|
||||||
|
|
||||||
public abstract fun full(value: T, shape: IntArray, device: Device): TorchTensorType
|
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
public abstract fun randIntegral(
|
return wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
low: T, high: T, shape: IntArray,
|
|
||||||
device: Device = Device.CPU
|
|
||||||
): TorchTensorType
|
|
||||||
|
|
||||||
public abstract fun TorchTensorType.randIntegral(low: T, high: T): TorchTensorType
|
|
||||||
public abstract fun TorchTensorType.randIntegralAssign(low: T, high: T): Unit
|
|
||||||
|
|
||||||
override val zero: TorchTensorType
|
|
||||||
get() = number(0)
|
|
||||||
override val one: TorchTensorType
|
|
||||||
get() = number(1)
|
|
||||||
|
|
||||||
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
|
|
||||||
check(a.device == b.device) {
|
|
||||||
"Tensors must be on the same device"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
|
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||||
check(a.shape contentEquals b.shape) {
|
if (checks) checkLinearOperation(this, other)
|
||||||
"Tensors must be of identical shape"
|
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
protected inline fun checkLinearOperation(a: TorchTensorType, b: TorchTensorType) {
|
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
||||||
if (a.isNotValue() and b.isNotValue()) {
|
if (checks) checkLinearOperation(this, other)
|
||||||
checkDeviceCompatible(a, b)
|
return wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
checkShapeCompatible(a, b)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType =
|
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||||
this.times(b, safe = true)
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
public fun TorchTensorType.times(b: TorchTensorType, safe: Boolean): TorchTensorType {
|
|
||||||
if (safe) checkLinearOperation(this, b)
|
|
||||||
return wrap(times_tensor(this.tensorHandle, b.tensorHandle)!!)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit =
|
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
||||||
this.timesAssign(b, safe = true)
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
return wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
public fun TorchTensorType.timesAssign(b: TorchTensorType, safe: Boolean): Unit {
|
|
||||||
if (safe) checkLinearOperation(this, b)
|
|
||||||
times_tensor_assign(this.tensorHandle, b.tensorHandle)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b
|
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||||
|
if (checks) checkLinearOperation(this, other)
|
||||||
override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType =
|
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
this.plus(b, safe = true)
|
|
||||||
|
|
||||||
public fun TorchTensorType.plus(b: TorchTensorType, safe: Boolean): TorchTensorType {
|
|
||||||
if (safe) checkLinearOperation(this, b)
|
|
||||||
return wrap(plus_tensor(this.tensorHandle, b.tensorHandle)!!)
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit =
|
|
||||||
this.plusAssign(b, false)
|
|
||||||
|
|
||||||
public fun TorchTensorType.plusAssign(b: TorchTensorType, safe: Boolean): Unit {
|
|
||||||
if (safe) checkLinearOperation(this, b)
|
|
||||||
plus_tensor_assign(this.tensorHandle, b.tensorHandle)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
|
|
||||||
|
|
||||||
override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType =
|
|
||||||
this.minus(b, safe = true)
|
|
||||||
|
|
||||||
public fun TorchTensorType.minus(b: TorchTensorType, safe: Boolean): TorchTensorType {
|
|
||||||
if (safe) checkLinearOperation(this, b)
|
|
||||||
return wrap(minus_tensor(this.tensorHandle, b.tensorHandle)!!)
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit =
|
|
||||||
this.minusAssign(b, safe = true)
|
|
||||||
|
|
||||||
public fun TorchTensorType.minusAssign(b: TorchTensorType, safe: Boolean): Unit {
|
|
||||||
if (safe) checkLinearOperation(this, b)
|
|
||||||
minus_tensor_assign(this.tensorHandle, b.tensorHandle)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||||
wrap(unary_minus(this.tensorHandle)!!)
|
wrap(unary_minus(this.tensorHandle)!!)
|
||||||
|
|
||||||
private inline fun checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit {
|
override infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType {
|
||||||
checkDeviceCompatible(a, b)
|
if (checks) checkDotOperation(this, other)
|
||||||
val sa = a.shape
|
return wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
||||||
val sb = b.shape
|
|
||||||
val na = sa.size
|
|
||||||
val nb = sb.size
|
|
||||||
var status: Boolean
|
|
||||||
if (nb == 1) {
|
|
||||||
status = sa.last() == sb[0]
|
|
||||||
} else {
|
|
||||||
status = sa.last() == sb[nb - 2]
|
|
||||||
if ((na > 2) and (nb > 2)) {
|
|
||||||
status = status and
|
|
||||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType =
|
override infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
|
||||||
this.dot(b, safe = true)
|
if (checks) checkDotOperation(this, other)
|
||||||
|
matmul_assign(this.tensorHandle, other.tensorHandle)
|
||||||
public fun TorchTensorType.dot(b: TorchTensorType, safe: Boolean): TorchTensorType {
|
|
||||||
if (safe) checkDotOperation(this, b)
|
|
||||||
return wrap(matmul(this.tensorHandle, b.tensorHandle)!!)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public infix fun TorchTensorType.dotAssign(b: TorchTensorType): Unit =
|
override infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
|
||||||
this.dotAssign(b, safe = true)
|
if (checks) checkDotOperation(this, other)
|
||||||
|
matmul_right_assign(this.tensorHandle, other.tensorHandle)
|
||||||
public fun TorchTensorType.dotAssign(b: TorchTensorType, safe: Boolean): Unit {
|
|
||||||
if (safe) checkDotOperation(this, b)
|
|
||||||
matmul_assign(this.tensorHandle, b.tensorHandle)
|
|
||||||
}
|
|
||||||
|
|
||||||
public infix fun TorchTensorType.dotRightAssign(b: TorchTensorType): Unit =
|
|
||||||
this.dotRightAssign(b, safe = true)
|
|
||||||
|
|
||||||
public fun TorchTensorType.dotRightAssign(b: TorchTensorType, safe: Boolean): Unit {
|
|
||||||
if (safe) checkDotOperation(this, b)
|
|
||||||
matmul_right_assign(this.tensorHandle, b.tensorHandle)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(
|
override fun diagonalEmbedding(
|
||||||
@ -165,106 +92,78 @@ public sealed class TorchTensorAlgebra<
|
|||||||
): TorchTensorType =
|
): TorchTensorType =
|
||||||
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
|
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
|
||||||
|
|
||||||
private inline fun checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType {
|
||||||
check((i < dim) and (j < dim)) {
|
if (checks) checkTranspose(this.dimension, i, j)
|
||||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
|
|
||||||
this.transpose(i, j, safe = true)
|
|
||||||
|
|
||||||
public fun TorchTensorType.transpose(i: Int, j: Int, safe: Boolean): TorchTensorType {
|
|
||||||
if (safe) checkTranspose(this.dimension, i, j)
|
|
||||||
return wrap(transpose_tensor(tensorHandle, i, j)!!)
|
return wrap(transpose_tensor(tensorHandle, i, j)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit =
|
override fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
||||||
this.transposeAssign(i, j, safe = true)
|
if (checks) checkTranspose(this.dimension, i, j)
|
||||||
|
|
||||||
public fun TorchTensorType.transposeAssign(i: Int, j: Int, safe: Boolean): Unit {
|
|
||||||
if (safe) checkTranspose(this.dimension, i, j)
|
|
||||||
transpose_tensor_assign(tensorHandle, i, j)
|
transpose_tensor_assign(tensorHandle, i, j)
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun checkView(a: TorchTensorType, shape: IntArray): Unit =
|
override fun TorchTensorType.view(shape: IntArray): TorchTensorType {
|
||||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
if (checks) checkView(this, shape)
|
||||||
|
|
||||||
override fun TorchTensorType.view(shape: IntArray): TorchTensorType =
|
|
||||||
this.view(shape, safe = true)
|
|
||||||
|
|
||||||
public fun TorchTensorType.view(shape: IntArray, safe: Boolean): TorchTensorType {
|
|
||||||
if (safe) checkView(this, shape)
|
|
||||||
return wrap(view_tensor(this.tensorHandle, shape.toCValues(), shape.size)!!)
|
return wrap(view_tensor(this.tensorHandle, shape.toCValues(), shape.size)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
|
override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
|
||||||
public fun TorchTensorType.absAssign(): Unit {
|
override fun TorchTensorType.absAssign(): Unit {
|
||||||
abs_tensor_assign(tensorHandle)
|
abs_tensor_assign(tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
|
override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
|
||||||
public fun TorchTensorType.sumAssign(): Unit {
|
override fun TorchTensorType.sumAssign(): Unit {
|
||||||
sum_tensor_assign(tensorHandle)
|
sum_tensor_assign(tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun TorchTensorType.copy(): TorchTensorType =
|
override fun TorchTensorType.copy(): TorchTensorType =
|
||||||
wrap(copy_tensor(this.tensorHandle)!!)
|
wrap(copy_tensor(this.tensorHandle)!!)
|
||||||
|
|
||||||
public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
|
override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
|
||||||
wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
|
wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||||
|
|
||||||
public infix fun TorchTensorType.swap(otherTensor: TorchTensorType): Unit {
|
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit {
|
||||||
swap_tensors(this.tensorHandle, otherTensor.tensorHandle)
|
swap_tensors(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitiveVar,
|
||||||
PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>>(scope: DeferScope) :
|
PrimitiveArrayType, TorchTensorType : TorchTensorOverFieldNative<T>>(scope: DeferScope) :
|
||||||
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
|
TorchTensorAlgebraNative<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
|
||||||
TensorFieldAlgebra<T, TorchTensorType> {
|
TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||||
|
|
||||||
override operator fun TorchTensorType.div(b: TorchTensorType): TorchTensorType =
|
override operator fun TorchTensorType.div(b: TorchTensorType): TorchTensorType {
|
||||||
this.div(b, safe = true)
|
if (checks) checkLinearOperation(this, b)
|
||||||
|
|
||||||
public fun TorchTensorType.div(b: TorchTensorType, safe: Boolean): TorchTensorType {
|
|
||||||
if (safe) checkLinearOperation(this, b)
|
|
||||||
return wrap(div_tensor(this.tensorHandle, b.tensorHandle)!!)
|
return wrap(div_tensor(this.tensorHandle, b.tensorHandle)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.divAssign(b: TorchTensorType): Unit =
|
override operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||||
this.divAssign(b, safe = true)
|
if (checks) checkLinearOperation(this, other)
|
||||||
|
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
public fun TorchTensorType.divAssign(b: TorchTensorType, safe: Boolean): Unit {
|
|
||||||
if (safe) checkLinearOperation(this, b)
|
|
||||||
div_tensor_assign(this.tensorHandle, b.tensorHandle)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun divide(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a / b
|
override fun TorchTensorType.randUniform(): TorchTensorType =
|
||||||
|
|
||||||
public abstract fun randUniform(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
|
||||||
public abstract fun randNormal(shape: IntArray, device: Device = Device.CPU): TorchTensorType
|
|
||||||
|
|
||||||
public fun TorchTensorType.randUniform(): TorchTensorType =
|
|
||||||
wrap(rand_like(this.tensorHandle)!!)
|
wrap(rand_like(this.tensorHandle)!!)
|
||||||
|
|
||||||
public fun TorchTensorType.randUniformAssign(): Unit {
|
override fun TorchTensorType.randUniformAssign(): Unit {
|
||||||
rand_like_assign(this.tensorHandle)
|
rand_like_assign(this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun TorchTensorType.randNormal(): TorchTensorType =
|
override fun TorchTensorType.randNormal(): TorchTensorType =
|
||||||
wrap(randn_like(this.tensorHandle)!!)
|
wrap(randn_like(this.tensorHandle)!!)
|
||||||
|
|
||||||
public fun TorchTensorType.randNormalAssign(): Unit {
|
override fun TorchTensorType.randNormalAssign(): Unit {
|
||||||
randn_like_assign(this.tensorHandle)
|
randn_like_assign(this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
|
override fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
|
||||||
public fun TorchTensorType.expAssign(): Unit {
|
override fun TorchTensorType.expAssign(): Unit {
|
||||||
exp_tensor_assign(tensorHandle)
|
exp_tensor_assign(tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
|
override fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
|
||||||
public fun TorchTensorType.logAssign(): Unit {
|
override fun TorchTensorType.logAssign(): Unit {
|
||||||
log_tensor_assign(tensorHandle)
|
log_tensor_assign(tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -283,31 +182,26 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
|||||||
return Pair(wrap(S), wrap(V))
|
return Pair(wrap(S), wrap(V))
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType {
|
override fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean): TorchTensorType {
|
||||||
this.checkIsValue()
|
if (checks) this.checkIsValue()
|
||||||
return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!)
|
return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
override infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
|
||||||
this.grad(variable, false)
|
if (checks) this.checkIsValue()
|
||||||
|
|
||||||
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
|
|
||||||
this.checkIsValue()
|
|
||||||
return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
override fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||||
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||||
TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
|
TorchTensorPartialDivisionAlgebraNative<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
|
||||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
|
||||||
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
override fun number(value: Number): TorchTensorReal =
|
|
||||||
full(value.toDouble(), intArrayOf(1), Device.CPU).sum()
|
|
||||||
|
|
||||||
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
||||||
this.elements().map { it.second }.toList().toDoubleArray()
|
this.elements().map { it.second }.toList().toDoubleArray()
|
||||||
|
|
||||||
@ -360,8 +254,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
|||||||
times_double_assign(value, this.tensorHandle)
|
times_double_assign(value, this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
|
|
||||||
|
|
||||||
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||||
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
@ -377,13 +269,10 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
|||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||||
TorchTensorFieldAlgebra<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
|
TorchTensorPartialDivisionAlgebraNative<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
|
||||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
|
||||||
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
|
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
override fun number(value: Number): TorchTensorFloat =
|
|
||||||
full(value.toFloat(), intArrayOf(1), Device.CPU).sum()
|
|
||||||
|
|
||||||
override fun TorchTensorFloat.copyToArray(): FloatArray =
|
override fun TorchTensorFloat.copyToArray(): FloatArray =
|
||||||
this.elements().map { it.second }.toList().toFloatArray()
|
this.elements().map { it.second }.toList().toFloatArray()
|
||||||
|
|
||||||
@ -436,8 +325,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
|||||||
times_float_assign(value, this.tensorHandle)
|
times_float_assign(value, this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
|
|
||||||
|
|
||||||
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
@ -453,13 +340,10 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
|||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||||
TorchTensorAlgebra<Long, LongVar, LongArray, TorchTensorLong>(scope) {
|
TorchTensorAlgebraNative<Long, LongVar, LongArray, TorchTensorLong>(scope) {
|
||||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong =
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong =
|
||||||
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
|
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
override fun number(value: Number): TorchTensorLong =
|
|
||||||
full(value.toLong(), intArrayOf(1), Device.CPU).sum()
|
|
||||||
|
|
||||||
override fun TorchTensorLong.copyToArray(): LongArray =
|
override fun TorchTensorLong.copyToArray(): LongArray =
|
||||||
this.elements().map { it.second }.toList().toLongArray()
|
this.elements().map { it.second }.toList().toLongArray()
|
||||||
|
|
||||||
@ -516,20 +400,15 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
|
|||||||
times_long_assign(value, this.tensorHandle)
|
times_long_assign(value, this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
|
|
||||||
|
|
||||||
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||||
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||||
TorchTensorAlgebra<Int, IntVar, IntArray, TorchTensorInt>(scope) {
|
TorchTensorAlgebraNative<Int, IntVar, IntArray, TorchTensorInt>(scope) {
|
||||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt =
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt =
|
||||||
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
|
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
override fun number(value: Number): TorchTensorInt =
|
|
||||||
full(value.toInt(), intArrayOf(1), Device.CPU).sum()
|
|
||||||
|
|
||||||
override fun TorchTensorInt.copyToArray(): IntArray =
|
override fun TorchTensorInt.copyToArray(): IntArray =
|
||||||
this.elements().map { it.second }.toList().toIntArray()
|
this.elements().map { it.second }.toList().toIntArray()
|
||||||
|
|
||||||
@ -586,12 +465,11 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
|||||||
times_int_assign(value, this.tensorHandle)
|
times_int_assign(value, this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
|
|
||||||
|
|
||||||
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||||
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||||
withDeferScope { TorchTensorRealAlgebra(this).block() }
|
withDeferScope { TorchTensorRealAlgebra(this).block() }
|
||||||
|
|
||||||
@ -604,12 +482,3 @@ public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() ->
|
|||||||
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||||
withDeferScope { TorchTensorIntAlgebra(this).block() }
|
withDeferScope { TorchTensorIntAlgebra(this).block() }
|
||||||
|
|
||||||
public inline fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
|
|
||||||
this.requiresGrad = true
|
|
||||||
return TorchTensorRealAlgebra(this.scope).block()
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
|
|
||||||
this.requiresGrad = true
|
|
||||||
return TorchTensorFloatAlgebra(this.scope).block()
|
|
||||||
}
|
|
@ -1,32 +1,25 @@
|
|||||||
package kscience.kmath.torch
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
|
||||||
import kscience.kmath.structures.TensorStructure
|
|
||||||
import kscience.kmath.memory.DeferScope
|
import kscience.kmath.memory.DeferScope
|
||||||
|
|
||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import kscience.kmath.ctorch.*
|
import kscience.kmath.torch.ctorch.*
|
||||||
|
|
||||||
|
|
||||||
public sealed class TorchTensor<T> constructor(
|
public sealed class TorchTensorNative<T> constructor(
|
||||||
public val scope: DeferScope,
|
scope: DeferScope,
|
||||||
internal val tensorHandle: COpaquePointer
|
internal val tensorHandle: COpaquePointer
|
||||||
) : TensorStructure<T>() {
|
) : TorchTensor<T>, TorchTensorMemoryHolder(scope) {
|
||||||
init {
|
|
||||||
scope.defer(::close)
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun close(): Unit = dispose_tensor(tensorHandle)
|
override fun close(): Unit = dispose_tensor(tensorHandle)
|
||||||
|
|
||||||
protected abstract fun item(): T
|
|
||||||
|
|
||||||
override val dimension: Int get() = get_dim(tensorHandle)
|
override val dimension: Int get() = get_dim(tensorHandle)
|
||||||
override val shape: IntArray
|
override val shape: IntArray
|
||||||
get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray()
|
get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray()
|
||||||
public val strides: IntArray
|
override val strides: IntArray
|
||||||
get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray()
|
get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray()
|
||||||
public val size: Int get() = get_numel(tensorHandle)
|
override val size: Int get() = get_numel(tensorHandle)
|
||||||
public val device: Device get() = Device.fromInt(get_device(tensorHandle))
|
override val device: Device get() = Device.fromInt(get_device(tensorHandle))
|
||||||
|
|
||||||
override fun toString(): String {
|
override fun toString(): String {
|
||||||
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(tensorHandle)!!
|
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(tensorHandle)!!
|
||||||
@ -35,25 +28,6 @@ public sealed class TorchTensor<T> constructor(
|
|||||||
return stringRepresentation
|
return stringRepresentation
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> {
|
|
||||||
if (dimension == 0) {
|
|
||||||
return emptySequence()
|
|
||||||
}
|
|
||||||
val indices = (1..size).asSequence().map { indexFromOffset(it - 1, strides, dimension) }
|
|
||||||
return indices.map { it to get(it) }
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun isValue(): Boolean = dimension == 0
|
|
||||||
public inline fun isNotValue(): Boolean = !isValue()
|
|
||||||
internal inline fun checkIsValue() = check(isValue()) {
|
|
||||||
"This tensor has shape ${shape.toList()}"
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun value(): T {
|
|
||||||
checkIsValue()
|
|
||||||
return item()
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = copy_to_double(this.tensorHandle)!!
|
tensorHandle = copy_to_double(this.tensorHandle)!!
|
||||||
@ -76,11 +50,11 @@ public sealed class TorchTensor<T> constructor(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class TorchTensorOverField<T> constructor(
|
public sealed class TorchTensorOverFieldNative<T> constructor(
|
||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer
|
||||||
) : TorchTensor<T>(scope, tensorHandle) {
|
) : TorchTensorNative<T>(scope, tensorHandle), TorchTensorOverField<T> {
|
||||||
public var requiresGrad: Boolean
|
override var requiresGrad: Boolean
|
||||||
get() = requires_grad(tensorHandle)
|
get() = requires_grad(tensorHandle)
|
||||||
set(value) = requires_grad_(tensorHandle, value)
|
set(value) = requires_grad_(tensorHandle, value)
|
||||||
}
|
}
|
||||||
@ -89,7 +63,7 @@ public sealed class TorchTensorOverField<T> constructor(
|
|||||||
public class TorchTensorReal internal constructor(
|
public class TorchTensorReal internal constructor(
|
||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer
|
||||||
) : TorchTensorOverField<Double>(scope, tensorHandle) {
|
) : TorchTensorOverFieldNative<Double>(scope, tensorHandle) {
|
||||||
override fun item(): Double = get_item_double(tensorHandle)
|
override fun item(): Double = get_item_double(tensorHandle)
|
||||||
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
|
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
|
||||||
override fun set(index: IntArray, value: Double) {
|
override fun set(index: IntArray, value: Double) {
|
||||||
@ -100,7 +74,7 @@ public class TorchTensorReal internal constructor(
|
|||||||
public class TorchTensorFloat internal constructor(
|
public class TorchTensorFloat internal constructor(
|
||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer
|
||||||
) : TorchTensorOverField<Float>(scope, tensorHandle) {
|
) : TorchTensorOverFieldNative<Float>(scope, tensorHandle) {
|
||||||
override fun item(): Float = get_item_float(tensorHandle)
|
override fun item(): Float = get_item_float(tensorHandle)
|
||||||
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
|
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
|
||||||
override fun set(index: IntArray, value: Float) {
|
override fun set(index: IntArray, value: Float) {
|
||||||
@ -111,7 +85,7 @@ public class TorchTensorFloat internal constructor(
|
|||||||
public class TorchTensorLong internal constructor(
|
public class TorchTensorLong internal constructor(
|
||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer
|
||||||
) : TorchTensor<Long>(scope, tensorHandle) {
|
) : TorchTensorNative<Long>(scope, tensorHandle) {
|
||||||
override fun item(): Long = get_item_long(tensorHandle)
|
override fun item(): Long = get_item_long(tensorHandle)
|
||||||
override fun get(index: IntArray): Long = get_long(tensorHandle, index.toCValues())
|
override fun get(index: IntArray): Long = get_long(tensorHandle, index.toCValues())
|
||||||
override fun set(index: IntArray, value: Long) {
|
override fun set(index: IntArray, value: Long) {
|
||||||
@ -122,24 +96,10 @@ public class TorchTensorLong internal constructor(
|
|||||||
public class TorchTensorInt internal constructor(
|
public class TorchTensorInt internal constructor(
|
||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer
|
||||||
) : TorchTensor<Int>(scope, tensorHandle) {
|
) : TorchTensorNative<Int>(scope, tensorHandle) {
|
||||||
override fun item(): Int = get_item_int(tensorHandle)
|
override fun item(): Int = get_item_int(tensorHandle)
|
||||||
override fun get(index: IntArray): Int = get_int(tensorHandle, index.toCValues())
|
override fun get(index: IntArray): Int = get_int(tensorHandle, index.toCValues())
|
||||||
override fun set(index: IntArray, value: Int) {
|
override fun set(index: IntArray, value: Int) {
|
||||||
set_int(tensorHandle, index.toCValues(), value)
|
set_int(tensorHandle, index.toCValues(), value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
|
||||||
val res = IntArray(nDim)
|
|
||||||
var current = offset
|
|
||||||
var strideIndex = 0
|
|
||||||
|
|
||||||
while (strideIndex < nDim) {
|
|
||||||
res[strideIndex] = (current / strides[strideIndex])
|
|
||||||
current %= strides[strideIndex]
|
|
||||||
strideIndex++
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
@ -1,20 +0,0 @@
|
|||||||
package kscience.kmath.torch
|
|
||||||
|
|
||||||
import kotlinx.cinterop.*
|
|
||||||
import kscience.kmath.ctorch.*
|
|
||||||
|
|
||||||
public fun getNumThreads(): Int {
|
|
||||||
return get_num_threads()
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun setNumThreads(numThreads: Int): Unit {
|
|
||||||
set_num_threads(numThreads)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun cudaAvailable(): Boolean {
|
|
||||||
return cuda_is_available()
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun setSeed(seed: Int): Unit {
|
|
||||||
set_seed(seed)
|
|
||||||
}
|
|
@ -14,8 +14,8 @@ internal fun benchmarkingMatMultDouble(
|
|||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||||
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||||
repeat(numWarmUp) { lhs.dotAssign(rhs, false) }
|
repeat(numWarmUp) { lhs dotAssign rhs }
|
||||||
val measuredTime = measureTime { repeat(numIter) { lhs.dotAssign(rhs, false) } }
|
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
|
||||||
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -31,8 +31,8 @@ internal fun benchmarkingMatMultFloat(
|
|||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||||
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||||
repeat(numWarmUp) { lhs.dotAssign(rhs, false) }
|
repeat(numWarmUp) { lhs dotAssign rhs }
|
||||||
val measuredTime = measureTime { repeat(numIter) { lhs.dotAssign(rhs, false) } }
|
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
|
||||||
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,9 +11,9 @@ internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
|
|||||||
val tensorSigma = randFeatures + randFeatures.transpose(0, 1)
|
val tensorSigma = randFeatures + randFeatures.transpose(0, 1)
|
||||||
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
|
||||||
val expressionAtX = tensorX.withGrad {
|
val expressionAtX = withGradAt(tensorX, { x ->
|
||||||
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
|
0.5 * (x dot (tensorSigma dot x)) + (tensorMu dot x) + 25.9
|
||||||
}
|
})
|
||||||
|
|
||||||
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
||||||
val hessianAtX = expressionAtX hess tensorX
|
val hessianAtX = expressionAtX hess tensorX
|
||||||
@ -25,9 +25,11 @@ internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun testingBatchedAutoGrad(bath: IntArray,
|
internal fun testingBatchedAutoGrad(
|
||||||
|
bath: IntArray,
|
||||||
dim: Int,
|
dim: Int,
|
||||||
device: Device = Device.CPU): Unit {
|
device: Device = Device.CPU
|
||||||
|
): Unit {
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
|
|
||||||
@ -36,10 +38,10 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
|
|||||||
val tensorSigma = randFeatures + randFeatures.transpose(-2, -1)
|
val tensorSigma = randFeatures + randFeatures.transpose(-2, -1)
|
||||||
val tensorMu = randNormal(shape = bath + intArrayOf(1, dim), device = device)
|
val tensorMu = randNormal(shape = bath + intArrayOf(1, dim), device = device)
|
||||||
|
|
||||||
val expressionAtX = tensorX.withGrad{
|
val expressionAtX = withGradAt(tensorX, { x ->
|
||||||
val tensorXt = tensorX.transpose(-1,-2)
|
val xt = x.transpose(-1, -2)
|
||||||
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2
|
0.5 * (x dot (tensorSigma dot xt)) + (tensorMu dot xt) + 58.2
|
||||||
}
|
})
|
||||||
expressionAtX.sumAssign()
|
expressionAtX.sumAssign()
|
||||||
|
|
||||||
val gradientAtX = expressionAtX grad tensorX
|
val gradientAtX = expressionAtX grad tensorX
|
||||||
@ -53,7 +55,7 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
|
|||||||
|
|
||||||
internal class TestAutograd {
|
internal class TestAutograd {
|
||||||
@Test
|
@Test
|
||||||
fun testAutoGrad() = testingAutoGrad(dim = 3)
|
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2, 10), dim = 30)
|
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2, 10), dim = 30)
|
||||||
|
@ -69,4 +69,5 @@ class TestTorchTensor {
|
|||||||
viewTensor[intArrayOf(0, 0)] = 10
|
viewTensor[intArrayOf(0, 0)] = 10
|
||||||
assertEquals(tensor[intArrayOf(0)], 10)
|
assertEquals(tensor[intArrayOf(0)], 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
@ -51,6 +51,7 @@ internal fun testingMatrixMultiplication(device: Device = Device.CPU): Unit {
|
|||||||
|
|
||||||
internal fun testingLinearStructure(device: Device = Device.CPU): Unit {
|
internal fun testingLinearStructure(device: Device = Device.CPU): Unit {
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
|
withChecks {
|
||||||
val shape = intArrayOf(3)
|
val shape = intArrayOf(3)
|
||||||
val tensorA = full(value = -4.5, shape = shape, device = device)
|
val tensorA = full(value = -4.5, shape = shape, device = device)
|
||||||
val tensorB = full(value = 10.9, shape = shape, device = device)
|
val tensorB = full(value = 10.9, shape = shape, device = device)
|
||||||
@ -81,7 +82,8 @@ internal fun testingLinearStructure(device: Device = Device.CPU): Unit {
|
|||||||
val error = (expected - result).abs().sum().value() +
|
val error = (expected - result).abs().sum().value() +
|
||||||
(expected - assignResult).abs().sum().value()
|
(expected - assignResult).abs().sum().value()
|
||||||
assertTrue(error < TOLERANCE)
|
assertTrue(error < TOLERANCE)
|
||||||
}
|
println(expected)
|
||||||
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun testingTensorTransformations(device: Device = Device.CPU): Unit {
|
internal fun testingTensorTransformations(device: Device = Device.CPU): Unit {
|
||||||
|
@ -3,9 +3,6 @@ package kscience.kmath.torch
|
|||||||
import kotlin.test.*
|
import kotlin.test.*
|
||||||
|
|
||||||
|
|
||||||
internal val SEED = 987654
|
|
||||||
internal val TOLERANCE = 1e-6
|
|
||||||
|
|
||||||
internal fun testingSetSeed(device: Device = Device.CPU): Unit {
|
internal fun testingSetSeed(device: Device = Device.CPU): Unit {
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
@ -22,10 +19,12 @@ internal fun testingSetSeed(device: Device = Device.CPU): Unit {
|
|||||||
internal class TestUtils {
|
internal class TestUtils {
|
||||||
@Test
|
@Test
|
||||||
fun testSetNumThreads() {
|
fun testSetNumThreads() {
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
val numThreads = 2
|
val numThreads = 2
|
||||||
setNumThreads(numThreads)
|
setNumThreads(numThreads)
|
||||||
assertEquals(numThreads, getNumThreads())
|
assertEquals(numThreads, getNumThreads())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSetSeed() = testingSetSeed()
|
fun testSetSeed() = testingSetSeed()
|
||||||
|
Loading…
Reference in New Issue
Block a user