tensors implementation

This commit is contained in:
Roland Grinis 2021-07-08 09:37:16 +01:00
parent 78b1cd41da
commit 40d8a05bb0
11 changed files with 1135 additions and 66 deletions

View File

@ -33,5 +33,186 @@ class JNoa {
public static native void disposeTensor(long tensorHandle); public static native void disposeTensor(long tensorHandle);
} public static native long fromBlobDouble(double[] data, int[] shape, int device);
public static native long fromBlobFloat(float[] data, int[] shape, int device);
public static native long fromBlobLong(long[] data, int[] shape, int device);
public static native long fromBlobInt(int[] data, int[] shape, int device);
public static native long copyTensor(long tensorHandle);
public static native long copyToDevice(long tensorHandle, int device);
public static native long copyToDouble(long tensorHandle);
public static native long copyToFloat(long tensorHandle);
public static native long copyToLong(long tensorHandle);
public static native long copyToInt(long tensorHandle);
public static native long viewTensor(long tensorHandle, int[] shape);
public static native String tensorToString(long tensorHandle);
public static native int getDim(long tensorHandle);
public static native int getNumel(long tensorHandle);
public static native int getShapeAt(long tensorHandle, int d);
public static native int getStrideAt(long tensorHandle, int d);
public static native int getDevice(long tensorHandle);
public static native double getItemDouble(long tensorHandle);
public static native float getItemFloat(long tensorHandle);
public static native long getItemLong(long tensorHandle);
public static native int getItemInt(long tensorHandle);
public static native double getDouble(long tensorHandle, int[] index);
public static native float getFloat(long tensorHandle, int[] index);
public static native long getLong(long tensorHandle, int[] index);
public static native int getInt(long tensorHandle, int[] index);
public static native void setDouble(long tensorHandle, int[] index, double value);
public static native void setFloat(long tensorHandle, int[] index, float value);
public static native void setLong(long tensorHandle, int[] index, long value);
public static native void setInt(long tensorHandle, int[] index, int value);
public static native long randDouble(int[] shape, int device);
public static native long randnDouble(int[] shape, int device);
public static native long randFloat(int[] shape, int device);
public static native long randnFloat(int[] shape, int device);
public static native long randintLong(long low, long high, int[] shape, int device);
public static native long randintInt(long low, long high, int[] shape, int device);
public static native long randLike(long tensorHandle);
public static native void randLikeAssign(long tensorHandle);
public static native long randnLike(long tensorHandle);
public static native void randnLikeAssign(long tensorHandle);
public static native long randintLike(long tensorHandle, long low, long high);
public static native void randintLikeAssign(long tensorHandle, long low, long high);
public static native long fullDouble(double value, int[] shape, int device);
public static native long fullFloat(float value, int[] shape, int device);
public static native long fullLong(long value, int[] shape, int device);
public static native long fullInt(int value, int[] shape, int device);
public static native long timesDouble(double value, long other);
public static native long timesFloat(float value, long other);
public static native long timesLong(long value, long other);
public static native long timesInt(int value, long other);
public static native void timesDoubleAssign(double value, long other);
public static native void timesFloatAssign(float value, long other);
public static native void timesLongAssign(long value, long other);
public static native void timesIntAssign(int value, long other);
public static native long plusDouble(double value, long other);
public static native long plusFloat(float value, long other);
public static native long plusLong(long value, long other);
public static native long plusInt(int value, long other);
public static native void plusDoubleAssign(double value, long other);
public static native void plusFloatAssign(float value, long other);
public static native void plusLongAssign(long value, long other);
public static native void plusIntAssign(int value, long other);
public static native long timesTensor(long lhs, long rhs);
public static native void timesTensorAssign(long lhs, long rhs);
public static native long divTensor(long lhs, long rhs);
public static native void divTensorAssign(long lhs, long rhs);
public static native long plusTensor(long lhs, long rhs);
public static native void plusTensorAssign(long lhs, long rhs);
public static native long minusTensor(long lhs, long rhs);
public static native void minusTensorAssign(long lhs, long rhs);
public static native long unaryMinus(long tensorHandle);
public static native long absTensor(long tensorHandle);
public static native void absTensorAssign(long tensorHandle);
public static native long transposeTensor(long tensorHandle, int i, int j);
public static native void transposeTensorAssign(long tensorHandle, int i, int j);
public static native long expTensor(long tensorHandle);
public static native void expTensorAssign(long tensorHandle);
public static native long logTensor(long tensorHandle);
public static native void logTensorAssign(long tensorHandle);
public static native long sumTensor(long tensorHandle);
public static native void sumTensorAssign(long tensorHandle);
public static native long matmul(long lhs, long rhs);
public static native void matmulAssign(long lhs, long rhs);
public static native void matmulRightAssign(long lhs, long rhs);
public static native long diagEmbed(long diagsHandle, int offset, int dim1, int dim2);
public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle);
public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle, boolean eigenvectors);
public static native boolean requiresGrad(long tensorHandle);
public static native void setRequiresGrad(long tensorHandle, boolean status);
public static native long detachFromGraph(long tensorHandle);
public static native long autogradTensor(long value, long variable, boolean retainGraph);
public static native long autohessTensor(long value, long variable);
}

