Optimize tensor shape computation

This commit is contained in:
Alexander Nozik 2022-10-16 20:15:37 +03:00
parent 94489b28e2
commit 8286db30af
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357

View File

@ -51,10 +51,15 @@ public abstract class Strides : ShapeIndexer {
*/
internal abstract val strides: IntArray
public override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
public override fun offset(index: IntArray): Int {
var res = 0
index.forEachIndexed { i, value ->
if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})")
value * strides[i]
}.sum()
res += value * strides[i]
}
return res
}
// TODO introduce a fast way to calculate index of the next element?
@ -74,8 +79,7 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
/**
* Strides for memory access
*/
override val strides: IntArray by lazy {
sequence {
override val strides: IntArray = sequence {
var current = 1
yield(1)
@ -84,7 +88,6 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
yield(current)
}
}.toList().toIntArray()
}
override fun index(offset: Int): IntArray {
val res = IntArray(shape.size)
@ -120,10 +123,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
*/
public class RowStrides(override val shape: ShapeND) : Strides() {
override val strides: IntArray by lazy {
override val strides: IntArray = run {
val nDim = shape.size
val res = IntArray(nDim)
if (nDim == 0) return@lazy res
if (nDim == 0) return@run res
var current = nDim - 1
res[current] = 1