diff --git a/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt b/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt index 9c77b8229..7c5fca1eb 100644 --- a/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt +++ b/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt @@ -2,6 +2,7 @@ package space.kscience.kmath.torch import space.kscience.kmath.memory.DeferScope import space.kscience.kmath.memory.withDeferScope +import space.kscience.kmath.tensors.* public sealed class TorchTensorAlgebraJVM< T, @@ -59,6 +60,10 @@ public sealed class TorchTensorAlgebraJVM< JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle) } + override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b + + override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b + override operator fun TorchTensorType.unaryMinus(): TorchTensorType = wrap(JTorch.unaryMinus(this.tensorHandle)) @@ -230,6 +235,14 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal = wrap(JTorch.fullDouble(value, shape, device.toInt())) + + override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble() + + override val zero: TorchTensorReal + get() = full(0.0, IntArray(0), Device.CPU) + + override val one: TorchTensorReal + get() = full(1.0, IntArray(0), Device.CPU) } public class TorchTensorFloatAlgebra(scope: DeferScope) : @@ -281,6 +294,14 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) : override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat = wrap(JTorch.fullFloat(value, shape, device.toInt())) + + override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat() + + override val zero: TorchTensorFloat + get() = full(0f, IntArray(0), Device.CPU) + + override val one: TorchTensorFloat + get() = full(1f, IntArray(0), Device.CPU) } public class TorchTensorLongAlgebra(scope: DeferScope) : @@ -326,6 +347,14 @@ public class TorchTensorLongAlgebra(scope: DeferScope) : override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong = wrap(JTorch.fullLong(value, shape, device.toInt())) + + override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong() + + override val zero: TorchTensorLong + get() = full(0, IntArray(0), Device.CPU) + + override val one: TorchTensorLong + get() = full(1, IntArray(0), Device.CPU) } public class TorchTensorIntAlgebra(scope: DeferScope) : @@ -371,6 +400,14 @@ public class TorchTensorIntAlgebra(scope: DeferScope) : override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt = wrap(JTorch.fullInt(value, shape, device.toInt())) + + override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt() + + override val zero: TorchTensorInt + get() = full(0, IntArray(0), Device.CPU) + + override val one: TorchTensorInt + get() = full(1, IntArray(0), Device.CPU) } public inline fun TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R = diff --git a/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt b/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt index 5f5515680..cf050f574 100644 --- a/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt +++ b/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt @@ -259,7 +259,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal = wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!) - override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble() override val zero: TorchTensorReal @@ -335,7 +334,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) : override val zero: TorchTensorFloat get() = full(0f, IntArray(0), Device.CPU) - override val one: TorchTensorFloat get() = full(1f, IntArray(0), Device.CPU) @@ -400,7 +398,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) : override val zero: TorchTensorLong get() = full(0, IntArray(0), Device.CPU) - override val one: TorchTensorLong get() = full(1, IntArray(0), Device.CPU) } @@ -459,13 +456,11 @@ public class TorchTensorIntAlgebra(scope: DeferScope) : override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt = wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!) - override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt() override val zero: TorchTensorInt get() = full(0, IntArray(0), Device.CPU) - override val one: TorchTensorInt get() = full(1, IntArray(0), Device.CPU) }