View File

@ -5,4 +5,24 @@
package space.kscience.kmath.noa package space.kscience.kmath.noa
public sealed class NoaAlgebra<T>{} import space.kscience.kmath.noa.memory.NoaScope
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra
public abstract class NoaAlgebra<T, TensorType: NoaTensor<T>>
internal constructor(protected val scope: NoaScope)
: TensorAlgebra<T> {
protected abstract fun Tensor<T>.cast(): TensorType
}
public abstract class NoaPartialDivisionAlgebra<T, TensorType: NoaTensor<T>>
internal constructor(scope: NoaScope)
: NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>, AnalyticTensorAlgebra<T> {
}

View File

@ -6,7 +6,7 @@
package space.kscience.kmath.noa.memory package space.kscience.kmath.noa.memory
public abstract class NoaResource public abstract class NoaResource
internal constructor(scope: NoaScope) { internal constructor(internal val scope: NoaScope) {
init { init {
scope.add(::dispose) scope.add(::dispose)
} }

View File

@ -26,7 +26,7 @@ public class NoaScope {
} }
} }
internal inline fun <R> withNoaScope(i: Int, block: NoaScope.() -> R): R { internal inline fun <R> withNoaScope(block: NoaScope.() -> R): R {
val noaScope = NoaScope() val noaScope = NoaScope()
val result = noaScope.block() val result = noaScope.block()
noaScope.disposeAll() noaScope.disposeAll()

View File

@ -5,19 +5,159 @@
package space.kscience.kmath.noa package space.kscience.kmath.noa
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.noa.memory.NoaResource import space.kscience.kmath.noa.memory.NoaResource
import space.kscience.kmath.noa.memory.NoaScope import space.kscience.kmath.noa.memory.NoaScope
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.TensorLinearStructure
internal typealias TensorHandle = Long internal typealias TensorHandle = Long
public sealed class NoaTensor<T> public sealed class NoaTensor<T>
constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) : constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
NoaResource(scope){ NoaResource(scope), Tensor<T> {
override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle) override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle)
override val dimension: Int get() = JNoa.getDim(tensorHandle)
override val shape: IntArray
get() = (1..dimension).map { JNoa.getShapeAt(tensorHandle, it - 1) }.toIntArray()
public val strides: IntArray
get() = (1..dimension).map { JNoa.getStrideAt(tensorHandle, it - 1) }.toIntArray()
public val numElements: Int get() = JNoa.getNumel(tensorHandle)
public val device: Device get() = Device.fromInt(JNoa.getDevice(tensorHandle))
override fun toString(): String = JNoa.tensorToString(tensorHandle)
@PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> {
if (dimension == 0) {
return emptySequence()
}
val indices = (1..numElements).asSequence().map {
TensorLinearStructure.indexFromOffset(it - 1, strides, dimension)
}
return indices.map { it to get(it) }
}
public fun copyToDouble(): NoaDoubleTensor = NoaDoubleTensor(
scope = scope,
tensorHandle = JNoa.copyToDouble(this.tensorHandle)
)
/*
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
scope = scope,
tensorHandle = JTorch.copyToFloat(this.tensorHandle)
)
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
scope = scope,
tensorHandle = JTorch.copyToLong(this.tensorHandle)
)
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
scope = scope,
tensorHandle = JTorch.copyToInt(this.tensorHandle)
)*/
} }
public sealed class NoaTensorOverField<T>
constructor(scope: NoaScope, tensorHandle: Long) :
NoaTensor<T>(scope, tensorHandle) {
public var requiresGrad: Boolean
get() = JNoa.requiresGrad(tensorHandle)
set(value) = JNoa.setRequiresGrad(tensorHandle, value)
}
public class NoaDoubleTensor public class NoaDoubleTensor
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) : internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
NoaTensor<Double>(scope, tensorHandle) NoaTensorOverField<Double>(scope, tensorHandle) {
internal fun item(): Double = JNoa.getItemDouble(tensorHandle)
@PerformancePitfall
override fun get(index: IntArray): Double = JNoa.getDouble(tensorHandle, index)
@PerformancePitfall
override fun set(index: IntArray, value: Double) {
JNoa.setDouble(tensorHandle, index, value)
}
}
public class NoaFloatTensor
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
NoaTensorOverField<Float>(scope, tensorHandle) {
internal fun item(): Float = JNoa.getItemFloat(tensorHandle)
@PerformancePitfall
override fun get(index: IntArray): Float = JNoa.getFloat(tensorHandle, index)
@PerformancePitfall
override fun set(index: IntArray, value: Float) {
JNoa.setFloat(tensorHandle, index, value)
}
}
public class NoaLongTensor
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
NoaTensor<Long>(scope, tensorHandle) {
internal fun item(): Long = JNoa.getItemLong(tensorHandle)
@PerformancePitfall
override fun get(index: IntArray): Long = JNoa.getLong(tensorHandle, index)
@PerformancePitfall
override fun set(index: IntArray, value: Long) {
JNoa.setLong(tensorHandle, index, value)
}
}
public class NoaIntTensor
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
NoaTensor<Int>(scope, tensorHandle) {
internal fun item(): Int = JNoa.getItemInt(tensorHandle)
@PerformancePitfall
override fun get(index: IntArray): Int = JNoa.getInt(tensorHandle, index)
@PerformancePitfall
override fun set(index: IntArray, value: Int) {
JNoa.setInt(tensorHandle, index, value)
}
}
public sealed class Device {
public object CPU : Device() {
override fun toString(): String {
return "CPU"
}
}
public data class CUDA(val index: Int) : Device()
public fun toInt(): Int {
when (this) {
is CPU -> return 0
is CUDA -> return this.index + 1
}
}
public companion object {
public fun fromInt(deviceInt: Int): Device {
return if (deviceInt == 0) CPU else CUDA(
deviceInt - 1
)
}
}
}

