TensorAlgebra first syncing

This commit is contained in:
Roland Grinis 2021-01-16 19:35:30 +00:00
parent 97ef57697d
commit ed4ac2623d
5 changed files with 45 additions and 81 deletions

View File

@@ -1,50 +0,0 @@
package kscience.kmath.torch
import kscience.kmath.operations.Field
import kscience.kmath.operations.Ring
/**
 * Ring of tensors with element type [T], extended with mixed scalar/tensor
 * arithmetic, in-place mutation, and basic linear-algebra primitives.
 *
 * Implementations supply the concrete [TorchTensorType]; the inherited [Ring]
 * contract provides tensor-tensor `plus`/`times` and the additive/multiplicative
 * identities.
 */
public interface TensorAlgebra<T, TorchTensorType : TensorStructure<T>> : Ring<TorchTensorType> {
// Scalar-tensor addition (presumably applied element-wise — TODO confirm in implementations).
public operator fun T.plus(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.plus(value: T): TorchTensorType
// In-place variants: mutate the receiver tensor rather than allocating a result.
public operator fun TorchTensorType.plusAssign(value: T): Unit
public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit
// Scalar-tensor subtraction, value and in-place forms.
public operator fun T.minus(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.minus(value: T): TorchTensorType
public operator fun TorchTensorType.minusAssign(value: T): Unit
public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit
// Scalar-tensor multiplication, value and in-place forms.
public operator fun T.times(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.times(value: T): TorchTensorType
public operator fun TorchTensorType.timesAssign(value: T): Unit
public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit
// Matrix/tensor contraction of the receiver with [b].
public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType
/**
 * Builds a tensor whose diagonal(s) along dimensions [dim1] and [dim2]
 * are filled from [diagonalEntries], shifted by [offset].
 * NOTE(review): defaults (offset=0, dim1=-2, dim2=-1) mirror PyTorch's
 * `diag_embed` — confirm against the native implementation.
 */
public fun diagonalEmbedding(
diagonalEntries: TorchTensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
): TorchTensorType
// Swaps dimensions [i] and [j] of the receiver.
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType
// Reinterprets the receiver with the given [shape]; element count must match.
public fun TorchTensorType.view(shape: IntArray): TorchTensorType
// Element-wise absolute value.
public fun TorchTensorType.abs(): TorchTensorType
// Sum over all elements; result is still a tensor (a value-shaped one — see TensorStructure.value).
public fun TorchTensorType.sum(): TorchTensorType
}
/**
 * [TensorAlgebra] over a [Field] of tensors: adds division, transcendental
 * element-wise functions, and matrix decompositions available only for
 * field-valued (e.g. floating-point) element types.
 */
public interface TensorFieldAlgebra<T, TorchTensorType : TensorStructure<T>> :
TensorAlgebra<T, TorchTensorType>, Field<TorchTensorType> {
// In-place element-wise division of the receiver by [b].
public operator fun TorchTensorType.divAssign(b: TorchTensorType)
// Element-wise exponential.
public fun TorchTensorType.exp(): TorchTensorType
// Element-wise natural logarithm.
public fun TorchTensorType.log(): TorchTensorType
// Singular value decomposition; presumably (U, S, V) — TODO confirm ordering against the backend.
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType>
// Symmetric eigendecomposition; when [eigenvectors] is false the second component
// is presumably not computed — confirm against the backend.
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType>
}

View File

@@ -1,13 +0,0 @@
package kscience.kmath.torch
import kscience.kmath.structures.MutableNDStructure
/**
 * Base class for tensors viewed as mutable N-dimensional structures.
 *
 * Unlike plain [MutableNDStructure]s, tensors may wrap shared native
 * resources, so structural equality is deliberately disabled (see below).
 */
public abstract class TensorStructure<T>: MutableNDStructure<T> {
// A tensor can have empty shape, in which case it represents just a value
public abstract fun value(): T
// Tensors might hold shared resources
// NOTE(review): constant-false equals violates the Any.equals reflexivity
// contract (x == x is false) — apparently deliberate to forbid comparing
// handle-backed tensors, but confirm no collection relies on equality.
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
}

View File

@@ -1,6 +1,8 @@
package kscience.kmath.torch
import kscience.kmath.structures.TensorStructure
import kotlinx.cinterop.*
import kscience.kmath.ctorch.*
@ -12,6 +14,7 @@ public sealed class TorchTensor<T> constructor(
init {
scope.defer(::close)
}
private fun close(): Unit = dispose_tensor(tensorHandle)
protected abstract fun item(): T
@ -44,27 +47,27 @@ public sealed class TorchTensor<T> constructor(
internal inline fun checkIsValue() = check(isValue()) {
"This tensor has shape ${shape.toList()}"
}
override fun value(): T {
checkIsValue()
return item()
}
public var requiresGrad: Boolean
get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value)
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = copy_to_double(this.tensorHandle)!!
)
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
scope = scope,
tensorHandle = copy_to_float(this.tensorHandle)!!
)
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
scope = scope,
tensorHandle = copy_to_long(this.tensorHandle)!!
)
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
scope = scope,
tensorHandle = copy_to_int(this.tensorHandle)!!
@ -72,10 +75,20 @@ public sealed class TorchTensor<T> constructor(
}
/**
 * Tensor whose element type forms a field (floating-point), and which can
 * therefore participate in autograd: exposes the native requires-grad flag.
 */
public sealed class TorchTensorOverField<T> constructor(
scope: DeferScope,
tensorHandle: COpaquePointer
) : TorchTensor<T>(scope, tensorHandle) {
// Bridges the native requires_grad flag; internal so gradient tracking is
// toggled only through module-level helpers (e.g. withGrad).
internal var requiresGrad: Boolean
get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value)
}
public class TorchTensorReal internal constructor(
scope: DeferScope,
tensorHandle: COpaquePointer
) : TorchTensor<Double>(scope, tensorHandle) {
) : TorchTensorOverField<Double>(scope, tensorHandle) {
override fun item(): Double = get_item_double(tensorHandle)
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Double) {
@ -86,7 +99,7 @@ public class TorchTensorReal internal constructor(
public class TorchTensorFloat internal constructor(
scope: DeferScope,
tensorHandle: COpaquePointer
) : TorchTensor<Float>(scope, tensorHandle) {
) : TorchTensorOverField<Float>(scope, tensorHandle) {
override fun item(): Float = get_item_float(tensorHandle)
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Float) {

View File

@@ -1,5 +1,8 @@
package kscience.kmath.torch
import kscience.kmath.structures.*
import kotlinx.cinterop.*
import kscience.kmath.ctorch.*
@ -40,12 +43,12 @@ public sealed class TorchTensorAlgebra<
override val one: TorchTensorType
get() = number(1)
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType) =
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
check(a.device == b.device) {
"Tensors must be on the same device"
}
protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType) =
protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
check(a.shape contentEquals b.shape) {
"Tensors must be of identical shape"
}
@ -214,7 +217,7 @@ public sealed class TorchTensorAlgebra<
}
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>>(scope: DeferScope) :
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
TensorFieldAlgebra<T, TorchTensorType> {
@ -598,3 +601,13 @@ public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() ->
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
memScoped { TorchTensorIntAlgebra(this).block() }
/**
 * Enables gradient tracking on this tensor, then evaluates [block] inside a
 * [TorchTensorRealAlgebra] bound to this tensor's memory scope and returns
 * the resulting tensor.
 */
public fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
    requiresGrad = true
    val algebra = TorchTensorRealAlgebra(scope)
    return algebra.block()
}
/**
 * Enables gradient tracking on this tensor, then evaluates [block] inside a
 * [TorchTensorFloatAlgebra] bound to this tensor's memory scope and returns
 * the resulting tensor.
 */
public fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
    requiresGrad = true
    val algebra = TorchTensorFloatAlgebra(scope)
    return algebra.block()
}

View File

@ -5,14 +5,15 @@ import kotlin.test.*
internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra {
setSeed(SEED)
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
tensorX.requiresGrad = true
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
val expressionAtX =
val expressionAtX = tensorX.withGrad {
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
}
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
val hessianAtX = expressionAtX hess tensorX
@ -31,14 +32,14 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
setSeed(SEED)
val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device)
tensorX.requiresGrad = true
val tensorXt = tensorX.transpose(-1,-2)
val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(-2,-1)
val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device)
val expressionAtX =
val expressionAtX = tensorX.withGrad{
val tensorXt = tensorX.transpose(-1,-2)
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2
}
expressionAtX.sumAssign()
val gradientAtX = expressionAtX grad tensorX
@ -52,7 +53,7 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
internal class TestAutograd {
@Test
fun testAutoGrad() = testingAutoGrad(dim = 100)
fun testAutoGrad() = testingAutoGrad(dim = 3)
@Test
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30)