forked from kscience/kmath
jit modules
This commit is contained in:
parent
1ad20cb143
commit
e4300d0530
@ -280,4 +280,23 @@ class JNoa {
|
||||
|
||||
public static native long autoHessTensor(long value, long variable);
|
||||
|
||||
public static native void backwardPass(long tensorHandle);
|
||||
|
||||
public static native long tensorGrad(long tensorHandle);
|
||||
|
||||
public static native void disposeJitModule(long jitModuleHandle);
|
||||
|
||||
public static native void trainMode(long jitModuleHandle, boolean status);
|
||||
|
||||
public static native long loadJitModuleDouble(String path, int device);
|
||||
|
||||
public static native long loadJitModuleFloat(String path, int device);
|
||||
|
||||
public static native long loadJitModuleLong(String path, int device);
|
||||
|
||||
public static native long loadJitModuleInt(String path, int device);
|
||||
|
||||
public static native long forwardPass(long jitModuleHandle, long tensorHandle);
|
||||
|
||||
public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle);
|
||||
}
|
||||
|
@ -0,0 +1,17 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.noa
|
||||
|
||||
import space.kscience.kmath.noa.memory.NoaResource
|
||||
import space.kscience.kmath.noa.memory.NoaScope
|
||||
|
||||
internal typealias JitModuleHandle = Long
|
||||
|
||||
public class NoaJitModule
|
||||
internal constructor(scope: NoaScope, internal val jitModuleHandle: JitModuleHandle)
|
||||
: NoaResource(scope){
|
||||
override fun dispose(): Unit = JNoa.disposeJitModule(jitModuleHandle)
|
||||
}
|
@ -5,6 +5,7 @@
|
||||
|
||||
package space.kscience.kmath.noa
|
||||
|
||||
import com.sun.security.auth.module.JndiLoginModule
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.noa.memory.NoaScope
|
||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||
@ -139,6 +140,14 @@ protected constructor(protected val scope: NoaScope) :
|
||||
public fun Tensor<T>.copyToDevice(device: Device): TensorType =
|
||||
wrap(JNoa.copyToDevice(tensor.tensorHandle, device.toInt()))
|
||||
|
||||
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
|
||||
|
||||
public fun NoaJitModule.forward(parameters: TensorType): TensorType =
|
||||
wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensorHandle))
|
||||
|
||||
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
|
||||
JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle)
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||
@ -278,17 +287,24 @@ protected constructor(scope: NoaScope) :
|
||||
return Pair(wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
public fun TensorType.grad(variable: TensorType, retainGraph: Boolean): TensorType {
|
||||
return wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph))
|
||||
}
|
||||
public fun TensorType.autoGradient(variable: TensorType, retainGraph: Boolean): TensorType =
|
||||
wrap(JNoa.autoGradTensor(tensorHandle, variable.tensorHandle, retainGraph))
|
||||
|
||||
public infix fun TensorType.hess(variable: TensorType): TensorType {
|
||||
return wrap(JNoa.autoHessTensor(tensorHandle, variable.tensorHandle))
|
||||
}
|
||||
public fun TensorType.autoHessian(variable: TensorType): TensorType =
|
||||
wrap(JNoa.autoHessTensor(tensorHandle, variable.tensorHandle))
|
||||
|
||||
public fun TensorType.detachFromGraph(): TensorType =
|
||||
wrap(JNoa.detachFromGraph(tensorHandle))
|
||||
|
||||
public fun TensorType.backward(): Unit =
|
||||
JNoa.backwardPass(tensorHandle)
|
||||
|
||||
public fun TensorType.grad(): TensorType =
|
||||
wrap(JNoa.tensorGrad(tensorHandle))
|
||||
|
||||
public fun NoaJitModule.train(status: Boolean): Unit =
|
||||
JNoa.trainMode(this.jitModuleHandle, status)
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaDoubleAlgebra
|
||||
@ -365,6 +381,8 @@ protected constructor(scope: NoaScope) :
|
||||
override fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor =
|
||||
wrap(JNoa.fullDouble(value, shape, device.toInt()))
|
||||
|
||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||
NoaJitModule(scope, JNoa.loadJitModuleDouble(path, device.toInt()))
|
||||
}
|
||||
|
||||
public sealed class NoaFloatAlgebra
|
||||
@ -441,6 +459,9 @@ protected constructor(scope: NoaScope) :
|
||||
override fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor =
|
||||
wrap(JNoa.fullFloat(value, shape, device.toInt()))
|
||||
|
||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||
NoaJitModule(scope, JNoa.loadJitModuleFloat(path, device.toInt()))
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaLongAlgebra
|
||||
@ -502,6 +523,9 @@ protected constructor(scope: NoaScope) :
|
||||
override fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor =
|
||||
wrap(JNoa.fullLong(value, shape, device.toInt()))
|
||||
|
||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||
NoaJitModule(scope, JNoa.loadJitModuleLong(path, device.toInt()))
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaIntAlgebra
|
||||
@ -563,4 +587,7 @@ protected constructor(scope: NoaScope) :
|
||||
override fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor =
|
||||
wrap(JNoa.fullInt(value, shape, device.toInt()))
|
||||
|
||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||
NoaJitModule(scope, JNoa.loadJitModuleInt(path, device.toInt()))
|
||||
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ import space.kscience.kmath.tensors.core.TensorLinearStructure
|
||||
internal typealias TensorHandle = Long
|
||||
|
||||
public sealed class NoaTensor<T>
|
||||
constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
|
||||
protected constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
|
||||
NoaResource(scope), Tensor<T> {
|
||||
|
||||
override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle)
|
||||
@ -69,7 +69,7 @@ constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
|
||||
}
|
||||
|
||||
public sealed class NoaTensorOverField<T>
|
||||
constructor(scope: NoaScope, tensorHandle: Long) :
|
||||
protected constructor(scope: NoaScope, tensorHandle: Long) :
|
||||
NoaTensor<T>(scope, tensorHandle) {
|
||||
public var requiresGrad: Boolean
|
||||
get() = JNoa.requiresGrad(tensorHandle)
|
||||
|
Loading…
Reference in New Issue
Block a user