Optimize tensor shape computation
This commit is contained in:
parent
94489b28e2
commit
8286db30af
@ -51,10 +51,15 @@ public abstract class Strides : ShapeIndexer {
|
||||
*/
|
||||
internal abstract val strides: IntArray
|
||||
|
||||
public override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||
if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})")
|
||||
value * strides[i]
|
||||
}.sum()
|
||||
public override fun offset(index: IntArray): Int {
|
||||
var res = 0
|
||||
index.forEachIndexed { i, value ->
|
||||
if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})")
|
||||
res += value * strides[i]
|
||||
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// TODO introduce a fast way to calculate index of the next element?
|
||||
|
||||
@ -74,17 +79,15 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
|
||||
/**
|
||||
* Strides for memory access
|
||||
*/
|
||||
override val strides: IntArray by lazy {
|
||||
sequence {
|
||||
var current = 1
|
||||
yield(1)
|
||||
override val strides: IntArray = sequence {
|
||||
var current = 1
|
||||
yield(1)
|
||||
|
||||
shape.forEach {
|
||||
current *= it
|
||||
yield(current)
|
||||
}
|
||||
}.toList().toIntArray()
|
||||
}
|
||||
shape.forEach {
|
||||
current *= it
|
||||
yield(current)
|
||||
}
|
||||
}.toList().toIntArray()
|
||||
|
||||
override fun index(offset: Int): IntArray {
|
||||
val res = IntArray(shape.size)
|
||||
@ -120,10 +123,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
|
||||
*/
|
||||
public class RowStrides(override val shape: ShapeND) : Strides() {
|
||||
|
||||
override val strides: IntArray by lazy {
|
||||
override val strides: IntArray = run {
|
||||
val nDim = shape.size
|
||||
val res = IntArray(nDim)
|
||||
if (nDim == 0) return@lazy res
|
||||
if (nDim == 0) return@run res
|
||||
|
||||
var current = nDim - 1
|
||||
res[current] = 1
|
||||
|
Loading…
Reference in New Issue
Block a user