forked from kscience/kmath
Cannot afford to derive from RingsWithNumbers
This commit is contained in:
parent
3e3d433a64
commit
86528fc943
@ -30,9 +30,9 @@ public sealed class TorchTensorAlgebraJVM<
|
|||||||
|
|
||||||
internal abstract fun wrap(tensorHandle: Long): TorchTensorType
|
internal abstract fun wrap(tensorHandle: Long): TorchTensorType
|
||||||
|
|
||||||
override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, b)
|
if (checks) checkLinearOperation(this, other)
|
||||||
return wrap(JTorch.timesTensor(this.tensorHandle, b.tensorHandle))
|
return wrap(JTorch.timesTensor(this.tensorHandle, other.tensorHandle))
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||||
@ -40,9 +40,9 @@ public sealed class TorchTensorAlgebraJVM<
|
|||||||
JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle)
|
JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, b)
|
if (checks) checkLinearOperation(this, other)
|
||||||
return wrap(JTorch.plusTensor(this.tensorHandle, b.tensorHandle))
|
return wrap(JTorch.plusTensor(this.tensorHandle, other.tensorHandle))
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||||
@ -50,9 +50,9 @@ public sealed class TorchTensorAlgebraJVM<
|
|||||||
JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle)
|
JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, b)
|
if (checks) checkLinearOperation(this, other)
|
||||||
return wrap(JTorch.minusTensor(this.tensorHandle, b.tensorHandle))
|
return wrap(JTorch.minusTensor(this.tensorHandle, other.tensorHandle))
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||||
@ -60,10 +60,6 @@ public sealed class TorchTensorAlgebraJVM<
|
|||||||
JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle)
|
JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
|
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b
|
|
||||||
|
|
||||||
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||||
wrap(JTorch.unaryMinus(this.tensorHandle))
|
wrap(JTorch.unaryMinus(this.tensorHandle))
|
||||||
|
|
||||||
@ -236,13 +232,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
|||||||
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||||
wrap(JTorch.fullDouble(value, shape, device.toInt()))
|
wrap(JTorch.fullDouble(value, shape, device.toInt()))
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
|
|
||||||
|
|
||||||
override val zero: TorchTensorReal
|
|
||||||
get() = full(0.0, IntArray(0), Device.CPU)
|
|
||||||
|
|
||||||
override val one: TorchTensorReal
|
|
||||||
get() = full(1.0, IntArray(0), Device.CPU)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||||
@ -295,13 +284,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
|||||||
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
wrap(JTorch.fullFloat(value, shape, device.toInt()))
|
wrap(JTorch.fullFloat(value, shape, device.toInt()))
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
|
|
||||||
|
|
||||||
override val zero: TorchTensorFloat
|
|
||||||
get() = full(0f, IntArray(0), Device.CPU)
|
|
||||||
|
|
||||||
override val one: TorchTensorFloat
|
|
||||||
get() = full(1f, IntArray(0), Device.CPU)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||||
@ -347,14 +329,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
|
|||||||
|
|
||||||
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||||
wrap(JTorch.fullLong(value, shape, device.toInt()))
|
wrap(JTorch.fullLong(value, shape, device.toInt()))
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
|
|
||||||
|
|
||||||
override val zero: TorchTensorLong
|
|
||||||
get() = full(0, IntArray(0), Device.CPU)
|
|
||||||
|
|
||||||
override val one: TorchTensorLong
|
|
||||||
get() = full(1, IntArray(0), Device.CPU)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||||
@ -401,13 +375,6 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
|||||||
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||||
wrap(JTorch.fullInt(value, shape, device.toInt()))
|
wrap(JTorch.fullInt(value, shape, device.toInt()))
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
|
|
||||||
|
|
||||||
override val zero: TorchTensorInt
|
|
||||||
get() = full(0, IntArray(0), Device.CPU)
|
|
||||||
|
|
||||||
override val one: TorchTensorInt
|
|
||||||
get() = full(1, IntArray(0), Device.CPU)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||||
|
@ -38,9 +38,9 @@ public sealed class TorchTensorAlgebraNative<
|
|||||||
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
|
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
|
||||||
public abstract fun TorchTensorType.getData(): CPointer<TVar>
|
public abstract fun TorchTensorType.getData(): CPointer<TVar>
|
||||||
|
|
||||||
override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, b)
|
if (checks) checkLinearOperation(this, other)
|
||||||
return wrap(times_tensor(this.tensorHandle, b.tensorHandle)!!)
|
return wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||||
@ -48,9 +48,9 @@ public sealed class TorchTensorAlgebraNative<
|
|||||||
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, b)
|
if (checks) checkLinearOperation(this, other)
|
||||||
return wrap(plus_tensor(this.tensorHandle, b.tensorHandle)!!)
|
return wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||||
@ -58,9 +58,9 @@ public sealed class TorchTensorAlgebraNative<
|
|||||||
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, b)
|
if (checks) checkLinearOperation(this, other)
|
||||||
return wrap(minus_tensor(this.tensorHandle, b.tensorHandle)!!)
|
return wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||||
@ -68,10 +68,6 @@ public sealed class TorchTensorAlgebraNative<
|
|||||||
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
|
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b
|
|
||||||
|
|
||||||
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||||
wrap(unary_minus(this.tensorHandle)!!)
|
wrap(unary_minus(this.tensorHandle)!!)
|
||||||
|
|
||||||
@ -258,14 +254,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
|||||||
|
|
||||||
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||||
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
|
|
||||||
|
|
||||||
override val zero: TorchTensorReal
|
|
||||||
get() = full(0.0, IntArray(0), Device.CPU)
|
|
||||||
|
|
||||||
override val one: TorchTensorReal
|
|
||||||
get() = full(1.0, IntArray(0), Device.CPU)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -329,14 +317,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
|||||||
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
|
|
||||||
|
|
||||||
override val zero: TorchTensorFloat
|
|
||||||
get() = full(0f, IntArray(0), Device.CPU)
|
|
||||||
|
|
||||||
override val one: TorchTensorFloat
|
|
||||||
get() = full(1f, IntArray(0), Device.CPU)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||||
@ -392,14 +372,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
|
|||||||
|
|
||||||
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||||
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
|
|
||||||
|
|
||||||
override val zero: TorchTensorLong
|
|
||||||
get() = full(0, IntArray(0), Device.CPU)
|
|
||||||
|
|
||||||
override val one: TorchTensorLong
|
|
||||||
get() = full(1, IntArray(0), Device.CPU)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||||
@ -455,14 +427,6 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
|||||||
|
|
||||||
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||||
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
|
|
||||||
|
|
||||||
override val zero: TorchTensorInt
|
|
||||||
get() = full(0, IntArray(0), Device.CPU)
|
|
||||||
|
|
||||||
override val one: TorchTensorInt
|
|
||||||
get() = full(1, IntArray(0), Device.CPU)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user