disposable pattern

This commit is contained in:
Roland Grinis 2021-07-07 11:58:48 +01:00
parent cadcb9916f
commit 78b1cd41da
8 changed files with 97 additions and 2 deletions

View File

@ -15,7 +15,7 @@ plugins {
description = "Wrapper for the Bayesian Computation library NOA on top of LibTorch" description = "Wrapper for the Bayesian Computation library NOA on top of LibTorch"
dependencies { dependencies {
implementation(project(":kmath-tensors")) api(project(":kmath-tensors"))
} }
val home: String = System.getProperty("user.home") val home: String = System.getProperty("user.home")

View File

@ -5,7 +5,7 @@
package space.kscience.kmath.noa; package space.kscience.kmath.noa;
public class JNoa { class JNoa {
static { static {
String jNoaPath = System.getProperty("user.home") + String jNoaPath = System.getProperty("user.home") +
@ -30,5 +30,8 @@ public class JNoa {
public static native void setNumThreads(int numThreads); public static native void setNumThreads(int numThreads);
public static native void setSeed(int seed); public static native void setSeed(int seed);
public static native void disposeTensor(long tensorHandle);
} }

View File

@ -0,0 +1,8 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.noa
public sealed class NoaAlgebra<T>{}

View File

@ -0,0 +1,15 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.noa.memory
public abstract class NoaResource
internal constructor(scope: NoaScope) {
init {
scope.add(::dispose)
}
protected abstract fun dispose(): Unit
}

View File

@ -0,0 +1,34 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.noa.memory
private typealias Disposable = () -> Unit
public class NoaScope {
private val disposables: ArrayDeque<Disposable> = ArrayDeque(0)
public fun disposeAll() {
disposables.forEach(Disposable::invoke)
disposables.clear()
}
internal inline fun add(crossinline disposable: Disposable) {
disposables += {
try {
disposable()
} catch (ignored: Throwable) {
}
}
}
}
internal inline fun <R> withNoaScope(i: Int, block: NoaScope.() -> R): R {
val noaScope = NoaScope()
val result = noaScope.block()
noaScope.disposeAll()
return result
}

View File

@ -0,0 +1,23 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.noa
import space.kscience.kmath.noa.memory.NoaResource
import space.kscience.kmath.noa.memory.NoaScope
import space.kscience.kmath.tensors.api.Tensor
internal typealias TensorHandle = Long
public sealed class NoaTensor<T>
constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) :
NoaResource(scope){
override fun dispose(): Unit = JNoa.disposeTensor(tensorHandle)
}
public class NoaDoubleTensor
internal constructor(scope: NoaScope, tensorHandle: TensorHandle) :
NoaTensor<Double>(scope, tensorHandle)

View File

@ -47,6 +47,14 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setNumThreads
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSeed JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setSeed
(JNIEnv *, jclass, jint); (JNIEnv *, jclass, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: disposeTensor
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_disposeTensor
(JNIEnv *, jclass, jlong);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.noa package space.kscience.kmath.noa
import space.kscience.kmath.noa.memory.NoaScope
import space.kscience.kmath.noa.memory.withNoaScope
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -26,4 +28,6 @@ class TestUtils {
setNumThreads(numThreads) setNumThreads(numThreads)
assertEquals(numThreads, getNumThreads()) assertEquals(numThreads, getNumThreads())
} }
} }