From 5da017ec2e340bda2c050b23b1bf2e6a28565dc9 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Thu, 8 Jul 2021 22:54:27 +0100 Subject: [PATCH] linear algebra --- .../java/space/kscience/kmath/noa/JNoa.java | 12 ++++- .../space/kscience/kmath/noa/algebras.kt | 25 ++++++++++- .../resources/space_kscience_kmath_noa_JNoa.h | 44 ++++++++++++++++++- 3 files changed, 77 insertions(+), 4 deletions(-) diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index e6e433f7b..fc65cd3b5 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -255,9 +255,19 @@ class JNoa { public static native long diagEmbed(long diagsHandle, int offset, int dim1, int dim2); + public static native long detTensor(long tensorHandle); + + public static native long invTensor(long tensorHandle); + + public static native long choleskyTensor(long tensorHandle); + + public static native void qrTensor(long tensorHandle, long Qhandle, long Rhandle); + + public static native void luTensor(long tensorHandle, long Phandle, long Lhandle, long Uhandle); + public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle); - public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle); + public static native void symEigTensor(long tensorHandle, long Shandle, long Vhandle); public static native boolean requiresGrad(long tensorHandle); diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index d477716fd..27bad3ff3 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -220,6 +220,29 @@ internal constructor(scope: NoaScope) : NoaAlgebra(scope), Linear override fun Tensor.floor(): TensorType = wrap(JNoa.floorTensor(tensor.tensorHandle)) + override fun Tensor.det(): Tensor = + wrap(JNoa.detTensor(tensor.tensorHandle)) + + override fun Tensor.inv(): Tensor = + wrap(JNoa.invTensor(tensor.tensorHandle)) + + override fun Tensor.cholesky(): Tensor = + wrap(JNoa.choleskyTensor(tensor.tensorHandle)) + + override fun Tensor.qr(): Pair { + val Q = JNoa.emptyTensor() + val R = JNoa.emptyTensor() + JNoa.qrTensor(tensor.tensorHandle, Q, R) + return Pair(wrap(Q), wrap(R)) + } + + override fun Tensor.lu(): Triple { + val P = JNoa.emptyTensor() + val L = JNoa.emptyTensor() + val U = JNoa.emptyTensor() + JNoa.svdTensor(tensor.tensorHandle, P, L, U) + return Triple(wrap(P), wrap(L), wrap(U)) + } override fun Tensor.svd(): Triple { val U = JNoa.emptyTensor() @@ -232,7 +255,7 @@ internal constructor(scope: NoaScope) : NoaAlgebra(scope), Linear override fun Tensor.symEig(): Pair { val V = JNoa.emptyTensor() val S = JNoa.emptyTensor() - JNoa.symeigTensor(tensor.tensorHandle, S, V) + JNoa.symEigTensor(tensor.tensorHandle, S, V) return Pair(wrap(S), wrap(V)) } diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h index c31bb622f..f45c76982 100644 --- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h +++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h @@ -943,6 +943,46 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_matmulRightAssign JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_diagEmbed (JNIEnv *, jclass, jlong, jint, jint, jint); +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: detTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_detTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: invTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_invTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: choleskyTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_choleskyTensor + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: qrTensor + * Signature: (JJJ)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_qrTensor + (JNIEnv *, jclass, jlong, jlong, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: luTensor + * Signature: (JJJJ)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_luTensor + (JNIEnv *, jclass, jlong, jlong, jlong, jlong); + /* * Class: space_kscience_kmath_noa_JNoa * Method: svdTensor @@ -953,10 +993,10 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_svdTensor /* * Class: space_kscience_kmath_noa_JNoa - * Method: symeigTensor + * Method: symEigTensor * Signature: (JJJ)V */ -JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_symeigTensor +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_symEigTensor (JNIEnv *, jclass, jlong, jlong, jlong); /*