Optimize tensor shape computation

This commit is contained in:
Alexander Nozik 2022-10-16 20:15:37 +03:00
parent 94489b28e2
commit 8286db30af
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357

View File

@ -51,10 +51,15 @@ public abstract class Strides : ShapeIndexer {
*/ */
internal abstract val strides: IntArray internal abstract val strides: IntArray
public override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> public override fun offset(index: IntArray): Int {
var res = 0
index.forEachIndexed { i, value ->
if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})") if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})")
value * strides[i] res += value * strides[i]
}.sum()
}
return res
}
// TODO introduce a fast way to calculate index of the next element? // TODO introduce a fast way to calculate index of the next element?
@ -74,8 +79,7 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
/** /**
* Strides for memory access * Strides for memory access
*/ */
override val strides: IntArray by lazy { override val strides: IntArray = sequence {
sequence {
var current = 1 var current = 1
yield(1) yield(1)
@ -84,7 +88,6 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
yield(current) yield(current)
} }
}.toList().toIntArray() }.toList().toIntArray()
}
override fun index(offset: Int): IntArray { override fun index(offset: Int): IntArray {
val res = IntArray(shape.size) val res = IntArray(shape.size)
@ -120,10 +123,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
*/ */
public class RowStrides(override val shape: ShapeND) : Strides() { public class RowStrides(override val shape: ShapeND) : Strides() {
override val strides: IntArray by lazy { override val strides: IntArray = run {
val nDim = shape.size val nDim = shape.size
val res = IntArray(nDim) val res = IntArray(nDim)
if (nDim == 0) return@lazy res if (nDim == 0) return@run res
var current = nDim - 1 var current = nDim - 1
res[current] = 1 res[current] = 1