Integrating withDeferScope

This commit is contained in:
Roland Grinis 2021-01-16 20:33:58 +00:00
commit 889691a122
6 changed files with 85 additions and 8 deletions

View File

@ -0,0 +1,7 @@
package kscience.kmath.memory
public expect class DeferScope {
public inline fun defer(crossinline block: () -> Unit)
}
public expect inline fun <R> withDeferScope(block: DeferScope.() -> R): R

View File

@ -0,0 +1,30 @@
package kscience.kmath.memory
private typealias Deferred = () -> Unit
public actual class DeferScope {
@PublishedApi
internal val deferred: MutableList<Deferred> = mutableListOf()
@PublishedApi
internal fun executeAllDeferred() {
deferred.forEach(Deferred::invoke)
deferred.clear()
}
public actual inline fun defer(crossinline block: () -> Unit) {
deferred += {
try {
block()
} catch (ignored: Throwable) {
}
}
}
}
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
val ds = DeferScope()
val r = ds.block()
ds.executeAllDeferred()
return r
}

View File

@ -0,0 +1,30 @@
package kscience.kmath.memory
private typealias Deferred = () -> Unit
public actual class DeferScope {
@PublishedApi
internal val deferred: MutableList<Deferred> = mutableListOf()
@PublishedApi
internal fun executeAllDeferred() {
deferred.forEach(Deferred::invoke)
deferred.clear()
}
public actual inline fun defer(crossinline block: () -> Unit) {
deferred += {
try {
block()
} catch (ignored: Throwable) {
}
}
}
}
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
val ds = DeferScope()
val r = ds.block()
ds.executeAllDeferred()
return r
}

View File

@ -0,0 +1,7 @@
package kscience.kmath.memory
import kotlinx.cinterop.memScoped
public actual typealias DeferScope = kotlinx.cinterop.DeferScope
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R = memScoped(block)

View File

@ -2,13 +2,14 @@ package kscience.kmath.torch
import kscience.kmath.structures.TensorStructure import kscience.kmath.structures.TensorStructure
import kscience.kmath.memory.DeferScope
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.ctorch.* import kscience.kmath.ctorch.*
public sealed class TorchTensor<T> constructor( public sealed class TorchTensor<T> constructor(
internal val scope: DeferScope, public val scope: DeferScope,
internal val tensorHandle: COpaquePointer internal val tensorHandle: COpaquePointer
) : TensorStructure<T>() { ) : TensorStructure<T>() {
init { init {
@ -79,7 +80,7 @@ public sealed class TorchTensorOverField<T> constructor(
scope: DeferScope, scope: DeferScope,
tensorHandle: COpaquePointer tensorHandle: COpaquePointer
) : TorchTensor<T>(scope, tensorHandle) { ) : TorchTensor<T>(scope, tensorHandle) {
internal var requiresGrad: Boolean public var requiresGrad: Boolean
get() = requires_grad(tensorHandle) get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value) set(value) = requires_grad_(tensorHandle, value)
} }

View File

@ -2,6 +2,8 @@ package kscience.kmath.torch
import kscience.kmath.structures.* import kscience.kmath.structures.*
import kscience.kmath.memory.DeferScope
import kscience.kmath.memory.withDeferScope
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.ctorch.* import kscience.kmath.ctorch.*
@ -591,23 +593,23 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
} }
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R = public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
memScoped { TorchTensorRealAlgebra(this).block() } withDeferScope { TorchTensorRealAlgebra(this).block() }
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R = public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
memScoped { TorchTensorFloatAlgebra(this).block() } withDeferScope { TorchTensorFloatAlgebra(this).block() }
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R = public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
memScoped { TorchTensorLongAlgebra(this).block() } withDeferScope { TorchTensorLongAlgebra(this).block() }
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R = public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
memScoped { TorchTensorIntAlgebra(this).block() } withDeferScope { TorchTensorIntAlgebra(this).block() }
public fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal { public inline fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
this.requiresGrad = true this.requiresGrad = true
return TorchTensorRealAlgebra(this.scope).block() return TorchTensorRealAlgebra(this.scope).block()
} }
public fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat { public inline fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
this.requiresGrad = true this.requiresGrad = true
return TorchTensorFloatAlgebra(this.scope).block() return TorchTensorFloatAlgebra(this.scope).block()
} }