diff --git a/kmath-torch/ctorch/include/ctorch.h b/kmath-torch/ctorch/include/ctorch.h index fed68a6c3..e5a8562fe 100644 --- a/kmath-torch/ctorch/include/ctorch.h +++ b/kmath-torch/ctorch/include/ctorch.h @@ -73,20 +73,20 @@ extern "C" TorchTensorHandle times_long(long value, TorchTensorHandle other); TorchTensorHandle times_int(int value, TorchTensorHandle other); - void times_assign_double(double value, TorchTensorHandle other); - void times_assign_float(float value, TorchTensorHandle other); - void times_assign_long(long value, TorchTensorHandle other); - void times_assign_int(int value, TorchTensorHandle other); + void times_double_assign(double value, TorchTensorHandle other); + void times_float_assign(float value, TorchTensorHandle other); + void times_long_assign(long value, TorchTensorHandle other); + void times_int_assign(int value, TorchTensorHandle other); TorchTensorHandle plus_double(double value, TorchTensorHandle other); TorchTensorHandle plus_float(float value, TorchTensorHandle other); TorchTensorHandle plus_long(long value, TorchTensorHandle other); TorchTensorHandle plus_int(int value, TorchTensorHandle other); - void plus_assign_double(double value, TorchTensorHandle other); - void plus_assign_float(float value, TorchTensorHandle other); - void plus_assign_long(long value, TorchTensorHandle other); - void plus_assign_int(int value, TorchTensorHandle other); + void plus_double_assign(double value, TorchTensorHandle other); + void plus_float_assign(float value, TorchTensorHandle other); + void plus_long_assign(long value, TorchTensorHandle other); + void plus_int_assign(int value, TorchTensorHandle other); TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); @@ -96,10 +96,22 @@ extern "C" void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); - TorchTensorHandle unary_minus(TorchTensorHandle tensor); + TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle); + TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle); - TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle); + void abs_tensor_assign(TorchTensorHandle tensor_handle); + TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j); + void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j); + + TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle); + void exp_tensor_assign(TorchTensorHandle tensor_handle); + + TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle); + void log_tensor_assign(TorchTensorHandle tensor_handle); + + TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle); + void sum_tensor_assign(TorchTensorHandle tensor_handle); bool requires_grad(TorchTensorHandle tensor_handle); void requires_grad_(TorchTensorHandle tensor_handle, bool status); diff --git a/kmath-torch/ctorch/src/ctorch.cc b/kmath-torch/ctorch/src/ctorch.cc index 0cf373d7a..7148f9e15 100644 --- a/kmath-torch/ctorch/src/ctorch.cc +++ b/kmath-torch/ctorch/src/ctorch.cc @@ -219,19 +219,19 @@ TorchTensorHandle plus_int(int value, TorchTensorHandle other) { return new torch::Tensor(ctorch::cast(other) + value); } -void plus_assign_double(double value, TorchTensorHandle other) +void plus_double_assign(double value, TorchTensorHandle other) { ctorch::cast(other) += value; } -void plus_assign_float(float value, TorchTensorHandle other) +void plus_float_assign(float value, TorchTensorHandle other) { ctorch::cast(other) += value; } -void plus_assign_long(long value, TorchTensorHandle other) +void plus_long_assign(long value, TorchTensorHandle other) { ctorch::cast(other) += value; } -void plus_assign_int(int value, TorchTensorHandle other) +void plus_int_assign(int value, TorchTensorHandle other) { ctorch::cast(other) += value; } @@ -252,19 +252,19 @@ TorchTensorHandle times_int(int value, TorchTensorHandle other) { return new torch::Tensor(value * ctorch::cast(other)); } -void times_assign_double(double value, TorchTensorHandle other) +void times_double_assign(double value, TorchTensorHandle other) { ctorch::cast(other) *= value; } -void times_assign_float(float value, TorchTensorHandle other) +void times_float_assign(float value, TorchTensorHandle other) { ctorch::cast(other) *= value; } -void times_assign_long(long value, TorchTensorHandle other) +void times_long_assign(long value, TorchTensorHandle other) { ctorch::cast(other) *= value; } -void times_assign_int(int value, TorchTensorHandle other) +void times_int_assign(int value, TorchTensorHandle other) { ctorch::cast(other) *= value; } @@ -301,21 +301,54 @@ void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) { ctorch::cast(lhs) /= ctorch::cast(rhs); } -TorchTensorHandle unary_minus(TorchTensorHandle tensor) +TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle) { - return new torch::Tensor(-ctorch::cast(tensor)); + return new torch::Tensor(-ctorch::cast(tensor_handle)); } + TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle) { return new torch::Tensor(ctorch::cast(tensor_handle).abs()); } +void abs_tensor_assign(TorchTensorHandle tensor_handle) +{ + ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs(); +} + +TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j) +{ + return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j)); +} +void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j) +{ + ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j); +} + +TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle) +{ + return new torch::Tensor(ctorch::cast(tensor_handle).exp()); +} +void exp_tensor_assign(TorchTensorHandle tensor_handle) +{ + ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp(); +} + +TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle) +{ + return new torch::Tensor(ctorch::cast(tensor_handle).log()); +} +void log_tensor_assign(TorchTensorHandle tensor_handle) +{ + ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log(); +} + TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle) { return new torch::Tensor(ctorch::cast(tensor_handle).sum()); } -TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j) +void sum_tensor_assign(TorchTensorHandle tensor_handle) { - return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j)); + ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum(); } bool requires_grad(TorchTensorHandle tensor_handle) diff --git a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestAutogradGPU.kt b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestAutogradGPU.kt new file mode 100644 index 000000000..18f13099d --- /dev/null +++ b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestAutogradGPU.kt @@ -0,0 +1,9 @@ +package kscience.kmath.torch + +import kotlin.test.* + + +internal class TestAutogradGPU { + @Test + fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0)) +} \ No newline at end of file diff --git a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt index 3f734c452..dcdc4d7df 100644 --- a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt +++ b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt @@ -18,5 +18,7 @@ class TestTorchTensorAlgebraGPU { testingLinearStructure(device = TorchDevice.TorchCUDA(0)) @Test - fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0)) + fun testTensorTransformations() = + testingTensorTransformations(device = TorchDevice.TorchCUDA(0)) + } diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt index 7e9904ecb..953368b4a 100644 --- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt +++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt @@ -16,7 +16,6 @@ public sealed class TorchTensor constructor( private fun close(): Unit = dispose_tensor(tensorHandle) protected abstract fun item(): T - internal abstract fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer): TorchTensor override val dimension: Int get() = get_dim(tensorHandle) override val shape: IntArray @@ -50,18 +49,6 @@ public sealed class TorchTensor constructor( return item() } - public fun copy(): TorchTensor = - wrap( - outScope = scope, - outTensorHandle = copy_tensor(tensorHandle)!! - ) - - public fun copyToDevice(device: TorchDevice): TorchTensor = - wrap( - outScope = scope, - outTensorHandle = copy_to_device(tensorHandle, device.toInt())!! - ) - public var requiresGrad: Boolean get() = requires_grad(tensorHandle) set(value) = requires_grad_(tensorHandle, value) @@ -76,9 +63,6 @@ public class TorchTensorReal internal constructor( tensorHandle: COpaquePointer ) : TorchTensor(scope, tensorHandle) { override fun item(): Double = get_item_double(tensorHandle) - override fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer - ): TorchTensorReal = TorchTensorReal(scope = outScope, tensorHandle = outTensorHandle) - override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues()) override fun set(index: IntArray, value: Double) { set_double(tensorHandle, index.toCValues(), value) diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt index 905475322..8f0f2c5e9 100644 --- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt +++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt @@ -3,93 +3,129 @@ package kscience.kmath.torch import kotlinx.cinterop.* import kscience.kmath.ctorch.* -public sealed class TorchTensorAlgebra constructor( +public sealed class TorchTensorAlgebra< + T, + TVar : CPrimitiveVar, + PrimitiveArrayType, + TorchTensorType : TorchTensor> constructor( internal val scope: DeferScope ) { - internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor + internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType public abstract fun copyFromArray( array: PrimitiveArrayType, shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU - ): TorchTensor - public abstract fun TorchTensor.copyToArray(): PrimitiveArrayType + ): TorchTensorType - public abstract fun fromBlob(arrayBlob: CPointer, shape: IntArray): TorchTensor - public abstract fun TorchTensor.getData(): CPointer + public abstract fun TorchTensorType.copyToArray(): PrimitiveArrayType - public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor + public abstract fun fromBlob(arrayBlob: CPointer, shape: IntArray): TorchTensorType + public abstract fun TorchTensorType.getData(): CPointer - public abstract operator fun T.plus(other: TorchTensor): TorchTensor - public abstract operator fun TorchTensor.plus(value: T): TorchTensor - public abstract operator fun TorchTensor.plusAssign(value: T): Unit - public abstract operator fun T.minus(other: TorchTensor): TorchTensor - public abstract operator fun TorchTensor.minus(value: T): TorchTensor - public abstract operator fun TorchTensor.minusAssign(value: T): Unit - public abstract operator fun T.times(other: TorchTensor): TorchTensor - public abstract operator fun TorchTensor.times(value: T): TorchTensor - public abstract operator fun TorchTensor.timesAssign(value: T): Unit + public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensorType - public operator fun TorchTensor.times(other: TorchTensor): TorchTensor = + public abstract operator fun T.plus(other: TorchTensorType): TorchTensorType + public abstract operator fun TorchTensorType.plus(value: T): TorchTensorType + public abstract operator fun TorchTensorType.plusAssign(value: T): Unit + public abstract operator fun T.minus(other: TorchTensorType): TorchTensorType + public abstract operator fun TorchTensorType.minus(value: T): TorchTensorType + public abstract operator fun TorchTensorType.minusAssign(value: T): Unit + public abstract operator fun T.times(other: TorchTensorType): TorchTensorType + public abstract operator fun TorchTensorType.times(value: T): TorchTensorType + public abstract operator fun TorchTensorType.timesAssign(value: T): Unit + + public operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType = wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!) - public operator fun TorchTensor.timesAssign(other: TorchTensor): Unit { + + public operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit { times_tensor_assign(this.tensorHandle, other.tensorHandle) } - public operator fun TorchTensor.div(other: TorchTensor): TorchTensor = + + public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType = wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!) - public operator fun TorchTensor.divAssign(other: TorchTensor): Unit { + + public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit { div_tensor_assign(this.tensorHandle, other.tensorHandle) } - public infix fun TorchTensor.dot(other: TorchTensor): TorchTensor = + public infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType = wrap(matmul(this.tensorHandle, other.tensorHandle)!!) - public infix fun TorchTensor.dotAssign(other: TorchTensor): Unit { + public infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit { matmul_assign(this.tensorHandle, other.tensorHandle) } - public infix fun TorchTensor.dotRightAssign(other: TorchTensor): Unit { + public infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit { matmul_right_assign(this.tensorHandle, other.tensorHandle) } - public operator fun TorchTensor.plus(other: TorchTensor): TorchTensor = + public operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType = wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!) - public operator fun TorchTensor.plusAssign(other: TorchTensor): Unit { + public operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit { plus_tensor_assign(this.tensorHandle, other.tensorHandle) } - public operator fun TorchTensor.minus(other: TorchTensor): TorchTensor = + public operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType = wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!) - public operator fun TorchTensor.minusAssign(other: TorchTensor): Unit { + public operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit { minus_tensor_assign(this.tensorHandle, other.tensorHandle) } - public operator fun TorchTensor.unaryMinus(): TorchTensor = + + public operator fun TorchTensorType.unaryMinus(): TorchTensorType = wrap(unary_minus(this.tensorHandle)!!) - public fun TorchTensor.abs(): TorchTensor = wrap(abs_tensor(tensorHandle)!!) + public fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!) + public fun TorchTensorType.absAssign(): Unit { + abs_tensor_assign(tensorHandle) + } - public fun TorchTensor.sum(): TorchTensor = wrap(sum_tensor(tensorHandle)!!) + public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType = + wrap(transpose_tensor(tensorHandle, i, j)!!) + public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit { + transpose_tensor_assign(tensorHandle, i, j) + } - public fun TorchTensor.transpose(i: Int, j: Int): TorchTensor = - wrap(transpose_tensor(tensorHandle, i , j)!!) + public fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!) + public fun TorchTensorType.sumAssign(): Unit { + sum_tensor_assign(tensorHandle) + } - public infix fun TorchTensor.grad(variable: TorchTensor): TorchTensor = + public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType = wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!) + + public fun TorchTensorType.copy(): TorchTensorType = + wrap(tensorHandle = copy_tensor(this.tensorHandle)!!) + public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType = + wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!) } -public sealed class TorchTensorFieldAlgebra(scope: DeferScope) : - TorchTensorAlgebra(scope) { - public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor - public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor +public sealed class TorchTensorFieldAlgebra>(scope: DeferScope) : + TorchTensorAlgebra(scope) { + public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType + public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType + + public fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!) + public fun TorchTensorType.expAssign(): Unit { + exp_tensor_assign(tensorHandle) + } + + public fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!) + public fun TorchTensorType.logAssign(): Unit { + log_tensor_assign(tensorHandle) + } + } -public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra(scope) { +public class TorchTensorRealAlgebra(scope: DeferScope) : + TorchTensorFieldAlgebra(scope) { override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal = TorchTensorReal(scope = scope, tensorHandle = tensorHandle) - override fun TorchTensor.copyToArray(): DoubleArray = + override fun TorchTensorReal.copyToArray(): DoubleArray = this.elements().map { it.second }.toList().toDoubleArray() override fun copyFromArray( @@ -120,8 +156,8 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra )!! ) - override fun TorchTensor.getData(): CPointer { - require(this.device is TorchDevice.TorchCPU){ + override fun TorchTensorReal.getData(): CPointer { + require(this.device is TorchDevice.TorchCPU) { "This tensor is not on available on CPU" } return get_data_double(this.tensorHandle)!! @@ -137,50 +173,50 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!! ) - override operator fun Double.plus(other: TorchTensor): TorchTensorReal = TorchTensorReal( + override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal( scope = scope, tensorHandle = plus_double(this, other.tensorHandle)!! ) - override fun TorchTensor.plus(value: Double): TorchTensorReal = TorchTensorReal( + override fun TorchTensorReal.plus(value: Double): TorchTensorReal = TorchTensorReal( scope = scope, tensorHandle = plus_double(value, this.tensorHandle)!! ) - override fun TorchTensor.plusAssign(value: Double): Unit { - plus_assign_double(value, this.tensorHandle) + override fun TorchTensorReal.plusAssign(value: Double): Unit { + plus_double_assign(value, this.tensorHandle) } - override operator fun Double.minus(other: TorchTensor): TorchTensorReal = TorchTensorReal( + override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal( scope = scope, tensorHandle = plus_double(-this, other.tensorHandle)!! ) - override fun TorchTensor.minus(value: Double): TorchTensorReal = TorchTensorReal( + override fun TorchTensorReal.minus(value: Double): TorchTensorReal = TorchTensorReal( scope = scope, tensorHandle = plus_double(-value, this.tensorHandle)!! ) - override fun TorchTensor.minusAssign(value: Double): Unit { - plus_assign_double(-value, this.tensorHandle) + override fun TorchTensorReal.minusAssign(value: Double): Unit { + plus_double_assign(-value, this.tensorHandle) } - override operator fun Double.times(other: TorchTensor): TorchTensorReal = TorchTensorReal( + override operator fun Double.times(other: TorchTensorReal): TorchTensorReal = TorchTensorReal( scope = scope, tensorHandle = times_double(this, other.tensorHandle)!! ) - override fun TorchTensor.times(value: Double): TorchTensorReal = TorchTensorReal( + override fun TorchTensorReal.times(value: Double): TorchTensorReal = TorchTensorReal( scope = scope, tensorHandle = times_double(value, this.tensorHandle)!! ) - override fun TorchTensor.timesAssign(value: Double): Unit { - times_assign_double(value, this.tensorHandle) + override fun TorchTensorReal.timesAssign(value: Double): Unit { + times_double_assign(value, this.tensorHandle) } - override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal( + override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal( scope = scope, tensorHandle = full_double(value, shape.toCValues(), shape.size, device.toInt())!! ) diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt new file mode 100644 index 000000000..2aa3e14a1 --- /dev/null +++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt @@ -0,0 +1,28 @@ +package kscience.kmath.torch + +import kotlin.test.* + +internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit { + TorchTensorRealAlgebra { + setSeed(SEED) + val tensorX = randNormal(shape = intArrayOf(dim), device = device) + tensorX.requiresGrad = true + val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device) + val tensorSigma = randFeatures + randFeatures.transpose(0,1) + val tensorMu = randNormal(shape = intArrayOf(dim), device = device) + + val expressionAtX = + 0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9 + + val gradientAtX = expressionAtX grad tensorX + val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu + + val error = (gradientAtX - expectedGradientAtX).abs().sum().value() + assertTrue(error < TOLERANCE) + } +} + +internal class TestAutograd { + @Test + fun testAutoGrad() = testingAutoGrad(dim = 100) +} \ No newline at end of file diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt index 09595024e..118128d12 100644 --- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt +++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt @@ -3,8 +3,7 @@ package kscience.kmath.torch import kscience.kmath.linear.RealMatrixContext import kscience.kmath.operations.invoke import kscience.kmath.structures.Matrix -import kotlin.math.abs -import kotlin.math.exp +import kotlin.math.* import kotlin.test.* internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit { @@ -40,7 +39,7 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch lhsTensorCopy dotAssign rhsTensor lhsTensor dotRightAssign rhsTensorCopy - var error: Double = 0.0 + var error = 0.0 product.elements().forEach { error += abs(expected[it.first] - it.second) + abs(expected[it.first] - lhsTensorCopy[it.first]) + @@ -85,27 +84,24 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU): } } -internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit { +internal fun testingTensorTransformations(device: TorchDevice = TorchDevice.TorchCPU): Unit { TorchTensorRealAlgebra { setSeed(SEED) - val tensorX = randNormal(shape = intArrayOf(dim), device = device) - tensorX.requiresGrad = true - val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device) - val tensorSigma = randFeatures + randFeatures.transpose(0,1) - val tensorMu = randNormal(shape = intArrayOf(dim), device = device) - - val expressionAtX = - 0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9 - - val gradientAtX = expressionAtX grad tensorX - val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu - - val error = (gradientAtX - expectedGradientAtX).abs().sum().value() - assertTrue(error < TOLERANCE) + val tensor = randNormal(shape = intArrayOf(3, 3), device = device) + val result = tensor.exp().log() + val assignResult = tensor.copy() + assignResult.transposeAssign(0,1) + assignResult.expAssign() + assignResult.logAssign() + assignResult.transposeAssign(0,1) + val error = tensor - result + error.absAssign() + error.sumAssign() + error += (tensor - assignResult).abs().sum() + assertTrue (error.value()< TOLERANCE) } } - internal class TestTorchTensorAlgebra { @Test @@ -118,6 +114,6 @@ internal class TestTorchTensorAlgebra { fun testLinearStructure() = testingLinearStructure() @Test - fun testAutoGrad() = testingAutoGrad(dim = 100) + fun testTensorTransformations() = testingTensorTransformations() } \ No newline at end of file