diff --git a/kmath-torch/README.md b/kmath-torch/README.md
index c46870946..df5ace944 100644
--- a/kmath-torch/README.md
+++ b/kmath-torch/README.md
@@ -76,7 +76,7 @@ pluginManagement {
 
 ## Usage
 
-Tensors implement the buffer protocol over `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
+Tensors are implemented over `MutableNDStructure`. They can be instantiated only through the provided factory methods and require scoping:
 
 ```kotlin
 TorchTensorRealAlgebra {
@@ -94,4 +94,24 @@ TorchTensorRealAlgebra {
     println(gpuRealTensor)
 }
 ```
+Enjoy a high-performance automatic differentiation engine:
+```kotlin
+TorchTensorRealAlgebra {
+    val dim = 10
+    val device = TorchDevice.TorchCPU // or TorchDevice.TorchCUDA(0)
+    val x = randNormal(shape = intArrayOf(dim), device = device)
+    // x is the variable
+    x.requiresGrad = true
+
+    val X = randNormal(shape = intArrayOf(dim, dim), device = device)
+    val Q = X + X.transpose(0, 1)
+    val mu = randNormal(shape = intArrayOf(dim), device = device)
+    val c = randNormal(shape = IntArray(0), device = device)
+
+    // expression to differentiate w.r.t. x
+    val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
+    // value of the gradient at x
+    val gradf = f grad x
+}
+```
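The README stops at first-order gradients; the same scope also gains constant-filled factories and scalar arithmetic further down in this patch. A minimal sketch of how those ops compose (values are illustrative, and `value()` is assumed to read a zero-dimensional result back as a `Double`):

```kotlin
TorchTensorRealAlgebra {
    val shape = intArrayOf(3)
    // tensors filled with a constant value
    val a = full(value = 1.0, shape = shape, device = TorchDevice.TorchCPU)
    val b = full(value = 2.0, shape = shape, device = TorchDevice.TorchCPU)
    // scalar products combine with plus/minus into linear expressions
    val c = 3.0 * a - 0.5 * b   // each entry equals 2.0
    // in-place variants mutate the receiver instead of allocating
    c += a                      // each entry equals 3.0
    c *= 2.0                    // each entry equals 6.0
    println(c.sum().value())    // 18.0
}
```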
diff --git a/kmath-torch/ctorch/include/ctorch.h b/kmath-torch/ctorch/include/ctorch.h
index 553997d89..66bf41446 100644
--- a/kmath-torch/ctorch/include/ctorch.h
+++ b/kmath-torch/ctorch/include/ctorch.h
@@ -29,7 +29,7 @@ extern "C"
     float get_item_float(TorchTensorHandle tensor_handle);
     long get_item_long(TorchTensorHandle tensor_handle);
     int get_item_int(TorchTensorHandle tensor_handle);
-    
+
     int get_dim(TorchTensorHandle tensor_handle);
     int get_numel(TorchTensorHandle tensor_handle);
     int get_shape_at(TorchTensorHandle tensor_handle, int d);
@@ -40,25 +40,52 @@ extern "C"
     void dispose_char(char *ptr);
     void dispose_tensor(TorchTensorHandle tensor_handle);
-
-    double get_double(TorchTensorHandle tensor_handle, int* index);
-    float get_float(TorchTensorHandle tensor_handle, int* index);
-    long get_long(TorchTensorHandle tensor_handle, int* index);
-    int get_int(TorchTensorHandle tensor_handle, int* index);
-    void set_double(TorchTensorHandle tensor_handle, int* index, double value);
-    void set_float(TorchTensorHandle tensor_handle, int* index, float value);
-    void set_long(TorchTensorHandle tensor_handle, int* index, long value);
-    void set_int(TorchTensorHandle tensor_handle, int* index, int value);
+    double get_double(TorchTensorHandle tensor_handle, int *index);
+    float get_float(TorchTensorHandle tensor_handle, int *index);
+    long get_long(TorchTensorHandle tensor_handle, int *index);
+    int get_int(TorchTensorHandle tensor_handle, int *index);
+    void set_double(TorchTensorHandle tensor_handle, int *index, double value);
+    void set_float(TorchTensorHandle tensor_handle, int *index, float value);
+    void set_long(TorchTensorHandle tensor_handle, int *index, long value);
+    void set_int(TorchTensorHandle tensor_handle, int *index, int value);
+    TorchTensorHandle randn_double(int *shape, int shape_size, int device);
+    TorchTensorHandle rand_double(int *shape, int shape_size, int device);
+    TorchTensorHandle randn_float(int *shape, int shape_size, int device);
+    TorchTensorHandle rand_float(int *shape, int shape_size, int device);
-    TorchTensorHandle randn_double(int* shape, int shape_size, int device);
-    TorchTensorHandle rand_double(int* shape, int shape_size, int device);
-    TorchTensorHandle randn_float(int* shape, int shape_size, int device);
-    TorchTensorHandle rand_float(int* shape, int shape_size, int device);
+    TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
+    TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
+    TorchTensorHandle full_long(long value, int *shape, int shape_size, int device);
+    TorchTensorHandle full_int(int value, int *shape, int shape_size, int device);
 
     TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
     void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
-
+    void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
+
+    TorchTensorHandle times_double(double value, TorchTensorHandle other);
+    TorchTensorHandle times_float(float value, TorchTensorHandle other);
+    TorchTensorHandle times_long(long value, TorchTensorHandle other);
+    TorchTensorHandle times_int(int value, TorchTensorHandle other);
+
+    void times_assign_double(double value, TorchTensorHandle other);
+    void times_assign_float(float value, TorchTensorHandle other);
+    void times_assign_long(long value, TorchTensorHandle other);
+    void times_assign_int(int value, TorchTensorHandle other);
+
+    TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
+    void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
+    TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
+    void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
+    TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
+    TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
+    TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
+
+    bool requires_grad(TorchTensorHandle tensor_handle);
+    void requires_grad_(TorchTensorHandle tensor_handle, bool status);
+    void detach_from_graph(TorchTensorHandle tensor_handle);
+    TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/kmath-torch/ctorch/include/utils.hh b/kmath-torch/ctorch/include/utils.hh
index d375438d4..79de077da 100644
--- a/kmath-torch/ctorch/include/utils.hh
+++ b/kmath-torch/ctorch/include/utils.hh
@@ -92,4 +92,10 @@ namespace ctorch
         return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
     }
 
+    template <typename Dtype>
+    inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device)
+    {
+        return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
+    }
+
 } // namespace ctorch
diff --git a/kmath-torch/ctorch/src/ctorch.cc b/kmath-torch/ctorch/src/ctorch.cc
index f5d1ce6ee..37b20811e 100644
--- a/kmath-torch/ctorch/src/ctorch.cc
+++ b/kmath-torch/ctorch/src/ctorch.cc
@@ -156,12 +156,111 @@ TorchTensorHandle rand_float(int *shape, int shape_size, int device)
     return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
 }
 
+TorchTensorHandle full_double(double value, int *shape, int shape_size, int device)
+{
+    return new torch::Tensor(ctorch::full(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
+}
+TorchTensorHandle full_float(float value, int *shape, int shape_size, int device)
+{
+    return new torch::Tensor(ctorch::full(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
+}
+TorchTensorHandle full_long(long value, int *shape, int shape_size, int device)
+{
+    return new torch::Tensor(ctorch::full(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
+}
+TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
+{
+    return new torch::Tensor(ctorch::full(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
+}
+
 TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
 {
     return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
 }
-
 void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
 {
     ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
-}
\ No newline at end of file
+}
+void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
+{
+    ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
+}
+
+TorchTensorHandle times_double(double value, TorchTensorHandle other)
+{
+    return new torch::Tensor(value * ctorch::cast(other));
+}
+TorchTensorHandle times_float(float value, TorchTensorHandle other)
+{
+    return new torch::Tensor(value * ctorch::cast(other));
+}
+TorchTensorHandle times_long(long value, TorchTensorHandle other)
+{
+    return new torch::Tensor(value * ctorch::cast(other));
+}
+TorchTensorHandle times_int(int value, TorchTensorHandle other)
+{
+    return new torch::Tensor(value * ctorch::cast(other));
+}
+void times_assign_double(double value, TorchTensorHandle other)
+{
+    ctorch::cast(other) *= value;
+}
+void times_assign_float(float value, TorchTensorHandle other)
+{
+    ctorch::cast(other) *= value;
+}
+void times_assign_long(long value, TorchTensorHandle other)
+{
+    ctorch::cast(other) *= value;
+}
+void times_assign_int(int value, TorchTensorHandle other)
+{
+    ctorch::cast(other) *= value;
+}
+
+TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
+{
+    return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
+}
+void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
+{
+    ctorch::cast(lhs) += ctorch::cast(rhs);
+}
+TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
+{
+    return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
+}
+void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
+{
+    ctorch::cast(lhs) -= ctorch::cast(rhs);
+}
+TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
+{
+    return new torch::Tensor(ctorch::cast(tensor_handle).abs());
+}
+TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
+{
+    return new torch::Tensor(ctorch::cast(tensor_handle).sum());
+}
+TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
+{
+    return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
+}
+
+bool requires_grad(TorchTensorHandle tensor_handle)
+{
+    return ctorch::cast(tensor_handle).requires_grad();
+}
+void requires_grad_(TorchTensorHandle tensor_handle, bool status)
+{
+    ctorch::cast(tensor_handle).requires_grad_(status);
+}
+void detach_from_graph(TorchTensorHandle tensor_handle)
+{
+    // detach() returns a new tensor and leaves the original untouched;
+    // the in-place variant is required for this to have any effect
+    ctorch::cast(tensor_handle).detach_();
+}
+TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
+{
+    return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
+}
diff --git a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt
index 625529631..3f734c452 100644
--- a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt
+++ b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt
@@ -6,8 +6,17 @@ import kotlin.test.*
 class TestTorchTensorAlgebraGPU {
 
     @Test
-    fun testScalarProduct() = testingScalarProduct(device = TorchDevice.TorchCUDA(0))
+    fun testScalarProduct() =
+        testingScalarProduct(device = TorchDevice.TorchCUDA(0))
 
     @Test
-    fun testMatrixMultiplication() = testingMatrixMultiplication(device = TorchDevice.TorchCUDA(0))
+    fun testMatrixMultiplication() =
+        testingMatrixMultiplication(device = TorchDevice.TorchCUDA(0))
+
+    @Test
+    fun testLinearStructure() =
+        testingLinearStructure(device = TorchDevice.TorchCUDA(0))
+
+    @Test
+    fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
 }
diff --git a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestUtilsGPU.kt b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestUtilsGPU.kt
index f8ecb5554..3cf1d1c2a 100644
--- a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestUtilsGPU.kt
+++ b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestUtilsGPU.kt
@@ -13,21 +13,4 @@ internal class TestUtilsGPU {
 
     @Test
     fun testSetSeed() = testingSetSeed(TorchDevice.TorchCUDA(0))
-
-    @Test
-    fun testReadmeFactory() = TorchTensorRealAlgebra {
-
-        val realTensor: TorchTensorReal = copyFromArray(
-            array = (1..10).map { it + 50.0 }.toList().toDoubleArray(),
-            shape = intArrayOf(2,5)
-        )
-        println(realTensor)
-
-        val gpuRealTensor: TorchTensorReal = copyFromArray(
-            array = (1..8).map { it * 2.5 }.toList().toDoubleArray(),
-            shape = intArrayOf(2, 2, 2),
-            device = TorchDevice.TorchCUDA(0)
-        )
-        println(gpuRealTensor)
-    }
-
 }
diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt
index f8e992a6f..7e9904ecb 100644
--- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt
+++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt
@@ -62,6 +62,13 @@ public sealed class TorchTensor constructor(
         outTensorHandle = copy_to_device(tensorHandle, device.toInt())!!
     )
 
+    public var requiresGrad: Boolean
+        get() = requires_grad(tensorHandle)
+        set(value) = requires_grad_(tensorHandle, value)
+
+    public fun detachFromGraph(): Unit {
+        detach_from_graph(tensorHandle)
+    }
 }
 
 public class TorchTensorReal internal constructor(
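The `requiresGrad` flag and `detachFromGraph` map directly onto libtorch's `requires_grad_` and `detach_` through the bindings above. A small sketch of the intended life cycle (shapes are illustrative, not from the patch):

```kotlin
TorchTensorRealAlgebra {
    val x = randNormal(shape = intArrayOf(2))
    x.requiresGrad = true    // start recording operations involving x
    val f = x dot x          // f is a scalar living in the graph
    val gradf = f grad x     // df/dx = 2x
    x.detachFromGraph()      // cut x out of the graph again
}
```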
diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
index 3324912bd..80df73765 100644
--- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
+++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
@@ -14,6 +14,11 @@ public sealed class TorchTensorAlgebra constructor(
     ): TorchTensor
 
     public abstract fun TorchTensor.copyToArray(): PrimitiveArrayType
+    public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor
+
+    public abstract operator fun T.times(other: TorchTensor): TorchTensor
+    public abstract operator fun TorchTensor.times(value: T): TorchTensor
+    public abstract operator fun TorchTensor.timesAssign(value: T): Unit
 
     public infix fun TorchTensor.dot(other: TorchTensor): TorchTensor =
         wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
@@ -21,6 +26,34 @@ public sealed class TorchTensorAlgebra constructor(
     public infix fun TorchTensor.dotAssign(other: TorchTensor): Unit {
         matmul_assign(this.tensorHandle, other.tensorHandle)
     }
+
+    public infix fun TorchTensor.dotRightAssign(other: TorchTensor): Unit {
+        matmul_right_assign(this.tensorHandle, other.tensorHandle)
+    }
+
+    public operator fun TorchTensor.plus(other: TorchTensor): TorchTensor =
+        wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
+
+    public operator fun TorchTensor.plusAssign(other: TorchTensor): Unit {
+        plus_tensor_assign(this.tensorHandle, other.tensorHandle)
+    }
+
+    public operator fun TorchTensor.minus(other: TorchTensor): TorchTensor =
+        wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
+
+    public operator fun TorchTensor.minusAssign(other: TorchTensor): Unit {
+        minus_tensor_assign(this.tensorHandle, other.tensorHandle)
+    }
+
+    public fun TorchTensor.abs(): TorchTensor = wrap(abs_tensor(tensorHandle)!!)
+
+    public fun TorchTensor.sum(): TorchTensor = wrap(sum_tensor(tensorHandle)!!)
+
+    public fun TorchTensor.transpose(i: Int, j: Int): TorchTensor =
+        wrap(transpose_tensor(tensorHandle, i, j)!!)
+
+    public infix fun TorchTensor.grad(variable: TorchTensor): TorchTensor =
+        wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
 }
 
 public sealed class TorchTensorFieldAlgebra(scope: DeferScope) :
@@ -60,6 +93,27 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
         scope = scope,
         tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
     )
+
+    override operator fun Double.times(other: TorchTensor): TorchTensorReal = TorchTensorReal(
+        scope = scope,
+        tensorHandle = times_double(this, other.tensorHandle)!!
+    )
+
+    override fun TorchTensor.times(value: Double): TorchTensorReal = TorchTensorReal(
+        scope = scope,
+        tensorHandle = times_double(value, this.tensorHandle)!!
+    )
+
+    override fun TorchTensor.timesAssign(value: Double): Unit {
+        times_assign_double(value, this.tensorHandle)
+    }
+
+    override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
+        scope = scope,
+        tensorHandle = full_double(value, shape.toCValues(), shape.size, device.toInt())!!
+    )
 }
 
 public fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
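The three matmul flavors introduced here differ only in where the product lands; a quick sketch (square shapes keep all three well-defined):

```kotlin
TorchTensorRealAlgebra {
    val a = randNormal(shape = intArrayOf(2, 2))
    val b = randNormal(shape = intArrayOf(2, 2))
    val c = a dot b        // fresh tensor; a and b are untouched
    a dotAssign b          // a is overwritten with a dot b
    a dotRightAssign b     // b is overwritten with a dot b
}
```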
diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt
index c8185e6b8..5a6b07473 100644
--- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt
+++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt
@@ -19,4 +19,14 @@ class TestTorchTensor {
 
     @Test
     fun testCopyFromArray() = testingCopyFromArray()
+
+    @Test
+    fun testRequiresGrad() = TorchTensorRealAlgebra {
+        val tensor = randNormal(intArrayOf(3))
+        assertTrue(!tensor.requiresGrad)
+        tensor.requiresGrad = true
+        assertTrue(tensor.requiresGrad)
+        tensor.requiresGrad = false
+        assertTrue(!tensor.requiresGrad)
+    }
 }
\ No newline at end of file
diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt
index c0518c056..0163688b3 100644
--- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt
+++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt
@@ -4,12 +4,13 @@ import kscience.kmath.linear.RealMatrixContext
 import kscience.kmath.operations.invoke
 import kscience.kmath.structures.Matrix
 import kotlin.math.abs
+import kotlin.math.exp
 import kotlin.test.*
 
 internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit {
     TorchTensorRealAlgebra {
-        val lhs = randUniform(shape = intArrayOf(10), device = device)
-        val rhs = randUniform(shape = intArrayOf(10), device = device)
+        val lhs = randUniform(shape = intArrayOf(3), device = device)
+        val rhs = randUniform(shape = intArrayOf(3), device = device)
         val product = lhs dot rhs
         var expected = 0.0
         lhs.elements().forEach {
@@ -23,27 +24,76 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch
     TorchTensorRealAlgebra {
         setSeed(SEED)
 
-        val lhsTensor = randNormal(shape = intArrayOf(20, 20), device = device)
-        val rhsTensor = randNormal(shape = intArrayOf(20, 20), device = device)
+        val lhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
+        val rhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
         val product = lhsTensor dot rhsTensor
 
         val expected: Matrix<Double> = RealMatrixContext {
-            val lhs = produce(20, 20) { i, j -> lhsTensor[intArrayOf(i, j)] }
-            val rhs = produce(20, 20) { i, j -> rhsTensor[intArrayOf(i, j)] }
+            val lhs = produce(3, 3) { i, j -> lhsTensor[intArrayOf(i, j)] }
+            val rhs = produce(3, 3) { i, j -> rhsTensor[intArrayOf(i, j)] }
             lhs dot rhs
         }
 
-        lhsTensor dotAssign rhsTensor
+        val lhsTensorCopy = lhsTensor.copy()
+        val rhsTensorCopy = rhsTensor.copy()
+
+        lhsTensorCopy dotAssign rhsTensor
+        lhsTensor dotRightAssign rhsTensorCopy
 
         var error: Double = 0.0
         product.elements().forEach {
             error += abs(expected[it.first] - it.second) +
-                    abs(expected[it.first] - lhsTensor[it.first])
+                    abs(expected[it.first] - lhsTensorCopy[it.first]) +
+                    abs(expected[it.first] - rhsTensorCopy[it.first])
         }
         assertTrue(error < TOLERANCE)
     }
 }
 
+internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU): Unit {
+    TorchTensorRealAlgebra {
+        val shape = intArrayOf(3)
+        val tensorA = full(value = -4.5, shape = shape, device = device)
+        val tensorB = full(value = 10.9, shape = shape, device = device)
+        val tensorC = full(value = 789.3, shape = shape, device = device)
+        val result = 15.8 * tensorA - 1.5 * tensorB + 0.02 * tensorC
+        val expected = copyFromArray(
+            array = (1..3).map { 15.8 * (-4.5) - 1.5 * 10.9 + 0.02 * 789.3 }.toDoubleArray(),
+            shape = shape,
+            device = device
+        )
+
+        val assignResult = full(value = 0.0, shape = shape, device = device)
+        tensorA *= 15.8
+        tensorB *= 1.5
+        tensorC *= 0.02
+        assignResult += tensorA
+        assignResult -= tensorB
+        assignResult += tensorC
+
+        val error = (expected - result).abs().sum().value() +
+                (expected - assignResult).abs().sum().value()
+        assertTrue(error < TOLERANCE)
+    }
+}
+
+internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
+    TorchTensorRealAlgebra {
+        setSeed(SEED)
+        val x = randNormal(shape = intArrayOf(dim), device = device)
+        x.requiresGrad = true
+        val X = randNormal(shape = intArrayOf(dim, dim), device = device)
+        val Q = X + X.transpose(0, 1)
+        val mu = randNormal(shape = intArrayOf(dim), device = device)
+        val c = randNormal(shape = IntArray(0), device = device)
+        val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
+        val gradf = f grad x
+        val error = (gradf - ((Q dot x) + mu)).abs().sum().value()
+        assertTrue(error < TOLERANCE)
+    }
+}
+
 
 internal class TestTorchTensorAlgebra {
 
     @Test
@@ -52,4 +102,10 @@ internal class TestTorchTensorAlgebra {
 
     @Test
     fun testMatrixMultiplication() = testingMatrixMultiplication()
+
+    @Test
+    fun testLinearStructure() = testingLinearStructure()
+
+    @Test
+    fun testAutoGrad() = testingAutoGrad(dim = 100)
+
 }
\ No newline at end of file
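A closing note on the tests: the error metric `(expected - actual).abs().sum().value()` recurs in every new test and could be factored out. A sketch of such a helper (the name `assertAllClose` is ours, not part of this patch; `value()` is assumed to read back a zero-dimensional tensor as a `Double`):

```kotlin
import kotlin.test.assertTrue

// hypothetical helper mirroring the L1-distance check used in the tests above
internal fun TorchTensorRealAlgebra.assertAllClose(
    actual: TorchTensorReal,
    expected: TorchTensorReal,
    tolerance: Double = TOLERANCE
) {
    val error = (expected - actual).abs().sum().value()
    assertTrue(error < tolerance)
}
```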