From 6d5e4a5776ff0626df97c2a9625f0e15bee07053 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Sat, 10 Jul 2021 15:22:15 +0100 Subject: [PATCH] inheritance is back --- .../space/kscience/kmath/noa/algebras.kt | 67 +++++++++++-------- 1 file changed, 40 insertions(+), 27 deletions(-) diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index f92acdebd..df010b7e8 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -14,7 +14,7 @@ import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.core.TensorLinearStructure -public sealed class NoaAlgebra> +public sealed class NoaAlgebra> protected constructor(protected val scope: NoaScope) : TensorAlgebra { @@ -36,6 +36,15 @@ protected constructor(protected val scope: NoaScope) : override fun Tensor.value(): T = tensor.item() + public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): TensorType + + @PerformancePitfall + public abstract fun Tensor.copyToArray(): PrimitiveArray + + public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device): TensorType + + public abstract fun full(value: T, shape: IntArray, device: Device): TensorType + override operator fun Tensor.times(other: Tensor): TensorType { return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle)) } @@ -132,9 +141,9 @@ protected constructor(protected val scope: NoaScope) : } -public sealed class NoaPartialDivisionAlgebra> +public sealed class NoaPartialDivisionAlgebra> protected constructor(scope: NoaScope) : - NoaAlgebra(scope), + NoaAlgebra(scope), LinearOpsTensorAlgebra, AnalyticTensorAlgebra { @@ -161,6 +170,10 @@ protected constructor(scope: NoaScope) : override fun Tensor.variance(dim: Int, keepDim: Boolean): TensorType = wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim)) + public abstract fun randNormal(shape: IntArray, device: Device): TensorType + + public abstract fun randUniform(shape: IntArray, device: Device): TensorType + public fun Tensor.randUniform(): TensorType = wrap(JNoa.randLike(tensor.tensorHandle)) @@ -278,7 +291,7 @@ protected constructor(scope: NoaScope) : public sealed class NoaDoubleAlgebra protected constructor(scope: NoaScope) : - NoaPartialDivisionAlgebra(scope) { + NoaPartialDivisionAlgebra(scope) { private fun Tensor.castHelper(): NoaDoubleTensor = copyFromArray( @@ -296,19 +309,19 @@ protected constructor(scope: NoaScope) : NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle) @PerformancePitfall - public fun Tensor.copyToArray(): DoubleArray = + override fun Tensor.copyToArray(): DoubleArray = tensor.elements().map { it.second }.toList().toDoubleArray() - public fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor = + override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor = wrap(JNoa.fromBlobDouble(array, shape, device.toInt())) - public fun randNormalDouble(shape: IntArray, device: Device): NoaDoubleTensor = + override fun randNormal(shape: IntArray, device: Device): NoaDoubleTensor = wrap(JNoa.randnDouble(shape, device.toInt())) - public fun randUniformDouble(shape: IntArray, device: Device): NoaDoubleTensor = + override fun randUniform(shape: IntArray, device: Device): NoaDoubleTensor = wrap(JNoa.randDouble(shape, device.toInt())) - public fun randDiscreteDouble(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor = + override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor = wrap(JNoa.randintDouble(low, high, shape, device.toInt())) override operator fun Double.plus(other: Tensor): NoaDoubleTensor = @@ -347,14 +360,14 @@ protected constructor(scope: NoaScope) : override fun Tensor.divAssign(value: Double): Unit = tensor.timesAssign(1 / value) - public fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor = + override fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor = wrap(JNoa.fullDouble(value, shape, device.toInt())) } public sealed class NoaFloatAlgebra protected constructor(scope: NoaScope) : - NoaPartialDivisionAlgebra(scope) { + NoaPartialDivisionAlgebra(scope) { private fun Tensor.castHelper(): NoaFloatTensor = copyFromArray( @@ -372,19 +385,19 @@ protected constructor(scope: NoaScope) : NoaFloatTensor(scope = scope, tensorHandle = tensorHandle) @PerformancePitfall - public fun Tensor.copyToArray(): FloatArray = + override fun Tensor.copyToArray(): FloatArray = tensor.elements().map { it.second }.toList().toFloatArray() - public fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor = + override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor = wrap(JNoa.fromBlobFloat(array, shape, device.toInt())) - public fun randNormalFloat(shape: IntArray, device: Device): NoaFloatTensor = + override fun randNormal(shape: IntArray, device: Device): NoaFloatTensor = wrap(JNoa.randnFloat(shape, device.toInt())) - public fun randUniformFloat(shape: IntArray, device: Device): NoaFloatTensor = + override fun randUniform(shape: IntArray, device: Device): NoaFloatTensor = wrap(JNoa.randFloat(shape, device.toInt())) - public fun randDiscreteFloat(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor = + override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor = wrap(JNoa.randintFloat(low, high, shape, device.toInt())) override operator fun Float.plus(other: Tensor): NoaFloatTensor = @@ -423,14 +436,14 @@ protected constructor(scope: NoaScope) : override fun Tensor.divAssign(value: Float): Unit = tensor.timesAssign(1 / value) - public fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor = + override fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor = wrap(JNoa.fullFloat(value, shape, device.toInt())) } public sealed class NoaLongAlgebra protected constructor(scope: NoaScope) : - NoaAlgebra(scope) { + NoaAlgebra(scope) { private fun Tensor.castHelper(): NoaLongTensor = copyFromArray( @@ -448,13 +461,13 @@ protected constructor(scope: NoaScope) : NoaLongTensor(scope = scope, tensorHandle = tensorHandle) @PerformancePitfall - public fun Tensor.copyToArray(): LongArray = + override fun Tensor.copyToArray(): LongArray = tensor.elements().map { it.second }.toList().toLongArray() - public fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor = + override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor = wrap(JNoa.fromBlobLong(array, shape, device.toInt())) - public fun randDiscreteLong(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor = + override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor = wrap(JNoa.randintLong(low, high, shape, device.toInt())) override operator fun Long.plus(other: Tensor): NoaLongTensor = @@ -484,14 +497,14 @@ protected constructor(scope: NoaScope) : override fun Tensor.timesAssign(value: Long): Unit = JNoa.timesLongAssign(value, tensor.tensorHandle) - public fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor = + override fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor = wrap(JNoa.fullLong(value, shape, device.toInt())) } public sealed class NoaIntAlgebra protected constructor(scope: NoaScope) : - NoaAlgebra(scope) { + NoaAlgebra(scope) { private fun Tensor.castHelper(): NoaIntTensor = copyFromArray( @@ -509,13 +522,13 @@ protected constructor(scope: NoaScope) : NoaIntTensor(scope = scope, tensorHandle = tensorHandle) @PerformancePitfall - public fun Tensor.copyToArray(): IntArray = + override fun Tensor.copyToArray(): IntArray = tensor.elements().map { it.second }.toList().toIntArray() - public fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor = + override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor = wrap(JNoa.fromBlobInt(array, shape, device.toInt())) - public fun randDiscreteInt(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor = + override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor = wrap(JNoa.randintInt(low, high, shape, device.toInt())) override operator fun Int.plus(other: Tensor): NoaIntTensor = @@ -545,7 +558,7 @@ protected constructor(scope: NoaScope) : override fun Tensor.timesAssign(value: Int): Unit = JNoa.timesIntAssign(value, tensor.tensorHandle) - public fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor = + override fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor = wrap(JNoa.fullInt(value, shape, device.toInt())) }