Optimize tensor shape computation
This commit is contained in:
parent
94489b28e2
commit
8286db30af
@ -51,10 +51,15 @@ public abstract class Strides : ShapeIndexer {
|
|||||||
*/
|
*/
|
||||||
internal abstract val strides: IntArray
|
internal abstract val strides: IntArray
|
||||||
|
|
||||||
public override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
public override fun offset(index: IntArray): Int {
|
||||||
if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})")
|
var res = 0
|
||||||
value * strides[i]
|
index.forEachIndexed { i, value ->
|
||||||
}.sum()
|
if (value !in 0 until shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0, ${this.shape[i]})")
|
||||||
|
res += value * strides[i]
|
||||||
|
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
// TODO introduce a fast way to calculate index of the next element?
|
// TODO introduce a fast way to calculate index of the next element?
|
||||||
|
|
||||||
@ -74,17 +79,15 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
|
|||||||
/**
|
/**
|
||||||
* Strides for memory access
|
* Strides for memory access
|
||||||
*/
|
*/
|
||||||
override val strides: IntArray by lazy {
|
override val strides: IntArray = sequence {
|
||||||
sequence {
|
var current = 1
|
||||||
var current = 1
|
yield(1)
|
||||||
yield(1)
|
|
||||||
|
|
||||||
shape.forEach {
|
shape.forEach {
|
||||||
current *= it
|
current *= it
|
||||||
yield(current)
|
yield(current)
|
||||||
}
|
}
|
||||||
}.toList().toIntArray()
|
}.toList().toIntArray()
|
||||||
}
|
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray {
|
override fun index(offset: Int): IntArray {
|
||||||
val res = IntArray(shape.size)
|
val res = IntArray(shape.size)
|
||||||
@ -120,10 +123,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
|
|||||||
*/
|
*/
|
||||||
public class RowStrides(override val shape: ShapeND) : Strides() {
|
public class RowStrides(override val shape: ShapeND) : Strides() {
|
||||||
|
|
||||||
override val strides: IntArray by lazy {
|
override val strides: IntArray = run {
|
||||||
val nDim = shape.size
|
val nDim = shape.size
|
||||||
val res = IntArray(nDim)
|
val res = IntArray(nDim)
|
||||||
if (nDim == 0) return@lazy res
|
if (nDim == 0) return@run res
|
||||||
|
|
||||||
var current = nDim - 1
|
var current = nDim - 1
|
||||||
res[current] = 1
|
res[current] = 1
|
||||||
|
Loading…
Reference in New Issue
Block a user