utils module for tensors
This commit is contained in:
parent
8f88a101d2
commit
384415dc98
@ -1,9 +1,8 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.linear.BufferMatrix
|
|
||||||
import space.kscience.kmath.nd.MutableNDBuffer
|
import space.kscience.kmath.nd.MutableNDBuffer
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
import space.kscience.kmath.structures.BufferAccessor2D
|
|
||||||
|
|
||||||
public open class BufferedTensor<T>(
|
public open class BufferedTensor<T>(
|
||||||
override val shape: IntArray,
|
override val shape: IntArray,
|
||||||
@ -15,13 +14,13 @@ public open class BufferedTensor<T>(
|
|||||||
buffer
|
buffer
|
||||||
) {
|
) {
|
||||||
|
|
||||||
public operator fun get(i: Int, j: Int): T{
|
public operator fun get(i: Int, j: Int): T {
|
||||||
check(this.dimension == 2) {"Not matrix"}
|
check(this.dimension == 2) { "Not matrix" }
|
||||||
return this[intArrayOf(i, j)]
|
return this[intArrayOf(i, j)]
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun set(i: Int, j: Int, value: T): Unit{
|
public operator fun set(i: Int, j: Int, value: T): Unit {
|
||||||
check(this.dimension == 2) {"Not matrix"}
|
check(this.dimension == 2) { "Not matrix" }
|
||||||
this[intArrayOf(i, j)] = value
|
this[intArrayOf(i, j)] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -31,3 +30,18 @@ public class IntTensor(
|
|||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: IntArray
|
buffer: IntArray
|
||||||
) : BufferedTensor<Int>(shape, IntBuffer(buffer))
|
) : BufferedTensor<Int>(shape, IntBuffer(buffer))
|
||||||
|
|
||||||
|
public class LongTensor(
|
||||||
|
shape: IntArray,
|
||||||
|
buffer: LongArray
|
||||||
|
) : BufferedTensor<Long>(shape, LongBuffer(buffer))
|
||||||
|
|
||||||
|
public class FloatTensor(
|
||||||
|
shape: IntArray,
|
||||||
|
buffer: FloatArray
|
||||||
|
) : BufferedTensor<Float>(shape, FloatBuffer(buffer))
|
||||||
|
|
||||||
|
public class RealTensor(
|
||||||
|
shape: IntArray,
|
||||||
|
buffer: DoubleArray
|
||||||
|
) : BufferedTensor<Double>(shape, RealBuffer(buffer))
|
@ -6,11 +6,6 @@ import kotlin.math.abs
|
|||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
|
|
||||||
|
|
||||||
public class RealTensor(
|
|
||||||
shape: IntArray,
|
|
||||||
buffer: DoubleArray
|
|
||||||
) : BufferedTensor<Double>(shape, RealBuffer(buffer))
|
|
||||||
|
|
||||||
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||||
|
|
||||||
override fun RealTensor.value(): Double {
|
override fun RealTensor.value(): Double {
|
||||||
@ -54,64 +49,6 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
return RealTensor(this.shape, this.buffer.array.copyOf())
|
return RealTensor(this.shape, this.buffer.array.copyOf())
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
|
||||||
var totalDim = 0
|
|
||||||
for (shape in shapes) {
|
|
||||||
totalDim = max(totalDim, shape.size)
|
|
||||||
}
|
|
||||||
|
|
||||||
val totalShape = IntArray(totalDim) {0}
|
|
||||||
for (shape in shapes) {
|
|
||||||
for (i in shape.indices) {
|
|
||||||
val curDim = shape[i]
|
|
||||||
val offset = totalDim - shape.size
|
|
||||||
totalShape[i + offset] = max(totalShape[i + offset], curDim)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (shape in shapes) {
|
|
||||||
for (i in shape.indices) {
|
|
||||||
val curDim = shape[i]
|
|
||||||
val offset = totalDim - shape.size
|
|
||||||
if (curDim != 1 && totalShape[i + offset] != curDim) {
|
|
||||||
throw RuntimeException("Shapes are not compatible and cannot be broadcast")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return totalShape
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun broadcastTensors(vararg tensors: RealTensor): List<RealTensor> {
|
|
||||||
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
|
||||||
val n = totalShape.reduce{ acc, i -> acc * i }
|
|
||||||
|
|
||||||
val res = ArrayList<RealTensor>(0)
|
|
||||||
for (tensor in tensors) {
|
|
||||||
val resTensor = RealTensor(totalShape, DoubleArray(n))
|
|
||||||
|
|
||||||
for (linearIndex in 0 until n) {
|
|
||||||
val totalMultiIndex = resTensor.strides.index(linearIndex)
|
|
||||||
val curMultiIndex = tensor.shape.copyOf()
|
|
||||||
|
|
||||||
val offset = totalMultiIndex.size - curMultiIndex.size
|
|
||||||
|
|
||||||
for (i in curMultiIndex.indices) {
|
|
||||||
if (curMultiIndex[i] != 1) {
|
|
||||||
curMultiIndex[i] = totalMultiIndex[i + offset]
|
|
||||||
} else {
|
|
||||||
curMultiIndex[i] = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
|
||||||
resTensor.buffer.array[linearIndex] = tensor.buffer.array[curLinearIndex]
|
|
||||||
}
|
|
||||||
res.add(resTensor)
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun Double.plus(other: RealTensor): RealTensor {
|
override fun Double.plus(other: RealTensor): RealTensor {
|
||||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
||||||
|
@ -13,9 +13,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
|
|
||||||
public fun TensorType.copy(): TensorType
|
public fun TensorType.copy(): TensorType
|
||||||
|
|
||||||
public fun broadcastShapes(vararg shapes: IntArray): IntArray
|
|
||||||
public fun broadcastTensors(vararg tensors: RealTensor): List<TensorType>
|
|
||||||
|
|
||||||
public operator fun T.plus(other: TensorType): TensorType
|
public operator fun T.plus(other: TensorType): TensorType
|
||||||
public operator fun TensorType.plus(value: T): TensorType
|
public operator fun TensorType.plus(value: T): TensorType
|
||||||
public operator fun TensorType.plus(other: TensorType): TensorType
|
public operator fun TensorType.plus(other: TensorType): TensorType
|
||||||
@ -87,45 +84,4 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
|
|||||||
//https://pytorch.org/docs/stable/generated/torch.symeig.html
|
//https://pytorch.org/docs/stable/generated/torch.symeig.html
|
||||||
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType>
|
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType>
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <T, TensorType : TensorStructure<T>,
|
|
||||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
|
||||||
TorchTensorAlgebraType.checkShapeCompatible(
|
|
||||||
a: TensorType, b: TensorType
|
|
||||||
): Unit =
|
|
||||||
check(a.shape contentEquals b.shape) {
|
|
||||||
"Tensors must be of identical shape"
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <T, TensorType : TensorStructure<T>,
|
|
||||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
|
||||||
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
|
||||||
val sa = a.shape
|
|
||||||
val sb = b.shape
|
|
||||||
val na = sa.size
|
|
||||||
val nb = sb.size
|
|
||||||
var status: Boolean
|
|
||||||
if (nb == 1) {
|
|
||||||
status = sa.last() == sb[0]
|
|
||||||
} else {
|
|
||||||
status = sa.last() == sb[nb - 2]
|
|
||||||
if ((na > 2) and (nb > 2)) {
|
|
||||||
status = status and
|
|
||||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <T, TensorType : TensorStructure<T>,
|
|
||||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
|
||||||
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
|
||||||
check((i < dim) and (j < dim)) {
|
|
||||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <T, TensorType : TensorStructure<T>,
|
|
||||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
|
||||||
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
|
||||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
|
@ -5,7 +5,7 @@ import space.kscience.kmath.nd.offsetFromIndex
|
|||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
|
|
||||||
|
|
||||||
public inline fun stridesFromShape(shape: IntArray): IntArray {
|
internal inline fun stridesFromShape(shape: IntArray): IntArray {
|
||||||
val nDim = shape.size
|
val nDim = shape.size
|
||||||
val res = IntArray(nDim)
|
val res = IntArray(nDim)
|
||||||
if (nDim == 0)
|
if (nDim == 0)
|
||||||
@ -22,7 +22,7 @@ public inline fun stridesFromShape(shape: IntArray): IntArray {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||||
val res = IntArray(nDim)
|
val res = IntArray(nDim)
|
||||||
var current = offset
|
var current = offset
|
||||||
var strideIndex = 0
|
var strideIndex = 0
|
||||||
@ -35,7 +35,7 @@ public inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): In
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun nextIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
internal inline fun nextIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
||||||
val res = index.copyOf()
|
val res = index.copyOf()
|
||||||
var current = nDim - 1
|
var current = nDim - 1
|
||||||
var carry = 0
|
var carry = 0
|
||||||
|
@ -0,0 +1,96 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
import space.kscience.kmath.structures.array
|
||||||
|
import kotlin.math.max
|
||||||
|
|
||||||
|
|
||||||
|
internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
||||||
|
var totalDim = 0
|
||||||
|
for (shape in shapes) {
|
||||||
|
totalDim = max(totalDim, shape.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
val totalShape = IntArray(totalDim) { 0 }
|
||||||
|
for (shape in shapes) {
|
||||||
|
for (i in shape.indices) {
|
||||||
|
val curDim = shape[i]
|
||||||
|
val offset = totalDim - shape.size
|
||||||
|
totalShape[i + offset] = max(totalShape[i + offset], curDim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (shape in shapes) {
|
||||||
|
for (i in shape.indices) {
|
||||||
|
val curDim = shape[i]
|
||||||
|
val offset = totalDim - shape.size
|
||||||
|
if (curDim != 1 && totalShape[i + offset] != curDim) {
|
||||||
|
throw RuntimeException("Shapes are not compatible and cannot be broadcast")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return totalShape
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun broadcastTensors(vararg tensors: RealTensor): List<RealTensor> {
|
||||||
|
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
||||||
|
val n = totalShape.reduce { acc, i -> acc * i }
|
||||||
|
|
||||||
|
val res = ArrayList<RealTensor>(0)
|
||||||
|
for (tensor in tensors) {
|
||||||
|
val resTensor = RealTensor(totalShape, DoubleArray(n))
|
||||||
|
|
||||||
|
for (linearIndex in 0 until n) {
|
||||||
|
val totalMultiIndex = resTensor.strides.index(linearIndex)
|
||||||
|
val curMultiIndex = tensor.shape.copyOf()
|
||||||
|
|
||||||
|
val offset = totalMultiIndex.size - curMultiIndex.size
|
||||||
|
|
||||||
|
for (i in curMultiIndex.indices) {
|
||||||
|
if (curMultiIndex[i] != 1) {
|
||||||
|
curMultiIndex[i] = totalMultiIndex[i + offset]
|
||||||
|
} else {
|
||||||
|
curMultiIndex[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
||||||
|
resTensor.buffer.array[linearIndex] = tensor.buffer.array[curLinearIndex]
|
||||||
|
}
|
||||||
|
res.add(resTensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
||||||
|
val sa = a.shape
|
||||||
|
val sb = b.shape
|
||||||
|
val na = sa.size
|
||||||
|
val nb = sb.size
|
||||||
|
var status: Boolean
|
||||||
|
if (nb == 1) {
|
||||||
|
status = sa.last() == sb[0]
|
||||||
|
} else {
|
||||||
|
status = sa.last() == sb[nb - 2]
|
||||||
|
if ((na > 2) and (nb > 2)) {
|
||||||
|
status = status and
|
||||||
|
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||||
|
check((i < dim) and (j < dim)) {
|
||||||
|
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
||||||
|
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
@ -2,8 +2,6 @@ package space.kscience.kmath.tensors
|
|||||||
|
|
||||||
import space.kscience.kmath.structures.array
|
import space.kscience.kmath.structures.array
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertFails
|
|
||||||
import kotlin.test.assertFailsWith
|
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
class TestRealTensorAlgebra {
|
class TestRealTensorAlgebra {
|
||||||
@ -51,11 +49,11 @@ class TestRealTensorAlgebra {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun broadcastShapes() = RealTensorAlgebra {
|
fun broadcastShapes() = RealTensorAlgebra {
|
||||||
assertTrue(this.broadcastShapes(
|
assertTrue(broadcastShapes(
|
||||||
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
||||||
) contentEquals intArrayOf(1, 2, 3))
|
) contentEquals intArrayOf(1, 2, 3))
|
||||||
|
|
||||||
assertTrue(this.broadcastShapes(
|
assertTrue(broadcastShapes(
|
||||||
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7,), intArrayOf(5, 1, 7)
|
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7,), intArrayOf(5, 1, 7)
|
||||||
) contentEquals intArrayOf(5, 6, 7))
|
) contentEquals intArrayOf(5, 6, 7))
|
||||||
}
|
}
|
||||||
@ -66,7 +64,7 @@ class TestRealTensorAlgebra {
|
|||||||
val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
|
|
||||||
val res = this.broadcastTensors(tensor1, tensor2, tensor3)
|
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
||||||
|
|
||||||
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
Loading…
Reference in New Issue
Block a user