inheritance is back

This commit is contained in:
Roland Grinis 2021-07-10 15:22:15 +01:00
parent e6e117f694
commit 6d5e4a5776

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.tensors.api.TensorAlgebra
import space.kscience.kmath.tensors.core.TensorLinearStructure
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
public sealed class NoaAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
protected constructor(protected val scope: NoaScope) :
TensorAlgebra<T> {
@ -36,6 +36,15 @@ protected constructor(protected val scope: NoaScope) :
override fun Tensor<T>.value(): T = tensor.item()
public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): TensorType
@PerformancePitfall
public abstract fun Tensor<T>.copyToArray(): PrimitiveArray
public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device): TensorType
public abstract fun full(value: T, shape: IntArray, device: Device): TensorType
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle))
}
@ -132,9 +141,9 @@ protected constructor(protected val scope: NoaScope) :
}
public sealed class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
protected constructor(scope: NoaScope) :
NoaAlgebra<T, TensorType>(scope),
NoaAlgebra<T, PrimitiveArray, TensorType>(scope),
LinearOpsTensorAlgebra<T>,
AnalyticTensorAlgebra<T> {
@ -161,6 +170,10 @@ protected constructor(scope: NoaScope) :
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim))
public abstract fun randNormal(shape: IntArray, device: Device): TensorType
public abstract fun randUniform(shape: IntArray, device: Device): TensorType
public fun Tensor<T>.randUniform(): TensorType =
wrap(JNoa.randLike(tensor.tensorHandle))
@ -278,7 +291,7 @@ protected constructor(scope: NoaScope) :
public sealed class NoaDoubleAlgebra
protected constructor(scope: NoaScope) :
NoaPartialDivisionAlgebra<Double, NoaDoubleTensor>(scope) {
NoaPartialDivisionAlgebra<Double, DoubleArray, NoaDoubleTensor>(scope) {
private fun Tensor<Double>.castHelper(): NoaDoubleTensor =
copyFromArray(
@ -296,19 +309,19 @@ protected constructor(scope: NoaScope) :
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
public fun Tensor<Double>.copyToArray(): DoubleArray =
override fun Tensor<Double>.copyToArray(): DoubleArray =
tensor.elements().map { it.second }.toList().toDoubleArray()
public fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor =
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.fromBlobDouble(array, shape, device.toInt()))
public fun randNormalDouble(shape: IntArray, device: Device): NoaDoubleTensor =
override fun randNormal(shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.randnDouble(shape, device.toInt()))
public fun randUniformDouble(shape: IntArray, device: Device): NoaDoubleTensor =
override fun randUniform(shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.randDouble(shape, device.toInt()))
public fun randDiscreteDouble(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor =
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.randintDouble(low, high, shape, device.toInt()))
override operator fun Double.plus(other: Tensor<Double>): NoaDoubleTensor =
@ -347,14 +360,14 @@ protected constructor(scope: NoaScope) :
override fun Tensor<Double>.divAssign(value: Double): Unit =
tensor.timesAssign(1 / value)
public fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor =
override fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.fullDouble(value, shape, device.toInt()))
}
public sealed class NoaFloatAlgebra
protected constructor(scope: NoaScope) :
NoaPartialDivisionAlgebra<Float, NoaFloatTensor>(scope) {
NoaPartialDivisionAlgebra<Float, FloatArray, NoaFloatTensor>(scope) {
private fun Tensor<Float>.castHelper(): NoaFloatTensor =
copyFromArray(
@ -372,19 +385,19 @@ protected constructor(scope: NoaScope) :
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
public fun Tensor<Float>.copyToArray(): FloatArray =
override fun Tensor<Float>.copyToArray(): FloatArray =
tensor.elements().map { it.second }.toList().toFloatArray()
public fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor =
override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.fromBlobFloat(array, shape, device.toInt()))
public fun randNormalFloat(shape: IntArray, device: Device): NoaFloatTensor =
override fun randNormal(shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.randnFloat(shape, device.toInt()))
public fun randUniformFloat(shape: IntArray, device: Device): NoaFloatTensor =
override fun randUniform(shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.randFloat(shape, device.toInt()))
public fun randDiscreteFloat(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor =
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.randintFloat(low, high, shape, device.toInt()))
override operator fun Float.plus(other: Tensor<Float>): NoaFloatTensor =
@ -423,14 +436,14 @@ protected constructor(scope: NoaScope) :
override fun Tensor<Float>.divAssign(value: Float): Unit =
tensor.timesAssign(1 / value)
public fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor =
override fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.fullFloat(value, shape, device.toInt()))
}
public sealed class NoaLongAlgebra
protected constructor(scope: NoaScope) :
NoaAlgebra<Long, NoaLongTensor>(scope) {
NoaAlgebra<Long, LongArray, NoaLongTensor>(scope) {
private fun Tensor<Long>.castHelper(): NoaLongTensor =
copyFromArray(
@ -448,13 +461,13 @@ protected constructor(scope: NoaScope) :
NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
public fun Tensor<Long>.copyToArray(): LongArray =
override fun Tensor<Long>.copyToArray(): LongArray =
tensor.elements().map { it.second }.toList().toLongArray()
public fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor =
override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.fromBlobLong(array, shape, device.toInt()))
public fun randDiscreteLong(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor =
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.randintLong(low, high, shape, device.toInt()))
override operator fun Long.plus(other: Tensor<Long>): NoaLongTensor =
@ -484,14 +497,14 @@ protected constructor(scope: NoaScope) :
override fun Tensor<Long>.timesAssign(value: Long): Unit =
JNoa.timesLongAssign(value, tensor.tensorHandle)
public fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor =
override fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.fullLong(value, shape, device.toInt()))
}
public sealed class NoaIntAlgebra
protected constructor(scope: NoaScope) :
NoaAlgebra<Int, NoaIntTensor>(scope) {
NoaAlgebra<Int, IntArray, NoaIntTensor>(scope) {
private fun Tensor<Int>.castHelper(): NoaIntTensor =
copyFromArray(
@ -509,13 +522,13 @@ protected constructor(scope: NoaScope) :
NoaIntTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
public fun Tensor<Int>.copyToArray(): IntArray =
override fun Tensor<Int>.copyToArray(): IntArray =
tensor.elements().map { it.second }.toList().toIntArray()
public fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor =
override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.fromBlobInt(array, shape, device.toInt()))
public fun randDiscreteInt(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor =
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.randintInt(low, high, shape, device.toInt()))
override operator fun Int.plus(other: Tensor<Int>): NoaIntTensor =
@ -545,7 +558,7 @@ protected constructor(scope: NoaScope) :
override fun Tensor<Int>.timesAssign(value: Int): Unit =
JNoa.timesIntAssign(value, tensor.tensorHandle)
public fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor =
override fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.fullInt(value, shape, device.toInt()))
}