forked from kscience/kmath
NoaTensorAlgebra
This commit is contained in:
parent
773ff10dd1
commit
803a88ac2c
@ -28,8 +28,6 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
|
||||
*/
|
||||
public fun INDArray.wrap(): Nd4jArrayStructure<T>
|
||||
|
||||
public fun INDArray.wrapInt(): Nd4jArrayStructure<Int>
|
||||
|
||||
/**
|
||||
* Unwraps to or acquires [INDArray] from [StructureND].
|
||||
*/
|
||||
@ -93,7 +91,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
|
||||
public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
|
||||
|
||||
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> =
|
||||
ndBase.get().argmax(ndArray, keepDim, dim).wrapInt()
|
||||
ndBase.get().argmax(ndArray, keepDim, dim).asIntStructure()
|
||||
|
||||
public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
|
||||
|
||||
@ -146,7 +144,6 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
|
||||
*/
|
||||
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
|
||||
public override fun INDArray.wrapInt(): Nd4jArrayStructure<Int> = asIntStructure()
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
public override val StructureND<Double>.ndArray: INDArray
|
||||
|
@ -191,6 +191,18 @@ class JNoa {
|
||||
|
||||
public static native long sumDimTensor(long tensorHandle, int dim, boolean keepDim);
|
||||
|
||||
public static native long minTensor(long tensorHandle);
|
||||
|
||||
public static native long minDimTensor(long tensorHandle, int dim, boolean keepDim);
|
||||
|
||||
public static native long maxTensor(long tensorHandle);
|
||||
|
||||
public static native long maxDimTensor(long tensorHandle, int dim, boolean keepDim);
|
||||
|
||||
public static native long argMaxTensor(long tensorHandle, int dim, boolean keepDim);
|
||||
|
||||
public static native long flattenTensor(long tensorHandle);
|
||||
|
||||
public static native long matmul(long lhs, long rhs);
|
||||
|
||||
public static native void matmulAssign(long lhs, long rhs);
|
||||
|
@ -100,6 +100,22 @@ constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
|
||||
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): TensorType =
|
||||
wrap(JNoa.sumDimTensor(this.cast().tensorHandle, dim, keepDim))
|
||||
|
||||
public fun Tensor<T>.minAll(): TensorType = wrap(JNoa.minTensor(this.cast().tensorHandle))
|
||||
override fun Tensor<T>.min(): T = minAll().item()
|
||||
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): TensorType =
|
||||
wrap(JNoa.minDimTensor(this.cast().tensorHandle, dim, keepDim))
|
||||
|
||||
public fun Tensor<T>.maxAll(): TensorType = wrap(JNoa.maxTensor(this.cast().tensorHandle))
|
||||
override fun Tensor<T>.max(): T = maxAll().item()
|
||||
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): TensorType =
|
||||
wrap(JNoa.maxDimTensor(this.cast().tensorHandle, dim, keepDim))
|
||||
|
||||
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): NoaIntTensor =
|
||||
NoaIntTensor(scope, JNoa.argMaxTensor(this.cast().tensorHandle, dim, keepDim))
|
||||
|
||||
public fun Tensor<T>.flatten(): TensorType =
|
||||
wrap(JNoa.flattenTensor(this.cast().tensorHandle))
|
||||
|
||||
public fun Tensor<T>.copy(): TensorType =
|
||||
wrap(JNoa.copyTensor(this.cast().tensorHandle))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user