forked from kscience/kmath
TensorAlgebra first syncing
This commit is contained in:
parent
97ef57697d
commit
ed4ac2623d
@ -1,50 +0,0 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kscience.kmath.operations.Field
|
||||
import kscience.kmath.operations.Ring
|
||||
|
||||
|
||||
public interface TensorAlgebra<T, TorchTensorType : TensorStructure<T>> : Ring<TorchTensorType> {
|
||||
|
||||
public operator fun T.plus(other: TorchTensorType): TorchTensorType
|
||||
public operator fun TorchTensorType.plus(value: T): TorchTensorType
|
||||
public operator fun TorchTensorType.plusAssign(value: T): Unit
|
||||
public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit
|
||||
|
||||
public operator fun T.minus(other: TorchTensorType): TorchTensorType
|
||||
public operator fun TorchTensorType.minus(value: T): TorchTensorType
|
||||
public operator fun TorchTensorType.minusAssign(value: T): Unit
|
||||
public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit
|
||||
|
||||
public operator fun T.times(other: TorchTensorType): TorchTensorType
|
||||
public operator fun TorchTensorType.times(value: T): TorchTensorType
|
||||
public operator fun TorchTensorType.timesAssign(value: T): Unit
|
||||
public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit
|
||||
|
||||
public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType
|
||||
|
||||
public fun diagonalEmbedding(
|
||||
diagonalEntries: TorchTensorType,
|
||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||
): TorchTensorType
|
||||
|
||||
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType
|
||||
public fun TorchTensorType.view(shape: IntArray): TorchTensorType
|
||||
|
||||
public fun TorchTensorType.abs(): TorchTensorType
|
||||
public fun TorchTensorType.sum(): TorchTensorType
|
||||
|
||||
}
|
||||
|
||||
public interface TensorFieldAlgebra<T, TorchTensorType : TensorStructure<T>> :
|
||||
TensorAlgebra<T, TorchTensorType>, Field<TorchTensorType> {
|
||||
|
||||
public operator fun TorchTensorType.divAssign(b: TorchTensorType)
|
||||
|
||||
public fun TorchTensorType.exp(): TorchTensorType
|
||||
public fun TorchTensorType.log(): TorchTensorType
|
||||
|
||||
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType>
|
||||
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType>
|
||||
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kscience.kmath.structures.MutableNDStructure
|
||||
|
||||
public abstract class TensorStructure<T>: MutableNDStructure<T> {
|
||||
|
||||
// A tensor can have empty shape, in which case it represents just a value
|
||||
public abstract fun value(): T
|
||||
|
||||
// Tensors might hold shared resources
|
||||
override fun equals(other: Any?): Boolean = false
|
||||
override fun hashCode(): Int = 0
|
||||
}
|
@ -1,6 +1,8 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
|
||||
import kscience.kmath.structures.TensorStructure
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import kscience.kmath.ctorch.*
|
||||
|
||||
@ -12,15 +14,16 @@ public sealed class TorchTensor<T> constructor(
|
||||
init {
|
||||
scope.defer(::close)
|
||||
}
|
||||
|
||||
private fun close(): Unit = dispose_tensor(tensorHandle)
|
||||
|
||||
protected abstract fun item(): T
|
||||
|
||||
override val dimension: Int get() = get_dim(tensorHandle)
|
||||
override val shape: IntArray
|
||||
get() = (1..dimension).map{get_shape_at(tensorHandle, it-1)}.toIntArray()
|
||||
get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray()
|
||||
public val strides: IntArray
|
||||
get() = (1..dimension).map{get_stride_at(tensorHandle, it-1)}.toIntArray()
|
||||
get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray()
|
||||
public val size: Int get() = get_numel(tensorHandle)
|
||||
public val device: Device get() = Device.fromInt(get_device(tensorHandle))
|
||||
|
||||
@ -44,27 +47,27 @@ public sealed class TorchTensor<T> constructor(
|
||||
internal inline fun checkIsValue() = check(isValue()) {
|
||||
"This tensor has shape ${shape.toList()}"
|
||||
}
|
||||
|
||||
override fun value(): T {
|
||||
checkIsValue()
|
||||
return item()
|
||||
}
|
||||
|
||||
public var requiresGrad: Boolean
|
||||
get() = requires_grad(tensorHandle)
|
||||
set(value) = requires_grad_(tensorHandle, value)
|
||||
|
||||
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_double(this.tensorHandle)!!
|
||||
)
|
||||
|
||||
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_float(this.tensorHandle)!!
|
||||
)
|
||||
|
||||
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_long(this.tensorHandle)!!
|
||||
)
|
||||
|
||||
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_int(this.tensorHandle)!!
|
||||
@ -72,10 +75,20 @@ public sealed class TorchTensor<T> constructor(
|
||||
|
||||
}
|
||||
|
||||
public sealed class TorchTensorOverField<T> constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensor<T>(scope, tensorHandle) {
|
||||
internal var requiresGrad: Boolean
|
||||
get() = requires_grad(tensorHandle)
|
||||
set(value) = requires_grad_(tensorHandle, value)
|
||||
}
|
||||
|
||||
|
||||
public class TorchTensorReal internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensor<Double>(scope, tensorHandle) {
|
||||
) : TorchTensorOverField<Double>(scope, tensorHandle) {
|
||||
override fun item(): Double = get_item_double(tensorHandle)
|
||||
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Double) {
|
||||
@ -86,7 +99,7 @@ public class TorchTensorReal internal constructor(
|
||||
public class TorchTensorFloat internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: COpaquePointer
|
||||
) : TorchTensor<Float>(scope, tensorHandle) {
|
||||
) : TorchTensorOverField<Float>(scope, tensorHandle) {
|
||||
override fun item(): Float = get_item_float(tensorHandle)
|
||||
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
|
||||
override fun set(index: IntArray, value: Float) {
|
||||
|
@ -1,5 +1,8 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
|
||||
import kscience.kmath.structures.*
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import kscience.kmath.ctorch.*
|
||||
|
||||
@ -40,12 +43,12 @@ public sealed class TorchTensorAlgebra<
|
||||
override val one: TorchTensorType
|
||||
get() = number(1)
|
||||
|
||||
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType) =
|
||||
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
|
||||
check(a.device == b.device) {
|
||||
"Tensors must be on the same device"
|
||||
}
|
||||
|
||||
protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType) =
|
||||
protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
|
||||
check(a.shape contentEquals b.shape) {
|
||||
"Tensors must be of identical shape"
|
||||
}
|
||||
@ -214,7 +217,7 @@ public sealed class TorchTensorAlgebra<
|
||||
}
|
||||
|
||||
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
|
||||
PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>>(scope: DeferScope) :
|
||||
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
|
||||
TensorFieldAlgebra<T, TorchTensorType> {
|
||||
|
||||
@ -598,3 +601,13 @@ public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() ->
|
||||
|
||||
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||
memScoped { TorchTensorIntAlgebra(this).block() }
|
||||
|
||||
public fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
|
||||
this.requiresGrad = true
|
||||
return TorchTensorRealAlgebra(this.scope).block()
|
||||
}
|
||||
|
||||
public fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
|
||||
this.requiresGrad = true
|
||||
return TorchTensorFloatAlgebra(this.scope).block()
|
||||
}
|
@ -5,14 +5,15 @@ import kotlin.test.*
|
||||
internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
setSeed(SEED)
|
||||
|
||||
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
||||
tensorX.requiresGrad = true
|
||||
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
|
||||
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
||||
|
||||
val expressionAtX =
|
||||
val expressionAtX = tensorX.withGrad {
|
||||
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
|
||||
}
|
||||
|
||||
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
||||
val hessianAtX = expressionAtX hess tensorX
|
||||
@ -31,14 +32,14 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
|
||||
setSeed(SEED)
|
||||
|
||||
val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
||||
tensorX.requiresGrad = true
|
||||
val tensorXt = tensorX.transpose(-1,-2)
|
||||
val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device)
|
||||
val tensorSigma = randFeatures + randFeatures.transpose(-2,-1)
|
||||
val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
||||
|
||||
val expressionAtX =
|
||||
val expressionAtX = tensorX.withGrad{
|
||||
val tensorXt = tensorX.transpose(-1,-2)
|
||||
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2
|
||||
}
|
||||
expressionAtX.sumAssign()
|
||||
|
||||
val gradientAtX = expressionAtX grad tensorX
|
||||
@ -52,7 +53,7 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
|
||||
|
||||
internal class TestAutograd {
|
||||
@Test
|
||||
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
||||
fun testAutoGrad() = testingAutoGrad(dim = 3)
|
||||
|
||||
@Test
|
||||
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30)
|
||||
|
Loading…
Reference in New Issue
Block a user