Syncing with tensor-algebra
This commit is contained in:
commit
6594ffc965
@ -162,7 +162,7 @@ public interface Strides {
|
|||||||
/**
|
/**
|
||||||
* Array strides
|
* Array strides
|
||||||
*/
|
*/
|
||||||
public val strides: List<Int>
|
public val strides: IntArray
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get linear index from multidimensional index
|
* Get linear index from multidimensional index
|
||||||
@ -189,6 +189,11 @@ public interface Strides {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal inline fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int = index.mapIndexed { i, value ->
|
||||||
|
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})")
|
||||||
|
value * strides[i]
|
||||||
|
}.sum()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Simple implementation of [Strides].
|
* Simple implementation of [Strides].
|
||||||
*/
|
*/
|
||||||
@ -199,7 +204,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
/**
|
/**
|
||||||
* Strides for memory access
|
* Strides for memory access
|
||||||
*/
|
*/
|
||||||
override val strides: List<Int> by lazy {
|
override val strides: IntArray by lazy {
|
||||||
sequence {
|
sequence {
|
||||||
var current = 1
|
var current = 1
|
||||||
yield(1)
|
yield(1)
|
||||||
@ -208,13 +213,10 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
current *= it
|
current *= it
|
||||||
yield(current)
|
yield(current)
|
||||||
}
|
}
|
||||||
}.toList()
|
}.toList().toIntArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
|
||||||
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
|
||||||
value * strides[i]
|
|
||||||
}.sum()
|
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray {
|
override fun index(offset: Int): IntArray {
|
||||||
val res = IntArray(shape.size)
|
val res = IntArray(shape.size)
|
||||||
@ -323,8 +325,7 @@ public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
|||||||
/**
|
/**
|
||||||
* Mutable ND buffer based on linear [MutableBuffer].
|
* Mutable ND buffer based on linear [MutableBuffer].
|
||||||
*/
|
*/
|
||||||
|
public open class MutableNDBuffer<T>(
|
||||||
public class MutableNDBuffer<T>(
|
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
buffer: MutableBuffer<T>,
|
buffer: MutableBuffer<T>,
|
||||||
) : NDBuffer<T>(strides, buffer), MutableNDStructure<T> {
|
) : NDBuffer<T>(strides, buffer), MutableNDStructure<T> {
|
||||||
|
@ -0,0 +1,173 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.MutableNDBuffer
|
||||||
|
import space.kscience.kmath.structures.RealBuffer
|
||||||
|
import space.kscience.kmath.structures.array
|
||||||
|
|
||||||
|
|
||||||
|
public class RealTensor(
|
||||||
|
override val shape: IntArray,
|
||||||
|
buffer: DoubleArray
|
||||||
|
) :
|
||||||
|
TensorStructure<Double>,
|
||||||
|
MutableNDBuffer<Double>(
|
||||||
|
TensorStrides(shape),
|
||||||
|
RealBuffer(buffer)
|
||||||
|
) {
|
||||||
|
override fun item(): Double = buffer[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||||
|
|
||||||
|
override fun add(a: RealTensor, b: RealTensor): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: RealTensor, k: Number): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override val zero: RealTensor
|
||||||
|
get() = TODO("Not yet implemented")
|
||||||
|
|
||||||
|
override fun multiply(a: RealTensor, b: RealTensor): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override val one: RealTensor
|
||||||
|
get() = TODO("Not yet implemented")
|
||||||
|
|
||||||
|
|
||||||
|
override fun Double.plus(other: RealTensor): RealTensor {
|
||||||
|
val n = other.buffer.size
|
||||||
|
val arr = other.buffer.array
|
||||||
|
val res = DoubleArray(n)
|
||||||
|
for (i in 1..n)
|
||||||
|
res[i - 1] = arr[i - 1] + this
|
||||||
|
return RealTensor(other.shape, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.plus(value: Double): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.plusAssign(value: Double) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.plusAssign(other: RealTensor) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Double.minus(other: RealTensor): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.minus(value: Double): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.minusAssign(value: Double) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.minusAssign(other: RealTensor) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Double.times(other: RealTensor): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.times(value: Double): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.timesAssign(value: Double) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.timesAssign(other: RealTensor) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.dot(other: RealTensor): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.dotAssign(other: RealTensor) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.dotRightAssign(other: RealTensor) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun diagonalEmbedding(diagonalEntries: RealTensor, offset: Int, dim1: Int, dim2: Int): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.transpose(i: Int, j: Int): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.transposeAssign(i: Int, j: Int) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.view(shape: IntArray): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.abs(): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.absAssign() {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.sum(): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.sumAssign() {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.div(other: RealTensor): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.divAssign(other: RealTensor) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.exp(): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.expAssign() {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.log(): RealTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.logAssign() {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun RealTensor.symEig(eigenvectors: Boolean): Pair<RealTensor, RealTensor> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <R> RealTensorAlgebra(block: RealTensorAlgebra.() -> R): R =
|
||||||
|
RealTensorAlgebra().block()
|
@ -1,70 +1,114 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.nd.MutableNDStructure
|
import space.kscience.kmath.operations.Ring
|
||||||
|
import space.kscience.kmath.operations.RingWithNumbers
|
||||||
public interface TensorStructure<T> : MutableNDStructure<T> {
|
|
||||||
// A tensor can have empty shape, in which case it represents just a value
|
|
||||||
public fun value(): T
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||||
|
public interface TensorAlgebra<T, TensorType : TensorStructure<T>>: RingWithNumbers<TensorType> {
|
||||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|
||||||
|
|
||||||
public operator fun T.plus(other: TensorType): TensorType
|
public operator fun T.plus(other: TensorType): TensorType
|
||||||
public operator fun TensorType.plus(value: T): TensorType
|
public operator fun TensorType.plus(value: T): TensorType
|
||||||
public operator fun TensorType.plus(other: TensorType): TensorType
|
|
||||||
public operator fun TensorType.plusAssign(value: T): Unit
|
public operator fun TensorType.plusAssign(value: T): Unit
|
||||||
public operator fun TensorType.plusAssign(other: TensorType): Unit
|
public operator fun TensorType.plusAssign(other: TensorType): Unit
|
||||||
|
|
||||||
public operator fun T.minus(other: TensorType): TensorType
|
public operator fun T.minus(other: TensorType): TensorType
|
||||||
public operator fun TensorType.minus(value: T): TensorType
|
public operator fun TensorType.minus(value: T): TensorType
|
||||||
public operator fun TensorType.minus(other: TensorType): TensorType
|
|
||||||
public operator fun TensorType.minusAssign(value: T): Unit
|
public operator fun TensorType.minusAssign(value: T): Unit
|
||||||
public operator fun TensorType.minusAssign(other: TensorType): Unit
|
public operator fun TensorType.minusAssign(other: TensorType): Unit
|
||||||
|
|
||||||
public operator fun T.times(other: TensorType): TensorType
|
public operator fun T.times(other: TensorType): TensorType
|
||||||
public operator fun TensorType.times(value: T): TensorType
|
public operator fun TensorType.times(value: T): TensorType
|
||||||
public operator fun TensorType.times(other: TensorType): TensorType
|
|
||||||
public operator fun TensorType.timesAssign(value: T): Unit
|
public operator fun TensorType.timesAssign(value: T): Unit
|
||||||
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
||||||
public operator fun TensorType.unaryMinus(): TensorType
|
|
||||||
|
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||||
public infix fun TensorType.dot(other: TensorType): TensorType
|
public infix fun TensorType.dot(other: TensorType): TensorType
|
||||||
public infix fun TensorType.dotAssign(other: TensorType): Unit
|
public infix fun TensorType.dotAssign(other: TensorType): Unit
|
||||||
public infix fun TensorType.dotRightAssign(other: TensorType): Unit
|
public infix fun TensorType.dotRightAssign(other: TensorType): Unit
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
||||||
public fun diagonalEmbedding(
|
public fun diagonalEmbedding(
|
||||||
diagonalEntries: TensorType,
|
diagonalEntries: TensorType,
|
||||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||||
): TensorType
|
): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
||||||
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
||||||
public fun TensorType.transposeAssign(i: Int, j: Int): Unit
|
public fun TensorType.transposeAssign(i: Int, j: Int): Unit
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/tensor_view.html
|
||||||
public fun TensorType.view(shape: IntArray): TensorType
|
public fun TensorType.view(shape: IntArray): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.abs.html
|
||||||
public fun TensorType.abs(): TensorType
|
public fun TensorType.abs(): TensorType
|
||||||
public fun TensorType.absAssign(): Unit
|
public fun TensorType.absAssign(): Unit
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.sum.html
|
||||||
public fun TensorType.sum(): TensorType
|
public fun TensorType.sum(): TensorType
|
||||||
public fun TensorType.sumAssign(): Unit
|
public fun TensorType.sumAssign(): Unit
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://proofwiki.org/wiki/Definition:Division_Algebra
|
// https://proofwiki.org/wiki/Definition:Division_Algebra
|
||||||
|
|
||||||
public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
|
public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
|
||||||
TensorAlgebra<T, TensorType> {
|
TensorAlgebra<T, TensorType> {
|
||||||
|
|
||||||
public operator fun TensorType.div(other: TensorType): TensorType
|
public operator fun TensorType.div(other: TensorType): TensorType
|
||||||
public operator fun TensorType.divAssign(other: TensorType)
|
public operator fun TensorType.divAssign(other: TensorType)
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.exp.html
|
||||||
public fun TensorType.exp(): TensorType
|
public fun TensorType.exp(): TensorType
|
||||||
public fun TensorType.expAssign(): Unit
|
public fun TensorType.expAssign(): Unit
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.log.html
|
||||||
public fun TensorType.log(): TensorType
|
public fun TensorType.log(): TensorType
|
||||||
public fun TensorType.logAssign(): Unit
|
public fun TensorType.logAssign(): Unit
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.svd.html
|
||||||
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.symeig.html
|
||||||
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType>
|
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType>
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkShapeCompatible(
|
||||||
|
a: TensorType, b: TensorType
|
||||||
|
): Unit =
|
||||||
|
check(a.shape contentEquals b.shape) {
|
||||||
|
"Tensors must be of identical shape"
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
||||||
|
val sa = a.shape
|
||||||
|
val sb = b.shape
|
||||||
|
val na = sa.size
|
||||||
|
val nb = sb.size
|
||||||
|
var status: Boolean
|
||||||
|
if (nb == 1) {
|
||||||
|
status = sa.last() == sb[0]
|
||||||
|
} else {
|
||||||
|
status = sa.last() == sb[nb - 2]
|
||||||
|
if ((na > 2) and (nb > 2)) {
|
||||||
|
status = status and
|
||||||
|
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||||
|
check((i < dim) and (j < dim)) {
|
||||||
|
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
||||||
|
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
@ -0,0 +1,50 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.Strides
|
||||||
|
import space.kscience.kmath.nd.offsetFromIndex
|
||||||
|
import kotlin.math.max
|
||||||
|
|
||||||
|
|
||||||
|
inline public fun stridesFromShape(shape: IntArray): IntArray {
|
||||||
|
val nDim = shape.size
|
||||||
|
val res = IntArray(nDim)
|
||||||
|
if (nDim == 0)
|
||||||
|
return res
|
||||||
|
|
||||||
|
var current = nDim - 1
|
||||||
|
res[current] = 1
|
||||||
|
|
||||||
|
while (current > 0) {
|
||||||
|
res[current - 1] = max(1, shape[current]) * res[current]
|
||||||
|
current--
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
inline public fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||||
|
val res = IntArray(nDim)
|
||||||
|
var current = offset
|
||||||
|
var strideIndex = 0
|
||||||
|
|
||||||
|
while (strideIndex < nDim) {
|
||||||
|
res[strideIndex] = (current / strides[strideIndex])
|
||||||
|
current %= strides[strideIndex]
|
||||||
|
strideIndex++
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public class TensorStrides(override val shape: IntArray) : Strides {
|
||||||
|
override val strides: IntArray
|
||||||
|
get() = stridesFromShape(shape)
|
||||||
|
|
||||||
|
override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
|
||||||
|
|
||||||
|
override fun index(offset: Int): IntArray =
|
||||||
|
indexFromOffset(offset, strides, shape.size)
|
||||||
|
|
||||||
|
override val linearSize: Int
|
||||||
|
get() = shape.fold(1) { acc, i -> acc * i }
|
||||||
|
}
|
@ -0,0 +1,23 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.MutableNDStructure
|
||||||
|
|
||||||
|
public interface TensorStructure<T> : MutableNDStructure<T> {
|
||||||
|
public fun item(): T
|
||||||
|
|
||||||
|
// A tensor can have empty shape, in which case it represents just a value
|
||||||
|
public fun value(): T {
|
||||||
|
checkIsValue()
|
||||||
|
return item()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T> TensorStructure<T>.isValue(): Boolean {
|
||||||
|
return (dimension == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T> TensorStructure<T>.isNotValue(): Boolean = !this.isValue()
|
||||||
|
|
||||||
|
public inline fun <T> TensorStructure<T>.checkIsValue(): Unit = check(this.isValue()) {
|
||||||
|
"This tensor has shape ${shape.toList()}"
|
||||||
|
}
|
@ -0,0 +1,24 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
|
||||||
|
import space.kscience.kmath.structures.array
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
class TestRealTensor {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun valueTest(){
|
||||||
|
val value = 12.5
|
||||||
|
val tensor = RealTensor(IntArray(0), doubleArrayOf(value))
|
||||||
|
assertEquals(tensor.value(), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun stridesTest(){
|
||||||
|
val tensor = RealTensor(intArrayOf(2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
||||||
|
assertEquals(tensor[intArrayOf(0,1)], 5.8)
|
||||||
|
assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.array)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,16 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
import space.kscience.kmath.structures.array
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
class TestRealTensorAlgebra {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun doublePlus() = RealTensorAlgebra {
|
||||||
|
val tensor = RealTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||||
|
val res = 10.0 + tensor
|
||||||
|
assertTrue(res.buffer.array contentEquals doubleArrayOf(11.0,12.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -10,13 +10,13 @@ public sealed class Device {
|
|||||||
public data class CUDA(val index: Int): Device()
|
public data class CUDA(val index: Int): Device()
|
||||||
public fun toInt(): Int {
|
public fun toInt(): Int {
|
||||||
when(this) {
|
when(this) {
|
||||||
is Device.CPU -> return 0
|
is CPU -> return 0
|
||||||
is Device.CUDA -> return this.index + 1
|
is CUDA -> return this.index + 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
public companion object {
|
public companion object {
|
||||||
public fun fromInt(deviceInt: Int): Device {
|
public fun fromInt(deviceInt: Int): Device {
|
||||||
return if (deviceInt == 0) Device.CPU else Device.CUDA(
|
return if (deviceInt == 0) CPU else CUDA(
|
||||||
deviceInt - 1
|
deviceInt - 1
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -2,17 +2,14 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.torch
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
import space.kscience.kmath.tensors.TensorStructure
|
import space.kscience.kmath.tensors.*
|
||||||
|
|
||||||
public interface TorchTensor<T> : TensorStructure<T> {
|
public interface TorchTensor<T> : TensorStructure<T> {
|
||||||
public fun item(): T
|
|
||||||
public val strides: IntArray
|
public val strides: IntArray
|
||||||
public val size: Int
|
public val size: Int
|
||||||
public val device: Device
|
public val device: Device
|
||||||
override fun value(): T {
|
|
||||||
checkIsValue()
|
|
||||||
return item()
|
|
||||||
}
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> {
|
override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||||
if (dimension == 0) {
|
if (dimension == 0) {
|
||||||
return emptySequence()
|
return emptySequence()
|
||||||
@ -22,31 +19,8 @@ public interface TorchTensor<T> : TensorStructure<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <T> TorchTensor<T>.isValue(): Boolean {
|
|
||||||
return (dimension == 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <T> TorchTensor<T>.isNotValue(): Boolean = !this.isValue()
|
|
||||||
|
|
||||||
public inline fun <T> TorchTensor<T>.checkIsValue(): Unit = check(this.isValue()) {
|
|
||||||
"This tensor has shape ${shape.toList()}"
|
|
||||||
}
|
|
||||||
|
|
||||||
public interface TorchTensorOverField<T>: TorchTensor<T>
|
public interface TorchTensorOverField<T>: TorchTensor<T>
|
||||||
{
|
{
|
||||||
public var requiresGrad: Boolean
|
public var requiresGrad: Boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
|
||||||
val res = IntArray(nDim)
|
|
||||||
var current = offset
|
|
||||||
var strideIndex = 0
|
|
||||||
|
|
||||||
while (strideIndex < nDim) {
|
|
||||||
res[strideIndex] = (current / strides[strideIndex])
|
|
||||||
current %= strides[strideIndex]
|
|
||||||
strideIndex++
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -2,8 +2,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.torch
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
import space.kscience.kmath.tensors.TensorAlgebra
|
import space.kscience.kmath.tensors.*
|
||||||
import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra
|
|
||||||
|
|
||||||
public interface TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>> :
|
public interface TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>> :
|
||||||
TensorAlgebra<T, TorchTensorType> {
|
TensorAlgebra<T, TorchTensorType> {
|
||||||
@ -75,15 +74,6 @@ public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
|||||||
"Tensors must be on the same device"
|
"Tensors must be on the same device"
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
|
||||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
|
||||||
TorchTensorAlgebraType.checkShapeCompatible(
|
|
||||||
a: TorchTensorType,
|
|
||||||
b: TorchTensorType
|
|
||||||
): Unit =
|
|
||||||
check(a.shape contentEquals b.shape) {
|
|
||||||
"Tensors must be of identical shape"
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
||||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
@ -92,8 +82,8 @@ public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
|||||||
b: TorchTensorType
|
b: TorchTensorType
|
||||||
) {
|
) {
|
||||||
if (a.isNotValue() and b.isNotValue()) {
|
if (a.isNotValue() and b.isNotValue()) {
|
||||||
this.checkDeviceCompatible(a, b)
|
checkDeviceCompatible(a, b)
|
||||||
this.checkShapeCompatible(a, b)
|
checkShapeCompatible(a, b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -101,35 +91,8 @@ public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
|||||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
TorchTensorAlgebraType.checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit {
|
TorchTensorAlgebraType.checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit {
|
||||||
checkDeviceCompatible(a, b)
|
checkDeviceCompatible(a, b)
|
||||||
val sa = a.shape
|
checkDot(a,b)
|
||||||
val sb = b.shape
|
|
||||||
val na = sa.size
|
|
||||||
val nb = sb.size
|
|
||||||
var status: Boolean
|
|
||||||
if (nb == 1) {
|
|
||||||
status = sa.last() == sb[0]
|
|
||||||
} else {
|
|
||||||
status = sa.last() == sb[nb - 2]
|
|
||||||
if ((na > 2) and (nb > 2)) {
|
|
||||||
status = status and
|
|
||||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
|
||||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
|
||||||
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
|
||||||
check((i < dim) and (j < dim)) {
|
|
||||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensor<T>,
|
|
||||||
TorchTensorAlgebraType : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
|
||||||
TorchTensorAlgebraType.checkView(a: TorchTensorType, shape: IntArray): Unit =
|
|
||||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
|
||||||
|
|
||||||
|
|
||||||
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
|
public inline fun <T, PrimitiveArrayType, TorchTensorType : TorchTensorOverField<T>,
|
||||||
TorchTensorDivisionAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
TorchTensorDivisionAlgebraType : TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType>>
|
||||||
|
@ -29,9 +29,9 @@ public sealed class TorchTensorAlgebraJVM<
|
|||||||
|
|
||||||
internal abstract fun wrap(tensorHandle: Long): TorchTensorType
|
internal abstract fun wrap(tensorHandle: Long): TorchTensorType
|
||||||
|
|
||||||
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, other)
|
if (checks) checkLinearOperation(this, b)
|
||||||
return wrap(JTorch.timesTensor(this.tensorHandle, other.tensorHandle))
|
return wrap(JTorch.timesTensor(this.tensorHandle, b.tensorHandle))
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||||
@ -39,9 +39,9 @@ public sealed class TorchTensorAlgebraJVM<
|
|||||||
JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle)
|
JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, other)
|
if (checks) checkLinearOperation(this, b)
|
||||||
return wrap(JTorch.plusTensor(this.tensorHandle, other.tensorHandle))
|
return wrap(JTorch.plusTensor(this.tensorHandle, b.tensorHandle))
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||||
@ -49,9 +49,9 @@ public sealed class TorchTensorAlgebraJVM<
|
|||||||
JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle)
|
JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, other)
|
if (checks) checkLinearOperation(this, b)
|
||||||
return wrap(JTorch.minusTensor(this.tensorHandle, other.tensorHandle))
|
return wrap(JTorch.minusTensor(this.tensorHandle, b.tensorHandle))
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
package space.kscience.kmath.torch
|
package space.kscience.kmath.torch
|
||||||
|
|
||||||
|
|
||||||
import space.kscience.kmath.memory.DeferScope
|
import space.kscience.kmath.memory.DeferScope
|
||||||
import space.kscience.kmath.memory.withDeferScope
|
import space.kscience.kmath.memory.withDeferScope
|
||||||
|
|
||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
|
import space.kscience.kmath.tensors.*
|
||||||
import space.kscience.kmath.torch.ctorch.*
|
import space.kscience.kmath.torch.ctorch.*
|
||||||
|
|
||||||
public sealed class TorchTensorAlgebraNative<
|
public sealed class TorchTensorAlgebraNative<
|
||||||
@ -38,9 +38,9 @@ public sealed class TorchTensorAlgebraNative<
|
|||||||
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
|
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
|
||||||
public abstract fun TorchTensorType.getData(): CPointer<TVar>
|
public abstract fun TorchTensorType.getData(): CPointer<TVar>
|
||||||
|
|
||||||
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, other)
|
if (checks) checkLinearOperation(this, b)
|
||||||
return wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
|
return wrap(times_tensor(this.tensorHandle, b.tensorHandle)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||||
@ -48,9 +48,9 @@ public sealed class TorchTensorAlgebraNative<
|
|||||||
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, other)
|
if (checks) checkLinearOperation(this, b)
|
||||||
return wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
return wrap(plus_tensor(this.tensorHandle, b.tensorHandle)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||||
@ -58,9 +58,9 @@ public sealed class TorchTensorAlgebraNative<
|
|||||||
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType {
|
||||||
if (checks) checkLinearOperation(this, other)
|
if (checks) checkLinearOperation(this, b)
|
||||||
return wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
return wrap(minus_tensor(this.tensorHandle, b.tensorHandle)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||||
@ -68,6 +68,10 @@ public sealed class TorchTensorAlgebraNative<
|
|||||||
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
|
||||||
|
|
||||||
|
override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b
|
||||||
|
|
||||||
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||||
wrap(unary_minus(this.tensorHandle)!!)
|
wrap(unary_minus(this.tensorHandle)!!)
|
||||||
|
|
||||||
@ -254,6 +258,15 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
|||||||
|
|
||||||
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||||
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
|
||||||
|
override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
|
||||||
|
|
||||||
|
override val zero: TorchTensorReal
|
||||||
|
get() = full(0.0, IntArray(0), Device.CPU)
|
||||||
|
|
||||||
|
override val one: TorchTensorReal
|
||||||
|
get() = full(1.0, IntArray(0), Device.CPU)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -317,6 +330,15 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
|||||||
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||||
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
|
||||||
|
|
||||||
|
override val zero: TorchTensorFloat
|
||||||
|
get() = full(0f, IntArray(0), Device.CPU)
|
||||||
|
|
||||||
|
|
||||||
|
override val one: TorchTensorFloat
|
||||||
|
get() = full(1f, IntArray(0), Device.CPU)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||||
@ -372,6 +394,15 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
|
|||||||
|
|
||||||
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||||
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
|
||||||
|
|
||||||
|
override val zero: TorchTensorLong
|
||||||
|
get() = full(0, IntArray(0), Device.CPU)
|
||||||
|
|
||||||
|
|
||||||
|
override val one: TorchTensorLong
|
||||||
|
get() = full(1, IntArray(0), Device.CPU)
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||||
@ -427,6 +458,16 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
|||||||
|
|
||||||
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||||
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
|
||||||
|
override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
|
||||||
|
|
||||||
|
override val zero: TorchTensorInt
|
||||||
|
get() = full(0, IntArray(0), Device.CPU)
|
||||||
|
|
||||||
|
|
||||||
|
override val one: TorchTensorInt
|
||||||
|
get() = full(1, IntArray(0), Device.CPU)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user