linear algebra

This commit is contained in:
Roland Grinis 2021-07-08 22:54:27 +01:00
parent f328c7f266
commit 5da017ec2e
3 changed files with 77 additions and 4 deletions

View File

@ -255,9 +255,19 @@ class JNoa {
public static native long diagEmbed(long diagsHandle, int offset, int dim1, int dim2); public static native long diagEmbed(long diagsHandle, int offset, int dim1, int dim2);
public static native long detTensor(long tensorHandle);
public static native long invTensor(long tensorHandle);
public static native long choleskyTensor(long tensorHandle);
public static native void qrTensor(long tensorHandle, long Qhandle, long Rhandle);
public static native void luTensor(long tensorHandle, long Phandle, long Lhandle, long Uhandle);
public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle); public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle);
public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle); public static native void symEigTensor(long tensorHandle, long Shandle, long Vhandle);
public static native boolean requiresGrad(long tensorHandle); public static native boolean requiresGrad(long tensorHandle);

View File

@ -220,6 +220,29 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
override fun Tensor<T>.floor(): TensorType = override fun Tensor<T>.floor(): TensorType =
wrap(JNoa.floorTensor(tensor.tensorHandle)) wrap(JNoa.floorTensor(tensor.tensorHandle))
override fun Tensor<T>.det(): Tensor<T> =
wrap(JNoa.detTensor(tensor.tensorHandle))
override fun Tensor<T>.inv(): Tensor<T> =
wrap(JNoa.invTensor(tensor.tensorHandle))
override fun Tensor<T>.cholesky(): Tensor<T> =
wrap(JNoa.choleskyTensor(tensor.tensorHandle))
override fun Tensor<T>.qr(): Pair<TensorType, TensorType> {
val Q = JNoa.emptyTensor()
val R = JNoa.emptyTensor()
JNoa.qrTensor(tensor.tensorHandle, Q, R)
return Pair(wrap(Q), wrap(R))
}
override fun Tensor<T>.lu(): Triple<TensorType, TensorType, TensorType> {
val P = JNoa.emptyTensor()
val L = JNoa.emptyTensor()
val U = JNoa.emptyTensor()
JNoa.svdTensor(tensor.tensorHandle, P, L, U)
return Triple(wrap(P), wrap(L), wrap(U))
}
override fun Tensor<T>.svd(): Triple<TensorType, TensorType, TensorType> { override fun Tensor<T>.svd(): Triple<TensorType, TensorType, TensorType> {
val U = JNoa.emptyTensor() val U = JNoa.emptyTensor()
@ -232,7 +255,7 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
override fun Tensor<T>.symEig(): Pair<TensorType, TensorType> { override fun Tensor<T>.symEig(): Pair<TensorType, TensorType> {
val V = JNoa.emptyTensor() val V = JNoa.emptyTensor()
val S = JNoa.emptyTensor() val S = JNoa.emptyTensor()
JNoa.symeigTensor(tensor.tensorHandle, S, V) JNoa.symEigTensor(tensor.tensorHandle, S, V)
return Pair(wrap(S), wrap(V)) return Pair(wrap(S), wrap(V))
} }

View File

@ -943,6 +943,46 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_matmulRightAssign
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_diagEmbed JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_diagEmbed
(JNIEnv *, jclass, jlong, jint, jint, jint); (JNIEnv *, jclass, jlong, jint, jint, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: detTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_detTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: invTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_invTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: choleskyTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_choleskyTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: qrTensor
* Signature: (JJJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_qrTensor
(JNIEnv *, jclass, jlong, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: luTensor
* Signature: (JJJJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_luTensor
(JNIEnv *, jclass, jlong, jlong, jlong, jlong);
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: svdTensor * Method: svdTensor
@ -953,10 +993,10 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_svdTensor
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: symeigTensor * Method: symEigTensor
* Signature: (JJJ)V * Signature: (JJJ)V
*/ */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_symeigTensor JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_symEigTensor
(JNIEnv *, jclass, jlong, jlong, jlong); (JNIEnv *, jclass, jlong, jlong, jlong);
/* /*