diff --git a/kmath-torch/README.md b/kmath-torch/README.md index df5ace944..f4341f29c 100644 --- a/kmath-torch/README.md +++ b/kmath-torch/README.md @@ -106,10 +106,9 @@ TorchTensorRealAlgebra { val X = randNormal(shape = intArrayOf(dim,dim), device = device) val Q = X + X.transpose(0,1) val mu = randNormal(shape = intArrayOf(dim), device = device) - val c = randNormal(shape = IntArray(0), device = device) // expression to differentiate w.r.t. x - val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c + val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + 25.3 // value of the gradient at x val gradf = f grad x } diff --git a/kmath-torch/ctorch/include/ctorch.h b/kmath-torch/ctorch/include/ctorch.h index 9dcbd8288..fed68a6c3 100644 --- a/kmath-torch/ctorch/include/ctorch.h +++ b/kmath-torch/ctorch/include/ctorch.h @@ -78,10 +78,25 @@ extern "C" void times_assign_long(long value, TorchTensorHandle other); void times_assign_int(int value, TorchTensorHandle other); + TorchTensorHandle plus_double(double value, TorchTensorHandle other); + TorchTensorHandle plus_float(float value, TorchTensorHandle other); + TorchTensorHandle plus_long(long value, TorchTensorHandle other); + TorchTensorHandle plus_int(int value, TorchTensorHandle other); + + void plus_assign_double(double value, TorchTensorHandle other); + void plus_assign_float(float value, TorchTensorHandle other); + void plus_assign_long(long value, TorchTensorHandle other); + void plus_assign_int(int value, TorchTensorHandle other); + TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); + TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); + void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); + TorchTensorHandle div_tensor(TorchTensorHandle 
lhs, TorchTensorHandle rhs); + void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); + TorchTensorHandle unary_minus(TorchTensorHandle tensor); TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle); TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle); TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j); diff --git a/kmath-torch/ctorch/src/ctorch.cc b/kmath-torch/ctorch/src/ctorch.cc index effaff6fe..0cf373d7a 100644 --- a/kmath-torch/ctorch/src/ctorch.cc +++ b/kmath-torch/ctorch/src/ctorch.cc @@ -203,6 +203,39 @@ void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs)); } +TorchTensorHandle plus_double(double value, TorchTensorHandle other) +{ + return new torch::Tensor(ctorch::cast(other) + value); +} +TorchTensorHandle plus_float(float value, TorchTensorHandle other) +{ + return new torch::Tensor(ctorch::cast(other) + value); +} +TorchTensorHandle plus_long(long value, TorchTensorHandle other) +{ + return new torch::Tensor(ctorch::cast(other) + value); +} +TorchTensorHandle plus_int(int value, TorchTensorHandle other) +{ + return new torch::Tensor(ctorch::cast(other) + value); +} +void plus_assign_double(double value, TorchTensorHandle other) +{ + ctorch::cast(other) += value; +} +void plus_assign_float(float value, TorchTensorHandle other) +{ + ctorch::cast(other) += value; +} +void plus_assign_long(long value, TorchTensorHandle other) +{ + ctorch::cast(other) += value; +} +void plus_assign_int(int value, TorchTensorHandle other) +{ + ctorch::cast(other) += value; +} + TorchTensorHandle times_double(double value, TorchTensorHandle other) { return new torch::Tensor(value * ctorch::cast(other)); @@ -252,6 +285,26 @@ void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) { ctorch::cast(lhs) -= ctorch::cast(rhs); } +TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs) +{ + return new 
torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs)); +} +void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) +{ + ctorch::cast(lhs) *= ctorch::cast(rhs); +} +TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs) +{ + return new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs)); +} +void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) +{ + ctorch::cast(lhs) /= ctorch::cast(rhs); +} +TorchTensorHandle unary_minus(TorchTensorHandle tensor) +{ + return new torch::Tensor(-ctorch::cast(tensor)); +} TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle) { return new torch::Tensor(ctorch::cast(tensor_handle).abs()); diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt index f3771b414..905475322 100644 --- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt +++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt @@ -20,10 +20,27 @@ public sealed class TorchTensorAlgebra + public abstract operator fun T.plus(other: TorchTensor): TorchTensor + public abstract operator fun TorchTensor.plus(value: T): TorchTensor + public abstract operator fun TorchTensor.plusAssign(value: T): Unit + public abstract operator fun T.minus(other: TorchTensor): TorchTensor + public abstract operator fun TorchTensor.minus(value: T): TorchTensor + public abstract operator fun TorchTensor.minusAssign(value: T): Unit public abstract operator fun T.times(other: TorchTensor): TorchTensor public abstract operator fun TorchTensor.times(value: T): TorchTensor public abstract operator fun TorchTensor.timesAssign(value: T): Unit + public operator fun TorchTensor.times(other: TorchTensor): TorchTensor = + wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!) 
+ public operator fun TorchTensor.timesAssign(other: TorchTensor): Unit { + times_tensor_assign(this.tensorHandle, other.tensorHandle) + } + public operator fun TorchTensor.div(other: TorchTensor): TorchTensor = + wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!) + public operator fun TorchTensor.divAssign(other: TorchTensor): Unit { + div_tensor_assign(this.tensorHandle, other.tensorHandle) + } + public infix fun TorchTensor.dot(other: TorchTensor): TorchTensor = wrap(matmul(this.tensorHandle, other.tensorHandle)!!) @@ -48,6 +65,8 @@ public sealed class TorchTensorAlgebra.minusAssign(other: TorchTensor): Unit { minus_tensor_assign(this.tensorHandle, other.tensorHandle) } + public operator fun TorchTensor.unaryMinus(): TorchTensor = + wrap(unary_minus(this.tensorHandle)!!) public fun TorchTensor.abs(): TorchTensor = wrap(abs_tensor(tensorHandle)!!) @@ -118,6 +137,34 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!! ) + override operator fun Double.plus(other: TorchTensor): TorchTensorReal = TorchTensorReal( + scope = scope, + tensorHandle = plus_double(this, other.tensorHandle)!! + ) + + override fun TorchTensor.plus(value: Double): TorchTensorReal = TorchTensorReal( + scope = scope, + tensorHandle = plus_double(value, this.tensorHandle)!! + ) + + override fun TorchTensor.plusAssign(value: Double): Unit { + plus_assign_double(value, this.tensorHandle) + } + + override operator fun Double.minus(other: TorchTensor): TorchTensorReal = TorchTensorReal( + scope = scope, + tensorHandle = plus_double(this, (-other).tensorHandle)!! /* this - other == (-other) + this; plus_double(-this, other) would yield other - this */ + ) + + override fun TorchTensor.minus(value: Double): TorchTensorReal = TorchTensorReal( + scope = scope, + tensorHandle = plus_double(-value, this.tensorHandle)!! 
+ ) + + override fun TorchTensor.minusAssign(value: Double): Unit { + plus_assign_double(-value, this.tensorHandle) + } + override operator fun Double.times(other: TorchTensor): TorchTensorReal = TorchTensorReal( scope = scope, tensorHandle = times_double(this, other.tensorHandle)!! diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt index 0163688b3..09595024e 100644 --- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt +++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt @@ -56,9 +56,14 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU): val tensorA = full(value = -4.5, shape = shape, device = device) val tensorB = full(value = 10.9, shape = shape, device = device) val tensorC = full(value = 789.3, shape = shape, device = device) - val result = 15.8 * tensorA - 1.5 * tensorB + 0.02 * tensorC + val tensorD = full(value = -72.9, shape = shape, device = device) + val tensorE = full(value = 553.1, shape = shape, device = device) + val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4 val expected = copyFromArray( - array = (1..3).map { 15.8 * (-4.5) - 1.5 * 10.9 + 0.02 * 789.3 }.toDoubleArray(), + array = (1..3).map { + 15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4 + } + .toDoubleArray(), shape = shape, device = device ) @@ -66,10 +71,13 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU): val assignResult = full(value = 0.0, shape = shape, device = device) tensorA *= 15.8 tensorB *= 1.5 + tensorB *= -tensorD tensorC *= 0.02 + tensorC /= tensorE assignResult += tensorA assignResult -= tensorB assignResult += tensorC + assignResult += -39.4 val error = (expected - result).abs().sum().value() + (expected - assignResult).abs().sum().value() @@ -80,15 +88,19 @@ 
internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU): internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit { TorchTensorRealAlgebra { setSeed(SEED) - val x = randNormal(shape = intArrayOf(dim), device = device) - x.requiresGrad = true - val X = randNormal(shape = intArrayOf(dim,dim), device = device) - val Q = X + X.transpose(0,1) - val mu = randNormal(shape = intArrayOf(dim), device = device) - val c = randNormal(shape = IntArray(0), device = device) - val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c - val gradf = f grad x - val error = (gradf - ((Q dot x) + mu)).abs().sum().value() + val tensorX = randNormal(shape = intArrayOf(dim), device = device) + tensorX.requiresGrad = true + val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device) + val tensorSigma = randFeatures + randFeatures.transpose(0,1) + val tensorMu = randNormal(shape = intArrayOf(dim), device = device) + + val expressionAtX = + 0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9 + + val gradientAtX = expressionAtX grad tensorX + val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu + + val error = (gradientAtX - expectedGradientAtX).abs().sum().value() assertTrue(error < TOLERANCE) } }