Fixed strides code duplication

This commit is contained in:
Roland Grinis 2021-04-30 14:53:02 +01:00
parent 6be5caa93f
commit 86c2816cfd
5 changed files with 26 additions and 25 deletions

View File

@ -789,7 +789,7 @@ public final class space/kscience/kmath/nd/DefaultStrides : space/kscience/kmath
public fun equals (Ljava/lang/Object;)Z public fun equals (Ljava/lang/Object;)Z
public fun getLinearSize ()I public fun getLinearSize ()I
public fun getShape ()[I public fun getShape ()[I
public fun getStrides ()Ljava/util/List; public fun getStrides ()[I
public fun hashCode ()I public fun hashCode ()I
public fun index (I)[I public fun index (I)[I
public fun offset ([I)I public fun offset ([I)I
@ -931,7 +931,7 @@ public final class space/kscience/kmath/nd/ShortRingNDKt {
public abstract interface class space/kscience/kmath/nd/Strides { public abstract interface class space/kscience/kmath/nd/Strides {
public abstract fun getLinearSize ()I public abstract fun getLinearSize ()I
public abstract fun getShape ()[I public abstract fun getShape ()[I
public abstract fun getStrides ()Ljava/util/List; public abstract fun getStrides ()[I
public abstract fun index (I)[I public abstract fun index (I)[I
public fun indices ()Lkotlin/sequences/Sequence; public fun indices ()Lkotlin/sequences/Sequence;
public abstract fun offset ([I)I public abstract fun offset ([I)I

View File

@ -184,7 +184,7 @@ public interface Strides {
/** /**
* Array strides * Array strides
*/ */
public val strides: List<Int> public val strides: IntArray
/** /**
* Get linear index from multidimensional index * Get linear index from multidimensional index
@ -221,7 +221,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
/** /**
* Strides for memory access * Strides for memory access
*/ */
override val strides: List<Int> by lazy { override val strides: IntArray by lazy {
sequence { sequence {
var current = 1 var current = 1
yield(1) yield(1)
@ -230,7 +230,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
current *= it current *= it
yield(current) yield(current)
} }
}.toList() }.toList().toIntArray()
} }
override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->

View File

@ -15,7 +15,7 @@ public open class BufferedTensor<T>(
get() = TensorLinearStructure(shape) get() = TensorLinearStructure(shape)
public val numElements: Int public val numElements: Int
get() = linearStructure.size get() = linearStructure.linearSize
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)] override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)]
@ -60,7 +60,7 @@ internal fun <T> TensorStructure<T>.copyToBufferedTensor(): BufferedTensor<T> =
internal fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) { internal fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides.toIntArray() contentEquals TensorLinearStructure(this.shape).strides) is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor() BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
else -> this.copyToBufferedTensor() else -> this.copyToBufferedTensor()
} }

View File

@ -20,7 +20,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i] newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
@ -28,7 +28,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) { for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
@ -38,7 +38,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i] newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
@ -46,7 +46,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) { for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
@ -56,7 +56,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[newThis.bufferStart + i] * newThis.mutableBuffer.array()[newThis.bufferStart + i] *
newOther.mutableBuffer.array()[newOther.bufferStart + i] newOther.mutableBuffer.array()[newOther.bufferStart + i]
} }
@ -65,7 +65,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) { for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
@ -75,7 +75,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[newOther.bufferStart + i] / newThis.mutableBuffer.array()[newOther.bufferStart + i] /
newOther.mutableBuffer.array()[newOther.bufferStart + i] newOther.mutableBuffer.array()[newOther.bufferStart + i]
} }
@ -84,7 +84,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) { for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }

View File

@ -5,16 +5,17 @@
package space.kscience.kmath.tensors.core.algebras package space.kscience.kmath.tensors.core.algebras
import space.kscience.kmath.nd.Strides
import kotlin.math.max import kotlin.math.max
internal inline fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int = internal fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int =
index.mapIndexed { i, value -> index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})") if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})")
value * strides[i] value * strides[i]
}.sum() }.sum()
internal inline fun stridesFromShape(shape: IntArray): IntArray { internal fun stridesFromShape(shape: IntArray): IntArray {
val nDim = shape.size val nDim = shape.size
val res = IntArray(nDim) val res = IntArray(nDim)
if (nDim == 0) if (nDim == 0)
@ -31,7 +32,7 @@ internal inline fun stridesFromShape(shape: IntArray): IntArray {
} }
internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray { internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
val res = IntArray(nDim) val res = IntArray(nDim)
var current = offset var current = offset
var strideIndex = 0 var strideIndex = 0
@ -44,7 +45,7 @@ internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int):
return res return res
} }
internal inline fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray { internal fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
val res = index.copyOf() val res = index.copyOf()
var current = nDim - 1 var current = nDim - 1
var carry = 0 var carry = 0
@ -62,26 +63,26 @@ internal inline fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntA
} }
public class TensorLinearStructure(public val shape: IntArray) public class TensorLinearStructure(override val shape: IntArray) : Strides
{ {
public val strides: IntArray override val strides: IntArray
get() = stridesFromShape(shape) get() = stridesFromShape(shape)
public fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides) override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
public fun index(offset: Int): IntArray = override fun index(offset: Int): IntArray =
indexFromOffset(offset, strides, shape.size) indexFromOffset(offset, strides, shape.size)
public fun stepIndex(index: IntArray): IntArray = public fun stepIndex(index: IntArray): IntArray =
stepIndex(index, shape, shape.size) stepIndex(index, shape, shape.size)
public val size: Int override val linearSize: Int
get() = shape.reduce(Int::times) get() = shape.reduce(Int::times)
public val dim: Int public val dim: Int
get() = shape.size get() = shape.size
public fun indices(): Sequence<IntArray> = (0 until size).asSequence().map { override fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map {
index(it) index(it)
} }
} }