linear algebra
This commit is contained in:
parent
f328c7f266
commit
5da017ec2e
@ -255,9 +255,19 @@ class JNoa {
|
||||
|
||||
public static native long diagEmbed(long diagsHandle, int offset, int dim1, int dim2);
|
||||
|
||||
public static native long detTensor(long tensorHandle);
|
||||
|
||||
public static native long invTensor(long tensorHandle);
|
||||
|
||||
public static native long choleskyTensor(long tensorHandle);
|
||||
|
||||
public static native void qrTensor(long tensorHandle, long Qhandle, long Rhandle);
|
||||
|
||||
public static native void luTensor(long tensorHandle, long Phandle, long Lhandle, long Uhandle);
|
||||
|
||||
public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle);
|
||||
|
||||
public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle);
|
||||
public static native void symEigTensor(long tensorHandle, long Shandle, long Vhandle);
|
||||
|
||||
public static native boolean requiresGrad(long tensorHandle);
|
||||
|
||||
|
@ -220,6 +220,29 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
|
||||
override fun Tensor<T>.floor(): TensorType =
|
||||
wrap(JNoa.floorTensor(tensor.tensorHandle))
|
||||
|
||||
override fun Tensor<T>.det(): Tensor<T> =
|
||||
wrap(JNoa.detTensor(tensor.tensorHandle))
|
||||
|
||||
override fun Tensor<T>.inv(): Tensor<T> =
|
||||
wrap(JNoa.invTensor(tensor.tensorHandle))
|
||||
|
||||
override fun Tensor<T>.cholesky(): Tensor<T> =
|
||||
wrap(JNoa.choleskyTensor(tensor.tensorHandle))
|
||||
|
||||
override fun Tensor<T>.qr(): Pair<TensorType, TensorType> {
|
||||
val Q = JNoa.emptyTensor()
|
||||
val R = JNoa.emptyTensor()
|
||||
JNoa.qrTensor(tensor.tensorHandle, Q, R)
|
||||
return Pair(wrap(Q), wrap(R))
|
||||
}
|
||||
|
||||
override fun Tensor<T>.lu(): Triple<TensorType, TensorType, TensorType> {
|
||||
val P = JNoa.emptyTensor()
|
||||
val L = JNoa.emptyTensor()
|
||||
val U = JNoa.emptyTensor()
|
||||
JNoa.svdTensor(tensor.tensorHandle, P, L, U)
|
||||
return Triple(wrap(P), wrap(L), wrap(U))
|
||||
}
|
||||
|
||||
override fun Tensor<T>.svd(): Triple<TensorType, TensorType, TensorType> {
|
||||
val U = JNoa.emptyTensor()
|
||||
@ -232,7 +255,7 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
|
||||
override fun Tensor<T>.symEig(): Pair<TensorType, TensorType> {
|
||||
val V = JNoa.emptyTensor()
|
||||
val S = JNoa.emptyTensor()
|
||||
JNoa.symeigTensor(tensor.tensorHandle, S, V)
|
||||
JNoa.symEigTensor(tensor.tensorHandle, S, V)
|
||||
return Pair(wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
|
@ -943,6 +943,46 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_matmulRightAssign
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_diagEmbed
|
||||
(JNIEnv *, jclass, jlong, jint, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: detTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_detTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: invTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_invTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: choleskyTensor
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_choleskyTensor
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: qrTensor
|
||||
* Signature: (JJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_qrTensor
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: luTensor
|
||||
* Signature: (JJJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_luTensor
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: svdTensor
|
||||
@ -953,10 +993,10 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_svdTensor
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: symeigTensor
|
||||
* Method: symEigTensor
|
||||
* Signature: (JJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_symeigTensor
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_symEigTensor
|
||||
(JNIEnv *, jclass, jlong, jlong, jlong);
|
||||
|
||||
/*
|
||||
|
Loading…
Reference in New Issue
Block a user