From 2a97aca9b6f8cd2cd66647803a2f6b1ba7599b24 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Thu, 8 Jul 2021 10:57:08 +0100 Subject: [PATCH] value method for algebra --- .../space/kscience/kmath/noa/algebras.kt | 22 ++++++++++++++----- .../space/kscience/kmath/noa/tensors.kt | 10 +++++---- .../kmath/tensors/api/TensorAlgebra.kt | 8 +++---- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 11 ++++++++++ 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index e9c80a465..deea592d4 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -11,16 +11,26 @@ import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra -public abstract class NoaAlgebra> -internal constructor(protected val scope: NoaScope) - : TensorAlgebra { + +public sealed class NoaAlgebra> +constructor(protected val scope: NoaScope) : TensorAlgebra { protected abstract fun Tensor.cast(): TensorType + + protected abstract fun wrap(tensorHandle: TensorHandle): TensorType + + /** + * A scalar tensor in this implementation must have empty shape + */ + override fun Tensor.valueOrNull(): T? = + try { this.cast().item() } catch (e: NoaException) { null } + + override fun Tensor.value(): T = this.cast().item() } -public abstract class NoaPartialDivisionAlgebra> -internal constructor(scope: NoaScope) - : NoaAlgebra(scope), LinearOpsTensorAlgebra, AnalyticTensorAlgebra { +public abstract class NoaPartialDivisionAlgebra> +internal constructor(scope: NoaScope) : NoaAlgebra(scope), LinearOpsTensorAlgebra, + AnalyticTensorAlgebra { } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt index dffb14886..e24f29b9b 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt @@ -19,6 +19,8 @@ constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) : override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle) + internal abstract fun item(): T + override val dimension: Int get() = JNoa.getDim(tensorHandle) override val shape: IntArray @@ -79,7 +81,7 @@ public class NoaDoubleTensor internal constructor(scope: NoaScope, tensorHandle: TensorHandle) : NoaTensorOverField(scope, tensorHandle) { - internal fun item(): Double = JNoa.getItemDouble(tensorHandle) + override fun item(): Double = JNoa.getItemDouble(tensorHandle) @PerformancePitfall override fun get(index: IntArray): Double = JNoa.getDouble(tensorHandle, index) @@ -94,7 +96,7 @@ public class NoaFloatTensor internal constructor(scope: NoaScope, tensorHandle: TensorHandle) : NoaTensorOverField(scope, tensorHandle) { - internal fun item(): Float = JNoa.getItemFloat(tensorHandle) + override fun item(): Float = JNoa.getItemFloat(tensorHandle) @PerformancePitfall override fun get(index: IntArray): Float = JNoa.getFloat(tensorHandle, index) @@ -109,7 +111,7 @@ public class NoaLongTensor internal constructor(scope: NoaScope, tensorHandle: TensorHandle) : NoaTensor(scope, tensorHandle) { - internal fun item(): Long = JNoa.getItemLong(tensorHandle) + override fun item(): Long = JNoa.getItemLong(tensorHandle) @PerformancePitfall override fun get(index: IntArray): Long = JNoa.getLong(tensorHandle, index) @@ -124,7 +126,7 @@ public class NoaIntTensor internal constructor(scope: NoaScope, tensorHandle: TensorHandle) : NoaTensor(scope, tensorHandle) { - internal fun item(): Int = JNoa.getItemInt(tensorHandle) + override fun item(): Int = JNoa.getItemInt(tensorHandle) @PerformancePitfall override fun get(index: IntArray): Int = JNoa.getInt(tensorHandle, index) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index 62b8ef046..2f5ab1dbe 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -16,19 +16,19 @@ import space.kscience.kmath.operations.Algebra public interface TensorAlgebra : Algebra> { /** - * Returns a single tensor value of unit dimension if tensor shape equals to [1]. + * For a tensor containing a single element, i.e. a scalar tensor, the value of that element is returned. + * You need to check if the implementation puts any restrictions on the shape of a scalar tensor. * * @return a nullable value of a potentially scalar tensor. */ public fun Tensor.valueOrNull(): T? /** - * Returns a single tensor value of unit dimension. The tensor shape must be equal to [1]. + * Unsafe version of [valueOrNull] * * @return the value of a scalar tensor. */ - public fun Tensor.value(): T = - valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape") + public fun Tensor.value(): T /** * Each element of the tensor [other] is added to this value. diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 1fd46bd57..b60ebffe2 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -24,9 +24,20 @@ public open class DoubleTensorAlgebra : public companion object : DoubleTensorAlgebra() + + /** + * Returns a single tensor value of unit dimension if tensor shape equals to [1]. + * + * @return a nullable value of a potentially scalar tensor. + */ override fun Tensor.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) tensor.mutableBuffer.array()[tensor.bufferStart] else null + /** + * Returns a single tensor value of unit dimension. The tensor shape must be equal to [1]. + * + * @return the value of a scalar tensor. + */ override fun Tensor.value(): Double = valueOrNull() ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")