NoaDoubleAlgebra

This commit is contained in:
Roland Grinis 2021-07-08 23:10:59 +01:00
parent 5da017ec2e
commit 62b3ccd111
3 changed files with 88 additions and 1 deletions

View File

@ -107,6 +107,10 @@ class JNoa {
public static native long randnFloat(int[] shape, int device);
public static native long randintDouble(long low, long high, int[] shape, int device);
public static native long randintFloat(long low, long high, int[] shape, int device);
public static native long randintLong(long low, long high, int[] shape, int device);
public static native long randintInt(long low, long high, int[] shape, int device);

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.noa
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.noa.memory.NoaScope
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
@ -240,7 +241,7 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
val P = JNoa.emptyTensor()
val L = JNoa.emptyTensor()
val U = JNoa.emptyTensor()
JNoa.svdTensor(tensor.tensorHandle, P, L, U)
JNoa.luTensor(tensor.tensorHandle, P, L, U)
return Triple(wrap(P), wrap(L), wrap(U))
}
@ -272,4 +273,70 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
}
public class NoaDoubleAlgebra(scope: NoaScope) :
NoaPartialDivisionAlgebra<Double, NoaDoubleTensor>(scope) {
override val Tensor<Double>.tensor: NoaDoubleTensor
get() = TODO("Not yet implemented")
override fun wrap(tensorHandle: Long): NoaDoubleTensor =
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
public fun Tensor<Double>.copyToArray(): DoubleArray =
tensor.elements().map { it.second }.toList().toDoubleArray()
public fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.fromBlobDouble(array, shape, device.toInt()))
public fun randNormalDouble(shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.randnDouble(shape, device.toInt()))
public fun randUniformDouble(shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.randDouble(shape, device.toInt()))
public fun randDiscreteDouble(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.randintDouble(low, high, shape, device.toInt()))
override operator fun Double.plus(other: Tensor<Double>): NoaDoubleTensor =
wrap(JNoa.plusDouble(this, other.tensor.tensorHandle))
override fun Tensor<Double>.plus(value: Double): NoaDoubleTensor =
wrap(JNoa.plusDouble(value, tensor.tensorHandle))
override fun Tensor<Double>.plusAssign(value: Double): Unit =
JNoa.plusDoubleAssign(value, tensor.tensorHandle)
override operator fun Double.minus(other: Tensor<Double>): NoaDoubleTensor =
wrap(JNoa.plusDouble(-this, other.tensor.tensorHandle))
override fun Tensor<Double>.minus(value: Double): NoaDoubleTensor =
wrap(JNoa.plusDouble(-value, tensor.tensorHandle))
override fun Tensor<Double>.minusAssign(value: Double): Unit =
JNoa.plusDoubleAssign(-value, tensor.tensorHandle)
override operator fun Double.times(other: Tensor<Double>): NoaDoubleTensor =
wrap(JNoa.timesDouble(this, other.tensor.tensorHandle))
override fun Tensor<Double>.times(value: Double): NoaDoubleTensor =
wrap(JNoa.timesDouble(value, tensor.tensorHandle))
override fun Tensor<Double>.timesAssign(value: Double): Unit =
JNoa.timesDoubleAssign(value, tensor.tensorHandle)
override fun Double.div(other: Tensor<Double>): NoaDoubleTensor =
other * (1 / this)
override fun Tensor<Double>.div(value: Double): NoaDoubleTensor =
tensor * (1 / value)
override fun Tensor<Double>.divAssign(value: Double): Unit =
tensor.timesAssign(1 / value)
public fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.fullDouble(value, shape, device.toInt()))
}

View File

@ -351,6 +351,22 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randFloat
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randnFloat
(JNIEnv *, jclass, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randintDouble
* Signature: (JJ[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randintDouble
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randintFloat
* Signature: (JJ[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randintFloat
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randintLong