add functions transpose and transposeAssign
This commit is contained in:
parent
3a37b88b5c
commit
723e0e458e
@ -174,6 +174,11 @@ public interface Strides {
|
||||
*/
|
||||
public fun index(offset: Int): IntArray
|
||||
|
||||
/**
|
||||
* Get next multidimensional index from the current multidimensional index
|
||||
*/
|
||||
public fun nextIndex(index: IntArray): IntArray
|
||||
|
||||
/**
|
||||
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
||||
*/
|
||||
@ -232,6 +237,10 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
||||
return res
|
||||
}
|
||||
|
||||
override fun nextIndex(index: IntArray): IntArray {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is DefaultStrides) return false
|
||||
|
@ -164,11 +164,33 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
||||
}
|
||||
|
||||
override fun RealTensor.transpose(i: Int, j: Int): RealTensor {
|
||||
TODO("Alya")
|
||||
val n = this.buffer.size
|
||||
val resBuffer = DoubleArray(n)
|
||||
|
||||
val resShape = this.shape.copyOf()
|
||||
resShape[i] = resShape[j].also { resShape[j] = resShape[i] }
|
||||
|
||||
val resTensor = RealTensor(resShape, resBuffer)
|
||||
|
||||
for (offset in 0 until n) {
|
||||
val oldMultiIndex = this.strides.index(offset)
|
||||
val newMultiIndex = oldMultiIndex.copyOf()
|
||||
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
|
||||
|
||||
val linearIndex = resTensor.strides.offset(newMultiIndex)
|
||||
resTensor.buffer.array[linearIndex] = this.buffer.array[offset]
|
||||
}
|
||||
return resTensor
|
||||
}
|
||||
|
||||
override fun RealTensor.transposeAssign(i: Int, j: Int) {
|
||||
TODO("Alya")
|
||||
val transposedTensor = this.transpose(i, j)
|
||||
for (i in transposedTensor.shape.indices) {
|
||||
this.shape[i] = transposedTensor.shape[i]
|
||||
}
|
||||
for (i in transposedTensor.buffer.array.indices) {
|
||||
this.buffer.array[i] = transposedTensor.buffer.array[i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun RealTensor.view(shape: IntArray): RealTensor {
|
||||
|
@ -35,6 +35,23 @@ public inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): In
|
||||
return res
|
||||
}
|
||||
|
||||
public inline fun nextIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
||||
val res = index.copyOf()
|
||||
var current = nDim - 1
|
||||
var carry = 0
|
||||
|
||||
do {
|
||||
res[current]++
|
||||
if (res[current] >= shape[current]) {
|
||||
carry = 1
|
||||
res[current] = 0
|
||||
}
|
||||
current--
|
||||
} while(carry != 0 && current >= 0)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
|
||||
|
||||
public class TensorStrides(override val shape: IntArray): Strides
|
||||
@ -47,6 +64,9 @@ public class TensorStrides(override val shape: IntArray): Strides
|
||||
override fun index(offset: Int): IntArray =
|
||||
indexFromOffset(offset, strides, shape.size)
|
||||
|
||||
override fun nextIndex(index: IntArray): IntArray =
|
||||
nextIndex(index, shape, shape.size)
|
||||
|
||||
override val linearSize: Int
|
||||
get() = shape.reduce(Int::times)
|
||||
}
|
@ -13,4 +13,55 @@ class TestRealTensorAlgebra {
|
||||
assertTrue(res.buffer.array contentEquals doubleArrayOf(11.0,12.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose1x1() = RealTensorAlgebra {
|
||||
val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0))
|
||||
val res = tensor.transpose(0, 0)
|
||||
|
||||
assertTrue(res.buffer.array contentEquals doubleArrayOf(0.0))
|
||||
assertTrue(res.shape contentEquals intArrayOf(1))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose3x2() = RealTensorAlgebra {
|
||||
val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val res = tensor.transpose(1, 0)
|
||||
|
||||
assertTrue(res.buffer.array contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose1x2x3() = RealTensorAlgebra {
|
||||
val tensor = RealTensor(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val res01 = tensor.transpose(0, 1)
|
||||
val res02 = tensor.transpose(0, 2)
|
||||
val res12 = tensor.transpose(1, 2)
|
||||
|
||||
assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3))
|
||||
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
||||
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
||||
|
||||
assertTrue(res01.buffer.array contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res02.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
assertTrue(res12.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transposeAssign1x2() = RealTensorAlgebra {
|
||||
val tensor = RealTensor(intArrayOf(1,2), doubleArrayOf(1.0, 2.0))
|
||||
tensor.transposeAssign(0, 1)
|
||||
|
||||
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 2.0))
|
||||
assertTrue(tensor.shape contentEquals intArrayOf(2, 1))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transposeAssign2x3() = RealTensorAlgebra {
|
||||
val tensor = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
tensor.transposeAssign(1, 0)
|
||||
|
||||
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
assertTrue(tensor.shape contentEquals intArrayOf(3, 2))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user