Fixed strides code duplication

This commit is contained in:
Roland Grinis 2021-04-30 14:53:02 +01:00
parent 6be5caa93f
commit 86c2816cfd
5 changed files with 26 additions and 25 deletions

View File

@ -789,7 +789,7 @@ public final class space/kscience/kmath/nd/DefaultStrides : space/kscience/kmath
public fun equals (Ljava/lang/Object;)Z
public fun getLinearSize ()I
public fun getShape ()[I
public fun getStrides ()Ljava/util/List;
public fun getStrides ()[I
public fun hashCode ()I
public fun index (I)[I
public fun offset ([I)I
@ -931,7 +931,7 @@ public final class space/kscience/kmath/nd/ShortRingNDKt {
public abstract interface class space/kscience/kmath/nd/Strides {
public abstract fun getLinearSize ()I
public abstract fun getShape ()[I
public abstract fun getStrides ()Ljava/util/List;
public abstract fun getStrides ()[I
public abstract fun index (I)[I
public fun indices ()Lkotlin/sequences/Sequence;
public abstract fun offset ([I)I

View File

@ -184,7 +184,7 @@ public interface Strides {
/**
* Array strides
*/
public val strides: List<Int>
public val strides: IntArray
/**
* Get linear index from multidimensional index
@ -221,7 +221,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
/**
* Strides for memory access
*/
override val strides: List<Int> by lazy {
override val strides: IntArray by lazy {
sequence {
var current = 1
yield(1)
@ -230,7 +230,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
current *= it
yield(current)
}
}.toList()
}.toList().toIntArray()
}
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->

View File

@ -15,7 +15,7 @@ public open class BufferedTensor<T>(
get() = TensorLinearStructure(shape)
public val numElements: Int
get() = linearStructure.size
get() = linearStructure.linearSize
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)]
@ -60,7 +60,7 @@ internal fun <T> TensorStructure<T>.copyToBufferedTensor(): BufferedTensor<T> =
internal fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides.toIntArray() contentEquals TensorLinearStructure(this.shape).strides)
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
else -> this.copyToBufferedTensor()
}

View File

@ -20,7 +20,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
@ -28,7 +28,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
@ -38,7 +38,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
@ -46,7 +46,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
@ -56,7 +56,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[newThis.bufferStart + i] *
newOther.mutableBuffer.array()[newOther.bufferStart + i]
}
@ -65,7 +65,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
@ -75,7 +75,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[newOther.bufferStart + i] /
newOther.mutableBuffer.array()[newOther.bufferStart + i]
}
@ -84,7 +84,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}

View File

@ -5,16 +5,17 @@
package space.kscience.kmath.tensors.core.algebras
import space.kscience.kmath.nd.Strides
import kotlin.math.max
internal inline fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int =
internal fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int =
index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})")
value * strides[i]
}.sum()
internal inline fun stridesFromShape(shape: IntArray): IntArray {
internal fun stridesFromShape(shape: IntArray): IntArray {
val nDim = shape.size
val res = IntArray(nDim)
if (nDim == 0)
@ -31,7 +32,7 @@ internal inline fun stridesFromShape(shape: IntArray): IntArray {
}
internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
val res = IntArray(nDim)
var current = offset
var strideIndex = 0
@ -44,7 +45,7 @@ internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int):
return res
}
internal inline fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
internal fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
val res = index.copyOf()
var current = nDim - 1
var carry = 0
@ -62,26 +63,26 @@ internal inline fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntA
}
public class TensorLinearStructure(public val shape: IntArray)
public class TensorLinearStructure(override val shape: IntArray) : Strides
{
public val strides: IntArray
override val strides: IntArray
get() = stridesFromShape(shape)
public fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
public fun index(offset: Int): IntArray =
override fun index(offset: Int): IntArray =
indexFromOffset(offset, strides, shape.size)
public fun stepIndex(index: IntArray): IntArray =
stepIndex(index, shape, shape.size)
public val size: Int
override val linearSize: Int
get() = shape.reduce(Int::times)
public val dim: Int
get() = shape.size
public fun indices(): Sequence<IntArray> = (0 until size).asSequence().map {
override fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map {
index(it)
}
}