value method for algebra
This commit is contained in:
parent
8f32397f50
commit
2a97aca9b6
@ -11,16 +11,26 @@ import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
|
||||
public abstract class NoaAlgebra<T, TensorType: NoaTensor<T>>
|
||||
internal constructor(protected val scope: NoaScope)
|
||||
: TensorAlgebra<T> {
|
||||
|
||||
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
|
||||
constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
|
||||
|
||||
protected abstract fun Tensor<T>.cast(): TensorType
|
||||
|
||||
protected abstract fun wrap(tensorHandle: TensorHandle): TensorType
|
||||
|
||||
/**
|
||||
* A scalar tensor in this implementation must have empty shape
|
||||
*/
|
||||
override fun Tensor<T>.valueOrNull(): T? =
|
||||
try { this.cast().item() } catch (e: NoaException) { null }
|
||||
|
||||
override fun Tensor<T>.value(): T = this.cast().item()
|
||||
}
|
||||
|
||||
public abstract class NoaPartialDivisionAlgebra<T, TensorType: NoaTensor<T>>
|
||||
internal constructor(scope: NoaScope)
|
||||
: NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>, AnalyticTensorAlgebra<T> {
|
||||
public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
|
||||
internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>,
|
||||
AnalyticTensorAlgebra<T> {
|
||||
|
||||
|
||||
}
|
||||
|
@ -19,6 +19,8 @@ constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
|
||||
|
||||
override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle)
|
||||
|
||||
internal abstract fun item(): T
|
||||
|
||||
override val dimension: Int get() = JNoa.getDim(tensorHandle)
|
||||
|
||||
override val shape: IntArray
|
||||
@ -79,7 +81,7 @@ public class NoaDoubleTensor
|
||||
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
||||
NoaTensorOverField<Double>(scope, tensorHandle) {
|
||||
|
||||
internal fun item(): Double = JNoa.getItemDouble(tensorHandle)
|
||||
override fun item(): Double = JNoa.getItemDouble(tensorHandle)
|
||||
|
||||
@PerformancePitfall
|
||||
override fun get(index: IntArray): Double = JNoa.getDouble(tensorHandle, index)
|
||||
@ -94,7 +96,7 @@ public class NoaFloatTensor
|
||||
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
||||
NoaTensorOverField<Float>(scope, tensorHandle) {
|
||||
|
||||
internal fun item(): Float = JNoa.getItemFloat(tensorHandle)
|
||||
override fun item(): Float = JNoa.getItemFloat(tensorHandle)
|
||||
|
||||
@PerformancePitfall
|
||||
override fun get(index: IntArray): Float = JNoa.getFloat(tensorHandle, index)
|
||||
@ -109,7 +111,7 @@ public class NoaLongTensor
|
||||
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
||||
NoaTensor<Long>(scope, tensorHandle) {
|
||||
|
||||
internal fun item(): Long = JNoa.getItemLong(tensorHandle)
|
||||
override fun item(): Long = JNoa.getItemLong(tensorHandle)
|
||||
|
||||
@PerformancePitfall
|
||||
override fun get(index: IntArray): Long = JNoa.getLong(tensorHandle, index)
|
||||
@ -124,7 +126,7 @@ public class NoaIntTensor
|
||||
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
||||
NoaTensor<Int>(scope, tensorHandle) {
|
||||
|
||||
internal fun item(): Int = JNoa.getItemInt(tensorHandle)
|
||||
override fun item(): Int = JNoa.getItemInt(tensorHandle)
|
||||
|
||||
@PerformancePitfall
|
||||
override fun get(index: IntArray): Int = JNoa.getInt(tensorHandle, index)
|
||||
|
@ -16,19 +16,19 @@ import space.kscience.kmath.operations.Algebra
|
||||
public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||
|
||||
/**
|
||||
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
|
||||
* For a tensor containing a single element, i.e. a scalar tensor, the value of that element is returned.
|
||||
* You need to check if the implementation puts any restrictions on the shape of a scalar tensor.
|
||||
*
|
||||
* @return a nullable value of a potentially scalar tensor.
|
||||
*/
|
||||
public fun Tensor<T>.valueOrNull(): T?
|
||||
|
||||
/**
|
||||
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
|
||||
* Unsafe version of [valueOrNull]
|
||||
*
|
||||
* @return the value of a scalar tensor.
|
||||
*/
|
||||
public fun Tensor<T>.value(): T =
|
||||
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
|
||||
public fun Tensor<T>.value(): T
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is added to this value.
|
||||
|
@ -24,9 +24,20 @@ public open class DoubleTensorAlgebra :
|
||||
|
||||
public companion object : DoubleTensorAlgebra()
|
||||
|
||||
|
||||
/**
|
||||
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
|
||||
*
|
||||
* @return a nullable value of a potentially scalar tensor.
|
||||
*/
|
||||
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart] else null
|
||||
|
||||
/**
|
||||
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
|
||||
*
|
||||
* @return the value of a scalar tensor.
|
||||
*/
|
||||
override fun Tensor<Double>.value(): Double = valueOrNull()
|
||||
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user