Inefficient implementations in RingWithNumbers

This commit is contained in:
Roland Grinis 2021-03-01 22:44:46 +00:00
parent 6594ffc965
commit 0a4e7acb4c
2 changed files with 37 additions and 5 deletions

View File

@ -2,6 +2,7 @@ package space.kscience.kmath.torch
import space.kscience.kmath.memory.DeferScope
import space.kscience.kmath.memory.withDeferScope
import space.kscience.kmath.tensors.*
public sealed class TorchTensorAlgebraJVM<
T,
@ -59,6 +60,10 @@ public sealed class TorchTensorAlgebraJVM<
JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle)
}
override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
wrap(JTorch.unaryMinus(this.tensorHandle))
@ -230,6 +235,14 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
wrap(JTorch.fullDouble(value, shape, device.toInt()))
override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
override val zero: TorchTensorReal
get() = full(0.0, IntArray(0), Device.CPU)
override val one: TorchTensorReal
get() = full(1.0, IntArray(0), Device.CPU)
}
public class TorchTensorFloatAlgebra(scope: DeferScope) :
@ -281,6 +294,14 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
wrap(JTorch.fullFloat(value, shape, device.toInt()))
override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
override val zero: TorchTensorFloat
get() = full(0f, IntArray(0), Device.CPU)
override val one: TorchTensorFloat
get() = full(1f, IntArray(0), Device.CPU)
}
public class TorchTensorLongAlgebra(scope: DeferScope) :
@ -326,6 +347,14 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
wrap(JTorch.fullLong(value, shape, device.toInt()))
override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
override val zero: TorchTensorLong
get() = full(0, IntArray(0), Device.CPU)
override val one: TorchTensorLong
get() = full(1, IntArray(0), Device.CPU)
}
public class TorchTensorIntAlgebra(scope: DeferScope) :
@ -371,6 +400,14 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
wrap(JTorch.fullInt(value, shape, device.toInt()))
override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
override val zero: TorchTensorInt
get() = full(0, IntArray(0), Device.CPU)
override val one: TorchTensorInt
get() = full(1, IntArray(0), Device.CPU)
}
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =

View File

@ -259,7 +259,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
override val zero: TorchTensorReal
@ -335,7 +334,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
override val zero: TorchTensorFloat
get() = full(0f, IntArray(0), Device.CPU)
override val one: TorchTensorFloat
get() = full(1f, IntArray(0), Device.CPU)
@ -400,7 +398,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
override val zero: TorchTensorLong
get() = full(0, IntArray(0), Device.CPU)
override val one: TorchTensorLong
get() = full(1, IntArray(0), Device.CPU)
}
@ -459,13 +456,11 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
override val zero: TorchTensorInt
get() = full(0, IntArray(0), Device.CPU)
override val one: TorchTensorInt
get() = full(1, IntArray(0), Device.CPU)
}