forked from kscience/kmath
Inefficient implementations in RingWithNumbers
This commit is contained in:
parent 6594ffc965
commit 0a4e7acb4c
@@ -2,6 +2,7 @@ package space.kscience.kmath.torch
 
 import space.kscience.kmath.memory.DeferScope
 import space.kscience.kmath.memory.withDeferScope
+import space.kscience.kmath.tensors.*
 
 public sealed class TorchTensorAlgebraJVM<
     T,
@@ -59,6 +60,10 @@ public sealed class TorchTensorAlgebraJVM<
         JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle)
     }
 
+    override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
+
+    override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b
+
     override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
         wrap(JTorch.unaryMinus(this.tensorHandle))
 
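For illustration only (not taken from the repository), a minimal sketch of how the new generic members behave at a call site; it assumes the full factory, Device.CPU, and the TorchTensorRealAlgebra block helper that appear further down in this diff. Since add and multiply simply delegate to the element-wise operators, they allocate a new result tensor on every call instead of reusing buffers:

    TorchTensorRealAlgebra {
        val a = full(1.0, intArrayOf(3), Device.CPU)
        val b = full(2.0, intArrayOf(3), Device.CPU)
        val viaRing = add(a, b)       // new in this commit: delegates to a + b
        val viaOperator = a + b       // same result, same out-of-place allocation
    }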
@@ -230,6 +235,14 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
 
     override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
         wrap(JTorch.fullDouble(value, shape, device.toInt()))
+
+    override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
+
+    override val zero: TorchTensorReal
+        get() = full(0.0, IntArray(0), Device.CPU)
+
+    override val one: TorchTensorReal
+        get() = full(1.0, IntArray(0), Device.CPU)
 }
 
 public class TorchTensorFloatAlgebra(scope: DeferScope) :
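The concrete algebras repeat this pattern for each element type. As a rough usage sketch (illustrative only, values and shapes made up, assuming the JVM API shown in this diff), note that every read of zero or one constructs a fresh scalar tensor through full(...):

    TorchTensorRealAlgebra {
        val x = full(2.0, intArrayOf(2, 2), Device.CPU)
        val neutral = one           // each access builds a new scalar tensor via full(1.0, ...)
        val y = multiply(x, 3)      // Number overload: converted with k.toDouble(), returns a new tensor
        val z = add(y, y)           // delegates to y + y
    }

Each call above allocates a new native tensor, which is presumably what the commit title describes as inefficient.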
@@ -281,6 +294,14 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
 
     override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
         wrap(JTorch.fullFloat(value, shape, device.toInt()))
+
+    override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
+
+    override val zero: TorchTensorFloat
+        get() = full(0f, IntArray(0), Device.CPU)
+
+    override val one: TorchTensorFloat
+        get() = full(1f, IntArray(0), Device.CPU)
 }
 
 public class TorchTensorLongAlgebra(scope: DeferScope) :
@@ -326,6 +347,14 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
 
     override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
         wrap(JTorch.fullLong(value, shape, device.toInt()))
+
+    override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
+
+    override val zero: TorchTensorLong
+        get() = full(0, IntArray(0), Device.CPU)
+
+    override val one: TorchTensorLong
+        get() = full(1, IntArray(0), Device.CPU)
 }
 
 public class TorchTensorIntAlgebra(scope: DeferScope) :
@@ -371,6 +400,14 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
 
     override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
         wrap(JTorch.fullInt(value, shape, device.toInt()))
+
+    override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
+
+    override val zero: TorchTensorInt
+        get() = full(0, IntArray(0), Device.CPU)
+
+    override val one: TorchTensorInt
+        get() = full(1, IntArray(0), Device.CPU)
 }
 
 public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
@@ -259,7 +259,6 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
     override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
         wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
 
-
     override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
 
     override val zero: TorchTensorReal
@@ -335,7 +334,6 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
     override val zero: TorchTensorFloat
         get() = full(0f, IntArray(0), Device.CPU)
 
-
     override val one: TorchTensorFloat
         get() = full(1f, IntArray(0), Device.CPU)
 
@@ -400,7 +398,6 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
     override val zero: TorchTensorLong
         get() = full(0, IntArray(0), Device.CPU)
 
-
     override val one: TorchTensorLong
         get() = full(1, IntArray(0), Device.CPU)
 }
@@ -459,13 +456,11 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
     override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
        wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
 
-
     override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
 
     override val zero: TorchTensorInt
         get() = full(0, IntArray(0), Device.CPU)
 
-
     override val one: TorchTensorInt
         get() = full(1, IntArray(0), Device.CPU)
 }
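For comparison, a purely hypothetical sketch (not part of this commit or of kmath) of how the per-access allocation in zero and one could be avoided by caching the neutral elements; it assumes the same full factory and Device.CPU as above, and the cached tensors would then live for the whole lifetime of the algebra's scope:

    private val cachedZero: TorchTensorReal by lazy { full(0.0, IntArray(0), Device.CPU) }
    private val cachedOne: TorchTensorReal by lazy { full(1.0, IntArray(0), Device.CPU) }

    override val zero: TorchTensorReal get() = cachedZero
    override val one: TorchTensorReal get() = cachedOne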