From 8286db30afef0fb5e647117fd7892f20ebfc1e73 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sun, 16 Oct 2022 20:15:37 +0300 Subject: [PATCH] Optimize tensor shape computation --- .../space/kscience/kmath/nd/ShapeIndices.kt | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndices.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndices.kt index 37e0c7b5e..3a27614c5 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndices.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndices.kt @@ -51,10 +51,15 @@ public abstract class Strides : ShapeIndexer { */ internal abstract val strides: IntArray - public override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> - if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})") - value * strides[i] - }.sum() + public override fun offset(index: IntArray): Int { + var res = 0 + index.forEachIndexed { i, value -> + if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})") + res += value * strides[i] + + } + return res + } // TODO introduce a fast way to calculate index of the next element? @@ -74,17 +79,15 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() { /** * Strides for memory access */ - override val strides: IntArray by lazy { - sequence { - var current = 1 - yield(1) + override val strides: IntArray = sequence { + var current = 1 + yield(1) - shape.forEach { - current *= it - yield(current) - } - }.toList().toIntArray() - } + shape.forEach { + current *= it + yield(current) + } + }.toList().toIntArray() override fun index(offset: Int): IntArray { val res = IntArray(shape.size) @@ -120,10 +123,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() { */ public class RowStrides(override val shape: ShapeND) : Strides() { - override val strides: IntArray by lazy { + override val strides: IntArray = run { val nDim = shape.size val res = IntArray(nDim) - if (nDim == 0) return@lazy res + if (nDim == 0) return@run res var current = nDim - 1 res[current] = 1