KMP library for tensors #300
@ -1,9 +1,8 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.linear.BufferMatrix
|
||||
import space.kscience.kmath.nd.MutableNDBuffer
|
||||
import space.kscience.kmath.structures.*
|
||||
import space.kscience.kmath.structures.BufferAccessor2D
|
||||
|
||||
|
||||
public open class BufferedTensor<T>(
|
||||
override val shape: IntArray,
|
||||
@ -15,13 +14,13 @@ public open class BufferedTensor<T>(
|
||||
buffer
|
||||
) {
|
||||
|
||||
public operator fun get(i: Int, j: Int): T{
|
||||
check(this.dimension == 2) {"Not matrix"}
|
||||
public operator fun get(i: Int, j: Int): T {
|
||||
check(this.dimension == 2) { "Not matrix" }
|
||||
return this[intArrayOf(i, j)]
|
||||
}
|
||||
|
||||
public operator fun set(i: Int, j: Int, value: T): Unit{
|
||||
check(this.dimension == 2) {"Not matrix"}
|
||||
public operator fun set(i: Int, j: Int, value: T): Unit {
|
||||
check(this.dimension == 2) { "Not matrix" }
|
||||
this[intArrayOf(i, j)] = value
|
||||
}
|
||||
|
||||
@ -31,3 +30,18 @@ public class IntTensor(
|
||||
shape: IntArray,
|
||||
buffer: IntArray
|
||||
) : BufferedTensor<Int>(shape, IntBuffer(buffer))
|
||||
|
||||
public class LongTensor(
|
||||
shape: IntArray,
|
||||
buffer: LongArray
|
||||
) : BufferedTensor<Long>(shape, LongBuffer(buffer))
|
||||
|
||||
public class FloatTensor(
|
||||
shape: IntArray,
|
||||
buffer: FloatArray
|
||||
) : BufferedTensor<Float>(shape, FloatBuffer(buffer))
|
||||
|
||||
public class RealTensor(
|
||||
shape: IntArray,
|
||||
buffer: DoubleArray
|
||||
) : BufferedTensor<Double>(shape, RealBuffer(buffer))
|
@ -6,11 +6,6 @@ import kotlin.math.abs
|
||||
import kotlin.math.max
|
||||
|
||||
|
||||
public class RealTensor(
|
||||
shape: IntArray,
|
||||
buffer: DoubleArray
|
||||
) : BufferedTensor<Double>(shape, RealBuffer(buffer))
|
||||
|
||||
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||
|
||||
override fun RealTensor.value(): Double {
|
||||
@ -54,64 +49,6 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
||||
return RealTensor(this.shape, this.buffer.array.copyOf())
|
||||
}
|
||||
|
||||
override fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
||||
var totalDim = 0
|
||||
for (shape in shapes) {
|
||||
totalDim = max(totalDim, shape.size)
|
||||
}
|
||||
|
||||
val totalShape = IntArray(totalDim) {0}
|
||||
for (shape in shapes) {
|
||||
for (i in shape.indices) {
|
||||
val curDim = shape[i]
|
||||
val offset = totalDim - shape.size
|
||||
totalShape[i + offset] = max(totalShape[i + offset], curDim)
|
||||
}
|
||||
}
|
||||
|
||||
for (shape in shapes) {
|
||||
for (i in shape.indices) {
|
||||
val curDim = shape[i]
|
||||
val offset = totalDim - shape.size
|
||||
if (curDim != 1 && totalShape[i + offset] != curDim) {
|
||||
throw RuntimeException("Shapes are not compatible and cannot be broadcast")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return totalShape
|
||||
}
|
||||
|
||||
override fun broadcastTensors(vararg tensors: RealTensor): List<RealTensor> {
|
||||
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
||||
val n = totalShape.reduce{ acc, i -> acc * i }
|
||||
|
||||
val res = ArrayList<RealTensor>(0)
|
||||
for (tensor in tensors) {
|
||||
val resTensor = RealTensor(totalShape, DoubleArray(n))
|
||||
|
||||
for (linearIndex in 0 until n) {
|
||||
val totalMultiIndex = resTensor.strides.index(linearIndex)
|
||||
val curMultiIndex = tensor.shape.copyOf()
|
||||
|
||||
val offset = totalMultiIndex.size - curMultiIndex.size
|
||||
|
||||
for (i in curMultiIndex.indices) {
|
||||
if (curMultiIndex[i] != 1) {
|
||||
curMultiIndex[i] = totalMultiIndex[i + offset]
|
||||
} else {
|
||||
curMultiIndex[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
||||
resTensor.buffer.array[linearIndex] = tensor.buffer.array[curLinearIndex]
|
||||
}
|
||||
res.add(resTensor)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
override fun Double.plus(other: RealTensor): RealTensor {
|
||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
||||
|
@ -13,9 +13,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
|
||||
public fun TensorType.copy(): TensorType
|
||||
|
||||
public fun broadcastShapes(vararg shapes: IntArray): IntArray
|
||||
public fun broadcastTensors(vararg tensors: RealTensor): List<TensorType>
|
||||
|
||||
public operator fun T.plus(other: TensorType): TensorType
|
||||
public operator fun TensorType.plus(value: T): TensorType
|
||||
public operator fun TensorType.plus(other: TensorType): TensorType
|
||||
@ -88,44 +85,3 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
|
||||
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType>
|
||||
|
||||
}
|
||||
|
||||
public inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkShapeCompatible(
|
||||
a: TensorType, b: TensorType
|
||||
): Unit =
|
||||
check(a.shape contentEquals b.shape) {
|
||||
"Tensors must be of identical shape"
|
||||
}
|
||||
|
||||
public inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
||||
val sa = a.shape
|
||||
val sb = b.shape
|
||||
val na = sa.size
|
||||
val nb = sb.size
|
||||
var status: Boolean
|
||||
if (nb == 1) {
|
||||
status = sa.last() == sb[0]
|
||||
} else {
|
||||
status = sa.last() == sb[nb - 2]
|
||||
if ((na > 2) and (nb > 2)) {
|
||||
status = status and
|
||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||
}
|
||||
}
|
||||
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
||||
}
|
||||
|
||||
public inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||
check((i < dim) and (j < dim)) {
|
||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||
}
|
||||
|
||||
public inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
@ -5,7 +5,7 @@ import space.kscience.kmath.nd.offsetFromIndex
|
||||
import kotlin.math.max
|
||||
|
||||
|
||||
public inline fun stridesFromShape(shape: IntArray): IntArray {
|
||||
internal inline fun stridesFromShape(shape: IntArray): IntArray {
|
||||
val nDim = shape.size
|
||||
val res = IntArray(nDim)
|
||||
if (nDim == 0)
|
||||
@ -22,7 +22,7 @@ public inline fun stridesFromShape(shape: IntArray): IntArray {
|
||||
|
||||
}
|
||||
|
||||
public inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||
internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||
val res = IntArray(nDim)
|
||||
var current = offset
|
||||
var strideIndex = 0
|
||||
@ -35,7 +35,7 @@ public inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): In
|
||||
return res
|
||||
}
|
||||
|
||||
public inline fun nextIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
||||
internal inline fun nextIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
||||
val res = index.copyOf()
|
||||
var current = nDim - 1
|
||||
var carry = 0
|
||||
|
@ -0,0 +1,96 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.structures.array
|
||||
import kotlin.math.max
|
||||
|
||||
|
||||
internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
||||
var totalDim = 0
|
||||
for (shape in shapes) {
|
||||
totalDim = max(totalDim, shape.size)
|
||||
}
|
||||
|
||||
val totalShape = IntArray(totalDim) { 0 }
|
||||
for (shape in shapes) {
|
||||
for (i in shape.indices) {
|
||||
val curDim = shape[i]
|
||||
val offset = totalDim - shape.size
|
||||
totalShape[i + offset] = max(totalShape[i + offset], curDim)
|
||||
}
|
||||
}
|
||||
|
||||
for (shape in shapes) {
|
||||
for (i in shape.indices) {
|
||||
val curDim = shape[i]
|
||||
val offset = totalDim - shape.size
|
||||
if (curDim != 1 && totalShape[i + offset] != curDim) {
|
||||
throw RuntimeException("Shapes are not compatible and cannot be broadcast")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return totalShape
|
||||
}
|
||||
|
||||
internal inline fun broadcastTensors(vararg tensors: RealTensor): List<RealTensor> {
|
||||
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
||||
val n = totalShape.reduce { acc, i -> acc * i }
|
||||
|
||||
val res = ArrayList<RealTensor>(0)
|
||||
for (tensor in tensors) {
|
||||
val resTensor = RealTensor(totalShape, DoubleArray(n))
|
||||
|
||||
for (linearIndex in 0 until n) {
|
||||
val totalMultiIndex = resTensor.strides.index(linearIndex)
|
||||
val curMultiIndex = tensor.shape.copyOf()
|
||||
|
||||
val offset = totalMultiIndex.size - curMultiIndex.size
|
||||
|
||||
for (i in curMultiIndex.indices) {
|
||||
if (curMultiIndex[i] != 1) {
|
||||
curMultiIndex[i] = totalMultiIndex[i + offset]
|
||||
} else {
|
||||
curMultiIndex[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
||||
resTensor.buffer.array[linearIndex] = tensor.buffer.array[curLinearIndex]
|
||||
}
|
||||
res.add(resTensor)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
||||
val sa = a.shape
|
||||
val sb = b.shape
|
||||
val na = sa.size
|
||||
val nb = sb.size
|
||||
var status: Boolean
|
||||
if (nb == 1) {
|
||||
status = sa.last() == sb[0]
|
||||
} else {
|
||||
status = sa.last() == sb[nb - 2]
|
||||
if ((na > 2) and (nb > 2)) {
|
||||
status = status and
|
||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||
}
|
||||
}
|
||||
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
||||
}
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||
check((i < dim) and (j < dim)) {
|
||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||
}
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
@ -2,8 +2,6 @@ package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.structures.array
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertFails
|
||||
import kotlin.test.assertFailsWith
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class TestRealTensorAlgebra {
|
||||
@ -51,11 +49,11 @@ class TestRealTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun broadcastShapes() = RealTensorAlgebra {
|
||||
assertTrue(this.broadcastShapes(
|
||||
assertTrue(broadcastShapes(
|
||||
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
||||
) contentEquals intArrayOf(1, 2, 3))
|
||||
|
||||
assertTrue(this.broadcastShapes(
|
||||
assertTrue(broadcastShapes(
|
||||
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7,), intArrayOf(5, 1, 7)
|
||||
) contentEquals intArrayOf(5, 6, 7))
|
||||
}
|
||||
@ -66,7 +64,7 @@ class TestRealTensorAlgebra {
|
||||
val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
||||
val res = this.broadcastTensors(tensor1, tensor2, tensor3)
|
||||
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
||||
|
Loading…
Reference in New Issue
Block a user