Optimize tensor shape computation

This commit is contained in:
Alexander Nozik 2022-10-16 20:15:37 +03:00
parent 94489b28e2
commit 8286db30af
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357

View File

@ -51,10 +51,15 @@ public abstract class Strides : ShapeIndexer {
*/ */
internal abstract val strides: IntArray internal abstract val strides: IntArray
public override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> public override fun offset(index: IntArray): Int {
if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})") var res = 0
value * strides[i] index.forEachIndexed { i, value ->
}.sum() if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})")
res += value * strides[i]
}
return res
}
// TODO introduce a fast way to calculate index of the next element? // TODO introduce a fast way to calculate index of the next element?
@ -74,17 +79,15 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
/** /**
* Strides for memory access * Strides for memory access
*/ */
override val strides: IntArray by lazy { override val strides: IntArray = sequence {
sequence { var current = 1
var current = 1 yield(1)
yield(1)
shape.forEach { shape.forEach {
current *= it current *= it
yield(current) yield(current)
} }
}.toList().toIntArray() }.toList().toIntArray()
}
override fun index(offset: Int): IntArray { override fun index(offset: Int): IntArray {
val res = IntArray(shape.size) val res = IntArray(shape.size)
@ -120,10 +123,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
*/ */
public class RowStrides(override val shape: ShapeND) : Strides() { public class RowStrides(override val shape: ShapeND) : Strides() {
override val strides: IntArray by lazy { override val strides: IntArray = run {
val nDim = shape.size val nDim = shape.size
val res = IntArray(nDim) val res = IntArray(nDim)
if (nDim == 0) return@lazy res if (nDim == 0) return@run res
var current = nDim - 1 var current = nDim - 1
res[current] = 1 res[current] = 1