TensorAlgebra first syncing
This commit is contained in:
parent
97ef57697d
commit
ed4ac2623d
@ -1,50 +0,0 @@
|
|||||||
package kscience.kmath.torch
|
|
||||||
|
|
||||||
import kscience.kmath.operations.Field
|
|
||||||
import kscience.kmath.operations.Ring
|
|
||||||
|
|
||||||
|
|
||||||
public interface TensorAlgebra<T, TorchTensorType : TensorStructure<T>> : Ring<TorchTensorType> {
|
|
||||||
|
|
||||||
public operator fun T.plus(other: TorchTensorType): TorchTensorType
|
|
||||||
public operator fun TorchTensorType.plus(value: T): TorchTensorType
|
|
||||||
public operator fun TorchTensorType.plusAssign(value: T): Unit
|
|
||||||
public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit
|
|
||||||
|
|
||||||
public operator fun T.minus(other: TorchTensorType): TorchTensorType
|
|
||||||
public operator fun TorchTensorType.minus(value: T): TorchTensorType
|
|
||||||
public operator fun TorchTensorType.minusAssign(value: T): Unit
|
|
||||||
public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit
|
|
||||||
|
|
||||||
public operator fun T.times(other: TorchTensorType): TorchTensorType
|
|
||||||
public operator fun TorchTensorType.times(value: T): TorchTensorType
|
|
||||||
public operator fun TorchTensorType.timesAssign(value: T): Unit
|
|
||||||
public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit
|
|
||||||
|
|
||||||
public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType
|
|
||||||
|
|
||||||
public fun diagonalEmbedding(
|
|
||||||
diagonalEntries: TorchTensorType,
|
|
||||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
|
||||||
): TorchTensorType
|
|
||||||
|
|
||||||
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType
|
|
||||||
public fun TorchTensorType.view(shape: IntArray): TorchTensorType
|
|
||||||
|
|
||||||
public fun TorchTensorType.abs(): TorchTensorType
|
|
||||||
public fun TorchTensorType.sum(): TorchTensorType
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public interface TensorFieldAlgebra<T, TorchTensorType : TensorStructure<T>> :
|
|
||||||
TensorAlgebra<T, TorchTensorType>, Field<TorchTensorType> {
|
|
||||||
|
|
||||||
public operator fun TorchTensorType.divAssign(b: TorchTensorType)
|
|
||||||
|
|
||||||
public fun TorchTensorType.exp(): TorchTensorType
|
|
||||||
public fun TorchTensorType.log(): TorchTensorType
|
|
||||||
|
|
||||||
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType>
|
|
||||||
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType>
|
|
||||||
|
|
||||||
}
|
|
@ -1,13 +0,0 @@
|
|||||||
package kscience.kmath.torch
|
|
||||||
|
|
||||||
import kscience.kmath.structures.MutableNDStructure
|
|
||||||
|
|
||||||
public abstract class TensorStructure<T>: MutableNDStructure<T> {
|
|
||||||
|
|
||||||
// A tensor can have empty shape, in which case it represents just a value
|
|
||||||
public abstract fun value(): T
|
|
||||||
|
|
||||||
// Tensors might hold shared resources
|
|
||||||
override fun equals(other: Any?): Boolean = false
|
|
||||||
override fun hashCode(): Int = 0
|
|
||||||
}
|
|
@ -1,6 +1,8 @@
|
|||||||
package kscience.kmath.torch
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
|
||||||
|
import kscience.kmath.structures.TensorStructure
|
||||||
|
|
||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import kscience.kmath.ctorch.*
|
import kscience.kmath.ctorch.*
|
||||||
|
|
||||||
@ -12,15 +14,16 @@ public sealed class TorchTensor<T> constructor(
|
|||||||
init {
|
init {
|
||||||
scope.defer(::close)
|
scope.defer(::close)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun close(): Unit = dispose_tensor(tensorHandle)
|
private fun close(): Unit = dispose_tensor(tensorHandle)
|
||||||
|
|
||||||
protected abstract fun item(): T
|
protected abstract fun item(): T
|
||||||
|
|
||||||
override val dimension: Int get() = get_dim(tensorHandle)
|
override val dimension: Int get() = get_dim(tensorHandle)
|
||||||
override val shape: IntArray
|
override val shape: IntArray
|
||||||
get() = (1..dimension).map{get_shape_at(tensorHandle, it-1)}.toIntArray()
|
get() = (1..dimension).map { get_shape_at(tensorHandle, it - 1) }.toIntArray()
|
||||||
public val strides: IntArray
|
public val strides: IntArray
|
||||||
get() = (1..dimension).map{get_stride_at(tensorHandle, it-1)}.toIntArray()
|
get() = (1..dimension).map { get_stride_at(tensorHandle, it - 1) }.toIntArray()
|
||||||
public val size: Int get() = get_numel(tensorHandle)
|
public val size: Int get() = get_numel(tensorHandle)
|
||||||
public val device: Device get() = Device.fromInt(get_device(tensorHandle))
|
public val device: Device get() = Device.fromInt(get_device(tensorHandle))
|
||||||
|
|
||||||
@ -44,27 +47,27 @@ public sealed class TorchTensor<T> constructor(
|
|||||||
internal inline fun checkIsValue() = check(isValue()) {
|
internal inline fun checkIsValue() = check(isValue()) {
|
||||||
"This tensor has shape ${shape.toList()}"
|
"This tensor has shape ${shape.toList()}"
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun value(): T {
|
override fun value(): T {
|
||||||
checkIsValue()
|
checkIsValue()
|
||||||
return item()
|
return item()
|
||||||
}
|
}
|
||||||
|
|
||||||
public var requiresGrad: Boolean
|
|
||||||
get() = requires_grad(tensorHandle)
|
|
||||||
set(value) = requires_grad_(tensorHandle, value)
|
|
||||||
|
|
||||||
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = copy_to_double(this.tensorHandle)!!
|
tensorHandle = copy_to_double(this.tensorHandle)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
|
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = copy_to_float(this.tensorHandle)!!
|
tensorHandle = copy_to_float(this.tensorHandle)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
|
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = copy_to_long(this.tensorHandle)!!
|
tensorHandle = copy_to_long(this.tensorHandle)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
|
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = copy_to_int(this.tensorHandle)!!
|
tensorHandle = copy_to_int(this.tensorHandle)!!
|
||||||
@ -72,10 +75,20 @@ public sealed class TorchTensor<T> constructor(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public sealed class TorchTensorOverField<T> constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensor<T>(scope, tensorHandle) {
|
||||||
|
internal var requiresGrad: Boolean
|
||||||
|
get() = requires_grad(tensorHandle)
|
||||||
|
set(value) = requires_grad_(tensorHandle, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
public class TorchTensorReal internal constructor(
|
public class TorchTensorReal internal constructor(
|
||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer
|
||||||
) : TorchTensor<Double>(scope, tensorHandle) {
|
) : TorchTensorOverField<Double>(scope, tensorHandle) {
|
||||||
override fun item(): Double = get_item_double(tensorHandle)
|
override fun item(): Double = get_item_double(tensorHandle)
|
||||||
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
|
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
|
||||||
override fun set(index: IntArray, value: Double) {
|
override fun set(index: IntArray, value: Double) {
|
||||||
@ -86,7 +99,7 @@ public class TorchTensorReal internal constructor(
|
|||||||
public class TorchTensorFloat internal constructor(
|
public class TorchTensorFloat internal constructor(
|
||||||
scope: DeferScope,
|
scope: DeferScope,
|
||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer
|
||||||
) : TorchTensor<Float>(scope, tensorHandle) {
|
) : TorchTensorOverField<Float>(scope, tensorHandle) {
|
||||||
override fun item(): Float = get_item_float(tensorHandle)
|
override fun item(): Float = get_item_float(tensorHandle)
|
||||||
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
|
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
|
||||||
override fun set(index: IntArray, value: Float) {
|
override fun set(index: IntArray, value: Float) {
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
package kscience.kmath.torch
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
|
||||||
|
import kscience.kmath.structures.*
|
||||||
|
|
||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import kscience.kmath.ctorch.*
|
import kscience.kmath.ctorch.*
|
||||||
|
|
||||||
@ -40,12 +43,12 @@ public sealed class TorchTensorAlgebra<
|
|||||||
override val one: TorchTensorType
|
override val one: TorchTensorType
|
||||||
get() = number(1)
|
get() = number(1)
|
||||||
|
|
||||||
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType) =
|
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
|
||||||
check(a.device == b.device) {
|
check(a.device == b.device) {
|
||||||
"Tensors must be on the same device"
|
"Tensors must be on the same device"
|
||||||
}
|
}
|
||||||
|
|
||||||
protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType) =
|
protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType): Unit =
|
||||||
check(a.shape contentEquals b.shape) {
|
check(a.shape contentEquals b.shape) {
|
||||||
"Tensors must be of identical shape"
|
"Tensors must be of identical shape"
|
||||||
}
|
}
|
||||||
@ -214,7 +217,7 @@ public sealed class TorchTensorAlgebra<
|
|||||||
}
|
}
|
||||||
|
|
||||||
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||||
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
|
PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>>(scope: DeferScope) :
|
||||||
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
|
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
|
||||||
TensorFieldAlgebra<T, TorchTensorType> {
|
TensorFieldAlgebra<T, TorchTensorType> {
|
||||||
|
|
||||||
@ -597,4 +600,14 @@ public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() ->
|
|||||||
memScoped { TorchTensorLongAlgebra(this).block() }
|
memScoped { TorchTensorLongAlgebra(this).block() }
|
||||||
|
|
||||||
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||||
memScoped { TorchTensorIntAlgebra(this).block() }
|
memScoped { TorchTensorIntAlgebra(this).block() }
|
||||||
|
|
||||||
|
public fun TorchTensorReal.withGrad(block: TorchTensorRealAlgebra.() -> TorchTensorReal): TorchTensorReal {
|
||||||
|
this.requiresGrad = true
|
||||||
|
return TorchTensorRealAlgebra(this.scope).block()
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun TorchTensorFloat.withGrad(block: TorchTensorFloatAlgebra.() -> TorchTensorFloat): TorchTensorFloat {
|
||||||
|
this.requiresGrad = true
|
||||||
|
return TorchTensorFloatAlgebra(this.scope).block()
|
||||||
|
}
|
@ -5,14 +5,15 @@ import kotlin.test.*
|
|||||||
internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
|
internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
|
|
||||||
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
tensorX.requiresGrad = true
|
|
||||||
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||||
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
|
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
|
||||||
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
|
||||||
val expressionAtX =
|
val expressionAtX = tensorX.withGrad {
|
||||||
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
|
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
|
||||||
|
}
|
||||||
|
|
||||||
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
val gradientAtX = expressionAtX.grad(tensorX, retainGraph = true)
|
||||||
val hessianAtX = expressionAtX hess tensorX
|
val hessianAtX = expressionAtX hess tensorX
|
||||||
@ -31,14 +32,14 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
|
|||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
|
|
||||||
val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
||||||
tensorX.requiresGrad = true
|
|
||||||
val tensorXt = tensorX.transpose(-1,-2)
|
|
||||||
val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device)
|
val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device)
|
||||||
val tensorSigma = randFeatures + randFeatures.transpose(-2,-1)
|
val tensorSigma = randFeatures + randFeatures.transpose(-2,-1)
|
||||||
val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
||||||
|
|
||||||
val expressionAtX =
|
val expressionAtX = tensorX.withGrad{
|
||||||
|
val tensorXt = tensorX.transpose(-1,-2)
|
||||||
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2
|
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2
|
||||||
|
}
|
||||||
expressionAtX.sumAssign()
|
expressionAtX.sumAssign()
|
||||||
|
|
||||||
val gradientAtX = expressionAtX grad tensorX
|
val gradientAtX = expressionAtX grad tensorX
|
||||||
@ -52,7 +53,7 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
|
|||||||
|
|
||||||
internal class TestAutograd {
|
internal class TestAutograd {
|
||||||
@Test
|
@Test
|
||||||
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
fun testAutoGrad() = testingAutoGrad(dim = 3)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30)
|
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30)
|
||||||
|
Loading…
Reference in New Issue
Block a user