From 62b3ccd111f2b4e5804bd6569eebc1db724fc9a1 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Thu, 8 Jul 2021 23:10:59 +0100 Subject: [PATCH] NoaDoubleAlgebra --- .../java/space/kscience/kmath/noa/JNoa.java | 4 ++ .../space/kscience/kmath/noa/algebras.kt | 69 ++++++++++++++++++- .../resources/space_kscience_kmath_noa_JNoa.h | 16 +++++ 3 files changed, 88 insertions(+), 1 deletion(-) diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index fc65cd3b5..55566d753 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -107,6 +107,10 @@ class JNoa { public static native long randnFloat(int[] shape, int device); + public static native long randintDouble(long low, long high, int[] shape, int device); + + public static native long randintFloat(long low, long high, int[] shape, int device); + public static native long randintLong(long low, long high, int[] shape, int device); public static native long randintInt(long low, long high, int[] shape, int device); diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 27bad3ff3..1ae6ae2e9 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.noa +import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.noa.memory.NoaScope import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra @@ -240,7 +241,7 @@ internal constructor(scope: NoaScope) : NoaAlgebra(scope), Linear val P = JNoa.emptyTensor() val L = JNoa.emptyTensor() val U = JNoa.emptyTensor() - JNoa.svdTensor(tensor.tensorHandle, P, L, U) + JNoa.luTensor(tensor.tensorHandle, P, L, U) return Triple(wrap(P), wrap(L), wrap(U)) } @@ -272,4 +273,70 @@ internal constructor(scope: NoaScope) : NoaAlgebra(scope), Linear } +public class NoaDoubleAlgebra(scope: NoaScope) : + NoaPartialDivisionAlgebra(scope) { + + override val Tensor.tensor: NoaDoubleTensor + get() = TODO("Not yet implemented") + + override fun wrap(tensorHandle: Long): NoaDoubleTensor = + NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle) + + @PerformancePitfall + public fun Tensor.copyToArray(): DoubleArray = + tensor.elements().map { it.second }.toList().toDoubleArray() + + public fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor = + wrap(JNoa.fromBlobDouble(array, shape, device.toInt())) + + public fun randNormalDouble(shape: IntArray, device: Device): NoaDoubleTensor = + wrap(JNoa.randnDouble(shape, device.toInt())) + + public fun randUniformDouble(shape: IntArray, device: Device): NoaDoubleTensor = + wrap(JNoa.randDouble(shape, device.toInt())) + + public fun randDiscreteDouble(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor = + wrap(JNoa.randintDouble(low, high, shape, device.toInt())) + + override operator fun Double.plus(other: Tensor): NoaDoubleTensor = + wrap(JNoa.plusDouble(this, other.tensor.tensorHandle)) + + override fun Tensor.plus(value: Double): NoaDoubleTensor = + wrap(JNoa.plusDouble(value, tensor.tensorHandle)) + + override fun Tensor.plusAssign(value: Double): Unit = + JNoa.plusDoubleAssign(value, tensor.tensorHandle) + + override operator fun Double.minus(other: Tensor): NoaDoubleTensor = + wrap(JNoa.plusDouble(-this, other.tensor.tensorHandle)) + + override fun Tensor.minus(value: Double): NoaDoubleTensor = + wrap(JNoa.plusDouble(-value, tensor.tensorHandle)) + + override fun Tensor.minusAssign(value: Double): Unit = + JNoa.plusDoubleAssign(-value, tensor.tensorHandle) + + override operator fun Double.times(other: Tensor): NoaDoubleTensor = + wrap(JNoa.timesDouble(this, other.tensor.tensorHandle)) + + override fun Tensor.times(value: Double): NoaDoubleTensor = + wrap(JNoa.timesDouble(value, tensor.tensorHandle)) + + override fun Tensor.timesAssign(value: Double): Unit = + JNoa.timesDoubleAssign(value, tensor.tensorHandle) + + override fun Double.div(other: Tensor): NoaDoubleTensor = + other * (1 / this) + + override fun Tensor.div(value: Double): NoaDoubleTensor = + tensor * (1 / value) + + override fun Tensor.divAssign(value: Double): Unit = + tensor.timesAssign(1 / value) + + public fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor = + wrap(JNoa.fullDouble(value, shape, device.toInt())) + +} + diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h index f45c76982..483f457a7 100644 --- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h +++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h @@ -351,6 +351,22 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randFloat JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randnFloat (JNIEnv *, jclass, jintArray, jint); +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: randintDouble + * Signature: (JJ[II)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randintDouble + (JNIEnv *, jclass, jlong, jlong, jintArray, jint); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: randintFloat + * Signature: (JJ[II)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randintFloat + (JNIEnv *, jclass, jlong, jlong, jintArray, jint); + /* * Class: space_kscience_kmath_noa_JNoa * Method: randintLong