v0.3.0-dev-9 #324

Merged
altavir merged 265 commits from dev into master 2021-05-08 17:16:29 +03:00
5 changed files with 69 additions and 67 deletions
Showing only changes of commit 94e5ee4a6d - Show all commits

View File

@ -177,7 +177,7 @@ public interface Strides {
/** /**
* Array strides * Array strides
*/ */
public val strides: IntArray public val strides: List<Int>
/** /**
* Get linear index from multidimensional index * Get linear index from multidimensional index
@ -209,12 +209,6 @@ public interface Strides {
} }
} }
internal inline fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int =
index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})")
value * strides[i]
}.sum()
/** /**
* Simple implementation of [Strides]. * Simple implementation of [Strides].
*/ */
@ -225,7 +219,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
/** /**
* Strides for memory access * Strides for memory access
*/ */
override val strides: IntArray by lazy { override val strides: List<Int> by lazy {
sequence { sequence {
var current = 1 var current = 1
yield(1) yield(1)
@ -234,10 +228,13 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
current *= it current *= it
yield(current) yield(current)
} }
}.toList().toIntArray() }.toList()
} }
override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides) override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})")
value * strides[i]
}.sum()
override fun index(offset: Int): IntArray { override fun index(offset: Int): IntArray {
val res = IntArray(shape.size) val res = IntArray(shape.size)

View File

@ -8,7 +8,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(this, other) val broadcast = broadcastTensors(this, other)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.strides.linearSize) { i -> val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[i] + newOther.buffer.array()[i] newThis.buffer.array()[i] + newOther.buffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
@ -16,7 +16,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun DoubleTensor.plusAssign(other: DoubleTensor) { override fun DoubleTensor.plusAssign(other: DoubleTensor) {
val newOther = broadcastTo(other, this.shape) val newOther = broadcastTo(other, this.shape)
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] += this.buffer.array()[this.bufferStart + i] +=
newOther.buffer.array()[this.bufferStart + i] newOther.buffer.array()[this.bufferStart + i]
} }
@ -26,7 +26,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(this, other) val broadcast = broadcastTensors(this, other)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.strides.linearSize) { i -> val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[i] - newOther.buffer.array()[i] newThis.buffer.array()[i] - newOther.buffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
@ -34,7 +34,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun DoubleTensor.minusAssign(other: DoubleTensor) { override fun DoubleTensor.minusAssign(other: DoubleTensor) {
val newOther = broadcastTo(other, this.shape) val newOther = broadcastTo(other, this.shape)
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] -= this.buffer.array()[this.bufferStart + i] -=
newOther.buffer.array()[this.bufferStart + i] newOther.buffer.array()[this.bufferStart + i]
} }
@ -44,7 +44,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(this, other) val broadcast = broadcastTensors(this, other)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.strides.linearSize) { i -> val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[newOther.bufferStart + i] * newThis.buffer.array()[newOther.bufferStart + i] *
newOther.buffer.array()[newOther.bufferStart + i] newOther.buffer.array()[newOther.bufferStart + i]
} }
@ -53,7 +53,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun DoubleTensor.timesAssign(other: DoubleTensor) { override fun DoubleTensor.timesAssign(other: DoubleTensor) {
val newOther = broadcastTo(other, this.shape) val newOther = broadcastTo(other, this.shape)
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] *= this.buffer.array()[this.bufferStart + i] *=
newOther.buffer.array()[this.bufferStart + i] newOther.buffer.array()[this.bufferStart + i]
} }
@ -63,7 +63,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(this, other) val broadcast = broadcastTensors(this, other)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.strides.linearSize) { i -> val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[newOther.bufferStart + i] / newThis.buffer.array()[newOther.bufferStart + i] /
newOther.buffer.array()[newOther.bufferStart + i] newOther.buffer.array()[newOther.bufferStart + i]
} }
@ -72,7 +72,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun DoubleTensor.divAssign(other: DoubleTensor) { override fun DoubleTensor.divAssign(other: DoubleTensor) {
val newOther = broadcastTo(other, this.shape) val newOther = broadcastTo(other, this.shape)
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] /= this.buffer.array()[this.bufferStart + i] /=
newOther.buffer.array()[this.bufferStart + i] newOther.buffer.array()[this.bufferStart + i]
} }
@ -130,7 +130,7 @@ internal inline fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): Doubl
} }
for (linearIndex in 0 until n) { for (linearIndex in 0 until n) {
val totalMultiIndex = resTensor.strides.index(linearIndex) val totalMultiIndex = resTensor.linearStructure.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf() val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size val offset = totalMultiIndex.size - curMultiIndex.size
@ -143,7 +143,7 @@ internal inline fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): Doubl
} }
} }
val curLinearIndex = tensor.strides.offset(curMultiIndex) val curLinearIndex = tensor.linearStructure.offset(curMultiIndex)
resTensor.buffer.array()[linearIndex] = resTensor.buffer.array()[linearIndex] =
tensor.buffer.array()[tensor.bufferStart + curLinearIndex] tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
} }
@ -159,7 +159,7 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
val resTensor = DoubleTensor(totalShape, DoubleArray(n)) val resTensor = DoubleTensor(totalShape, DoubleArray(n))
for (linearIndex in 0 until n) { for (linearIndex in 0 until n) {
val totalMultiIndex = resTensor.strides.index(linearIndex) val totalMultiIndex = resTensor.linearStructure.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf() val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size val offset = totalMultiIndex.size - curMultiIndex.size
@ -172,7 +172,7 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
} }
} }
val curLinearIndex = tensor.strides.offset(curMultiIndex) val curLinearIndex = tensor.linearStructure.offset(curMultiIndex)
resTensor.buffer.array()[linearIndex] = resTensor.buffer.array()[linearIndex] =
tensor.buffer.array()[tensor.bufferStart + curLinearIndex] tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
} }

