analytic algebra

This commit is contained in:
Roland Grinis 2021-07-08 22:42:59 +01:00
parent 0088be99f5
commit f328c7f266
4 changed files with 345 additions and 20 deletions

View File

@ -189,6 +189,36 @@ class JNoa {
public static native long lnTensor(long tensorHandle); public static native long lnTensor(long tensorHandle);
public static native long sqrtTensor(long tensorHandle);
public static native long cosTensor(long tensorHandle);
public static native long acosTensor(long tensorHandle);
public static native long coshTensor(long tensorHandle);
public static native long acoshTensor(long tensorHandle);
public static native long sinTensor(long tensorHandle);
public static native long asinTensor(long tensorHandle);
public static native long sinhTensor(long tensorHandle);
public static native long asinhTensor(long tensorHandle);
public static native long tanTensor(long tensorHandle);
public static native long atanTensor(long tensorHandle);
public static native long tanhTensor(long tensorHandle);
public static native long atanhTensor(long tensorHandle);
public static native long ceilTensor(long tensorHandle);
public static native long floorTensor(long tensorHandle);
public static native long sumTensor(long tensorHandle); public static native long sumTensor(long tensorHandle);
public static native long sumDimTensor(long tensorHandle, int dim, boolean keepDim); public static native long sumDimTensor(long tensorHandle, int dim, boolean keepDim);
@ -201,6 +231,18 @@ class JNoa {
public static native long maxDimTensor(long tensorHandle, int dim, boolean keepDim); public static native long maxDimTensor(long tensorHandle, int dim, boolean keepDim);
public static native long meanTensor(long tensorHandle);
public static native long meanDimTensor(long tensorHandle, int dim, boolean keepDim);
public static native long stdTensor(long tensorHandle);
public static native long stdDimTensor(long tensorHandle, int dim, boolean keepDim);
public static native long varTensor(long tensorHandle);
public static native long varDimTensor(long tensorHandle, int dim, boolean keepDim);
public static native long argMaxTensor(long tensorHandle, int dim, boolean keepDim); public static native long argMaxTensor(long tensorHandle, int dim, boolean keepDim);
public static native long flattenTensor(long tensorHandle); public static native long flattenTensor(long tensorHandle);
@ -223,8 +265,8 @@ class JNoa {
public static native long detachFromGraph(long tensorHandle); public static native long detachFromGraph(long tensorHandle);
public static native long autogradTensor(long value, long variable, boolean retainGraph); public static native long autoGradTensor(long value, long variable, boolean retainGraph);
public static native long autohessTensor(long value, long variable); public static native long autoHessTensor(long value, long variable);
} }

View File

@ -116,10 +116,10 @@ constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
public fun Tensor<T>.flatten(): TensorType = public fun Tensor<T>.flatten(): TensorType =
wrap(JNoa.flattenTensor(tensor.tensorHandle)) wrap(JNoa.flattenTensor(tensor.tensorHandle))
public fun Tensor<T>.randIntegral(low: Long, high: Long): TensorType = public fun Tensor<T>.randDiscrete(low: Long, high: Long): TensorType =
wrap(JNoa.randintLike(tensor.tensorHandle, low, high)) wrap(JNoa.randintLike(tensor.tensorHandle, low, high))
public fun Tensor<T>.randIntegralAssign(low: Long, high: Long): Unit = public fun Tensor<T>.randDiscreteAssign(low: Long, high: Long): Unit =
JNoa.randintLikeAssign(tensor.tensorHandle, low, high) JNoa.randintLikeAssign(tensor.tensorHandle, low, high)
public fun Tensor<T>.copy(): TensorType = public fun Tensor<T>.copy(): TensorType =
@ -142,6 +142,21 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
JNoa.divTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle) JNoa.divTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle)
} }
public fun Tensor<T>.meanAll(): TensorType = wrap(JNoa.meanTensor(tensor.tensorHandle))
override fun Tensor<T>.mean(): T = meanAll().item()
override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.meanDimTensor(tensor.tensorHandle, dim, keepDim))
public fun Tensor<T>.stdAll(): TensorType = wrap(JNoa.stdTensor(tensor.tensorHandle))
override fun Tensor<T>.std(): T = stdAll().item()
override fun Tensor<T>.std(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.stdDimTensor(tensor.tensorHandle, dim, keepDim))
public fun Tensor<T>.varAll(): TensorType = wrap(JNoa.varTensor(tensor.tensorHandle))
override fun Tensor<T>.variance(): T = varAll().item()
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim))
public fun Tensor<T>.randUniform(): TensorType = public fun Tensor<T>.randUniform(): TensorType =
wrap(JNoa.randLike(tensor.tensorHandle)) wrap(JNoa.randLike(tensor.tensorHandle))
@ -160,6 +175,51 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
override fun Tensor<T>.ln(): TensorType = override fun Tensor<T>.ln(): TensorType =
wrap(JNoa.lnTensor(tensor.tensorHandle)) wrap(JNoa.lnTensor(tensor.tensorHandle))
override fun Tensor<T>.sqrt(): TensorType =
wrap(JNoa.sqrtTensor(tensor.tensorHandle))
override fun Tensor<T>.cos(): TensorType =
wrap(JNoa.cosTensor(tensor.tensorHandle))
override fun Tensor<T>.acos(): TensorType =
wrap(JNoa.acosTensor(tensor.tensorHandle))
override fun Tensor<T>.cosh(): TensorType =
wrap(JNoa.coshTensor(tensor.tensorHandle))
override fun Tensor<T>.acosh(): TensorType =
wrap(JNoa.acoshTensor(tensor.tensorHandle))
override fun Tensor<T>.sin(): TensorType =
wrap(JNoa.sinTensor(tensor.tensorHandle))
override fun Tensor<T>.asin(): TensorType =
wrap(JNoa.asinTensor(tensor.tensorHandle))
override fun Tensor<T>.sinh(): TensorType =
wrap(JNoa.sinhTensor(tensor.tensorHandle))
override fun Tensor<T>.asinh(): TensorType =
wrap(JNoa.asinhTensor(tensor.tensorHandle))
override fun Tensor<T>.tan(): TensorType =
wrap(JNoa.tanTensor(tensor.tensorHandle))
override fun Tensor<T>.atan(): TensorType =
wrap(JNoa.atanTensor(tensor.tensorHandle))
override fun Tensor<T>.tanh(): TensorType =
wrap(JNoa.tanhTensor(tensor.tensorHandle))
override fun Tensor<T>.atanh(): TensorType =
wrap(JNoa.atanhTensor(tensor.tensorHandle))
override fun Tensor<T>.ceil(): TensorType =
wrap(JNoa.ceilTensor(tensor.tensorHandle))
override fun Tensor<T>.floor(): TensorType =
wrap(JNoa.floorTensor(tensor.tensorHandle))
override fun Tensor<T>.svd(): Triple<TensorType, TensorType, TensorType> { override fun Tensor<T>.svd(): Triple<TensorType, TensorType, TensorType> {
val U = JNoa.emptyTensor() val U = JNoa.emptyTensor()
@ -177,11 +237,11 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
} }
public fun TensorType.grad(variable: TensorType, retainGraph: Boolean): TensorType { public fun TensorType.grad(variable: TensorType, retainGraph: Boolean): TensorType {
return wrap(JNoa.autogradTensor(tensorHandle, variable.tensorHandle, retainGraph)) return wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph))
} }
public infix fun TensorType.hess(variable: TensorType): TensorType { public infix fun TensorType.hess(variable: TensorType): TensorType {
return wrap(JNoa.autohessTensor(tensorHandle, variable.tensorHandle)) return wrap(JNoa.autoHessTensor(tensorHandle, variable.tensorHandle))
} }
public fun TensorType.detachFromGraph(): TensorType = public fun TensorType.detachFromGraph(): TensorType =

