NoaTensorAlgebra

This commit is contained in:
Roland Grinis 2021-07-08 21:37:13 +01:00
parent 773ff10dd1
commit 803a88ac2c
3 changed files with 29 additions and 4 deletions

View File

@@ -28,8 +28,6 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
 */
public fun INDArray.wrap(): Nd4jArrayStructure<T>
public fun INDArray.wrapInt(): Nd4jArrayStructure<Int>
/**
 * Unwraps to or acquires [INDArray] from [StructureND].
 */
@@ -93,7 +91,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)

override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> =
    ndBase.get().argmax(ndArray, keepDim, dim).asIntStructure()

public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
@@ -146,7 +144,6 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
 */
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
    public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
public override fun INDArray.wrapInt(): Nd4jArrayStructure<Int> = asIntStructure()
    @OptIn(PerformancePitfall::class)
    public override val StructureND<Double>.ndArray: INDArray

View File

@@ -191,6 +191,18 @@ class JNoa {
public static native long sumDimTensor(long tensorHandle, int dim, boolean keepDim);
/** JNI binding: reduces the given tensor to a new tensor holding its minimum element; returns the native handle. */
public static native long minTensor(long tensorHandle);
/** JNI binding: minimum along dimension {@code dim}; {@code keepDim} presumably retains the reduced axis as size 1 — NOTE(review): confirm against the native implementation. */
public static native long minDimTensor(long tensorHandle, int dim, boolean keepDim);
/** JNI binding: reduces the given tensor to a new tensor holding its maximum element; returns the native handle. */
public static native long maxTensor(long tensorHandle);
/** JNI binding: maximum along dimension {@code dim}; {@code keepDim} presumably retains the reduced axis as size 1 — NOTE(review): confirm against the native implementation. */
public static native long maxDimTensor(long tensorHandle, int dim, boolean keepDim);
/** JNI binding: indices of maxima along dimension {@code dim}; consumed by the Kotlin {@code argMax} wrapper as an integer tensor. */
public static native long argMaxTensor(long tensorHandle, int dim, boolean keepDim);
/** JNI binding: returns a handle to a rank-1 (flattened) version of the tensor — TODO(review): confirm view vs. copy semantics in the native code. */
public static native long flattenTensor(long tensorHandle);
public static native long matmul(long lhs, long rhs);
public static native void matmulAssign(long lhs, long rhs);

View File

@@ -100,6 +100,22 @@ constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): TensorType =
    wrap(JNoa.sumDimTensor(this.cast().tensorHandle, dim, keepDim))
/** Reduces over all elements to a scalar tensor holding the minimum (delegates to native [JNoa.minTensor]). */
public fun Tensor<T>.minAll(): TensorType = wrap(JNoa.minTensor(this.cast().tensorHandle))
/** Minimum over all elements, unwrapped to a value of type [T] via `item()` on the scalar result of [minAll]. */
override fun Tensor<T>.min(): T = minAll().item()
/** Minimum along [dim]; [keepDim] is forwarded to the native call — presumably keeps the reduced axis, confirm in native code. */
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.minDimTensor(this.cast().tensorHandle, dim, keepDim))
/** Reduces over all elements to a scalar tensor holding the maximum (delegates to native [JNoa.maxTensor]). */
public fun Tensor<T>.maxAll(): TensorType = wrap(JNoa.maxTensor(this.cast().tensorHandle))
/** Maximum over all elements, unwrapped to a value of type [T] via `item()` on the scalar result of [maxAll]. */
override fun Tensor<T>.max(): T = maxAll().item()
/** Maximum along [dim]; [keepDim] is forwarded to the native call — presumably keeps the reduced axis, confirm in native code. */
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.maxDimTensor(this.cast().tensorHandle, dim, keepDim))
/** Indices of maxima along [dim], wrapped as a [NoaIntTensor] tied to this algebra's [scope]. */
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): NoaIntTensor =
NoaIntTensor(scope, JNoa.argMaxTensor(this.cast().tensorHandle, dim, keepDim))
/** Rank-1 (flattened) version of the tensor — NOTE(review): view vs. copy semantics depend on the native impl. */
public fun Tensor<T>.flatten(): TensorType =
wrap(JNoa.flattenTensor(this.cast().tensorHandle))
public fun Tensor<T>.copy(): TensorType =
    wrap(JNoa.copyTensor(this.cast().tensorHandle))