Cannot afford to derive from RingsWithNumbers

This commit is contained in:
Roland Grinis 2021-03-01 23:08:23 +00:00
parent 3e3d433a64
commit 86528fc943
2 changed files with 18 additions and 87 deletions

View File

@ -30,9 +30,9 @@ public sealed class TorchTensorAlgebraJVM<
internal abstract fun wrap(tensorHandle: Long): TorchTensorType
override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, b)
return wrap(JTorch.timesTensor(this.tensorHandle, b.tensorHandle))
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
return wrap(JTorch.timesTensor(this.tensorHandle, other.tensorHandle))
}
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
@ -40,9 +40,9 @@ public sealed class TorchTensorAlgebraJVM<
JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle)
}
override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, b)
return wrap(JTorch.plusTensor(this.tensorHandle, b.tensorHandle))
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
return wrap(JTorch.plusTensor(this.tensorHandle, other.tensorHandle))
}
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
@ -50,9 +50,9 @@ public sealed class TorchTensorAlgebraJVM<
JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle)
}
override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, b)
return wrap(JTorch.minusTensor(this.tensorHandle, b.tensorHandle))
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
return wrap(JTorch.minusTensor(this.tensorHandle, other.tensorHandle))
}
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
@ -60,10 +60,6 @@ public sealed class TorchTensorAlgebraJVM<
JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle)
}
override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
wrap(JTorch.unaryMinus(this.tensorHandle))
@ -236,13 +232,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
wrap(JTorch.fullDouble(value, shape, device.toInt()))
override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
override val zero: TorchTensorReal
get() = full(0.0, IntArray(0), Device.CPU)
override val one: TorchTensorReal
get() = full(1.0, IntArray(0), Device.CPU)
}
public class TorchTensorFloatAlgebra(scope: DeferScope) :
@ -295,13 +284,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
wrap(JTorch.fullFloat(value, shape, device.toInt()))
override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
override val zero: TorchTensorFloat
get() = full(0f, IntArray(0), Device.CPU)
override val one: TorchTensorFloat
get() = full(1f, IntArray(0), Device.CPU)
}
public class TorchTensorLongAlgebra(scope: DeferScope) :
@ -347,14 +329,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
wrap(JTorch.fullLong(value, shape, device.toInt()))
override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
override val zero: TorchTensorLong
get() = full(0, IntArray(0), Device.CPU)
override val one: TorchTensorLong
get() = full(1, IntArray(0), Device.CPU)
}
public class TorchTensorIntAlgebra(scope: DeferScope) :
@ -401,13 +375,6 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
wrap(JTorch.fullInt(value, shape, device.toInt()))
override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
override val zero: TorchTensorInt
get() = full(0, IntArray(0), Device.CPU)
override val one: TorchTensorInt
get() = full(1, IntArray(0), Device.CPU)
}
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =

View File

@ -38,9 +38,9 @@ public sealed class TorchTensorAlgebraNative<
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
public abstract fun TorchTensorType.getData(): CPointer<TVar>
override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, b)
return wrap(times_tensor(this.tensorHandle, b.tensorHandle)!!)
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
return wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
}
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
@ -48,9 +48,9 @@ public sealed class TorchTensorAlgebraNative<
times_tensor_assign(this.tensorHandle, other.tensorHandle)
}
override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, b)
return wrap(plus_tensor(this.tensorHandle, b.tensorHandle)!!)
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
return wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
}
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
@ -58,9 +58,9 @@ public sealed class TorchTensorAlgebraNative<
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
}
override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, b)
return wrap(minus_tensor(this.tensorHandle, b.tensorHandle)!!)
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
return wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
}
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
@ -68,10 +68,6 @@ public sealed class TorchTensorAlgebraNative<
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
}
override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
wrap(unary_minus(this.tensorHandle)!!)
@ -258,14 +254,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
override val zero: TorchTensorReal
get() = full(0.0, IntArray(0), Device.CPU)
override val one: TorchTensorReal
get() = full(1.0, IntArray(0), Device.CPU)
}
@ -329,14 +317,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
override val zero: TorchTensorFloat
get() = full(0f, IntArray(0), Device.CPU)
override val one: TorchTensorFloat
get() = full(1f, IntArray(0), Device.CPU)
}
public class TorchTensorLongAlgebra(scope: DeferScope) :
@ -392,14 +372,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
override val zero: TorchTensorLong
get() = full(0, IntArray(0), Device.CPU)
override val one: TorchTensorLong
get() = full(1, IntArray(0), Device.CPU)
}
public class TorchTensorIntAlgebra(scope: DeferScope) :
@ -455,14 +427,6 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
override val zero: TorchTensorInt
get() = full(0, IntArray(0), Device.CPU)
override val one: TorchTensorInt
get() = full(1, IntArray(0), Device.CPU)
}