jit modules

This commit is contained in:
Roland Grinis 2021-07-12 15:48:07 +01:00
parent 1ad20cb143
commit e4300d0530
4 changed files with 71 additions and 8 deletions

View File

@ -280,4 +280,23 @@ class JNoa {
public static native long autoHessTensor(long value, long variable);
public static native void backwardPass(long tensorHandle);
public static native long tensorGrad(long tensorHandle);
public static native void disposeJitModule(long jitModuleHandle);
public static native void trainMode(long jitModuleHandle, boolean status);
public static native long loadJitModuleDouble(String path, int device);
public static native long loadJitModuleFloat(String path, int device);
public static native long loadJitModuleLong(String path, int device);
public static native long loadJitModuleInt(String path, int device);
public static native long forwardPass(long jitModuleHandle, long tensorHandle);
public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle);
}

View File

@ -0,0 +1,17 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.noa
import space.kscience.kmath.noa.memory.NoaResource
import space.kscience.kmath.noa.memory.NoaScope
internal typealias JitModuleHandle = Long
public class NoaJitModule
internal constructor(scope: NoaScope, internal val jitModuleHandle: JitModuleHandle)
: NoaResource(scope){
override fun dispose(): Unit = JNoa.disposeJitModule(jitModuleHandle)
}

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.noa
import com.sun.security.auth.module.JndiLoginModule
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.noa.memory.NoaScope
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
@ -139,6 +140,14 @@ protected constructor(protected val scope: NoaScope) :
public fun Tensor<T>.copyToDevice(device: Device): TensorType =
wrap(JNoa.copyToDevice(tensor.tensorHandle, device.toInt()))
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
public fun NoaJitModule.forward(parameters: TensorType): TensorType =
wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensorHandle))
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle)
}
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
@ -278,17 +287,24 @@ protected constructor(scope: NoaScope) :
return Pair(wrap(S), wrap(V))
}
public fun TensorType.grad(variable: TensorType, retainGraph: Boolean): TensorType {
return wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph))
}
public fun TensorType.autoGradient(variable: TensorType, retainGraph: Boolean): TensorType =
wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph))
public infix fun TensorType.hess(variable: TensorType): TensorType {
return wrap(JNoa.autoHessTensor(tensorHandle, variable.tensorHandle))
}
public fun TensorType.autoHessian(variable: TensorType): TensorType =
wrap(JNoa.autoHessTensor(tensorHandle, variable.tensorHandle))
public fun TensorType.detachFromGraph(): TensorType =
wrap(JNoa.detachFromGraph(tensorHandle))
public fun TensorType.backward(): Unit =
JNoa.backwardPass(tensorHandle)
public fun TensorType.grad(): TensorType =
wrap(JNoa.tensorGrad(tensorHandle))
public fun NoaJitModule.train(status: Boolean): Unit =
JNoa.trainMode(this.jitModuleHandle, status)
}
public sealed class NoaDoubleAlgebra
@ -365,6 +381,8 @@ protected constructor(scope: NoaScope) :
override fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor =
wrap(JNoa.fullDouble(value, shape, device.toInt()))
override fun loadJitModule(path: String, device: Device): NoaJitModule =
NoaJitModule(scope, JNoa.loadJitModuleDouble(path, device.toInt()))
}
public sealed class NoaFloatAlgebra
@ -441,6 +459,9 @@ protected constructor(scope: NoaScope) :
override fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.fullFloat(value, shape, device.toInt()))
override fun loadJitModule(path: String, device: Device): NoaJitModule =
NoaJitModule(scope, JNoa.loadJitModuleFloat(path, device.toInt()))
}
public sealed class NoaLongAlgebra
@ -502,6 +523,9 @@ protected constructor(scope: NoaScope) :
override fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.fullLong(value, shape, device.toInt()))
override fun loadJitModule(path: String, device: Device): NoaJitModule =
NoaJitModule(scope, JNoa.loadJitModuleLong(path, device.toInt()))
}
public sealed class NoaIntAlgebra
@ -563,4 +587,7 @@ protected constructor(scope: NoaScope) :
override fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.fullInt(value, shape, device.toInt()))
override fun loadJitModule(path: String, device: Device): NoaJitModule =
NoaJitModule(scope, JNoa.loadJitModuleInt(path, device.toInt()))
}

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.tensors.core.TensorLinearStructure
internal typealias TensorHandle = Long
public sealed class NoaTensor<T>
constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
protected constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
NoaResource(scope), Tensor<T> {
override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle)
@ -69,7 +69,7 @@ constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
}
public sealed class NoaTensorOverField<T>
constructor(scope: NoaScope, tensorHandle: Long) :
protected constructor(scope: NoaScope, tensorHandle: Long) :
NoaTensor<T>(scope, tensorHandle) {
public var requiresGrad: Boolean
get() = JNoa.requiresGrad(tensorHandle)