forked from kscience/kmath
value method for algebra
This commit is contained in:
parent
8f32397f50
commit
2a97aca9b6
@ -11,16 +11,26 @@ import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
|||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||||
|
|
||||||
public abstract class NoaAlgebra<T, TensorType: NoaTensor<T>>
|
|
||||||
internal constructor(protected val scope: NoaScope)
|
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
|
||||||
: TensorAlgebra<T> {
|
constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
|
||||||
|
|
||||||
protected abstract fun Tensor<T>.cast(): TensorType
|
protected abstract fun Tensor<T>.cast(): TensorType
|
||||||
|
|
||||||
|
protected abstract fun wrap(tensorHandle: TensorHandle): TensorType
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A scalar tensor in this implementation must have empty shape
|
||||||
|
*/
|
||||||
|
override fun Tensor<T>.valueOrNull(): T? =
|
||||||
|
try { this.cast().item() } catch (e: NoaException) { null }
|
||||||
|
|
||||||
|
override fun Tensor<T>.value(): T = this.cast().item()
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract class NoaPartialDivisionAlgebra<T, TensorType: NoaTensor<T>>
|
public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
|
||||||
internal constructor(scope: NoaScope)
|
internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>,
|
||||||
: NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>, AnalyticTensorAlgebra<T> {
|
AnalyticTensorAlgebra<T> {
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,8 @@ constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
|
|||||||
|
|
||||||
override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle)
|
override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle)
|
||||||
|
|
||||||
|
internal abstract fun item(): T
|
||||||
|
|
||||||
override val dimension: Int get() = JNoa.getDim(tensorHandle)
|
override val dimension: Int get() = JNoa.getDim(tensorHandle)
|
||||||
|
|
||||||
override val shape: IntArray
|
override val shape: IntArray
|
||||||
@ -79,7 +81,7 @@ public class NoaDoubleTensor
|
|||||||
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
||||||
NoaTensorOverField<Double>(scope, tensorHandle) {
|
NoaTensorOverField<Double>(scope, tensorHandle) {
|
||||||
|
|
||||||
internal fun item(): Double = JNoa.getItemDouble(tensorHandle)
|
override fun item(): Double = JNoa.getItemDouble(tensorHandle)
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
override fun get(index: IntArray): Double = JNoa.getDouble(tensorHandle, index)
|
override fun get(index: IntArray): Double = JNoa.getDouble(tensorHandle, index)
|
||||||
@ -94,7 +96,7 @@ public class NoaFloatTensor
|
|||||||
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
||||||
NoaTensorOverField<Float>(scope, tensorHandle) {
|
NoaTensorOverField<Float>(scope, tensorHandle) {
|
||||||
|
|
||||||
internal fun item(): Float = JNoa.getItemFloat(tensorHandle)
|
override fun item(): Float = JNoa.getItemFloat(tensorHandle)
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
override fun get(index: IntArray): Float = JNoa.getFloat(tensorHandle, index)
|
override fun get(index: IntArray): Float = JNoa.getFloat(tensorHandle, index)
|
||||||
@ -109,7 +111,7 @@ public class NoaLongTensor
|
|||||||
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
||||||
NoaTensor<Long>(scope, tensorHandle) {
|
NoaTensor<Long>(scope, tensorHandle) {
|
||||||
|
|
||||||
internal fun item(): Long = JNoa.getItemLong(tensorHandle)
|
override fun item(): Long = JNoa.getItemLong(tensorHandle)
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
override fun get(index: IntArray): Long = JNoa.getLong(tensorHandle, index)
|
override fun get(index: IntArray): Long = JNoa.getLong(tensorHandle, index)
|
||||||
@ -124,7 +126,7 @@ public class NoaIntTensor
|
|||||||
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
|
||||||
NoaTensor<Int>(scope, tensorHandle) {
|
NoaTensor<Int>(scope, tensorHandle) {
|
||||||
|
|
||||||
internal fun item(): Int = JNoa.getItemInt(tensorHandle)
|
override fun item(): Int = JNoa.getItemInt(tensorHandle)
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
override fun get(index: IntArray): Int = JNoa.getInt(tensorHandle, index)
|
override fun get(index: IntArray): Int = JNoa.getInt(tensorHandle, index)
|
||||||
|
@ -16,19 +16,19 @@ import space.kscience.kmath.operations.Algebra
|
|||||||
public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
|
* For a tensor containing a single element, i.e. a scalar tensor, the value of that element is returned.
|
||||||
|
* You need to check if the implementation puts any restrictions on the shape of a scalar tensor.
|
||||||
*
|
*
|
||||||
* @return a nullable value of a potentially scalar tensor.
|
* @return a nullable value of a potentially scalar tensor.
|
||||||
*/
|
*/
|
||||||
public fun Tensor<T>.valueOrNull(): T?
|
public fun Tensor<T>.valueOrNull(): T?
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
|
* Unsafe version of [valueOrNull]
|
||||||
*
|
*
|
||||||
* @return the value of a scalar tensor.
|
* @return the value of a scalar tensor.
|
||||||
*/
|
*/
|
||||||
public fun Tensor<T>.value(): T =
|
public fun Tensor<T>.value(): T
|
||||||
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Each element of the tensor [other] is added to this value.
|
* Each element of the tensor [other] is added to this value.
|
||||||
|
@ -24,9 +24,20 @@ public open class DoubleTensorAlgebra :
|
|||||||
|
|
||||||
public companion object : DoubleTensorAlgebra()
|
public companion object : DoubleTensorAlgebra()
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
|
||||||
|
*
|
||||||
|
* @return a nullable value of a potentially scalar tensor.
|
||||||
|
*/
|
||||||
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
|
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
|
||||||
tensor.mutableBuffer.array()[tensor.bufferStart] else null
|
tensor.mutableBuffer.array()[tensor.bufferStart] else null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
|
||||||
|
*
|
||||||
|
* @return the value of a scalar tensor.
|
||||||
|
*/
|
||||||
override fun Tensor<Double>.value(): Double = valueOrNull()
|
override fun Tensor<Double>.value(): Double = valueOrNull()
|
||||||
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")
|
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user