diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/NDStructure.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/NDStructure.kt index b25db6e0f..50e44e81c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/NDStructure.kt @@ -162,7 +162,7 @@ public interface Strides { /** * Array strides */ - public val strides: List + public val strides: IntArray /** * Get linear index from multidimensional index @@ -189,6 +189,11 @@ public interface Strides { } } +internal inline fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int = index.mapIndexed { i, value -> + if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})") + value * strides[i] +}.sum() + /** * Simple implementation of [Strides]. */ @@ -199,7 +204,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) : /** * Strides for memory access */ - override val strides: List by lazy { + override val strides: IntArray by lazy { sequence { var current = 1 yield(1) @@ -208,13 +213,10 @@ public class DefaultStrides private constructor(override val shape: IntArray) : current *= it yield(current) } - }.toList() + }.toList().toIntArray() } - override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> - if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})") - value * strides[i] - }.sum() + override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides) override fun index(offset: Int): IntArray { val res = IntArray(shape.size) @@ -323,8 +325,7 @@ public inline fun NDStructure.mapToBuffer( /** * Mutable ND buffer based on linear [MutableBuffer]. */ - -public class MutableNDBuffer( +public open class MutableNDBuffer( strides: Strides, buffer: MutableBuffer, ) : NDBuffer(strides, buffer), MutableNDStructure { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt new file mode 100644 index 000000000..0a28ace4f --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt @@ -0,0 +1,173 @@ +package space.kscience.kmath.tensors + +import space.kscience.kmath.nd.MutableNDBuffer +import space.kscience.kmath.structures.RealBuffer +import space.kscience.kmath.structures.array + + +public class RealTensor( + override val shape: IntArray, + buffer: DoubleArray +) : + TensorStructure, + MutableNDBuffer( + TensorStrides(shape), + RealBuffer(buffer) + ) { + override fun item(): Double = buffer[0] +} + +public class RealTensorAlgebra : TensorPartialDivisionAlgebra { + + override fun add(a: RealTensor, b: RealTensor): RealTensor { + TODO("Not yet implemented") + } + + override fun multiply(a: RealTensor, k: Number): RealTensor { + TODO("Not yet implemented") + } + + override val zero: RealTensor + get() = TODO("Not yet implemented") + + override fun multiply(a: RealTensor, b: RealTensor): RealTensor { + TODO("Not yet implemented") + } + + override val one: RealTensor + get() = TODO("Not yet implemented") + + + override fun Double.plus(other: RealTensor): RealTensor { + val n = other.buffer.size + val arr = other.buffer.array + val res = DoubleArray(n) + for (i in 1..n) + res[i - 1] = arr[i - 1] + this + return RealTensor(other.shape, res) + } + + override fun RealTensor.plus(value: Double): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.plusAssign(value: Double) { + TODO("Not yet implemented") + } + + override fun RealTensor.plusAssign(other: RealTensor) { + TODO("Not yet implemented") + } + + override fun Double.minus(other: RealTensor): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.minus(value: Double): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.minusAssign(value: Double) { + TODO("Not yet implemented") + } + + override fun RealTensor.minusAssign(other: RealTensor) { + TODO("Not yet implemented") + } + + override fun Double.times(other: RealTensor): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.times(value: Double): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.timesAssign(value: Double) { + TODO("Not yet implemented") + } + + override fun RealTensor.timesAssign(other: RealTensor) { + TODO("Not yet implemented") + } + + override fun RealTensor.dot(other: RealTensor): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.dotAssign(other: RealTensor) { + TODO("Not yet implemented") + } + + override fun RealTensor.dotRightAssign(other: RealTensor) { + TODO("Not yet implemented") + } + + override fun diagonalEmbedding(diagonalEntries: RealTensor, offset: Int, dim1: Int, dim2: Int): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.transpose(i: Int, j: Int): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.transposeAssign(i: Int, j: Int) { + TODO("Not yet implemented") + } + + override fun RealTensor.view(shape: IntArray): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.abs(): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.absAssign() { + TODO("Not yet implemented") + } + + override fun RealTensor.sum(): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.sumAssign() { + TODO("Not yet implemented") + } + + override fun RealTensor.div(other: RealTensor): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.divAssign(other: RealTensor) { + TODO("Not yet implemented") + } + + override fun RealTensor.exp(): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.expAssign() { + TODO("Not yet implemented") + } + + override fun RealTensor.log(): RealTensor { + TODO("Not yet implemented") + } + + override fun RealTensor.logAssign() { + TODO("Not yet implemented") + } + + override fun RealTensor.svd(): Triple { + TODO("Not yet implemented") + } + + override fun RealTensor.symEig(eigenvectors: Boolean): Pair { + TODO("Not yet implemented") + } + +} + +public inline fun RealTensorAlgebra(block: RealTensorAlgebra.() -> R): R = + RealTensorAlgebra().block() \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt index cf16769c9..83c45feec 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt @@ -1,70 +1,114 @@ package space.kscience.kmath.tensors -import space.kscience.kmath.nd.MutableNDStructure - -public interface TensorStructure : MutableNDStructure { - // A tensor can have empty shape, in which case it represents just a value - public fun value(): T -} +import space.kscience.kmath.operations.Ring +import space.kscience.kmath.operations.RingWithNumbers // https://proofwiki.org/wiki/Definition:Algebra_over_Ring - -public interface TensorAlgebra> { +public interface TensorAlgebra>: RingWithNumbers { public operator fun T.plus(other: TensorType): TensorType public operator fun TensorType.plus(value: T): TensorType - public operator fun TensorType.plus(other: TensorType): TensorType public operator fun TensorType.plusAssign(value: T): Unit public operator fun TensorType.plusAssign(other: TensorType): Unit public operator fun T.minus(other: TensorType): TensorType public operator fun TensorType.minus(value: T): TensorType - public operator fun TensorType.minus(other: TensorType): TensorType public operator fun TensorType.minusAssign(value: T): Unit public operator fun TensorType.minusAssign(other: TensorType): Unit public operator fun T.times(other: TensorType): TensorType public operator fun TensorType.times(value: T): TensorType - public operator fun TensorType.times(other: TensorType): TensorType public operator fun TensorType.timesAssign(value: T): Unit public operator fun TensorType.timesAssign(other: TensorType): Unit - public operator fun TensorType.unaryMinus(): TensorType + //https://pytorch.org/docs/stable/generated/torch.matmul.html public infix fun TensorType.dot(other: TensorType): TensorType public infix fun TensorType.dotAssign(other: TensorType): Unit public infix fun TensorType.dotRightAssign(other: TensorType): Unit + //https://pytorch.org/docs/stable/generated/torch.diag_embed.html public fun diagonalEmbedding( diagonalEntries: TensorType, offset: Int = 0, dim1: Int = -2, dim2: Int = -1 ): TensorType + //https://pytorch.org/docs/stable/generated/torch.transpose.html public fun TensorType.transpose(i: Int, j: Int): TensorType public fun TensorType.transposeAssign(i: Int, j: Int): Unit + //https://pytorch.org/docs/stable/tensor_view.html public fun TensorType.view(shape: IntArray): TensorType + //https://pytorch.org/docs/stable/generated/torch.abs.html public fun TensorType.abs(): TensorType public fun TensorType.absAssign(): Unit + + //https://pytorch.org/docs/stable/generated/torch.sum.html public fun TensorType.sum(): TensorType public fun TensorType.sumAssign(): Unit } // https://proofwiki.org/wiki/Definition:Division_Algebra - public interface TensorPartialDivisionAlgebra> : TensorAlgebra { public operator fun TensorType.div(other: TensorType): TensorType public operator fun TensorType.divAssign(other: TensorType) + //https://pytorch.org/docs/stable/generated/torch.exp.html public fun TensorType.exp(): TensorType public fun TensorType.expAssign(): Unit + + //https://pytorch.org/docs/stable/generated/torch.log.html public fun TensorType.log(): TensorType public fun TensorType.logAssign(): Unit + //https://pytorch.org/docs/stable/generated/torch.svd.html public fun TensorType.svd(): Triple + + //https://pytorch.org/docs/stable/generated/torch.symeig.html public fun TensorType.symEig(eigenvectors: Boolean = true): Pair -} \ No newline at end of file +} + +public inline fun , + TorchTensorAlgebraType : TensorAlgebra> + TorchTensorAlgebraType.checkShapeCompatible( + a: TensorType, b: TensorType +): Unit = + check(a.shape contentEquals b.shape) { + "Tensors must be of identical shape" + } + +public inline fun , + TorchTensorAlgebraType : TensorAlgebra> + TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit { + val sa = a.shape + val sb = b.shape + val na = sa.size + val nb = sb.size + var status: Boolean + if (nb == 1) { + status = sa.last() == sb[0] + } else { + status = sa.last() == sb[nb - 2] + if ((na > 2) and (nb > 2)) { + status = status and + (sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray()) + } + } + check(status) { "Incompatible shapes $sa and $sb for dot product" } +} + +public inline fun , + TorchTensorAlgebraType : TensorAlgebra> + TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit = + check((i < dim) and (j < dim)) { + "Cannot transpose $i to $j for a tensor of dim $dim" + } + +public inline fun , + TorchTensorAlgebraType : TensorAlgebra> + TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit = + check(a.shape.reduce(Int::times) == shape.reduce(Int::times)) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorStrides.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorStrides.kt new file mode 100644 index 000000000..822e514d5 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorStrides.kt @@ -0,0 +1,50 @@ +package space.kscience.kmath.tensors + +import space.kscience.kmath.nd.Strides +import space.kscience.kmath.nd.offsetFromIndex +import kotlin.math.max + + +inline public fun stridesFromShape(shape: IntArray): IntArray { + val nDim = shape.size + val res = IntArray(nDim) + if (nDim == 0) + return res + + var current = nDim - 1 + res[current] = 1 + + while (current > 0) { + res[current - 1] = max(1, shape[current]) * res[current] + current-- + } + return res + +} + +inline public fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray { + val res = IntArray(nDim) + var current = offset + var strideIndex = 0 + + while (strideIndex < nDim) { + res[strideIndex] = (current / strides[strideIndex]) + current %= strides[strideIndex] + strideIndex++ + } + return res +} + + +public class TensorStrides(override val shape: IntArray) : Strides { + override val strides: IntArray + get() = stridesFromShape(shape) + + override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides) + + override fun index(offset: Int): IntArray = + indexFromOffset(offset, strides, shape.size) + + override val linearSize: Int + get() = shape.fold(1) { acc, i -> acc * i } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorStructure.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorStructure.kt new file mode 100644 index 000000000..6d2d855b6 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorStructure.kt @@ -0,0 +1,23 @@ +package space.kscience.kmath.tensors + +import space.kscience.kmath.nd.MutableNDStructure + +public interface TensorStructure : MutableNDStructure { + public fun item(): T + + // A tensor can have empty shape, in which case it represents just a value + public fun value(): T { + checkIsValue() + return item() + } +} + +public inline fun TensorStructure.isValue(): Boolean { + return (dimension == 0) +} + +public inline fun TensorStructure.isNotValue(): Boolean = !this.isValue() + +public inline fun TensorStructure.checkIsValue(): Unit = check(this.isValue()) { + "This tensor has shape ${shape.toList()}" +} \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensor.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensor.kt new file mode 100644 index 000000000..7938eb864 --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensor.kt @@ -0,0 +1,24 @@ +package space.kscience.kmath.tensors + + +import space.kscience.kmath.structures.array +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class TestRealTensor { + + @Test + fun valueTest(){ + val value = 12.5 + val tensor = RealTensor(IntArray(0), doubleArrayOf(value)) + assertEquals(tensor.value(), value) + } + + @Test + fun stridesTest(){ + val tensor = RealTensor(intArrayOf(2,2), doubleArrayOf(3.5,5.8,58.4,2.4)) + assertEquals(tensor[intArrayOf(0,1)], 5.8) + assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.array) + } +} \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt new file mode 100644 index 000000000..86caa0338 --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt @@ -0,0 +1,16 @@ +package space.kscience.kmath.tensors + +import space.kscience.kmath.structures.array +import kotlin.test.Test +import kotlin.test.assertTrue + +class TestRealTensorAlgebra { + + @Test + fun doublePlus() = RealTensorAlgebra { + val tensor = RealTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0)) + val res = 10.0 + tensor + assertTrue(res.buffer.array contentEquals doubleArrayOf(11.0,12.0)) + } + +} \ No newline at end of file diff --git a/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/Device.kt b/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/Device.kt index 8c086c396..6b74212ea 100644 --- a/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/Device.kt +++ b/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/Device.kt @@ -10,13 +10,13 @@ public sealed class Device { public data class CUDA(val index: Int): Device() public fun toInt(): Int { when(this) { - is Device.CPU -> return 0 - is Device.CUDA -> return this.index + 1 + is CPU -> return 0 + is CUDA -> return this.index + 1 } } public companion object { public fun fromInt(deviceInt: Int): Device { - return if (deviceInt == 0) Device.CPU else Device.CUDA( + return if (deviceInt == 0) CPU else CUDA( deviceInt - 1 ) } diff --git a/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/TorchTensor.kt b/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/TorchTensor.kt index 8175f0009..0cda1a375 100644 --- a/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/TorchTensor.kt +++ b/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/TorchTensor.kt @@ -2,17 +2,14 @@ package space.kscience.kmath.torch -import space.kscience.kmath.tensors.TensorStructure +import space.kscience.kmath.tensors.* public interface TorchTensor : TensorStructure { - public fun item(): T + public val strides: IntArray public val size: Int public val device: Device - override fun value(): T { - checkIsValue() - return item() - } + override fun elements(): Sequence> { if (dimension == 0) { return emptySequence() @@ -22,31 +19,8 @@ public interface TorchTensor : TensorStructure { } } -public inline fun TorchTensor.isValue(): Boolean { - return (dimension == 0) -} - -public inline fun TorchTensor.isNotValue(): Boolean = !this.isValue() - -public inline fun TorchTensor.checkIsValue(): Unit = check(this.isValue()) { - "This tensor has shape ${shape.toList()}" -} - public interface TorchTensorOverField: TorchTensor { public var requiresGrad: Boolean } -private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray { - val res = IntArray(nDim) - var current = offset - var strideIndex = 0 - - while (strideIndex < nDim) { - res[strideIndex] = (current / strides[strideIndex]) - current %= strides[strideIndex] - strideIndex++ - } - return res -} - diff --git a/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebra.kt b/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebra.kt index 9def17c8e..ac760d109 100644 --- a/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebra.kt +++ b/kmath-torch/src/commonMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebra.kt @@ -2,8 +2,7 @@ package space.kscience.kmath.torch -import space.kscience.kmath.tensors.TensorAlgebra -import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra +import space.kscience.kmath.tensors.* public interface TorchTensorAlgebra> : TensorAlgebra { @@ -75,15 +74,6 @@ public inline fun , "Tensors must be on the same device" } -public inline fun , - TorchTensorAlgebraType : TorchTensorAlgebra> - TorchTensorAlgebraType.checkShapeCompatible( - a: TorchTensorType, - b: TorchTensorType -): Unit = - check(a.shape contentEquals b.shape) { - "Tensors must be of identical shape" - } public inline fun , TorchTensorAlgebraType : TorchTensorAlgebra> @@ -92,8 +82,8 @@ public inline fun , b: TorchTensorType ) { if (a.isNotValue() and b.isNotValue()) { - this.checkDeviceCompatible(a, b) - this.checkShapeCompatible(a, b) + checkDeviceCompatible(a, b) + checkShapeCompatible(a, b) } } @@ -101,36 +91,9 @@ public inline fun , TorchTensorAlgebraType : TorchTensorAlgebra> TorchTensorAlgebraType.checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit { checkDeviceCompatible(a, b) - val sa = a.shape - val sb = b.shape - val na = sa.size - val nb = sb.size - var status: Boolean - if (nb == 1) { - status = sa.last() == sb[0] - } else { - status = sa.last() == sb[nb - 2] - if ((na > 2) and (nb > 2)) { - status = status and - (sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray()) - } - } - check(status) { "Incompatible shapes $sa and $sb for dot product" } + checkDot(a,b) } -public inline fun , - TorchTensorAlgebraType : TorchTensorAlgebra> - TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit = - check((i < dim) and (j < dim)) { - "Cannot transpose $i to $j for a tensor of dim $dim" - } - -public inline fun , - TorchTensorAlgebraType : TorchTensorAlgebra> - TorchTensorAlgebraType.checkView(a: TorchTensorType, shape: IntArray): Unit = - check(a.shape.reduce(Int::times) == shape.reduce(Int::times)) - - public inline fun , TorchTensorDivisionAlgebraType : TorchTensorPartialDivisionAlgebra> TorchTensorDivisionAlgebraType.withGradAt( diff --git a/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt b/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt index a11ed9cad..9c77b8229 100644 --- a/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt +++ b/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt @@ -29,9 +29,9 @@ public sealed class TorchTensorAlgebraJVM< internal abstract fun wrap(tensorHandle: Long): TorchTensorType - override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, other) - return wrap(JTorch.timesTensor(this.tensorHandle, other.tensorHandle)) + override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, b) + return wrap(JTorch.timesTensor(this.tensorHandle, b.tensorHandle)) } override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit { @@ -39,9 +39,9 @@ public sealed class TorchTensorAlgebraJVM< JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle) } - override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, other) - return wrap(JTorch.plusTensor(this.tensorHandle, other.tensorHandle)) + override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, b) + return wrap(JTorch.plusTensor(this.tensorHandle, b.tensorHandle)) } override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit { @@ -49,9 +49,9 @@ public sealed class TorchTensorAlgebraJVM< JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle) } - override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, other) - return wrap(JTorch.minusTensor(this.tensorHandle, other.tensorHandle)) + override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, b) + return wrap(JTorch.minusTensor(this.tensorHandle, b.tensorHandle)) } override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit { diff --git a/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt b/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt index 3aba636e9..5f5515680 100644 --- a/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt +++ b/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt @@ -1,10 +1,10 @@ package space.kscience.kmath.torch - import space.kscience.kmath.memory.DeferScope import space.kscience.kmath.memory.withDeferScope import kotlinx.cinterop.* +import space.kscience.kmath.tensors.* import space.kscience.kmath.torch.ctorch.* public sealed class TorchTensorAlgebraNative< @@ -38,9 +38,9 @@ public sealed class TorchTensorAlgebraNative< public abstract fun fromBlob(arrayBlob: CPointer, shape: IntArray): TorchTensorType public abstract fun TorchTensorType.getData(): CPointer - override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, other) - return wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!) + override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, b) + return wrap(times_tensor(this.tensorHandle, b.tensorHandle)!!) } override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit { @@ -48,9 +48,9 @@ public sealed class TorchTensorAlgebraNative< times_tensor_assign(this.tensorHandle, other.tensorHandle) } - override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, other) - return wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!) + override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, b) + return wrap(plus_tensor(this.tensorHandle, b.tensorHandle)!!) } override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit { @@ -58,9 +58,9 @@ public sealed class TorchTensorAlgebraNative< plus_tensor_assign(this.tensorHandle, other.tensorHandle) } - override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, other) - return wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!) + override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, b) + return wrap(minus_tensor(this.tensorHandle, b.tensorHandle)!!) } override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit { @@ -68,6 +68,10 @@ public sealed class TorchTensorAlgebraNative< minus_tensor_assign(this.tensorHandle, other.tensorHandle) } + override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b + + override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b + override operator fun TorchTensorType.unaryMinus(): TorchTensorType = wrap(unary_minus(this.tensorHandle)!!) @@ -254,6 +258,15 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal = wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!) + + + override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble() + + override val zero: TorchTensorReal + get() = full(0.0, IntArray(0), Device.CPU) + + override val one: TorchTensorReal + get() = full(1.0, IntArray(0), Device.CPU) } @@ -317,6 +330,15 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) : override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat = wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!) + override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat() + + override val zero: TorchTensorFloat + get() = full(0f, IntArray(0), Device.CPU) + + + override val one: TorchTensorFloat + get() = full(1f, IntArray(0), Device.CPU) + } public class TorchTensorLongAlgebra(scope: DeferScope) : @@ -372,6 +394,15 @@ public class TorchTensorLongAlgebra(scope: DeferScope) : override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong = wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!) + + override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong() + + override val zero: TorchTensorLong + get() = full(0, IntArray(0), Device.CPU) + + + override val one: TorchTensorLong + get() = full(1, IntArray(0), Device.CPU) } public class TorchTensorIntAlgebra(scope: DeferScope) : @@ -427,6 +458,16 @@ public class TorchTensorIntAlgebra(scope: DeferScope) : override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt = wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!) + + + override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt() + + override val zero: TorchTensorInt + get() = full(0, IntArray(0), Device.CPU) + + + override val one: TorchTensorInt + get() = full(1, IntArray(0), Device.CPU) }