forked from kscience/kmath
optimiser
This commit is contained in:
parent
1029871047
commit
3dd915e2fb
@ -307,4 +307,14 @@ class JNoa {
|
||||
public static native long getModuleBuffer(long jitModuleHandle, String name);
|
||||
|
||||
public static native void setModuleBuffer(long jitModuleHandle, String name, long tensorHandle);
|
||||
|
||||
public static native long adamOptim(long jitModuleHandle, double learningRate);
|
||||
|
||||
public static native void disposeAdamOptim(long adamOptHandle);
|
||||
|
||||
public static native void stepAdamOptim(long adamOptHandle);
|
||||
|
||||
public static native void zeroGradAdamOptim(long adamOptHandle);
|
||||
|
||||
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
||||
}
|
||||
|
@ -5,7 +5,6 @@
|
||||
|
||||
package space.kscience.kmath.noa
|
||||
|
||||
import com.sun.security.auth.module.JndiLoginModule
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.noa.memory.NoaScope
|
||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||
@ -143,22 +142,25 @@ protected constructor(protected val scope: NoaScope) :
|
||||
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
|
||||
|
||||
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
|
||||
wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensor.tensorHandle))
|
||||
wrap(JNoa.forwardPass(jitModuleHandle, parameters.tensor.tensorHandle))
|
||||
|
||||
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
|
||||
JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle)
|
||||
JNoa.forwardPassAssign(jitModuleHandle, parameters.tensorHandle)
|
||||
|
||||
public fun NoaJitModule.getParameter(name: String): TensorType =
|
||||
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name))
|
||||
wrap(JNoa.getModuleParameter(jitModuleHandle, name))
|
||||
|
||||
public fun NoaJitModule.setParameter(name: String, parameter: Tensor<T>): Unit =
|
||||
JNoa.setModuleParameter(this.jitModuleHandle, name, parameter.tensor.tensorHandle)
|
||||
JNoa.setModuleParameter(jitModuleHandle, name, parameter.tensor.tensorHandle)
|
||||
|
||||
public fun NoaJitModule.getBuffer(name: String): TensorType =
|
||||
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name))
|
||||
wrap(JNoa.getModuleParameter(jitModuleHandle, name))
|
||||
|
||||
public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
|
||||
JNoa.setModuleBuffer(this.jitModuleHandle, name, buffer.tensor.tensorHandle)
|
||||
JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle)
|
||||
|
||||
public infix fun TensorType.swap(other: TensorType): Unit =
|
||||
JNoa.swapTensors(tensorHandle, other.tensorHandle)
|
||||
|
||||
}
|
||||
|
||||
@ -315,7 +317,10 @@ protected constructor(scope: NoaScope) :
|
||||
wrap(JNoa.tensorGrad(tensorHandle))
|
||||
|
||||
public fun NoaJitModule.train(status: Boolean): Unit =
|
||||
JNoa.trainMode(this.jitModuleHandle, status)
|
||||
JNoa.trainMode(jitModuleHandle, status)
|
||||
|
||||
public fun NoaJitModule.adamOptimiser(learningRate: Double): AdamOptimiser =
|
||||
AdamOptimiser(scope, JNoa.adamOptim(jitModuleHandle, learningRate))
|
||||
}
|
||||
|
||||
public sealed class NoaDoubleAlgebra
|
||||
|
25
kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt
Normal file
25
kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt
Normal file
@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.noa
|
||||
|
||||
import space.kscience.kmath.noa.memory.NoaResource
|
||||
import space.kscience.kmath.noa.memory.NoaScope
|
||||
|
||||
internal typealias OptimiserHandle = Long
|
||||
|
||||
public abstract class NoaOptimiser
|
||||
internal constructor(scope: NoaScope) : NoaResource(scope) {
|
||||
public abstract fun step(): Unit
|
||||
public abstract fun zeroGrad(): Unit
|
||||
}
|
||||
|
||||
public class AdamOptimiser
|
||||
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
|
||||
: NoaOptimiser(scope) {
|
||||
override fun dispose(): Unit = JNoa.disposeAdamOptim(optimiserHandle)
|
||||
override fun step(): Unit = JNoa.stepAdamOptim(optimiserHandle)
|
||||
override fun zeroGrad(): Unit = JNoa.zeroGradAdamOptim(optimiserHandle)
|
||||
}
|
@ -1159,6 +1159,46 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleBuffer
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleBuffer
|
||||
(JNIEnv *, jclass, jlong, jstring, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: adamOptim
|
||||
* Signature: (JD)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_adamOptim
|
||||
(JNIEnv *, jclass, jlong, jdouble);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: disposeAdamOptim
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_disposeAdamOptim
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: stepAdamOptim
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_stepAdamOptim
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: zeroGradAdamOptim
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_zeroGradAdamOptim
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: swapTensors
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user