From 86528fc943c3f69b9b64d633791354d6412d606a Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Mon, 1 Mar 2021 23:08:23 +0000 Subject: [PATCH] Cannot afford to derive from RingsWithNumbers --- .../kmath/torch/TorchTensorAlgebraJVM.kt | 51 ++++-------------- .../kmath/torch/TorchTensorAlgebraNative.kt | 54 ++++--------------- 2 files changed, 18 insertions(+), 87 deletions(-) diff --git a/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt b/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt index 7c5fca1eb..12a6c9586 100644 --- a/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt +++ b/kmath-torch/src/jvmMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraJVM.kt @@ -30,9 +30,9 @@ public sealed class TorchTensorAlgebraJVM< internal abstract fun wrap(tensorHandle: Long): TorchTensorType - override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, b) - return wrap(JTorch.timesTensor(this.tensorHandle, b.tensorHandle)) + override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, other) + return wrap(JTorch.timesTensor(this.tensorHandle, other.tensorHandle)) } override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit { @@ -40,9 +40,9 @@ public sealed class TorchTensorAlgebraJVM< JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle) } - override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, b) - return wrap(JTorch.plusTensor(this.tensorHandle, b.tensorHandle)) + override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, other) + return wrap(JTorch.plusTensor(this.tensorHandle, other.tensorHandle)) } override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit { @@ -50,9 +50,9 @@ public sealed class TorchTensorAlgebraJVM< JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle) } - override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, b) - return wrap(JTorch.minusTensor(this.tensorHandle, b.tensorHandle)) + override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, other) + return wrap(JTorch.minusTensor(this.tensorHandle, other.tensorHandle)) } override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit { @@ -60,10 +60,6 @@ public sealed class TorchTensorAlgebraJVM< JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle) } - override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b - - override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b - override operator fun TorchTensorType.unaryMinus(): TorchTensorType = wrap(JTorch.unaryMinus(this.tensorHandle)) @@ -236,13 +232,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal = wrap(JTorch.fullDouble(value, shape, device.toInt())) - override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble() - - override val zero: TorchTensorReal - get() = full(0.0, IntArray(0), Device.CPU) - - override val one: TorchTensorReal - get() = full(1.0, IntArray(0), Device.CPU) } public class TorchTensorFloatAlgebra(scope: DeferScope) : @@ -295,13 +284,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) : override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat = wrap(JTorch.fullFloat(value, shape, device.toInt())) - override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat() - - override val zero: TorchTensorFloat - get() = full(0f, IntArray(0), Device.CPU) - - override val one: TorchTensorFloat - get() = full(1f, IntArray(0), Device.CPU) } public class TorchTensorLongAlgebra(scope: DeferScope) : @@ -347,14 +329,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) : override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong = wrap(JTorch.fullLong(value, shape, device.toInt())) - - override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong() - - override val zero: TorchTensorLong - get() = full(0, IntArray(0), Device.CPU) - - override val one: TorchTensorLong - get() = full(1, IntArray(0), Device.CPU) } public class TorchTensorIntAlgebra(scope: DeferScope) : @@ -401,13 +375,6 @@ public class TorchTensorIntAlgebra(scope: DeferScope) : override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt = wrap(JTorch.fullInt(value, shape, device.toInt())) - override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt() - - override val zero: TorchTensorInt - get() = full(0, IntArray(0), Device.CPU) - - override val one: TorchTensorInt - get() = full(1, IntArray(0), Device.CPU) } public inline fun TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R = diff --git a/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt b/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt index cf050f574..2d6a70b0f 100644 --- a/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt +++ b/kmath-torch/src/nativeMain/kotlin/space/kscience/kmath/torch/TorchTensorAlgebraNative.kt @@ -38,9 +38,9 @@ public sealed class TorchTensorAlgebraNative< public abstract fun fromBlob(arrayBlob: CPointer, shape: IntArray): TorchTensorType public abstract fun TorchTensorType.getData(): CPointer - override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, b) - return wrap(times_tensor(this.tensorHandle, b.tensorHandle)!!) + override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, other) + return wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!) } override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit { @@ -48,9 +48,9 @@ public sealed class TorchTensorAlgebraNative< times_tensor_assign(this.tensorHandle, other.tensorHandle) } - override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, b) - return wrap(plus_tensor(this.tensorHandle, b.tensorHandle)!!) + override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, other) + return wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!) } override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit { @@ -58,9 +58,9 @@ public sealed class TorchTensorAlgebraNative< plus_tensor_assign(this.tensorHandle, other.tensorHandle) } - override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType { - if (checks) checkLinearOperation(this, b) - return wrap(minus_tensor(this.tensorHandle, b.tensorHandle)!!) + override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType { + if (checks) checkLinearOperation(this, other) + return wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!) } override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit { @@ -68,10 +68,6 @@ public sealed class TorchTensorAlgebraNative< minus_tensor_assign(this.tensorHandle, other.tensorHandle) } - override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b - - override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b - override operator fun TorchTensorType.unaryMinus(): TorchTensorType = wrap(unary_minus(this.tensorHandle)!!) @@ -258,14 +254,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal = wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!) - - override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble() - - override val zero: TorchTensorReal - get() = full(0.0, IntArray(0), Device.CPU) - - override val one: TorchTensorReal - get() = full(1.0, IntArray(0), Device.CPU) } @@ -329,14 +317,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) : override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat = wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!) - override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat() - - override val zero: TorchTensorFloat - get() = full(0f, IntArray(0), Device.CPU) - - override val one: TorchTensorFloat - get() = full(1f, IntArray(0), Device.CPU) - } public class TorchTensorLongAlgebra(scope: DeferScope) : @@ -392,14 +372,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) : override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong = wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!) - - override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong() - - override val zero: TorchTensorLong - get() = full(0, IntArray(0), Device.CPU) - - override val one: TorchTensorLong - get() = full(1, IntArray(0), Device.CPU) } public class TorchTensorIntAlgebra(scope: DeferScope) : @@ -455,14 +427,6 @@ public class TorchTensorIntAlgebra(scope: DeferScope) : override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt = wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!) - - override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt() - - override val zero: TorchTensorInt - get() = full(0, IntArray(0), Device.CPU) - - override val one: TorchTensorInt - get() = full(1, IntArray(0), Device.CPU) }