TensorAlgebra first syncing

This commit is contained in:
Roland Grinis 2021-01-16 19:35:30 +00:00
parent 97ef57697d
commit ed4ac2623d
5 changed files with 45 additions and 81 deletions

View File

@ -1,50 +0,0 @@
package kscience.kmath.torch
import kscience.kmath.operations.Field
import kscience.kmath.operations.Ring
public interface TensorAlgebra<T, TorchTensorType : TensorStructure<T>> : Ring<TorchTensorType> {
public operator fun T.plus(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.plus(value: T): TorchTensorType
public operator fun TorchTensorType.plusAssign(value: T): Unit
public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit
public operator fun T.minus(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.minus(value: T): TorchTensorType
public operator fun TorchTensorType.minusAssign(value: T): Unit
public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit
public operator fun T.times(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.times(value: T): TorchTensorType
public operator fun TorchTensorType.timesAssign(value: T): Unit
public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit
public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType
public fun diagonalEmbedding(
diagonalEntries: TorchTensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
): TorchTensorType
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType
public fun TorchTensorType.view(shape: IntArray): TorchTensorType
public fun TorchTensorType.abs(): TorchTensorType
public fun TorchTensorType.sum(): TorchTensorType
}
public interface TensorFieldAlgebra<T, TorchTensorType : TensorStructure<T>> :
TensorAlgebra<T, TorchTensorType>, Field<TorchTensorType> {
public operator fun TorchTensorType.divAssign(b: TorchTensorType)
public fun TorchTensorType.exp(): TorchTensorType
public fun TorchTensorType.log(): TorchTensorType
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType>
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType>
}

View File

@ -1,13 +0,0 @@
package kscience.kmath.torch
import kscience.kmath.structures.MutableNDStructure
public abstract class TensorStructure<T>: MutableNDStructure<T> {
// A tensor can have empty shape, in which case it represents just a value
public abstract fun value(): T
// Tensors might hold shared resources
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
}

View File

@ -1,6 +1,8 @@
package kscience.kmath.torch package kscience.kmath.torch
import kscience.kmath.structures.TensorStructure
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.ctorch.* import kscience.kmath.ctorch.*
@ -12,15 +14,16 @@ public sealed class TorchTensor<T> constructor(
init { init {
scope.defer(::close) scope.defer(::close)
} }
private fun close(): Unit = dispose_tensor(tensorHandle) private fun close(): Unit = dispose_tensor(tensorHandle)
protected abstract fun item(): T protected abstract fun item(): T
override val dimension: Int get() = get_dim(tensorHandle) override val dimension: Int get() = get_dim(tensorHandle)
override val shape: IntArray override val shape: IntArray
get() = (1..dimension).map{get_shape_at(tensorHandle, it-1)}.toIntArray() get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray()
public val strides: IntArray public val strides: IntArray
get() = (1..dimension).map{get_stride_at(tensorHandle, it-1)}.toIntArray() get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray()
public val size: Int get() = get_numel(tensorHandle) public val size: Int get() = get_numel(tensorHandle)
public val device: Device get() = Device.fromInt(get_device(tensorHandle)) public val device: Device get() = Device.fromInt(get_device(tensorHandle))
@ -44,27 +47,27 @@ public sealed class TorchTensor<T> constructor(
internal inline fun checkIsValue() = check(isValue()) { internal inline fun checkIsValue() = check(isValue()) {
"This tensor has shape ${shape.toList()}" "This tensor has shape ${shape.toList()}"
} }
override fun value(): T { override fun value(): T {
checkIsValue() checkIsValue()
return item() return item()
} }
public var requiresGrad: Boolean
get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value)
public fun copyToDouble(): TorchTensorReal = TorchTensorReal( public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
scope = scope, scope = scope,
tensorHandle = copy_to_double(this.tensorHandle)!! tensorHandle = copy_to_double(this.tensorHandle)!!
) )
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat( public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
scope = scope, scope = scope,
tensorHandle = copy_to_float(this.tensorHandle)!! tensorHandle = copy_to_float(this.tensorHandle)!!
) )
public fun copyToLong(): TorchTensorLong = TorchTensorLong( public fun copyToLong(): TorchTensorLong = TorchTensorLong(
scope = scope, scope = scope,
tensorHandle = copy_to_long(this.tensorHandle)!! tensorHandle = copy_to_long(this.tensorHandle)!!
) )
public fun copyToInt(): TorchTensorInt = TorchTensorInt( public fun copyToInt(): TorchTensorInt = TorchTensorInt(
scope = scope, scope = scope,
tensorHandle = copy_to_int(this.tensorHandle)!! tensorHandle = copy_to_int(this.tensorHandle)!!
@ -72,10 +75,20 @@ public sealed class TorchTensor<T> constructor(
} }
public sealed class TorchTensorOverField<T> constructor(
scope: DeferScope,
tensorHandle: COpaquePointer
) : TorchTensor<T>(scope, tensorHandle) {
internal var requiresGrad: Boolean
get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value)
}
public class TorchTensorReal internal constructor( public class TorchTensorReal internal constructor(
scope: DeferScope, scope: DeferScope,
tensorHandle: COpaquePointer tensorHandle: COpaquePointer
) : TorchTensor<Double>(scope, tensorHandle) { ) : TorchTensorOverField<Double>(scope, tensorHandle) {
override fun item(): Double = get_item_double(tensorHandle) override fun item(): Double = get_item_double(tensorHandle)
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues()) override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Double) { override fun set(index: IntArray, value: Double) {
@ -86,7 +99,7 @@ public class TorchTensorReal internal constructor(
public class TorchTensorFloat internal constructor( public class TorchTensorFloat internal constructor(
scope: DeferScope, scope: DeferScope,
tensorHandle: COpaquePointer tensorHandle: COpaquePointer
) : TorchTensor<Float>(scope, tensorHandle) { ) : TorchTensorOverField<Float>(scope, tensorHandle) {
override fun item(): Float = get_item_float(tensorHandle) override fun item(): Float = get_item_float(tensorHandle)
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues()) override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Float) { override fun set(index: IntArray, value: Float) {

View File

@ -1,5 +1,8 @@
package kscience.kmath.torch package kscience.kmath.torch
import kscience.kmath.structures.*
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.ctorch.* import kscience.kmath.ctorch.*
@ -40,12 +43,12 @@ public sealed class TorchTensorAlgebra<
override val one: TorchTensorType override val one: TorchTensorType
get() = number(1) get() = number(1)
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType) = protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
check(a.device == b.device) { check(a.device == b.device) {
"Tensors must be on the same device" "Tensors must be on the same device"
} }
protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType) = protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
check(a.shape contentEquals b.shape) { check(a.shape contentEquals b.shape) {
"Tensors must be of identical shape" "Tensors must be of identical shape"
} }
@ -214,7 +217,7 @@ public sealed class TorchTensorAlgebra<
} }
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar, public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) : PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>>(scope: DeferScope) :
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope), TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
TensorFieldAlgebra<T, TorchTensorType> { TensorFieldAlgebra<T, TorchTensorType> {
@ -598,3 +601,13 @@ public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() ->
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R = public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
memScoped { TorchTensorIntAlgebra(this).block() } memScoped { TorchTensorIntAlgebra(this).block() }
public fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
this.requiresGrad = true
return TorchTensorRealAlgebra(this.scope).block()
}
public fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
this.requiresGrad = true
return TorchTensorFloatAlgebra(this.scope).block()
}

View File

@ -5,14 +5,15 @@ import kotlin.test.*
internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit { internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
setSeed(SEED) setSeed(SEED)
val tensorX = randNormal(shape = intArrayOf(dim), device = device) val tensorX = randNormal(shape = intArrayOf(dim), device = device)
tensorX.requiresGrad = true
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device) val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(0,1) val tensorSigma = randFeatures + randFeatures.transpose(0,1)
val tensorMu = randNormal(shape = intArrayOf(dim), device = device) val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
val expressionAtX = val expressionAtX = tensorX.withGrad {
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9 0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
}
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true) val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
val hessianAtX = expressionAtX hess tensorX val hessianAtX = expressionAtX hess tensorX
@ -31,14 +32,14 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
setSeed(SEED) setSeed(SEED)
val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device) val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device)
tensorX.requiresGrad = true
val tensorXt = tensorX.transpose(-1,-2)
val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device) val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(-2,-1) val tensorSigma = randFeatures + randFeatures.transpose(-2,-1)
val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device) val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device)
val expressionAtX = val expressionAtX = tensorX.withGrad{
val tensorXt = tensorX.transpose(-1,-2)
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2 0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2
}
expressionAtX.sumAssign() expressionAtX.sumAssign()
val gradientAtX = expressionAtX grad tensorX val gradientAtX = expressionAtX grad tensorX
@ -52,7 +53,7 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
internal class TestAutograd { internal class TestAutograd {
@Test @Test
fun testAutoGrad() = testingAutoGrad(dim = 100) fun testAutoGrad() = testingAutoGrad(dim = 3)
@Test @Test
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30) fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30)