View File

@ -55,6 +55,14 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSeed
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_disposeTensor JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_disposeTensor
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: emptyTensor
* Signature: ()J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_emptyTensor
(JNIEnv *, jclass);
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: fromBlobDouble * Method: fromBlobDouble
@ -665,10 +673,130 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_expTensor
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: logTensor * Method: lnTensor
* Signature: (J)J * Signature: (J)J
*/ */
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_logTensor JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_lnTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: sqrtTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sqrtTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: cosTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_cosTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: acosTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_acosTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: coshTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_coshTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: acoshTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_acoshTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: sinTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sinTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: asinTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_asinTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: sinhTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sinhTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: asinhTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_asinhTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: tanTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_tanTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: atanTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_atanTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: tanhTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_tanhTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: atanhTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_atanhTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: ceilTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_ceilTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: floorTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_floorTensor
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong);
/* /*
@ -687,6 +815,102 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensor
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumDimTensor JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumDimTensor
(JNIEnv *, jclass, jlong, jint, jboolean); (JNIEnv *, jclass, jlong, jint, jboolean);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: minTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_minTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: minDimTensor
* Signature: (JIZ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_minDimTensor
(JNIEnv *, jclass, jlong, jint, jboolean);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: maxTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_maxTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: maxDimTensor
* Signature: (JIZ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_maxDimTensor
(JNIEnv *, jclass, jlong, jint, jboolean);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: meanTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_meanTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: meanDimTensor
* Signature: (JIZ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_meanDimTensor
(JNIEnv *, jclass, jlong, jint, jboolean);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: stdTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_stdTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: stdDimTensor
* Signature: (JIZ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_stdDimTensor
(JNIEnv *, jclass, jlong, jint, jboolean);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: varTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_varTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: varDimTensor
* Signature: (JIZ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_varDimTensor
(JNIEnv *, jclass, jlong, jint, jboolean);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: argMaxTensor
* Signature: (JIZ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_argMaxTensor
(JNIEnv *, jclass, jlong, jint, jboolean);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: flattenTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_flattenTensor
(JNIEnv *, jclass, jlong);
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: matmul * Method: matmul
@ -730,10 +954,10 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_svdTensor
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: symeigTensor * Method: symeigTensor
* Signature: (JJJZ)V * Signature: (JJJ)V
*/ */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_symeigTensor JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_symeigTensor
(JNIEnv *, jclass, jlong, jlong, jlong, jboolean); (JNIEnv *, jclass, jlong, jlong, jlong);
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
@ -761,18 +985,18 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_detachFromGraph
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: autogradTensor * Method: autoGradTensor
* Signature: (JJZ)J * Signature: (JJZ)J
*/ */
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_autogradTensor JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_autoGradTensor
(JNIEnv *, jclass, jlong, jlong, jboolean); (JNIEnv *, jclass, jlong, jlong, jboolean);
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: autohessTensor * Method: autoHessTensor
* Signature: (JJ)J * Signature: (JJ)J
*/ */
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_autohessTensor JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_autoHessTensor
(JNIEnv *, jclass, jlong, jlong); (JNIEnv *, jclass, jlong, jlong);
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -6,7 +6,6 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.tensors.core.DoubleTensor
/** /**
* Algebra over a ring on [Tensor]. * Algebra over a ring on [Tensor].