inheritance is back

This commit is contained in:
Roland Grinis 2021-07-10 15:22:15 +01:00
parent e6e117f694
commit 6d5e4a5776

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.tensors.api.TensorAlgebra
import space.kscience.kmath.tensors.core.TensorLinearStructure import space.kscience.kmath.tensors.core.TensorLinearStructure
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>> public sealed class NoaAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
protected constructor(protected val scope: NoaScope) : protected constructor(protected val scope: NoaScope) :
TensorAlgebra<T> { TensorAlgebra<T> {
@ -36,6 +36,15 @@ protected constructor(protected val scope: NoaScope) :
override fun Tensor<T>.value(): T = tensor.item() override fun Tensor<T>.value(): T = tensor.item()
public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): TensorType
@PerformancePitfall
public abstract fun Tensor<T>.copyToArray(): PrimitiveArray
public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device): TensorType
public abstract fun full(value: T, shape: IntArray, device: Device): TensorType
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType { override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle)) return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle))
} }
@ -132,9 +141,9 @@ protected constructor(protected val scope: NoaScope) :
} }
public sealed class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>> public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
protected constructor(scope: NoaScope) : protected constructor(scope: NoaScope) :
NoaAlgebra<T, TensorType>(scope), NoaAlgebra<T, PrimitiveArray, TensorType>(scope),
LinearOpsTensorAlgebra<T>, LinearOpsTensorAlgebra<T>,
AnalyticTensorAlgebra<T> { AnalyticTensorAlgebra<T> {
@ -161,6 +170,10 @@ protected constructor(scope: NoaScope) :
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): TensorType = override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim)) wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim))
public abstract fun randNormal(shape: IntArray, device: Device): TensorType
public abstract fun randUniform(shape: IntArray, device: Device): TensorType
public fun Tensor<T>.randUniform(): TensorType = public fun Tensor<T>.randUniform(): TensorType =
wrap(JNoa.randLike(tensor.tensorHandle)) wrap(JNoa.randLike(tensor.tensorHandle))
@ -278,7 +291,7 @@ protected constructor(scope: NoaScope) :
public sealed class NoaDoubleAlgebra public sealed class NoaDoubleAlgebra
protected constructor(scope: NoaScope) : protected constructor(scope: NoaScope) :
NoaPartialDivisionAlgebra<Double, NoaDoubleTensor>(scope) { NoaPartialDivisionAlgebra<Double, DoubleArray, NoaDoubleTensor>(scope) {
private fun Tensor<Double>.castHelper(): NoaDoubleTensor = private fun Tensor<Double>.castHelper(): NoaDoubleTensor =
copyFromArray( copyFromArray(
@ -296,19 +309,19 @@ protected constructor(scope: NoaScope) :
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle) NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall @PerformancePitfall
public fun Tensor<Double>.copyToArray(): DoubleArray = override fun Tensor<Double>.copyToArray(): DoubleArray =
tensor.elements().map { it.second }.toList().toDoubleArray() tensor.elements().map { it.second }.toList().toDoubleArray()
public fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor = override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.fromBlobDouble(array, shape, device.toInt())) wrap(JNoa.fromBlobDouble(array, shape, device.toInt()))
public fun randNormalDouble(shape: IntArray, device: Device): NoaDoubleTensor = override fun randNormal(shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.randnDouble(shape, device.toInt())) wrap(JNoa.randnDouble(shape, device.toInt()))
public fun randUniformDouble(shape: IntArray, device: Device): NoaDoubleTensor = override fun randUniform(shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.randDouble(shape, device.toInt())) wrap(JNoa.randDouble(shape, device.toInt()))
public fun randDiscreteDouble(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor = override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.randintDouble(low, high, shape, device.toInt())) wrap(JNoa.randintDouble(low, high, shape, device.toInt()))
override operator fun Double.plus(other: Tensor<Double>): NoaDoubleTensor = override operator fun Double.plus(other: Tensor<Double>): NoaDoubleTensor =
@ -347,14 +360,14 @@ protected constructor(scope: NoaScope) :
override fun Tensor<Double>.divAssign(value: Double): Unit = override fun Tensor<Double>.divAssign(value: Double): Unit =
tensor.timesAssign(1 / value) tensor.timesAssign(1 / value)
public fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor = override fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.fullDouble(value, shape, device.toInt())) wrap(JNoa.fullDouble(value, shape, device.toInt()))
} }
public sealed class NoaFloatAlgebra public sealed class NoaFloatAlgebra
protected constructor(scope: NoaScope) : protected constructor(scope: NoaScope) :
NoaPartialDivisionAlgebra<Float, NoaFloatTensor>(scope) { NoaPartialDivisionAlgebra<Float, FloatArray, NoaFloatTensor>(scope) {
private fun Tensor<Float>.castHelper(): NoaFloatTensor = private fun Tensor<Float>.castHelper(): NoaFloatTensor =
copyFromArray( copyFromArray(
@ -372,19 +385,19 @@ protected constructor(scope: NoaScope) :
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle) NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall @PerformancePitfall
public fun Tensor<Float>.copyToArray(): FloatArray = override fun Tensor<Float>.copyToArray(): FloatArray =
tensor.elements().map { it.second }.toList().toFloatArray() tensor.elements().map { it.second }.toList().toFloatArray()
public fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor = override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.fromBlobFloat(array, shape, device.toInt())) wrap(JNoa.fromBlobFloat(array, shape, device.toInt()))
public fun randNormalFloat(shape: IntArray, device: Device): NoaFloatTensor = override fun randNormal(shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.randnFloat(shape, device.toInt())) wrap(JNoa.randnFloat(shape, device.toInt()))
public fun randUniformFloat(shape: IntArray, device: Device): NoaFloatTensor = override fun randUniform(shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.randFloat(shape, device.toInt())) wrap(JNoa.randFloat(shape, device.toInt()))
public fun randDiscreteFloat(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor = override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.randintFloat(low, high, shape, device.toInt())) wrap(JNoa.randintFloat(low, high, shape, device.toInt()))
override operator fun Float.plus(other: Tensor<Float>): NoaFloatTensor = override operator fun Float.plus(other: Tensor<Float>): NoaFloatTensor =
@ -423,14 +436,14 @@ protected constructor(scope: NoaScope) :
override fun Tensor<Float>.divAssign(value: Float): Unit = override fun Tensor<Float>.divAssign(value: Float): Unit =
tensor.timesAssign(1 / value) tensor.timesAssign(1 / value)
public fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor = override fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.fullFloat(value, shape, device.toInt())) wrap(JNoa.fullFloat(value, shape, device.toInt()))
} }
public sealed class NoaLongAlgebra public sealed class NoaLongAlgebra
protected constructor(scope: NoaScope) : protected constructor(scope: NoaScope) :
NoaAlgebra<Long, NoaLongTensor>(scope) { NoaAlgebra<Long, LongArray, NoaLongTensor>(scope) {
private fun Tensor<Long>.castHelper(): NoaLongTensor = private fun Tensor<Long>.castHelper(): NoaLongTensor =
copyFromArray( copyFromArray(
@ -448,13 +461,13 @@ protected constructor(scope: NoaScope) :
NoaLongTensor(scope = scope, tensorHandle = tensorHandle) NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall @PerformancePitfall
public fun Tensor<Long>.copyToArray(): LongArray = override fun Tensor<Long>.copyToArray(): LongArray =
tensor.elements().map { it.second }.toList().toLongArray() tensor.elements().map { it.second }.toList().toLongArray()
public fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor = override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.fromBlobLong(array, shape, device.toInt())) wrap(JNoa.fromBlobLong(array, shape, device.toInt()))
public fun randDiscreteLong(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor = override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.randintLong(low, high, shape, device.toInt())) wrap(JNoa.randintLong(low, high, shape, device.toInt()))
override operator fun Long.plus(other: Tensor<Long>): NoaLongTensor = override operator fun Long.plus(other: Tensor<Long>): NoaLongTensor =
@ -484,14 +497,14 @@ protected constructor(scope: NoaScope) :
override fun Tensor<Long>.timesAssign(value: Long): Unit = override fun Tensor<Long>.timesAssign(value: Long): Unit =
JNoa.timesLongAssign(value, tensor.tensorHandle) JNoa.timesLongAssign(value, tensor.tensorHandle)
public fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor = override fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.fullLong(value, shape, device.toInt())) wrap(JNoa.fullLong(value, shape, device.toInt()))
} }
public sealed class NoaIntAlgebra public sealed class NoaIntAlgebra
protected constructor(scope: NoaScope) : protected constructor(scope: NoaScope) :
NoaAlgebra<Int, NoaIntTensor>(scope) { NoaAlgebra<Int, IntArray, NoaIntTensor>(scope) {
private fun Tensor<Int>.castHelper(): NoaIntTensor = private fun Tensor<Int>.castHelper(): NoaIntTensor =
copyFromArray( copyFromArray(
@ -509,13 +522,13 @@ protected constructor(scope: NoaScope) :
NoaIntTensor(scope = scope, tensorHandle = tensorHandle) NoaIntTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall @PerformancePitfall
public fun Tensor<Int>.copyToArray(): IntArray = override fun Tensor<Int>.copyToArray(): IntArray =
tensor.elements().map { it.second }.toList().toIntArray() tensor.elements().map { it.second }.toList().toIntArray()
public fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor = override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.fromBlobInt(array, shape, device.toInt())) wrap(JNoa.fromBlobInt(array, shape, device.toInt()))
public fun randDiscreteInt(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor = override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.randintInt(low, high, shape, device.toInt())) wrap(JNoa.randintInt(low, high, shape, device.toInt()))
override operator fun Int.plus(other: Tensor<Int>): NoaIntTensor = override operator fun Int.plus(other: Tensor<Int>): NoaIntTensor =
@ -545,7 +558,7 @@ protected constructor(scope: NoaScope) :
override fun Tensor<Int>.timesAssign(value: Int): Unit = override fun Tensor<Int>.timesAssign(value: Int): Unit =
JNoa.timesIntAssign(value, tensor.tensorHandle) JNoa.timesIntAssign(value, tensor.tensorHandle)
public fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor = override fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.fullInt(value, shape, device.toInt())) wrap(JNoa.fullInt(value, shape, device.toInt()))
} }