diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TensorAlgebra.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TensorAlgebra.kt deleted file mode 100644 index 96653f305..000000000 --- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TensorAlgebra.kt +++ /dev/null @@ -1,50 +0,0 @@ -package kscience.kmath.torch - -import kscience.kmath.operations.Field -import kscience.kmath.operations.Ring - - -public interface TensorAlgebra> : Ring { - - public operator fun T.plus(other: TorchTensorType): TorchTensorType - public operator fun TorchTensorType.plus(value: T): TorchTensorType - public operator fun TorchTensorType.plusAssign(value: T): Unit - public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit - - public operator fun T.minus(other: TorchTensorType): TorchTensorType - public operator fun TorchTensorType.minus(value: T): TorchTensorType - public operator fun TorchTensorType.minusAssign(value: T): Unit - public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit - - public operator fun T.times(other: TorchTensorType): TorchTensorType - public operator fun TorchTensorType.times(value: T): TorchTensorType - public operator fun TorchTensorType.timesAssign(value: T): Unit - public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit - - public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType - - public fun diagonalEmbedding( - diagonalEntries: TorchTensorType, - offset: Int = 0, dim1: Int = -2, dim2: Int = -1 - ): TorchTensorType - - public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType - public fun TorchTensorType.view(shape: IntArray): TorchTensorType - - public fun TorchTensorType.abs(): TorchTensorType - public fun TorchTensorType.sum(): TorchTensorType - -} - -public interface TensorFieldAlgebra> : - TensorAlgebra, Field { - - public operator fun TorchTensorType.divAssign(b: TorchTensorType) - - public fun TorchTensorType.exp(): TorchTensorType - public fun TorchTensorType.log(): TorchTensorType - - public fun TorchTensorType.svd(): Triple - public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair - -} \ No newline at end of file diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TensorStructure.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TensorStructure.kt deleted file mode 100644 index 53539045a..000000000 --- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TensorStructure.kt +++ /dev/null @@ -1,13 +0,0 @@ -package kscience.kmath.torch - -import kscience.kmath.structures.MutableNDStructure - -public abstract class TensorStructure: MutableNDStructure { - - // A tensor can have empty shape, in which case it represents just a value - public abstract fun value(): T - - // Tensors might hold shared resources - override fun equals(other: Any?): Boolean = false - override fun hashCode(): Int = 0 -} \ No newline at end of file diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt index 4302bc94c..cc41359a2 100644 --- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt +++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt @@ -1,6 +1,8 @@ package kscience.kmath.torch +import kscience.kmath.structures.TensorStructure + import kotlinx.cinterop.* import kscience.kmath.ctorch.* @@ -12,15 +14,16 @@ public sealed class TorchTensor constructor( init { scope.defer(::close) } + private fun close(): Unit = dispose_tensor(tensorHandle) protected abstract fun item(): T override val dimension: Int get() = get_dim(tensorHandle) override val shape: IntArray - get() = (1..dimension).map{get_shape_at(tensorHandle, it-1)}.toIntArray() + get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray() public val strides: IntArray - get() = (1..dimension).map{get_stride_at(tensorHandle, it-1)}.toIntArray() + get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray() public val size: Int get() = get_numel(tensorHandle) public val device: Device get() = Device.fromInt(get_device(tensorHandle)) @@ -44,27 +47,27 @@ public sealed class TorchTensor constructor( internal inline fun checkIsValue() = check(isValue()) { "This tensor has shape ${shape.toList()}" } + override fun value(): T { checkIsValue() return item() } - public var requiresGrad: Boolean - get() = requires_grad(tensorHandle) - set(value) = requires_grad_(tensorHandle, value) - public fun copyToDouble(): TorchTensorReal = TorchTensorReal( scope = scope, tensorHandle = copy_to_double(this.tensorHandle)!! ) + public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat( scope = scope, tensorHandle = copy_to_float(this.tensorHandle)!! ) + public fun copyToLong(): TorchTensorLong = TorchTensorLong( scope = scope, tensorHandle = copy_to_long(this.tensorHandle)!! ) + public fun copyToInt(): TorchTensorInt = TorchTensorInt( scope = scope, tensorHandle = copy_to_int(this.tensorHandle)!! @@ -72,10 +75,20 @@ public sealed class TorchTensor constructor( } +public sealed class TorchTensorOverField constructor( + scope: DeferScope, + tensorHandle: COpaquePointer +) : TorchTensor(scope, tensorHandle) { + internal var requiresGrad: Boolean + get() = requires_grad(tensorHandle) + set(value) = requires_grad_(tensorHandle, value) +} + + public class TorchTensorReal internal constructor( scope: DeferScope, tensorHandle: COpaquePointer -) : TorchTensor(scope, tensorHandle) { +) : TorchTensorOverField(scope, tensorHandle) { override fun item(): Double = get_item_double(tensorHandle) override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues()) override fun set(index: IntArray, value: Double) { @@ -86,7 +99,7 @@ public class TorchTensorReal internal constructor( public class TorchTensorFloat internal constructor( scope: DeferScope, tensorHandle: COpaquePointer -) : TorchTensor(scope, tensorHandle) { +) : TorchTensorOverField(scope, tensorHandle) { override fun item(): Float = get_item_float(tensorHandle) override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues()) override fun set(index: IntArray, value: Float) { diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt index db644962d..dc66813de 100644 --- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt +++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt @@ -1,5 +1,8 @@ package kscience.kmath.torch + +import kscience.kmath.structures.* + import kotlinx.cinterop.* import kscience.kmath.ctorch.* @@ -40,12 +43,12 @@ public sealed class TorchTensorAlgebra< override val one: TorchTensorType get() = number(1) - protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType) = + protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType): Unit = check(a.device == b.device) { "Tensors must be on the same device" } - protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType) = + protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType): Unit = check(a.shape contentEquals b.shape) { "Tensors must be of identical shape" } @@ -214,7 +217,7 @@ public sealed class TorchTensorAlgebra< } public sealed class TorchTensorFieldAlgebra>(scope: DeferScope) : + PrimitiveArrayType, TorchTensorType : TorchTensorOverField>(scope: DeferScope) : TorchTensorAlgebra(scope), TensorFieldAlgebra { @@ -597,4 +600,14 @@ public inline fun TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> memScoped { TorchTensorLongAlgebra(this).block() } public inline fun TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R = - memScoped { TorchTensorIntAlgebra(this).block() } \ No newline at end of file + memScoped { TorchTensorIntAlgebra(this).block() } + +public fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal { + this.requiresGrad = true + return TorchTensorRealAlgebra(this.scope).block() +} + +public fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat { + this.requiresGrad = true + return TorchTensorFloatAlgebra(this.scope).block() +} \ No newline at end of file diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt index a8e27511e..0ce3ef23a 100644 --- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt +++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt @@ -5,14 +5,15 @@ import kotlin.test.* internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit { TorchTensorRealAlgebra { setSeed(SEED) + val tensorX = randNormal(shape = intArrayOf(dim), device = device) - tensorX.requiresGrad = true val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device) val tensorSigma = randFeatures + randFeatures.transpose(0,1) val tensorMu = randNormal(shape = intArrayOf(dim), device = device) - val expressionAtX = + val expressionAtX = tensorX.withGrad { 0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9 + } val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true) val hessianAtX = expressionAtX hess tensorX @@ -31,14 +32,14 @@ internal fun testingBatchedAutoGrad(bath: IntArray, setSeed(SEED) val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device) - tensorX.requiresGrad = true - val tensorXt = tensorX.transpose(-1,-2) val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device) val tensorSigma = randFeatures + randFeatures.transpose(-2,-1) val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device) - val expressionAtX = + val expressionAtX = tensorX.withGrad{ + val tensorXt = tensorX.transpose(-1,-2) 0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2 + } expressionAtX.sumAssign() val gradientAtX = expressionAtX grad tensorX @@ -52,7 +53,7 @@ internal fun testingBatchedAutoGrad(bath: IntArray, internal class TestAutograd { @Test - fun testAutoGrad() = testingAutoGrad(dim = 100) + fun testAutoGrad() = testingAutoGrad(dim = 3) @Test fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30)