forked from kscience/kmath
Integrating withDeferScope
This commit is contained in:
commit
889691a122
@ -0,0 +1,7 @@
|
||||
package kscience.kmath.memory
|
||||
|
||||
public expect class DeferScope {
|
||||
public inline fun defer(crossinline block: () -> Unit)
|
||||
}
|
||||
|
||||
public expect inline fun <R> withDeferScope(block: DeferScope.() -> R): R
|
@ -0,0 +1,30 @@
|
||||
package kscience.kmath.memory
|
||||
|
||||
private typealias Deferred = () -> Unit
|
||||
|
||||
public actual class DeferScope {
|
||||
@PublishedApi
|
||||
internal val deferred: MutableList<Deferred> = mutableListOf()
|
||||
|
||||
@PublishedApi
|
||||
internal fun executeAllDeferred() {
|
||||
deferred.forEach(Deferred::invoke)
|
||||
deferred.clear()
|
||||
}
|
||||
|
||||
public actual inline fun defer(crossinline block: () -> Unit) {
|
||||
deferred += {
|
||||
try {
|
||||
block()
|
||||
} catch (ignored: Throwable) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
|
||||
val ds = DeferScope()
|
||||
val r = ds.block()
|
||||
ds.executeAllDeferred()
|
||||
return r
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
package kscience.kmath.memory
|
||||
|
||||
private typealias Deferred = () -> Unit
|
||||
|
||||
public actual class DeferScope {
|
||||
@PublishedApi
|
||||
internal val deferred: MutableList<Deferred> = mutableListOf()
|
||||
|
||||
@PublishedApi
|
||||
internal fun executeAllDeferred() {
|
||||
deferred.forEach(Deferred::invoke)
|
||||
deferred.clear()
|
||||
}
|
||||
|
||||
public actual inline fun defer(crossinline block: () -> Unit) {
|
||||
deferred += {
|
||||
try {
|
||||
block()
|
||||
} catch (ignored: Throwable) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
|
||||
val ds = DeferScope()
|
||||
val r = ds.block()
|
||||
ds.executeAllDeferred()
|
||||
return r
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
package kscience.kmath.memory
|
||||
|
||||
import kotlinx.cinterop.memScoped
|
||||
|
||||
public actual typealias DeferScope = kotlinx.cinterop.DeferScope
|
||||
|
||||
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R = memScoped(block)
|
@ -2,13 +2,14 @@ package kscience.kmath.torch
|
||||
|
||||
|
||||
import kscience.kmath.structures.TensorStructure
|
||||
import kscience.kmath.memory.DeferScope
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import kscience.kmath.ctorch.*
|
||||
|
||||
|
||||
public sealed class TorchTensor<T> constructor(
|
||||
internal val scope: DeferScope,
|
||||
public val scope: DeferScope,
|
||||
internal val tensorHandle: COpaquePointer
|
||||
) : TensorStructure<T>() {
|
||||
init {
|
||||
@ -79,7 +80,7 @@ public sealed class TorchTensorOverField<T> constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensor<T>(scope, tensorHandle) {
|
||||
internal var requiresGrad: Boolean
|
||||
public var requiresGrad: Boolean
|
||||
get() = requires_grad(tensorHandle)
|
||||
set(value) = requires_grad_(tensorHandle, value)
|
||||
}
|
||||
|
@ -2,6 +2,8 @@ package kscience.kmath.torch
|
||||
|
||||
|
||||
import kscience.kmath.structures.*
|
||||
import kscience.kmath.memory.DeferScope
|
||||
import kscience.kmath.memory.withDeferScope
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import kscience.kmath.ctorch.*
|
||||
@ -591,23 +593,23 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
}
|
||||
|
||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||
memScoped { TorchTensorRealAlgebra(this).block() }
|
||||
withDeferScope { TorchTensorRealAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
||||
memScoped { TorchTensorFloatAlgebra(this).block() }
|
||||
withDeferScope { TorchTensorFloatAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
|
||||
memScoped { TorchTensorLongAlgebra(this).block() }
|
||||
withDeferScope { TorchTensorLongAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||
memScoped { TorchTensorIntAlgebra(this).block() }
|
||||
withDeferScope { TorchTensorIntAlgebra(this).block() }
|
||||
|
||||
public fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
|
||||
public inline fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
|
||||
this.requiresGrad = true
|
||||
return TorchTensorRealAlgebra(this.scope).block()
|
||||
}
|
||||
|
||||
public fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
|
||||
public inline fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
|
||||
this.requiresGrad = true
|
||||
return TorchTensorFloatAlgebra(this.scope).block()
|
||||
}
|
Loading…
Reference in New Issue
Block a user