From f328c7f266bf78cd4bebfee05da0875a52e9506e Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Thu, 8 Jul 2021 22:42:59 +0100 Subject: [PATCH] analytic algebra --- .../java/space/kscience/kmath/noa/JNoa.java | 46 +++- .../space/kscience/kmath/noa/algebras.kt | 78 +++++- .../resources/space_kscience_kmath_noa_JNoa.h | 240 +++++++++++++++++- .../kmath/tensors/api/TensorAlgebra.kt | 1 - 4 files changed, 345 insertions(+), 20 deletions(-) diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index ce7eddb93..e6e433f7b 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -189,6 +189,36 @@ class JNoa { public static native long lnTensor(long tensorHandle); + public static native long sqrtTensor(long tensorHandle); + + public static native long cosTensor(long tensorHandle); + + public static native long acosTensor(long tensorHandle); + + public static native long coshTensor(long tensorHandle); + + public static native long acoshTensor(long tensorHandle); + + public static native long sinTensor(long tensorHandle); + + public static native long asinTensor(long tensorHandle); + + public static native long sinhTensor(long tensorHandle); + + public static native long asinhTensor(long tensorHandle); + + public static native long tanTensor(long tensorHandle); + + public static native long atanTensor(long tensorHandle); + + public static native long tanhTensor(long tensorHandle); + + public static native long atanhTensor(long tensorHandle); + + public static native long ceilTensor(long tensorHandle); + + public static native long floorTensor(long tensorHandle); + public static native long sumTensor(long tensorHandle); public static native long sumDimTensor(long tensorHandle, int dim, boolean keepDim); @@ -201,6 +231,18 @@ class JNoa { public static native long maxDimTensor(long tensorHandle, int dim, boolean keepDim); + public static native long meanTensor(long tensorHandle); + + public static native long meanDimTensor(long tensorHandle, int dim, boolean keepDim); + + public static native long stdTensor(long tensorHandle); + + public static native long stdDimTensor(long tensorHandle, int dim, boolean keepDim); + + public static native long varTensor(long tensorHandle); + + public static native long varDimTensor(long tensorHandle, int dim, boolean keepDim); + public static native long argMaxTensor(long tensorHandle, int dim, boolean keepDim); public static native long flattenTensor(long tensorHandle); @@ -223,8 +265,8 @@ class JNoa { public static native long detachFromGraph(long tensorHandle); - public static native long autogradTensor(long value, long variable, boolean retainGraph); + public static native long autoGradTensor(long value, long variable, boolean retainGraph); - public static native long autohessTensor(long value, long variable); + public static native long autoHessTensor(long value, long variable); } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 56bd712c0..d477716fd 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -116,10 +116,10 @@ constructor(protected val scope: NoaScope) : TensorAlgebra { public fun Tensor.flatten(): TensorType = wrap(JNoa.flattenTensor(tensor.tensorHandle)) - public fun Tensor.randIntegral(low: Long, high: Long): TensorType = + public fun Tensor.randDiscrete(low: Long, high: Long): TensorType = wrap(JNoa.randintLike(tensor.tensorHandle, low, high)) - public fun Tensor.randIntegralAssign(low: Long, high: Long): Unit = + public fun Tensor.randDiscreteAssign(low: Long, high: Long): Unit = JNoa.randintLikeAssign(tensor.tensorHandle, low, high) public fun Tensor.copy(): TensorType = @@ -142,6 +142,21 @@ internal constructor(scope: NoaScope) : NoaAlgebra(scope), Linear JNoa.divTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle) } + public fun Tensor.meanAll(): TensorType = wrap(JNoa.meanTensor(tensor.tensorHandle)) + override fun Tensor.mean(): T = meanAll().item() + override fun Tensor.mean(dim: Int, keepDim: Boolean): TensorType = + wrap(JNoa.meanDimTensor(tensor.tensorHandle, dim, keepDim)) + + public fun Tensor.stdAll(): TensorType = wrap(JNoa.stdTensor(tensor.tensorHandle)) + override fun Tensor.std(): T = stdAll().item() + override fun Tensor.std(dim: Int, keepDim: Boolean): TensorType = + wrap(JNoa.stdDimTensor(tensor.tensorHandle, dim, keepDim)) + + public fun Tensor.varAll(): TensorType = wrap(JNoa.varTensor(tensor.tensorHandle)) + override fun Tensor.variance(): T = varAll().item() + override fun Tensor.variance(dim: Int, keepDim: Boolean): TensorType = + wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim)) + public fun Tensor.randUniform(): TensorType = wrap(JNoa.randLike(tensor.tensorHandle)) @@ -154,12 +169,57 @@ internal constructor(scope: NoaScope) : NoaAlgebra(scope), Linear public fun Tensor.randNormalAssign(): Unit = JNoa.randnLikeAssign(tensor.tensorHandle) - override fun Tensor.exp(): TensorType = + override fun Tensor.exp(): TensorType = wrap(JNoa.expTensor(tensor.tensorHandle)) - - override fun Tensor.ln(): TensorType = + + override fun Tensor.ln(): TensorType = wrap(JNoa.lnTensor(tensor.tensorHandle)) - + + override fun Tensor.sqrt(): TensorType = + wrap(JNoa.sqrtTensor(tensor.tensorHandle)) + + override fun Tensor.cos(): TensorType = + wrap(JNoa.cosTensor(tensor.tensorHandle)) + + override fun Tensor.acos(): TensorType = + wrap(JNoa.acosTensor(tensor.tensorHandle)) + + override fun Tensor.cosh(): TensorType = + wrap(JNoa.coshTensor(tensor.tensorHandle)) + + override fun Tensor.acosh(): TensorType = + wrap(JNoa.acoshTensor(tensor.tensorHandle)) + + override fun Tensor.sin(): TensorType = + wrap(JNoa.sinTensor(tensor.tensorHandle)) + + override fun Tensor.asin(): TensorType = + wrap(JNoa.asinTensor(tensor.tensorHandle)) + + override fun Tensor.sinh(): TensorType = + wrap(JNoa.sinhTensor(tensor.tensorHandle)) + + override fun Tensor.asinh(): TensorType = + wrap(JNoa.asinhTensor(tensor.tensorHandle)) + + override fun Tensor.tan(): TensorType = + wrap(JNoa.tanTensor(tensor.tensorHandle)) + + override fun Tensor.atan(): TensorType = + wrap(JNoa.atanTensor(tensor.tensorHandle)) + + override fun Tensor.tanh(): TensorType = + wrap(JNoa.tanhTensor(tensor.tensorHandle)) + + override fun Tensor.atanh(): TensorType = + wrap(JNoa.atanhTensor(tensor.tensorHandle)) + + override fun Tensor.ceil(): TensorType = + wrap(JNoa.ceilTensor(tensor.tensorHandle)) + + override fun Tensor.floor(): TensorType = + wrap(JNoa.floorTensor(tensor.tensorHandle)) + override fun Tensor.svd(): Triple { val U = JNoa.emptyTensor() @@ -177,16 +237,16 @@ internal constructor(scope: NoaScope) : NoaAlgebra(scope), Linear } public fun TensorType.grad(variable: TensorType, retainGraph: Boolean): TensorType { - return wrap(JNoa.autogradTensor(tensorHandle, variable.tensorHandle, retainGraph)) + return wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph)) } public infix fun TensorType.hess(variable: TensorType): TensorType { - return wrap(JNoa.autohessTensor(tensorHandle, variable.tensorHandle)) + return wrap(JNoa.autoHessTensor(tensorHandle, variable.tensorHandle)) } public fun TensorType.detachFromGraph(): TensorType = wrap(JNoa.detachFromGraph(tensorHandle)) - + } diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h index 3c7342aa5..c31bb622f 100644 --- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h +++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h @@ -55,6 +55,14 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSeed JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_disposeTensor (JNIEnv *, jclass, jlong); +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: emptyTensor + * Signature: ()J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_emptyTensor + (JNIEnv *, jclass); + /* * Class: space_kscience_kmath_noa_JNoa * Method: fromBlobDouble @@ -665,10 +673,130 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_expTensor /* * Class: space_kscience_kmath_noa_JNoa - * Method: logTensor + * Method: lnTensor * Signature: (J)J */ -JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_logTensor +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_lnTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: sqrtTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sqrtTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: cosTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_cosTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: acosTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_acosTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: coshTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_coshTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: acoshTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_acoshTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: sinTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sinTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: asinTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_asinTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: sinhTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sinhTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: asinhTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_asinhTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: tanTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_tanTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: atanTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_atanTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: tanhTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_tanhTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: atanhTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_atanhTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: ceilTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_ceilTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: floorTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_floorTensor (JNIEnv *, jclass, jlong); /* @@ -687,6 +815,102 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensor JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumDimTensor (JNIEnv *, jclass, jlong, jint, jboolean); +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: minTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_minTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: minDimTensor + * Signature: (JIZ)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_minDimTensor + (JNIEnv *, jclass, jlong, jint, jboolean); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: maxTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_maxTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: maxDimTensor + * Signature: (JIZ)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_maxDimTensor + (JNIEnv *, jclass, jlong, jint, jboolean); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: meanTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_meanTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: meanDimTensor + * Signature: (JIZ)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_meanDimTensor + (JNIEnv *, jclass, jlong, jint, jboolean); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: stdTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_stdTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: stdDimTensor + * Signature: (JIZ)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_stdDimTensor + (JNIEnv *, jclass, jlong, jint, jboolean); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: varTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_varTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: varDimTensor + * Signature: (JIZ)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_varDimTensor + (JNIEnv *, jclass, jlong, jint, jboolean); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: argMaxTensor + * Signature: (JIZ)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_argMaxTensor + (JNIEnv *, jclass, jlong, jint, jboolean); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: flattenTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_flattenTensor + (JNIEnv *, jclass, jlong); + /* * Class: space_kscience_kmath_noa_JNoa * Method: matmul @@ -730,10 +954,10 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_svdTensor /* * Class: space_kscience_kmath_noa_JNoa * Method: symeigTensor - * Signature: (JJJZ)V + * Signature: (JJJ)V */ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_symeigTensor - (JNIEnv *, jclass, jlong, jlong, jlong, jboolean); + (JNIEnv *, jclass, jlong, jlong, jlong); /* * Class: space_kscience_kmath_noa_JNoa @@ -761,18 +985,18 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_detachFromGraph /* * Class: space_kscience_kmath_noa_JNoa - * Method: autogradTensor + * Method: autoGradTensor * Signature: (JJZ)J */ -JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_autogradTensor +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_autoGradTensor (JNIEnv *, jclass, jlong, jlong, jboolean); /* * Class: space_kscience_kmath_noa_JNoa - * Method: autohessTensor + * Method: autoHessTensor * Signature: (JJ)J */ -JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_autohessTensor +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_autoHessTensor (JNIEnv *, jclass, jlong, jlong); #ifdef __cplusplus diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index 52a00c837..de3b12fd6 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -6,7 +6,6 @@ package space.kscience.kmath.tensors.api import space.kscience.kmath.operations.Algebra -import space.kscience.kmath.tensors.core.DoubleTensor /** * Algebra over a ring on [Tensor].