Integrating withDeferScope

This commit is contained in:
Roland Grinis 2021-01-16 20:33:58 +00:00
commit 889691a122
6 changed files with 85 additions and 8 deletions

View File

@ -0,0 +1,7 @@
package kscience.kmath.memory
public expect class DeferScope {
public inline fun defer(crossinline block: () -> Unit)
}
public expect inline fun <R> withDeferScope(block: DeferScope.() -> R): R

View File

@ -0,0 +1,30 @@
package kscience.kmath.memory
private typealias Deferred = () -> Unit
public actual class DeferScope {
@PublishedApi
internal val deferred: MutableList<Deferred> = mutableListOf()
@PublishedApi
internal fun executeAllDeferred() {
deferred.forEach(Deferred::invoke)
deferred.clear()
}
public actual inline fun defer(crossinline block: () -> Unit) {
deferred += {
try {
block()
} catch (ignored: Throwable) {
}
}
}
}
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
val ds = DeferScope()
val r = ds.block()
ds.executeAllDeferred()
return r
}

View File

@ -0,0 +1,30 @@
package kscience.kmath.memory
private typealias Deferred = () -> Unit
public actual class DeferScope {
@PublishedApi
internal val deferred: MutableList<Deferred> = mutableListOf()
@PublishedApi
internal fun executeAllDeferred() {
deferred.forEach(Deferred::invoke)
deferred.clear()
}
public actual inline fun defer(crossinline block: () -> Unit) {
deferred += {
try {
block()
} catch (ignored: Throwable) {
}
}
}
}
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
val ds = DeferScope()
val r = ds.block()
ds.executeAllDeferred()
return r
}

View File

@ -0,0 +1,7 @@
package kscience.kmath.memory
import kotlinx.cinterop.memScoped
public actual typealias DeferScope = kotlinx.cinterop.DeferScope
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R = memScoped(block)

View File

@ -2,13 +2,14 @@ package kscience.kmath.torch
import kscience.kmath.structures.TensorStructure
import kscience.kmath.memory.DeferScope
import kotlinx.cinterop.*
import kscience.kmath.ctorch.*
public sealed class TorchTensor<T> constructor(
internal val scope: DeferScope,
public val scope: DeferScope,
internal val tensorHandle: COpaquePointer
) : TensorStructure<T>() {
init {
@ -79,7 +80,7 @@ public sealed class TorchTensorOverField<T> constructor(
scope: DeferScope,
tensorHandle: COpaquePointer
) : TorchTensor<T>(scope, tensorHandle) {
internal var requiresGrad: Boolean
public var requiresGrad: Boolean
get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value)
}

View File

@ -2,6 +2,8 @@ package kscience.kmath.torch
import kscience.kmath.structures.*
import kscience.kmath.memory.DeferScope
import kscience.kmath.memory.withDeferScope
import kotlinx.cinterop.*
import kscience.kmath.ctorch.*
@ -591,23 +593,23 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
}
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
memScoped { TorchTensorRealAlgebra(this).block() }
withDeferScope { TorchTensorRealAlgebra(this).block() }
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
memScoped { TorchTensorFloatAlgebra(this).block() }
withDeferScope { TorchTensorFloatAlgebra(this).block() }
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
memScoped { TorchTensorLongAlgebra(this).block() }
withDeferScope { TorchTensorLongAlgebra(this).block() }
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
memScoped { TorchTensorIntAlgebra(this).block() }
withDeferScope { TorchTensorIntAlgebra(this).block() }
public fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
public inline fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
this.requiresGrad = true
return TorchTensorRealAlgebra(this.scope).block()
}
public fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
public inline fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
this.requiresGrad = true
return TorchTensorFloatAlgebra(this.scope).block()
}