forked from kscience/kmath
optimiser
This commit is contained in:
parent
1029871047
commit
3dd915e2fb
@ -307,4 +307,14 @@ class JNoa {
|
|||||||
public static native long getModuleBuffer(long jitModuleHandle, String name);
|
public static native long getModuleBuffer(long jitModuleHandle, String name);
|
||||||
|
|
||||||
public static native void setModuleBuffer(long jitModuleHandle, String name, long tensorHandle);
|
public static native void setModuleBuffer(long jitModuleHandle, String name, long tensorHandle);
|
||||||
|
|
||||||
|
public static native long adamOptim(long jitModuleHandle, double learningRate);
|
||||||
|
|
||||||
|
public static native void disposeAdamOptim(long adamOptHandle);
|
||||||
|
|
||||||
|
public static native void stepAdamOptim(long adamOptHandle);
|
||||||
|
|
||||||
|
public static native void zeroGradAdamOptim(long adamOptHandle);
|
||||||
|
|
||||||
|
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.noa
|
package space.kscience.kmath.noa
|
||||||
|
|
||||||
import com.sun.security.auth.module.JndiLoginModule
|
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.noa.memory.NoaScope
|
import space.kscience.kmath.noa.memory.NoaScope
|
||||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||||
@ -143,22 +142,25 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
|
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
|
||||||
|
|
||||||
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
|
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
|
||||||
wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensor.tensorHandle))
|
wrap(JNoa.forwardPass(jitModuleHandle, parameters.tensor.tensorHandle))
|
||||||
|
|
||||||
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
|
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
|
||||||
JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle)
|
JNoa.forwardPassAssign(jitModuleHandle, parameters.tensorHandle)
|
||||||
|
|
||||||
public fun NoaJitModule.getParameter(name: String): TensorType =
|
public fun NoaJitModule.getParameter(name: String): TensorType =
|
||||||
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name))
|
wrap(JNoa.getModuleParameter(jitModuleHandle, name))
|
||||||
|
|
||||||
public fun NoaJitModule.setParameter(name: String, parameter: Tensor<T>): Unit =
|
public fun NoaJitModule.setParameter(name: String, parameter: Tensor<T>): Unit =
|
||||||
JNoa.setModuleParameter(this.jitModuleHandle, name, parameter.tensor.tensorHandle)
|
JNoa.setModuleParameter(jitModuleHandle, name, parameter.tensor.tensorHandle)
|
||||||
|
|
||||||
public fun NoaJitModule.getBuffer(name: String): TensorType =
|
public fun NoaJitModule.getBuffer(name: String): TensorType =
|
||||||
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name))
|
wrap(JNoa.getModuleParameter(jitModuleHandle, name))
|
||||||
|
|
||||||
public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
|
public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
|
||||||
JNoa.setModuleBuffer(this.jitModuleHandle, name, buffer.tensor.tensorHandle)
|
JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle)
|
||||||
|
|
||||||
|
public infix fun TensorType.swap(other: TensorType): Unit =
|
||||||
|
JNoa.swapTensors(tensorHandle, other.tensorHandle)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -315,7 +317,10 @@ protected constructor(scope: NoaScope) :
|
|||||||
wrap(JNoa.tensorGrad(tensorHandle))
|
wrap(JNoa.tensorGrad(tensorHandle))
|
||||||
|
|
||||||
public fun NoaJitModule.train(status: Boolean): Unit =
|
public fun NoaJitModule.train(status: Boolean): Unit =
|
||||||
JNoa.trainMode(this.jitModuleHandle, status)
|
JNoa.trainMode(jitModuleHandle, status)
|
||||||
|
|
||||||
|
public fun NoaJitModule.adamOptimiser(learningRate: Double): AdamOptimiser =
|
||||||
|
AdamOptimiser(scope, JNoa.adamOptim(jitModuleHandle, learningRate))
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaDoubleAlgebra
|
public sealed class NoaDoubleAlgebra
|
||||||
|
25
kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt
Normal file
25
kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.noa
|
||||||
|
|
||||||
|
import space.kscience.kmath.noa.memory.NoaResource
|
||||||
|
import space.kscience.kmath.noa.memory.NoaScope
|
||||||
|
|
||||||
|
internal typealias OptimiserHandle = Long
|
||||||
|
|
||||||
|
public abstract class NoaOptimiser
|
||||||
|
internal constructor(scope: NoaScope) : NoaResource(scope) {
|
||||||
|
public abstract fun step(): Unit
|
||||||
|
public abstract fun zeroGrad(): Unit
|
||||||
|
}
|
||||||
|
|
||||||
|
public class AdamOptimiser
|
||||||
|
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
|
||||||
|
: NoaOptimiser(scope) {
|
||||||
|
override fun dispose(): Unit = JNoa.disposeAdamOptim(optimiserHandle)
|
||||||
|
override fun step(): Unit = JNoa.stepAdamOptim(optimiserHandle)
|
||||||
|
override fun zeroGrad(): Unit = JNoa.zeroGradAdamOptim(optimiserHandle)
|
||||||
|
}
|
@ -1159,6 +1159,46 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleBuffer
|
|||||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleBuffer
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleBuffer
|
||||||
(JNIEnv *, jclass, jlong, jstring, jlong);
|
(JNIEnv *, jclass, jlong, jstring, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: adamOptim
|
||||||
|
* Signature: (JD)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_adamOptim
|
||||||
|
(JNIEnv *, jclass, jlong, jdouble);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: disposeAdamOptim
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_disposeAdamOptim
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: stepAdamOptim
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_stepAdamOptim
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: zeroGradAdamOptim
|
||||||
|
* Signature: (J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_zeroGradAdamOptim
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: swapTensors
|
||||||
|
* Signature: (JJ)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
Reference in New Issue
Block a user