diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index a5c9ba75b..2e96ed03a 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -280,4 +280,23 @@ class JNoa { public static native long autoHessTensor(long value, long variable); + public static native void backwardPass(long tensorHandle); + + public static native long tensorGrad(long tensorHandle); + + public static native void disposeJitModule(long jitModuleHandle); + + public static native void trainMode(long jitModuleHandle, boolean status); + + public static native long loadJitModuleDouble(String path, int device); + + public static native long loadJitModuleFloat(String path, int device); + + public static native long loadJitModuleLong(String path, int device); + + public static native long loadJitModuleInt(String path, int device); + + public static native long forwardPass(long jitModuleHandle, long tensorHandle); + + public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle); } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/NoaJitModule.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/NoaJitModule.kt new file mode 100644 index 000000000..402cdf012 --- /dev/null +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/NoaJitModule.kt @@ -0,0 +1,17 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.noa + +import space.kscience.kmath.noa.memory.NoaResource +import space.kscience.kmath.noa.memory.NoaScope + +internal typealias JitModuleHandle = Long + +public class NoaJitModule +internal constructor(scope: NoaScope, internal val jitModuleHandle: JitModuleHandle) + : NoaResource(scope){ + override fun dispose(): Unit = JNoa.disposeJitModule(jitModuleHandle) +} \ No newline at end of file diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index fbd89a5a4..47c18f8cf 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.noa +import com.sun.security.auth.module.JndiLoginModule import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.noa.memory.NoaScope import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra @@ -139,6 +140,14 @@ protected constructor(protected val scope: NoaScope) : public fun Tensor.copyToDevice(device: Device): TensorType = wrap(JNoa.copyToDevice(tensor.tensorHandle, device.toInt())) + public abstract fun loadJitModule(path: String, device: Device): NoaJitModule + + public fun NoaJitModule.forward(parameters: TensorType): TensorType = + wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensorHandle)) + + public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit = + JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle) + } public sealed class NoaPartialDivisionAlgebra> @@ -278,17 +287,24 @@ protected constructor(scope: NoaScope) : return Pair(wrap(S), wrap(V)) } - public fun TensorType.grad(variable: TensorType, retainGraph: Boolean): TensorType { - return wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph)) - } + public fun TensorType.autoGradient(variable: TensorType, retainGraph: Boolean): TensorType = + wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph)) - public infix fun TensorType.hess(variable: TensorType): TensorType { - return wrap(JNoa.autoHessTensor(tensorHandle, variable.tensorHandle)) - } + public fun TensorType.autoHessian(variable: TensorType): TensorType = + wrap(JNoa.autoHessTensor(tensorHandle, variable.tensorHandle)) public fun TensorType.detachFromGraph(): TensorType = wrap(JNoa.detachFromGraph(tensorHandle)) + public fun TensorType.backward(): Unit = + JNoa.backwardPass(tensorHandle) + + public fun TensorType.grad(): TensorType = + wrap(JNoa.tensorGrad(tensorHandle)) + + public fun NoaJitModule.train(status: Boolean): Unit = + JNoa.trainMode(this.jitModuleHandle, status) + } public sealed class NoaDoubleAlgebra @@ -365,6 +381,8 @@ protected constructor(scope: NoaScope) : override fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor = wrap(JNoa.fullDouble(value, shape, device.toInt())) + override fun loadJitModule(path: String, device: Device): NoaJitModule = + NoaJitModule(scope, JNoa.loadJitModuleDouble(path, device.toInt())) } public sealed class NoaFloatAlgebra @@ -441,6 +459,9 @@ protected constructor(scope: NoaScope) : override fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor = wrap(JNoa.fullFloat(value, shape, device.toInt())) + override fun loadJitModule(path: String, device: Device): NoaJitModule = + NoaJitModule(scope, JNoa.loadJitModuleFloat(path, device.toInt())) + } public sealed class NoaLongAlgebra @@ -502,6 +523,9 @@ protected constructor(scope: NoaScope) : override fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor = wrap(JNoa.fullLong(value, shape, device.toInt())) + override fun loadJitModule(path: String, device: Device): NoaJitModule = + NoaJitModule(scope, JNoa.loadJitModuleLong(path, device.toInt())) + } public sealed class NoaIntAlgebra @@ -563,4 +587,7 @@ protected constructor(scope: NoaScope) : override fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor = wrap(JNoa.fullInt(value, shape, device.toInt())) + override fun loadJitModule(path: String, device: Device): NoaJitModule = + NoaJitModule(scope, JNoa.loadJitModuleInt(path, device.toInt())) + } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt index d6353f179..3c9508db2 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt @@ -14,7 +14,7 @@ import space.kscience.kmath.tensors.core.TensorLinearStructure internal typealias TensorHandle = Long public sealed class NoaTensor -constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) : +protected constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) : NoaResource(scope), Tensor { override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle) @@ -69,7 +69,7 @@ constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) : } public sealed class NoaTensorOverField -constructor(scope: NoaScope, tensorHandle: Long) : +protected constructor(scope: NoaScope, tensorHandle: Long) : NoaTensor(scope, tensorHandle) { public var requiresGrad: Boolean get() = JNoa.requiresGrad(tensorHandle)