View File

@ -1,9 +1,7 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.linear.Matrix
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.TensorStrides
import space.kscience.kmath.tensors.TensorStructure import space.kscience.kmath.tensors.TensorStructure
@ -13,19 +11,19 @@ public open class BufferedTensor<T>(
internal val bufferStart: Int internal val bufferStart: Int
) : TensorStructure<T> ) : TensorStructure<T>
{ {
public val strides: TensorStrides public val linearStructure: TensorLinearStructure
get() = TensorStrides(shape) get() = TensorLinearStructure(shape)
public val numel: Int public val numel: Int
get() = strides.linearSize get() = linearStructure.size
override fun get(index: IntArray): T = buffer[bufferStart + strides.offset(index)] override fun get(index: IntArray): T = buffer[bufferStart + linearStructure.offset(index)]
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
buffer[bufferStart + strides.offset(index)] = value buffer[bufferStart + linearStructure.offset(index)] = value
} }
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { override fun elements(): Sequence<Pair<IntArray, T>> = linearStructure.indices().map {
it to this[it] it to this[it]
} }
@ -35,7 +33,7 @@ public open class BufferedTensor<T>(
public fun vectorSequence(): Sequence<MutableStructure1D<T>> = sequence { public fun vectorSequence(): Sequence<MutableStructure1D<T>> = sequence {
check(shape.size >= 1) {"todo"} check(shape.size >= 1) {"todo"}
val vectorOffset = strides.strides[0] val vectorOffset = linearStructure.strides[0]
val vectorShape = intArrayOf(shape.last()) val vectorShape = intArrayOf(shape.last())
for (offset in 0 until numel step vectorOffset) { for (offset in 0 until numel step vectorOffset) {
val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D() val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D()
@ -45,7 +43,7 @@ public open class BufferedTensor<T>(
public fun matrixSequence(): Sequence<MutableStructure2D<T>> = sequence { public fun matrixSequence(): Sequence<MutableStructure2D<T>> = sequence {
check(shape.size >= 2) {"todo"} check(shape.size >= 2) {"todo"}
val matrixOffset = strides.strides[1] val matrixOffset = linearStructure.strides[1]
val matrixShape = intArrayOf(shape[shape.size - 2], shape.last()) //todo better way? val matrixShape = intArrayOf(shape[shape.size - 2], shape.last()) //todo better way?
for (offset in 0 until numel step matrixOffset) { for (offset in 0 until numel step matrixOffset) {
val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D() val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D()

View File

@ -27,7 +27,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.fullLike(value: Double): DoubleTensor { override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
val shape = this.shape val shape = this.shape
val buffer = DoubleArray(this.strides.linearSize) { value } val buffer = DoubleArray(this.linearStructure.size) { value }
return DoubleTensor(shape, buffer) return DoubleTensor(shape, buffer)
} }
@ -54,7 +54,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
} }
override fun Double.plus(other: DoubleTensor): DoubleTensor { override fun Double.plus(other: DoubleTensor): DoubleTensor {
val resBuffer = DoubleArray(other.strides.linearSize) { i -> val resBuffer = DoubleArray(other.linearStructure.size) { i ->
other.buffer.array()[other.bufferStart + i] + this other.buffer.array()[other.bufferStart + i] + this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
@ -64,35 +64,35 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.plus(other: DoubleTensor): DoubleTensor { override fun DoubleTensor.plus(other: DoubleTensor): DoubleTensor {
checkShapesCompatible(this, other) checkShapesCompatible(this, other)
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val resBuffer = DoubleArray(this.linearStructure.size) { i ->
this.buffer.array()[i] + other.buffer.array()[i] this.buffer.array()[i] + other.buffer.array()[i]
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
} }
override fun DoubleTensor.plusAssign(value: Double) { override fun DoubleTensor.plusAssign(value: Double) {
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] += value this.buffer.array()[this.bufferStart + i] += value
} }
} }
override fun DoubleTensor.plusAssign(other: DoubleTensor) { override fun DoubleTensor.plusAssign(other: DoubleTensor) {
checkShapesCompatible(this, other) checkShapesCompatible(this, other)
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] += this.buffer.array()[this.bufferStart + i] +=
other.buffer.array()[this.bufferStart + i] other.buffer.array()[this.bufferStart + i]
} }
} }
override fun Double.minus(other: DoubleTensor): DoubleTensor { override fun Double.minus(other: DoubleTensor): DoubleTensor {
val resBuffer = DoubleArray(other.strides.linearSize) { i -> val resBuffer = DoubleArray(other.linearStructure.size) { i ->
this - other.buffer.array()[other.bufferStart + i] this - other.buffer.array()[other.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun DoubleTensor.minus(value: Double): DoubleTensor { override fun DoubleTensor.minus(value: Double): DoubleTensor {
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val resBuffer = DoubleArray(this.linearStructure.size) { i ->
this.buffer.array()[this.bufferStart + i] - value this.buffer.array()[this.bufferStart + i] - value
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
@ -100,28 +100,28 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.minus(other: DoubleTensor): DoubleTensor { override fun DoubleTensor.minus(other: DoubleTensor): DoubleTensor {
checkShapesCompatible(this, other) checkShapesCompatible(this, other)
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val resBuffer = DoubleArray(this.linearStructure.size) { i ->
this.buffer.array()[i] - other.buffer.array()[i] this.buffer.array()[i] - other.buffer.array()[i]
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
} }
override fun DoubleTensor.minusAssign(value: Double) { override fun DoubleTensor.minusAssign(value: Double) {
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] -= value this.buffer.array()[this.bufferStart + i] -= value
} }
} }
override fun DoubleTensor.minusAssign(other: DoubleTensor) { override fun DoubleTensor.minusAssign(other: DoubleTensor) {
checkShapesCompatible(this, other) checkShapesCompatible(this, other)
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] -= this.buffer.array()[this.bufferStart + i] -=
other.buffer.array()[this.bufferStart + i] other.buffer.array()[this.bufferStart + i]
} }
} }
override fun Double.times(other: DoubleTensor): DoubleTensor { override fun Double.times(other: DoubleTensor): DoubleTensor {
val resBuffer = DoubleArray(other.strides.linearSize) { i -> val resBuffer = DoubleArray(other.linearStructure.size) { i ->
other.buffer.array()[other.bufferStart + i] * this other.buffer.array()[other.bufferStart + i] * this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
@ -131,7 +131,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor { override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
checkShapesCompatible(this, other) checkShapesCompatible(this, other)
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val resBuffer = DoubleArray(this.linearStructure.size) { i ->
this.buffer.array()[other.bufferStart + i] * this.buffer.array()[other.bufferStart + i] *
other.buffer.array()[other.bufferStart + i] other.buffer.array()[other.bufferStart + i]
} }
@ -139,21 +139,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
} }
override fun DoubleTensor.timesAssign(value: Double) { override fun DoubleTensor.timesAssign(value: Double) {
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] *= value this.buffer.array()[this.bufferStart + i] *= value
} }
} }
override fun DoubleTensor.timesAssign(other: DoubleTensor) { override fun DoubleTensor.timesAssign(other: DoubleTensor) {
checkShapesCompatible(this, other) checkShapesCompatible(this, other)
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] *= this.buffer.array()[this.bufferStart + i] *=
other.buffer.array()[this.bufferStart + i] other.buffer.array()[this.bufferStart + i]
} }
} }
override fun DoubleTensor.div(value: Double): DoubleTensor { override fun DoubleTensor.div(value: Double): DoubleTensor {
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val resBuffer = DoubleArray(this.linearStructure.size) { i ->
this.buffer.array()[this.bufferStart + i] / value this.buffer.array()[this.bufferStart + i] / value
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
@ -161,7 +161,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.div(other: DoubleTensor): DoubleTensor { override fun DoubleTensor.div(other: DoubleTensor): DoubleTensor {
checkShapesCompatible(this, other) checkShapesCompatible(this, other)
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val resBuffer = DoubleArray(this.linearStructure.size) { i ->
this.buffer.array()[other.bufferStart + i] / this.buffer.array()[other.bufferStart + i] /
other.buffer.array()[other.bufferStart + i] other.buffer.array()[other.bufferStart + i]
} }
@ -169,21 +169,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
} }
override fun DoubleTensor.divAssign(value: Double) { override fun DoubleTensor.divAssign(value: Double) {
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] /= value this.buffer.array()[this.bufferStart + i] /= value
} }
} }
override fun DoubleTensor.divAssign(other: DoubleTensor) { override fun DoubleTensor.divAssign(other: DoubleTensor) {
checkShapesCompatible(this, other) checkShapesCompatible(this, other)
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.linearStructure.size) {
this.buffer.array()[this.bufferStart + i] /= this.buffer.array()[this.bufferStart + i] /=
other.buffer.array()[this.bufferStart + i] other.buffer.array()[this.bufferStart + i]
} }
} }
override fun DoubleTensor.unaryMinus(): DoubleTensor { override fun DoubleTensor.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val resBuffer = DoubleArray(this.linearStructure.size) { i ->
this.buffer.array()[this.bufferStart + i].unaryMinus() this.buffer.array()[this.bufferStart + i].unaryMinus()
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(this.shape, resBuffer)
@ -191,7 +191,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.transpose(i: Int, j: Int): DoubleTensor { override fun DoubleTensor.transpose(i: Int, j: Int): DoubleTensor {
checkTranspose(this.dimension, i, j) checkTranspose(this.dimension, i, j)
val n = this.strides.linearSize val n = this.linearStructure.size
val resBuffer = DoubleArray(n) val resBuffer = DoubleArray(n)
val resShape = this.shape.copyOf() val resShape = this.shape.copyOf()
@ -200,11 +200,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
val resTensor = DoubleTensor(resShape, resBuffer) val resTensor = DoubleTensor(resShape, resBuffer)
for (offset in 0 until n) { for (offset in 0 until n) {
val oldMultiIndex = this.strides.index(offset) val oldMultiIndex = this.linearStructure.index(offset)
val newMultiIndex = oldMultiIndex.copyOf() val newMultiIndex = oldMultiIndex.copyOf()
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] } newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
val linearIndex = resTensor.strides.offset(newMultiIndex) val linearIndex = resTensor.linearStructure.offset(newMultiIndex)
resTensor.buffer.array()[linearIndex] = resTensor.buffer.array()[linearIndex] =
this.buffer.array()[this.bufferStart + offset] this.buffer.array()[this.bufferStart + offset]
} }

View File

@ -1,10 +1,14 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.Strides
import space.kscience.kmath.nd.offsetFromIndex
import kotlin.math.max import kotlin.math.max
internal inline fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int =
index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})")
value * strides[i]
}.sum()
internal inline fun stridesFromShape(shape: IntArray): IntArray { internal inline fun stridesFromShape(shape: IntArray): IntArray {
val nDim = shape.size val nDim = shape.size
val res = IntArray(nDim) val res = IntArray(nDim)
@ -35,7 +39,7 @@ internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int):
return res return res
} }
internal inline fun nextIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray { internal inline fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
val res = index.copyOf() val res = index.copyOf()
var current = nDim - 1 var current = nDim - 1
var carry = 0 var carry = 0
@ -47,26 +51,29 @@ internal inline fun nextIndex(index: IntArray, shape: IntArray, nDim: Int): IntA
res[current] = 0 res[current] = 0
} }
current-- current--
} while(carry != 0 && current >= 0) } while (carry != 0 && current >= 0)
return res return res
} }
public class TensorLinearStructure(public val shape: IntArray)
public class TensorStrides(override val shape: IntArray): Strides
{ {
override val strides: IntArray public val strides: IntArray
get() = stridesFromShape(shape) get() = stridesFromShape(shape)
override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides) public fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
override fun index(offset: Int): IntArray = public fun index(offset: Int): IntArray =
indexFromOffset(offset, strides, shape.size) indexFromOffset(offset, strides, shape.size)
override fun nextIndex(index: IntArray): IntArray = public fun stepIndex(index: IntArray): IntArray =
nextIndex(index, shape, shape.size) stepIndex(index, shape, shape.size)
override val linearSize: Int public val size: Int
get() = shape.reduce(Int::times) get() = shape.reduce(Int::times)
public fun indices(): Sequence<IntArray> = (0 until size).asSequence().map {
index(it)
}
} }