forked from kscience/kmath
Integrating withDeferScope
This commit is contained in:
commit
889691a122
@ -0,0 +1,7 @@
|
|||||||
|
package kscience.kmath.memory
|
||||||
|
|
||||||
|
public expect class DeferScope {
|
||||||
|
public inline fun defer(crossinline block: () -> Unit)
|
||||||
|
}
|
||||||
|
|
||||||
|
public expect inline fun <R> withDeferScope(block: DeferScope.() -> R): R
|
@ -0,0 +1,30 @@
|
|||||||
|
package kscience.kmath.memory
|
||||||
|
|
||||||
|
private typealias Deferred = () -> Unit
|
||||||
|
|
||||||
|
public actual class DeferScope {
|
||||||
|
@PublishedApi
|
||||||
|
internal val deferred: MutableList<Deferred> = mutableListOf()
|
||||||
|
|
||||||
|
@PublishedApi
|
||||||
|
internal fun executeAllDeferred() {
|
||||||
|
deferred.forEach(Deferred::invoke)
|
||||||
|
deferred.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
public actual inline fun defer(crossinline block: () -> Unit) {
|
||||||
|
deferred += {
|
||||||
|
try {
|
||||||
|
block()
|
||||||
|
} catch (ignored: Throwable) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
|
||||||
|
val ds = DeferScope()
|
||||||
|
val r = ds.block()
|
||||||
|
ds.executeAllDeferred()
|
||||||
|
return r
|
||||||
|
}
|
@ -0,0 +1,30 @@
|
|||||||
|
package kscience.kmath.memory
|
||||||
|
|
||||||
|
private typealias Deferred = () -> Unit
|
||||||
|
|
||||||
|
public actual class DeferScope {
|
||||||
|
@PublishedApi
|
||||||
|
internal val deferred: MutableList<Deferred> = mutableListOf()
|
||||||
|
|
||||||
|
@PublishedApi
|
||||||
|
internal fun executeAllDeferred() {
|
||||||
|
deferred.forEach(Deferred::invoke)
|
||||||
|
deferred.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
public actual inline fun defer(crossinline block: () -> Unit) {
|
||||||
|
deferred += {
|
||||||
|
try {
|
||||||
|
block()
|
||||||
|
} catch (ignored: Throwable) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R {
|
||||||
|
val ds = DeferScope()
|
||||||
|
val r = ds.block()
|
||||||
|
ds.executeAllDeferred()
|
||||||
|
return r
|
||||||
|
}
|
@ -0,0 +1,7 @@
|
|||||||
|
package kscience.kmath.memory
|
||||||
|
|
||||||
|
import kotlinx.cinterop.memScoped
|
||||||
|
|
||||||
|
public actual typealias DeferScope = kotlinx.cinterop.DeferScope
|
||||||
|
|
||||||
|
public actual inline fun <R> withDeferScope(block: DeferScope.() -> R): R = memScoped(block)
|
@ -2,13 +2,14 @@ package kscience.kmath.torch
|
|||||||
|
|
||||||
|
|
||||||
import kscience.kmath.structures.TensorStructure
|
import kscience.kmath.structures.TensorStructure
|
||||||
|
import kscience.kmath.memory.DeferScope
|
||||||
|
|
||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import kscience.kmath.ctorch.*
|
import kscience.kmath.ctorch.*
|
||||||
|
|
||||||
|
|
||||||
public sealed class TorchTensor<T> constructor(
|
public sealed class TorchTensor<T> constructor(
|
||||||
internal val scope: DeferScope,
|
public val scope: DeferScope,
|
||||||
internal val tensorHandle: COpaquePointer
|
internal val tensorHandle: COpaquePointer
|
||||||
) : TensorStructure<T>() {
|
) : TensorStructure<T>() {
|
||||||
init {
|
init {
|
||||||
@ -79,7 +80,7 @@ public sealed class TorchTensorOverField<T> constructor(
|
|||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer
|
||||||
) : TorchTensor<T>(scope, tensorHandle) {
|
) : TorchTensor<T>(scope, tensorHandle) {
|
||||||
internal var requiresGrad: Boolean
|
public var requiresGrad: Boolean
|
||||||
get() = requires_grad(tensorHandle)
|
get() = requires_grad(tensorHandle)
|
||||||
set(value) = requires_grad_(tensorHandle, value)
|
set(value) = requires_grad_(tensorHandle, value)
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,8 @@ package kscience.kmath.torch
|
|||||||
|
|
||||||
|
|
||||||
import kscience.kmath.structures.*
|
import kscience.kmath.structures.*
|
||||||
|
import kscience.kmath.memory.DeferScope
|
||||||
|
import kscience.kmath.memory.withDeferScope
|
||||||
|
|
||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import kscience.kmath.ctorch.*
|
import kscience.kmath.ctorch.*
|
||||||
@ -591,23 +593,23 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
|||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||||
memScoped { TorchTensorRealAlgebra(this).block() }
|
withDeferScope { TorchTensorRealAlgebra(this).block() }
|
||||||
|
|
||||||
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
||||||
memScoped { TorchTensorFloatAlgebra(this).block() }
|
withDeferScope { TorchTensorFloatAlgebra(this).block() }
|
||||||
|
|
||||||
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
|
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
|
||||||
memScoped { TorchTensorLongAlgebra(this).block() }
|
withDeferScope { TorchTensorLongAlgebra(this).block() }
|
||||||
|
|
||||||
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||||
memScoped { TorchTensorIntAlgebra(this).block() }
|
withDeferScope { TorchTensorIntAlgebra(this).block() }
|
||||||
|
|
||||||
public fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
|
public inline fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
|
||||||
this.requiresGrad = true
|
this.requiresGrad = true
|
||||||
return TorchTensorRealAlgebra(this.scope).block()
|
return TorchTensorRealAlgebra(this.scope).block()
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
|
public inline fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
|
||||||
this.requiresGrad = true
|
this.requiresGrad = true
|
||||||
return TorchTensorFloatAlgebra(this.scope).block()
|
return TorchTensorFloatAlgebra(this.scope).block()
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user