diff --git a/kmath-memory/src/commonMain/kotlin/kscience/kmath/memory/DeferScope.kt b/kmath-memory/src/commonMain/kotlin/kscience/kmath/memory/DeferScope.kt new file mode 100644 index 000000000..df5040ee0 --- /dev/null +++ b/kmath-memory/src/commonMain/kotlin/kscience/kmath/memory/DeferScope.kt @@ -0,0 +1,7 @@ +package kscience.kmath.memory + +public expect class DeferScope { + public inline fun defer(crossinline block: () -> Unit) +} + +public expect inline fun withDeferScope(block: DeferScope.() -> R): R diff --git a/kmath-memory/src/jsMain/kotlin/kscience/kmath/memory/DeferScopeJS.kt b/kmath-memory/src/jsMain/kotlin/kscience/kmath/memory/DeferScopeJS.kt new file mode 100644 index 000000000..aeb699a39 --- /dev/null +++ b/kmath-memory/src/jsMain/kotlin/kscience/kmath/memory/DeferScopeJS.kt @@ -0,0 +1,30 @@ +package kscience.kmath.memory + +private typealias Deferred = () -> Unit + +public actual class DeferScope { + @PublishedApi + internal val deferred: MutableList = mutableListOf() + + @PublishedApi + internal fun executeAllDeferred() { + deferred.forEach(Deferred::invoke) + deferred.clear() + } + + public actual inline fun defer(crossinline block: () -> Unit) { + deferred += { + try { + block() + } catch (ignored: Throwable) { + } + } + } +} + +public actual inline fun withDeferScope(block: DeferScope.() -> R): R { + val ds = DeferScope() + val r = ds.block() + ds.executeAllDeferred() + return r +} diff --git a/kmath-memory/src/jvmMain/kotlin/kscience/kmath/memory/DeferScopeJvm.kt b/kmath-memory/src/jvmMain/kotlin/kscience/kmath/memory/DeferScopeJvm.kt new file mode 100644 index 000000000..aeb699a39 --- /dev/null +++ b/kmath-memory/src/jvmMain/kotlin/kscience/kmath/memory/DeferScopeJvm.kt @@ -0,0 +1,30 @@ +package kscience.kmath.memory + +private typealias Deferred = () -> Unit + +public actual class DeferScope { + @PublishedApi + internal val deferred: MutableList = mutableListOf() + + @PublishedApi + internal fun executeAllDeferred() { + deferred.forEach(Deferred::invoke) + deferred.clear() + } + + public actual inline fun defer(crossinline block: () -> Unit) { + deferred += { + try { + block() + } catch (ignored: Throwable) { + } + } + } +} + +public actual inline fun withDeferScope(block: DeferScope.() -> R): R { + val ds = DeferScope() + val r = ds.block() + ds.executeAllDeferred() + return r +} diff --git a/kmath-memory/src/nativeMain/kotlin/kscience/kmath/memory/DeferScopeNative.kt b/kmath-memory/src/nativeMain/kotlin/kscience/kmath/memory/DeferScopeNative.kt new file mode 100644 index 000000000..6809b18b3 --- /dev/null +++ b/kmath-memory/src/nativeMain/kotlin/kscience/kmath/memory/DeferScopeNative.kt @@ -0,0 +1,7 @@ +package kscience.kmath.memory + +import kotlinx.cinterop.memScoped + +public actual typealias DeferScope = kotlinx.cinterop.DeferScope + +public actual inline fun withDeferScope(block: DeferScope.() -> R): R = memScoped(block) diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt index cc41359a2..3e2ebd562 100644 --- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt +++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt @@ -2,13 +2,14 @@ package kscience.kmath.torch import kscience.kmath.structures.TensorStructure +import kscience.kmath.memory.DeferScope import kotlinx.cinterop.* import kscience.kmath.ctorch.* public sealed class TorchTensor constructor( - internal val scope: DeferScope, + public val scope: DeferScope, internal val tensorHandle: COpaquePointer ) : TensorStructure() { init { @@ -79,7 +80,7 @@ public sealed class TorchTensorOverField constructor( scope: DeferScope, tensorHandle: COpaquePointer ) : TorchTensor(scope, tensorHandle) { - internal var requiresGrad: Boolean + public var requiresGrad: Boolean get() = requires_grad(tensorHandle) set(value) = requires_grad_(tensorHandle, value) } diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt index dc66813de..00d123c4a 100644 --- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt +++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt @@ -2,6 +2,8 @@ package kscience.kmath.torch import kscience.kmath.structures.* +import kscience.kmath.memory.DeferScope +import kscience.kmath.memory.withDeferScope import kotlinx.cinterop.* import kscience.kmath.ctorch.* @@ -591,23 +593,23 @@ public class TorchTensorIntAlgebra(scope: DeferScope) : } public inline fun TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R = - memScoped { TorchTensorRealAlgebra(this).block() } + withDeferScope { TorchTensorRealAlgebra(this).block() } public inline fun TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R = - memScoped { TorchTensorFloatAlgebra(this).block() } + withDeferScope { TorchTensorFloatAlgebra(this).block() } public inline fun TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R = - memScoped { TorchTensorLongAlgebra(this).block() } + withDeferScope { TorchTensorLongAlgebra(this).block() } public inline fun TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R = - memScoped { TorchTensorIntAlgebra(this).block() } + withDeferScope { TorchTensorIntAlgebra(this).block() } -public fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal { +public inline fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal { this.requiresGrad = true return TorchTensorRealAlgebra(this.scope).block() } -public fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat { +public inline fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat { this.requiresGrad = true return TorchTensorFloatAlgebra(this.scope).block() } \ No newline at end of file