utils module for tensors

This commit is contained in:
Roland Grinis 2021-03-13 19:16:13 +00:00
parent 8f88a101d2
commit 384415dc98
6 changed files with 124 additions and 123 deletions

View File

@ -1,9 +1,8 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.linear.BufferMatrix
import space.kscience.kmath.nd.MutableNDBuffer
import space.kscience.kmath.structures.*
import space.kscience.kmath.structures.BufferAccessor2D
public open class BufferedTensor<T>(
override val shape: IntArray,
@ -15,13 +14,13 @@ public open class BufferedTensor<T>(
buffer
) {
public operator fun get(i: Int, j: Int): T{
check(this.dimension == 2) {"Not matrix"}
return this[intArrayOf(i, j)]
public operator fun get(i: Int, j: Int): T {
check(this.dimension == 2) { "Not matrix" }
return this[intArrayOf(i, j)]
}
public operator fun set(i: Int, j: Int, value: T): Unit{
check(this.dimension == 2) {"Not matrix"}
public operator fun set(i: Int, j: Int, value: T): Unit {
check(this.dimension == 2) { "Not matrix" }
this[intArrayOf(i, j)] = value
}
@ -31,3 +30,18 @@ public class IntTensor(
shape: IntArray,
buffer: IntArray
) : BufferedTensor<Int>(shape, IntBuffer(buffer))
public class LongTensor(
shape: IntArray,
buffer: LongArray
) : BufferedTensor<Long>(shape, LongBuffer(buffer))
public class FloatTensor(
shape: IntArray,
buffer: FloatArray
) : BufferedTensor<Float>(shape, FloatBuffer(buffer))
public class RealTensor(
shape: IntArray,
buffer: DoubleArray
) : BufferedTensor<Double>(shape, RealBuffer(buffer))

View File

@ -6,11 +6,6 @@ import kotlin.math.abs
import kotlin.math.max
public class RealTensor(
shape: IntArray,
buffer: DoubleArray
) : BufferedTensor<Double>(shape, RealBuffer(buffer))
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
override fun RealTensor.value(): Double {
@ -54,64 +49,6 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
return RealTensor(this.shape, this.buffer.array.copyOf())
}
override fun broadcastShapes(vararg shapes: IntArray): IntArray {
var totalDim = 0
for (shape in shapes) {
totalDim = max(totalDim, shape.size)
}
val totalShape = IntArray(totalDim) {0}
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
totalShape[i + offset] = max(totalShape[i + offset], curDim)
}
}
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
if (curDim != 1 && totalShape[i + offset] != curDim) {
throw RuntimeException("Shapes are not compatible and cannot be broadcast")
}
}
}
return totalShape
}
override fun broadcastTensors(vararg tensors: RealTensor): List<RealTensor> {
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
val n = totalShape.reduce{ acc, i -> acc * i }
val res = ArrayList<RealTensor>(0)
for (tensor in tensors) {
val resTensor = RealTensor(totalShape, DoubleArray(n))
for (linearIndex in 0 until n) {
val totalMultiIndex = resTensor.strides.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.array[linearIndex] = tensor.buffer.array[curLinearIndex]
}
res.add(resTensor)
}
return res
}
override fun Double.plus(other: RealTensor): RealTensor {
val resBuffer = DoubleArray(other.buffer.size) { i ->

View File

@ -13,9 +13,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun TensorType.copy(): TensorType
public fun broadcastShapes(vararg shapes: IntArray): IntArray
public fun broadcastTensors(vararg tensors: RealTensor): List<TensorType>
public operator fun T.plus(other: TensorType): TensorType
public operator fun TensorType.plus(value: T): TensorType
public operator fun TensorType.plus(other: TensorType): TensorType
@ -88,44 +85,3 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType>
}
public inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkShapeCompatible(
a: TensorType, b: TensorType
): Unit =
check(a.shape contentEquals b.shape) {
"Tensors must be of identical shape"
}
public inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
val sa = a.shape
val sb = b.shape
val na = sa.size
val nb = sb.size
var status: Boolean
if (nb == 1) {
status = sa.last() == sb[0]
} else {
status = sa.last() == sb[nb - 2]
if ((na > 2) and (nb > 2)) {
status = status and
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
}
}
check(status) { "Incompatible shapes $sa and $sb for dot product" }
}
public inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
check((i < dim) and (j < dim)) {
"Cannot transpose $i to $j for a tensor of dim $dim"
}
public inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))

View File

@ -5,7 +5,7 @@ import space.kscience.kmath.nd.offsetFromIndex
import kotlin.math.max
public inline fun stridesFromShape(shape: IntArray): IntArray {
internal inline fun stridesFromShape(shape: IntArray): IntArray {
val nDim = shape.size
val res = IntArray(nDim)
if (nDim == 0)
@ -22,7 +22,7 @@ public inline fun stridesFromShape(shape: IntArray): IntArray {
}
public inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
val res = IntArray(nDim)
var current = offset
var strideIndex = 0
@ -35,7 +35,7 @@ public inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): In
return res
}
public inline fun nextIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
internal inline fun nextIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
val res = index.copyOf()
var current = nDim - 1
var carry = 0

View File

@ -0,0 +1,96 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.structures.array
import kotlin.math.max
internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
var totalDim = 0
for (shape in shapes) {
totalDim = max(totalDim, shape.size)
}
val totalShape = IntArray(totalDim) { 0 }
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
totalShape[i + offset] = max(totalShape[i + offset], curDim)
}
}
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
if (curDim != 1 && totalShape[i + offset] != curDim) {
throw RuntimeException("Shapes are not compatible and cannot be broadcast")
}
}
}
return totalShape
}
internal inline fun broadcastTensors(vararg tensors: RealTensor): List<RealTensor> {
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
val n = totalShape.reduce { acc, i -> acc * i }
val res = ArrayList<RealTensor>(0)
for (tensor in tensors) {
val resTensor = RealTensor(totalShape, DoubleArray(n))
for (linearIndex in 0 until n) {
val totalMultiIndex = resTensor.strides.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.array[linearIndex] = tensor.buffer.array[curLinearIndex]
}
res.add(resTensor)
}
return res
}
internal inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
val sa = a.shape
val sb = b.shape
val na = sa.size
val nb = sb.size
var status: Boolean
if (nb == 1) {
status = sa.last() == sb[0]
} else {
status = sa.last() == sb[nb - 2]
if ((na > 2) and (nb > 2)) {
status = status and
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
}
}
check(status) { "Incompatible shapes $sa and $sb for dot product" }
}
internal inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
check((i < dim) and (j < dim)) {
"Cannot transpose $i to $j for a tensor of dim $dim"
}
internal inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))

View File

@ -2,8 +2,6 @@ package space.kscience.kmath.tensors
import space.kscience.kmath.structures.array
import kotlin.test.Test
import kotlin.test.assertFails
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
class TestRealTensorAlgebra {
@ -51,11 +49,11 @@ class TestRealTensorAlgebra {
@Test
fun broadcastShapes() = RealTensorAlgebra {
assertTrue(this.broadcastShapes(
assertTrue(broadcastShapes(
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
) contentEquals intArrayOf(1, 2, 3))
assertTrue(this.broadcastShapes(
assertTrue(broadcastShapes(
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7,), intArrayOf(5, 1, 7)
) contentEquals intArrayOf(5, 6, 7))
}
@ -66,7 +64,7 @@ class TestRealTensorAlgebra {
val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val res = this.broadcastTensors(tensor1, tensor2, tensor3)
val res = broadcastTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))