add functions transpose and transposeAssign
This commit is contained in:
parent
3a37b88b5c
commit
723e0e458e
@ -174,6 +174,11 @@ public interface Strides {
|
|||||||
*/
|
*/
|
||||||
public fun index(offset: Int): IntArray
|
public fun index(offset: Int): IntArray
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get next multidimensional index from the current multidimensional index
|
||||||
|
*/
|
||||||
|
public fun nextIndex(index: IntArray): IntArray
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
||||||
*/
|
*/
|
||||||
@ -232,6 +237,10 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun nextIndex(index: IntArray): IntArray {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
if (other !is DefaultStrides) return false
|
if (other !is DefaultStrides) return false
|
||||||
|
@ -164,11 +164,33 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.transpose(i: Int, j: Int): RealTensor {
|
override fun RealTensor.transpose(i: Int, j: Int): RealTensor {
|
||||||
TODO("Alya")
|
val n = this.buffer.size
|
||||||
|
val resBuffer = DoubleArray(n)
|
||||||
|
|
||||||
|
val resShape = this.shape.copyOf()
|
||||||
|
resShape[i] = resShape[j].also { resShape[j] = resShape[i] }
|
||||||
|
|
||||||
|
val resTensor = RealTensor(resShape, resBuffer)
|
||||||
|
|
||||||
|
for (offset in 0 until n) {
|
||||||
|
val oldMultiIndex = this.strides.index(offset)
|
||||||
|
val newMultiIndex = oldMultiIndex.copyOf()
|
||||||
|
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
|
||||||
|
|
||||||
|
val linearIndex = resTensor.strides.offset(newMultiIndex)
|
||||||
|
resTensor.buffer.array[linearIndex] = this.buffer.array[offset]
|
||||||
|
}
|
||||||
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.transposeAssign(i: Int, j: Int) {
|
override fun RealTensor.transposeAssign(i: Int, j: Int) {
|
||||||
TODO("Alya")
|
val transposedTensor = this.transpose(i, j)
|
||||||
|
for (i in transposedTensor.shape.indices) {
|
||||||
|
this.shape[i] = transposedTensor.shape[i]
|
||||||
|
}
|
||||||
|
for (i in transposedTensor.buffer.array.indices) {
|
||||||
|
this.buffer.array[i] = transposedTensor.buffer.array[i]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.view(shape: IntArray): RealTensor {
|
override fun RealTensor.view(shape: IntArray): RealTensor {
|
||||||
|
@ -35,6 +35,23 @@ public inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): In
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public inline fun nextIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
||||||
|
val res = index.copyOf()
|
||||||
|
var current = nDim - 1
|
||||||
|
var carry = 0
|
||||||
|
|
||||||
|
do {
|
||||||
|
res[current]++
|
||||||
|
if (res[current] >= shape[current]) {
|
||||||
|
carry = 1
|
||||||
|
res[current] = 0
|
||||||
|
}
|
||||||
|
current--
|
||||||
|
} while(carry != 0 && current >= 0)
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public class TensorStrides(override val shape: IntArray): Strides
|
public class TensorStrides(override val shape: IntArray): Strides
|
||||||
@ -47,6 +64,9 @@ public class TensorStrides(override val shape: IntArray): Strides
|
|||||||
override fun index(offset: Int): IntArray =
|
override fun index(offset: Int): IntArray =
|
||||||
indexFromOffset(offset, strides, shape.size)
|
indexFromOffset(offset, strides, shape.size)
|
||||||
|
|
||||||
|
override fun nextIndex(index: IntArray): IntArray =
|
||||||
|
nextIndex(index, shape, shape.size)
|
||||||
|
|
||||||
override val linearSize: Int
|
override val linearSize: Int
|
||||||
get() = shape.reduce(Int::times)
|
get() = shape.reduce(Int::times)
|
||||||
}
|
}
|
@ -13,4 +13,55 @@ class TestRealTensorAlgebra {
|
|||||||
assertTrue(res.buffer.array contentEquals doubleArrayOf(11.0,12.0))
|
assertTrue(res.buffer.array contentEquals doubleArrayOf(11.0,12.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun transpose1x1() = RealTensorAlgebra {
|
||||||
|
val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0))
|
||||||
|
val res = tensor.transpose(0, 0)
|
||||||
|
|
||||||
|
assertTrue(res.buffer.array contentEquals doubleArrayOf(0.0))
|
||||||
|
assertTrue(res.shape contentEquals intArrayOf(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun transpose3x2() = RealTensorAlgebra {
|
||||||
|
val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val res = tensor.transpose(1, 0)
|
||||||
|
|
||||||
|
assertTrue(res.buffer.array contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||||
|
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun transpose1x2x3() = RealTensorAlgebra {
|
||||||
|
val tensor = RealTensor(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val res01 = tensor.transpose(0, 1)
|
||||||
|
val res02 = tensor.transpose(0, 2)
|
||||||
|
val res12 = tensor.transpose(1, 2)
|
||||||
|
|
||||||
|
assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3))
|
||||||
|
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
||||||
|
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
||||||
|
|
||||||
|
assertTrue(res01.buffer.array contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
assertTrue(res02.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
|
assertTrue(res12.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun transposeAssign1x2() = RealTensorAlgebra {
|
||||||
|
val tensor = RealTensor(intArrayOf(1,2), doubleArrayOf(1.0, 2.0))
|
||||||
|
tensor.transposeAssign(0, 1)
|
||||||
|
|
||||||
|
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 2.0))
|
||||||
|
assertTrue(tensor.shape contentEquals intArrayOf(2, 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun transposeAssign2x3() = RealTensorAlgebra {
|
||||||
|
val tensor = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
tensor.transposeAssign(1, 0)
|
||||||
|
|
||||||
|
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
|
assertTrue(tensor.shape contentEquals intArrayOf(3, 2))
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user