value method for algebra

This commit is contained in:
Roland Grinis 2021-07-08 10:57:08 +01:00
parent 8f32397f50
commit 2a97aca9b6
4 changed files with 37 additions and 14 deletions

View File

@ -11,16 +11,26 @@ import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
public abstract class NoaAlgebra<T, TensorType: NoaTensor<T>>
internal constructor(protected val scope: NoaScope) public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
: TensorAlgebra<T> { constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
protected abstract fun Tensor<T>.cast(): TensorType protected abstract fun Tensor<T>.cast(): TensorType
protected abstract fun wrap(tensorHandle: TensorHandle): TensorType
/**
* A scalar tensor in this implementation must have empty shape
*/
override fun Tensor<T>.valueOrNull(): T? =
try { this.cast().item() } catch (e: NoaException) { null }
override fun Tensor<T>.value(): T = this.cast().item()
} }
public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>> public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
internal constructor(scope: NoaScope) internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>,
: NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>, AnalyticTensorAlgebra<T> { AnalyticTensorAlgebra<T> {
} }

View File

@ -19,6 +19,8 @@ constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle) override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle)
internal abstract fun item(): T
override val dimension: Int get() = JNoa.getDim(tensorHandle) override val dimension: Int get() = JNoa.getDim(tensorHandle)
override val shape: IntArray override val shape: IntArray
@ -79,7 +81,7 @@ public class NoaDoubleTensor
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) : internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
NoaTensorOverField<Double>(scope, tensorHandle) { NoaTensorOverField<Double>(scope, tensorHandle) {
internal fun item(): Double = JNoa.getItemDouble(tensorHandle) override fun item(): Double = JNoa.getItemDouble(tensorHandle)
@PerformancePitfall @PerformancePitfall
override fun get(index: IntArray): Double = JNoa.getDouble(tensorHandle, index) override fun get(index: IntArray): Double = JNoa.getDouble(tensorHandle, index)
@ -94,7 +96,7 @@ public class NoaFloatTensor
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) : internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
NoaTensorOverField<Float>(scope, tensorHandle) { NoaTensorOverField<Float>(scope, tensorHandle) {
internal fun item(): Float = JNoa.getItemFloat(tensorHandle) override fun item(): Float = JNoa.getItemFloat(tensorHandle)
@PerformancePitfall @PerformancePitfall
override fun get(index: IntArray): Float = JNoa.getFloat(tensorHandle, index) override fun get(index: IntArray): Float = JNoa.getFloat(tensorHandle, index)
@ -109,7 +111,7 @@ public class NoaLongTensor
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) : internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
NoaTensor<Long>(scope, tensorHandle) { NoaTensor<Long>(scope, tensorHandle) {
internal fun item(): Long = JNoa.getItemLong(tensorHandle) override fun item(): Long = JNoa.getItemLong(tensorHandle)
@PerformancePitfall @PerformancePitfall
override fun get(index: IntArray): Long = JNoa.getLong(tensorHandle, index) override fun get(index: IntArray): Long = JNoa.getLong(tensorHandle, index)
@ -124,7 +126,7 @@ public class NoaIntTensor
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) : internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
NoaTensor<Int>(scope, tensorHandle) { NoaTensor<Int>(scope, tensorHandle) {
internal fun item(): Int = JNoa.getItemInt(tensorHandle) override fun item(): Int = JNoa.getItemInt(tensorHandle)
@PerformancePitfall @PerformancePitfall
override fun get(index: IntArray): Int = JNoa.getInt(tensorHandle, index) override fun get(index: IntArray): Int = JNoa.getInt(tensorHandle, index)

View File

@ -16,19 +16,19 @@ import space.kscience.kmath.operations.Algebra
public interface TensorAlgebra<T> : Algebra<Tensor<T>> { public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
/** /**
* Returns a single tensor value of unit dimension if tensor shape equals to [1]. * For a tensor containing a single element, i.e. a scalar tensor, the value of that element is returned.
* You need to check if the implementation puts any restrictions on the shape of a scalar tensor.
* *
* @return a nullable value of a potentially scalar tensor. * @return a nullable value of a potentially scalar tensor.
*/ */
public fun Tensor<T>.valueOrNull(): T? public fun Tensor<T>.valueOrNull(): T?
/** /**
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1]. * Unsafe version of [valueOrNull]
* *
* @return the value of a scalar tensor. * @return the value of a scalar tensor.
*/ */
public fun Tensor<T>.value(): T = public fun Tensor<T>.value(): T
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
/** /**
* Each element of the tensor [other] is added to this value. * Each element of the tensor [other] is added to this value.

View File

@ -24,9 +24,20 @@ public open class DoubleTensorAlgebra :
public companion object : DoubleTensorAlgebra() public companion object : DoubleTensorAlgebra()
/**
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
*
* @return a nullable value of a potentially scalar tensor.
*/
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
tensor.mutableBuffer.array()[tensor.bufferStart] else null tensor.mutableBuffer.array()[tensor.bufferStart] else null
/**
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
*
* @return the value of a scalar tensor.
*/
override fun Tensor<Double>.value(): Double = valueOrNull() override fun Tensor<Double>.value(): Double = valueOrNull()
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]") ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")