View File

@ -55,6 +55,734 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSeed
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_disposeTensor JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_disposeTensor
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: fromBlobDouble
* Signature: ([D[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_fromBlobDouble
(JNIEnv *, jclass, jdoubleArray, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: fromBlobFloat
* Signature: ([F[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_fromBlobFloat
(JNIEnv *, jclass, jfloatArray, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: fromBlobLong
* Signature: ([J[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_fromBlobLong
(JNIEnv *, jclass, jlongArray, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: fromBlobInt
* Signature: ([I[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_fromBlobInt
(JNIEnv *, jclass, jintArray, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: copyTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_copyTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: copyToDevice
* Signature: (JI)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_copyToDevice
(JNIEnv *, jclass, jlong, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: copyToDouble
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_copyToDouble
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: copyToFloat
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_copyToFloat
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: copyToLong
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_copyToLong
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: copyToInt
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_copyToInt
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: viewTensor
* Signature: (J[I)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_viewTensor
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: tensorToString
* Signature: (J)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_space_kscience_kmath_noa_JNoa_tensorToString
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getDim
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getDim
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getNumel
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getNumel
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getShapeAt
* Signature: (JI)I
*/
JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getShapeAt
(JNIEnv *, jclass, jlong, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getStrideAt
* Signature: (JI)I
*/
JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getStrideAt
(JNIEnv *, jclass, jlong, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getDevice
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getDevice
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getItemDouble
* Signature: (J)D
*/
JNIEXPORT jdouble JNICALL Java_space_kscience_kmath_noa_JNoa_getItemDouble
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getItemFloat
* Signature: (J)F
*/
JNIEXPORT jfloat JNICALL Java_space_kscience_kmath_noa_JNoa_getItemFloat
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getItemLong
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getItemLong
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getItemInt
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getItemInt
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getDouble
* Signature: (J[I)D
*/
JNIEXPORT jdouble JNICALL Java_space_kscience_kmath_noa_JNoa_getDouble
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getFloat
* Signature: (J[I)F
*/
JNIEXPORT jfloat JNICALL Java_space_kscience_kmath_noa_JNoa_getFloat
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getLong
* Signature: (J[I)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getLong
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getInt
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getInt
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setDouble
* Signature: (J[ID)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setDouble
(JNIEnv *, jclass, jlong, jintArray, jdouble);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setFloat
* Signature: (J[IF)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setFloat
(JNIEnv *, jclass, jlong, jintArray, jfloat);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setLong
* Signature: (J[IJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setLong
(JNIEnv *, jclass, jlong, jintArray, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setInt
* Signature: (J[II)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setInt
(JNIEnv *, jclass, jlong, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randDouble
* Signature: ([II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randDouble
(JNIEnv *, jclass, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randnDouble
* Signature: ([II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randnDouble
(JNIEnv *, jclass, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randFloat
* Signature: ([II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randFloat
(JNIEnv *, jclass, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randnFloat
* Signature: ([II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randnFloat
(JNIEnv *, jclass, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randintLong
* Signature: (JJ[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randintLong
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randintInt
* Signature: (JJ[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randintInt
(JNIEnv *, jclass, jlong, jlong, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randLike
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randLike
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randLikeAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_randLikeAssign
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randnLike
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randnLike
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randnLikeAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_randnLikeAssign
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randintLike
* Signature: (JJJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_randintLike
(JNIEnv *, jclass, jlong, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: randintLikeAssign
* Signature: (JJJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_randintLikeAssign
(JNIEnv *, jclass, jlong, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: fullDouble
* Signature: (D[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_fullDouble
(JNIEnv *, jclass, jdouble, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: fullFloat
* Signature: (F[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_fullFloat
(JNIEnv *, jclass, jfloat, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: fullLong
* Signature: (J[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_fullLong
(JNIEnv *, jclass, jlong, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: fullInt
* Signature: (I[II)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_fullInt
(JNIEnv *, jclass, jint, jintArray, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: timesDouble
* Signature: (DJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_timesDouble
(JNIEnv *, jclass, jdouble, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: timesFloat
* Signature: (FJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_timesFloat
(JNIEnv *, jclass, jfloat, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: timesLong
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_timesLong
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: timesInt
* Signature: (IJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_timesInt
(JNIEnv *, jclass, jint, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: timesDoubleAssign
* Signature: (DJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_timesDoubleAssign
(JNIEnv *, jclass, jdouble, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: timesFloatAssign
* Signature: (FJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_timesFloatAssign
(JNIEnv *, jclass, jfloat, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: timesLongAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_timesLongAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: timesIntAssign
* Signature: (IJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_timesIntAssign
(JNIEnv *, jclass, jint, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: plusDouble
* Signature: (DJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_plusDouble
(JNIEnv *, jclass, jdouble, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: plusFloat
* Signature: (FJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_plusFloat
(JNIEnv *, jclass, jfloat, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: plusLong
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_plusLong
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: plusInt
* Signature: (IJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_plusInt
(JNIEnv *, jclass, jint, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: plusDoubleAssign
* Signature: (DJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_plusDoubleAssign
(JNIEnv *, jclass, jdouble, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: plusFloatAssign
* Signature: (FJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_plusFloatAssign
(JNIEnv *, jclass, jfloat, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: plusLongAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_plusLongAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: plusIntAssign
* Signature: (IJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_plusIntAssign
(JNIEnv *, jclass, jint, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: timesTensor
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_timesTensor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: timesTensorAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_timesTensorAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: divTensor
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_divTensor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: divTensorAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_divTensorAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: plusTensor
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_plusTensor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: plusTensorAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_plusTensorAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: minusTensor
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_minusTensor
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: minusTensorAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_minusTensorAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: unaryMinus
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_unaryMinus
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: absTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: absTensorAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_absTensorAssign
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: transposeTensor
* Signature: (JII)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensor
(JNIEnv *, jclass, jlong, jint, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: transposeTensorAssign
* Signature: (JII)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensorAssign
(JNIEnv *, jclass, jlong, jint, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: expTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_expTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: expTensorAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_expTensorAssign
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: logTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_logTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: logTensorAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_logTensorAssign
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: sumTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensor
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: sumTensorAssign
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensorAssign
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: matmul
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_matmul
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: matmulAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_matmulAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: matmulRightAssign
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_matmulRightAssign
(JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: diagEmbed
* Signature: (JIII)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_diagEmbed
(JNIEnv *, jclass, jlong, jint, jint, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: svdTensor
* Signature: (JJJJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_svdTensor
(JNIEnv *, jclass, jlong, jlong, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: symeigTensor
* Signature: (JJJZ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_symeigTensor
(JNIEnv *, jclass, jlong, jlong, jlong, jboolean);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: requiresGrad
* Signature: (J)Z
*/
JNIEXPORT jboolean JNICALL Java_space_kscience_kmath_noa_JNoa_requiresGrad
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setRequiresGrad
* Signature: (JZ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setRequiresGrad
(JNIEnv *, jclass, jlong, jboolean);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: detachFromGraph
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_detachFromGraph
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: autogradTensor
* Signature: (JJZ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_autogradTensor
(JNIEnv *, jclass, jlong, jlong, jboolean);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: autohessTensor
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_autohessTensor
(JNIEnv *, jclass, jlong, jlong);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -5,8 +5,6 @@
package space.kscience.kmath.noa package space.kscience.kmath.noa
import space.kscience.kmath.noa.memory.NoaScope
import space.kscience.kmath.noa.memory.withNoaScope
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -4,7 +4,6 @@ import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.Strides import space.kscience.kmath.nd.Strides
import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.TensorLinearStructure
/** /**
* Represents [Tensor] over a [MutableBuffer] intended to be used through [DoubleTensor] and [IntTensor] * Represents [Tensor] over a [MutableBuffer] intended to be used through [DoubleTensor] and [IntTensor]

View File

@ -0,0 +1,59 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.Strides
import kotlin.math.max
/**
* This [Strides] implementation follows the last dimension first convention
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
*
* @param shape the shape of the tensor.
*/
public class TensorLinearStructure(override val shape: IntArray) : Strides {
override val strides: IntArray
get() = stridesFromShape(shape)
override fun index(offset: Int): IntArray =
indexFromOffset(offset, strides, shape.size)
override val linearSize: Int
get() = shape.reduce(Int::times)
public companion object {
public fun stridesFromShape(shape: IntArray): IntArray {
val nDim = shape.size
val res = IntArray(nDim)
if (nDim == 0)
return res
var current = nDim - 1
res[current] = 1
while (current > 0) {
res[current - 1] = max(1, shape[current]) * res[current]
current--
}
return res
}
public fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
val res = IntArray(nDim)
var current = offset
var strideIndex = 0
while (strideIndex < nDim) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex++
}
return res
}
}
}

View File

@ -1,57 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.Strides
import kotlin.math.max
internal fun stridesFromShape(shape: IntArray): IntArray {
val nDim = shape.size
val res = IntArray(nDim)
if (nDim == 0)
return res
var current = nDim - 1
res[current] = 1
while (current > 0) {
res[current - 1] = max(1, shape[current]) * res[current]
current--
}
return res
}
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
val res = IntArray(nDim)
var current = offset
var strideIndex = 0
while (strideIndex < nDim) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex++
}
return res
}
/**
* This [Strides] implementation follows the last dimension first convention
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
*
* @param shape the shape of the tensor.
*/
internal class TensorLinearStructure(override val shape: IntArray) : Strides {
override val strides: IntArray
get() = stridesFromShape(shape)
override fun index(offset: Int): IntArray =
indexFromOffset(offset, strides, shape.size)
override val linearSize: Int
get() = shape.reduce(Int::times)
}

View File

@ -11,6 +11,7 @@ import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.BufferedTensor import space.kscience.kmath.tensors.core.BufferedTensor
import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.IntTensor import space.kscience.kmath.tensors.core.IntTensor
import space.kscience.kmath.tensors.core.TensorLinearStructure
internal fun BufferedTensor<Int>.asTensor(): IntTensor = internal fun BufferedTensor<Int>.asTensor(): IntTensor =
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)