inheritance is back
This commit is contained in:
parent
e6e117f694
commit
6d5e4a5776
@ -14,7 +14,7 @@ import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
||||
|
||||
|
||||
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
|
||||
public sealed class NoaAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||
protected constructor(protected val scope: NoaScope) :
|
||||
TensorAlgebra<T> {
|
||||
|
||||
@ -36,6 +36,15 @@ protected constructor(protected val scope: NoaScope) :
|
||||
|
||||
override fun Tensor<T>.value(): T = tensor.item()
|
||||
|
||||
public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): TensorType
|
||||
|
||||
@PerformancePitfall
|
||||
public abstract fun Tensor<T>.copyToArray(): PrimitiveArray
|
||||
|
||||
public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device): TensorType
|
||||
|
||||
public abstract fun full(value: T, shape: IntArray, device: Device): TensorType
|
||||
|
||||
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
|
||||
return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle))
|
||||
}
|
||||
@ -132,9 +141,9 @@ protected constructor(protected val scope: NoaScope) :
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
|
||||
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||
protected constructor(scope: NoaScope) :
|
||||
NoaAlgebra<T, TensorType>(scope),
|
||||
NoaAlgebra<T, PrimitiveArray, TensorType>(scope),
|
||||
LinearOpsTensorAlgebra<T>,
|
||||
AnalyticTensorAlgebra<T> {
|
||||
|
||||
@ -161,6 +170,10 @@ protected constructor(scope: NoaScope) :
|
||||
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): TensorType =
|
||||
wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim))
|
||||
|
||||
public abstract fun randNormal(shape: IntArray, device: Device): TensorType
|
||||
|
||||
public abstract fun randUniform(shape: IntArray, device: Device): TensorType
|
||||
|
||||
public fun Tensor<T>.randUniform(): TensorType =
|
||||
wrap(JNoa.randLike(tensor.tensorHandle))
|
||||
|
||||
@ -278,7 +291,7 @@ protected constructor(scope: NoaScope) :
|
||||
|
||||
public sealed class NoaDoubleAlgebra
|
||||
protected constructor(scope: NoaScope) :
|
||||
NoaPartialDivisionAlgebra<Double, NoaDoubleTensor>(scope) {
|
||||
NoaPartialDivisionAlgebra<Double, DoubleArray, NoaDoubleTensor>(scope) {
|
||||
|
||||
private fun Tensor<Double>.castHelper(): NoaDoubleTensor =
|
||||
copyFromArray(
|
||||
@ -296,19 +309,19 @@ protected constructor(scope: NoaScope) :
|
||||
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
@PerformancePitfall
|
||||
public fun Tensor<Double>.copyToArray(): DoubleArray =
|
||||
override fun Tensor<Double>.copyToArray(): DoubleArray =
|
||||
tensor.elements().map { it.second }.toList().toDoubleArray()
|
||||
|
||||
public fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
wrap(JNoa.fromBlobDouble(array, shape, device.toInt()))
|
||||
|
||||
public fun randNormalDouble(shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
override fun randNormal(shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
wrap(JNoa.randnDouble(shape, device.toInt()))
|
||||
|
||||
public fun randUniformDouble(shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
override fun randUniform(shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
wrap(JNoa.randDouble(shape, device.toInt()))
|
||||
|
||||
public fun randDiscreteDouble(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
wrap(JNoa.randintDouble(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Double.plus(other: Tensor<Double>): NoaDoubleTensor =
|
||||
@ -347,14 +360,14 @@ protected constructor(scope: NoaScope) :
|
||||
override fun Tensor<Double>.divAssign(value: Double): Unit =
|
||||
tensor.timesAssign(1 / value)
|
||||
|
||||
public fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
override fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
wrap(JNoa.fullDouble(value, shape, device.toInt()))
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaFloatAlgebra
|
||||
protected constructor(scope: NoaScope) :
|
||||
NoaPartialDivisionAlgebra<Float, NoaFloatTensor>(scope) {
|
||||
NoaPartialDivisionAlgebra<Float, FloatArray, NoaFloatTensor>(scope) {
|
||||
|
||||
private fun Tensor<Float>.castHelper(): NoaFloatTensor =
|
||||
copyFromArray(
|
||||
@ -372,19 +385,19 @@ protected constructor(scope: NoaScope) :
|
||||
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
@PerformancePitfall
|
||||
public fun Tensor<Float>.copyToArray(): FloatArray =
|
||||
override fun Tensor<Float>.copyToArray(): FloatArray =
|
||||
tensor.elements().map { it.second }.toList().toFloatArray()
|
||||
|
||||
public fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor =
|
||||
override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor =
|
||||
wrap(JNoa.fromBlobFloat(array, shape, device.toInt()))
|
||||
|
||||
public fun randNormalFloat(shape: IntArray, device: Device): NoaFloatTensor =
|
||||
override fun randNormal(shape: IntArray, device: Device): NoaFloatTensor =
|
||||
wrap(JNoa.randnFloat(shape, device.toInt()))
|
||||
|
||||
public fun randUniformFloat(shape: IntArray, device: Device): NoaFloatTensor =
|
||||
override fun randUniform(shape: IntArray, device: Device): NoaFloatTensor =
|
||||
wrap(JNoa.randFloat(shape, device.toInt()))
|
||||
|
||||
public fun randDiscreteFloat(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor =
|
||||
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor =
|
||||
wrap(JNoa.randintFloat(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Float.plus(other: Tensor<Float>): NoaFloatTensor =
|
||||
@ -423,14 +436,14 @@ protected constructor(scope: NoaScope) :
|
||||
override fun Tensor<Float>.divAssign(value: Float): Unit =
|
||||
tensor.timesAssign(1 / value)
|
||||
|
||||
public fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor =
|
||||
override fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor =
|
||||
wrap(JNoa.fullFloat(value, shape, device.toInt()))
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaLongAlgebra
|
||||
protected constructor(scope: NoaScope) :
|
||||
NoaAlgebra<Long, NoaLongTensor>(scope) {
|
||||
NoaAlgebra<Long, LongArray, NoaLongTensor>(scope) {
|
||||
|
||||
private fun Tensor<Long>.castHelper(): NoaLongTensor =
|
||||
copyFromArray(
|
||||
@ -448,13 +461,13 @@ protected constructor(scope: NoaScope) :
|
||||
NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
@PerformancePitfall
|
||||
public fun Tensor<Long>.copyToArray(): LongArray =
|
||||
override fun Tensor<Long>.copyToArray(): LongArray =
|
||||
tensor.elements().map { it.second }.toList().toLongArray()
|
||||
|
||||
public fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor =
|
||||
override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor =
|
||||
wrap(JNoa.fromBlobLong(array, shape, device.toInt()))
|
||||
|
||||
public fun randDiscreteLong(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor =
|
||||
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor =
|
||||
wrap(JNoa.randintLong(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Long.plus(other: Tensor<Long>): NoaLongTensor =
|
||||
@ -484,14 +497,14 @@ protected constructor(scope: NoaScope) :
|
||||
override fun Tensor<Long>.timesAssign(value: Long): Unit =
|
||||
JNoa.timesLongAssign(value, tensor.tensorHandle)
|
||||
|
||||
public fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor =
|
||||
override fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor =
|
||||
wrap(JNoa.fullLong(value, shape, device.toInt()))
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaIntAlgebra
|
||||
protected constructor(scope: NoaScope) :
|
||||
NoaAlgebra<Int, NoaIntTensor>(scope) {
|
||||
NoaAlgebra<Int, IntArray, NoaIntTensor>(scope) {
|
||||
|
||||
private fun Tensor<Int>.castHelper(): NoaIntTensor =
|
||||
copyFromArray(
|
||||
@ -509,13 +522,13 @@ protected constructor(scope: NoaScope) :
|
||||
NoaIntTensor(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
@PerformancePitfall
|
||||
public fun Tensor<Int>.copyToArray(): IntArray =
|
||||
override fun Tensor<Int>.copyToArray(): IntArray =
|
||||
tensor.elements().map { it.second }.toList().toIntArray()
|
||||
|
||||
public fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor =
|
||||
override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor =
|
||||
wrap(JNoa.fromBlobInt(array, shape, device.toInt()))
|
||||
|
||||
public fun randDiscreteInt(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor =
|
||||
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor =
|
||||
wrap(JNoa.randintInt(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Int.plus(other: Tensor<Int>): NoaIntTensor =
|
||||
@ -545,7 +558,7 @@ protected constructor(scope: NoaScope) :
|
||||
override fun Tensor<Int>.timesAssign(value: Int): Unit =
|
||||
JNoa.timesIntAssign(value, tensor.tensorHandle)
|
||||
|
||||
public fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor =
|
||||
override fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor =
|
||||
wrap(JNoa.fullInt(value, shape, device.toInt()))
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user