Fixed strides code duplication
This commit is contained in:
parent
6be5caa93f
commit
86c2816cfd
@ -789,7 +789,7 @@ public final class space/kscience/kmath/nd/DefaultStrides : space/kscience/kmath
|
||||
public fun equals (Ljava/lang/Object;)Z
|
||||
public fun getLinearSize ()I
|
||||
public fun getShape ()[I
|
||||
public fun getStrides ()Ljava/util/List;
|
||||
public fun getStrides ()[I
|
||||
public fun hashCode ()I
|
||||
public fun index (I)[I
|
||||
public fun offset ([I)I
|
||||
@ -931,7 +931,7 @@ public final class space/kscience/kmath/nd/ShortRingNDKt {
|
||||
public abstract interface class space/kscience/kmath/nd/Strides {
|
||||
public abstract fun getLinearSize ()I
|
||||
public abstract fun getShape ()[I
|
||||
public abstract fun getStrides ()Ljava/util/List;
|
||||
public abstract fun getStrides ()[I
|
||||
public abstract fun index (I)[I
|
||||
public fun indices ()Lkotlin/sequences/Sequence;
|
||||
public abstract fun offset ([I)I
|
||||
|
@ -184,7 +184,7 @@ public interface Strides {
|
||||
/**
|
||||
* Array strides
|
||||
*/
|
||||
public val strides: List<Int>
|
||||
public val strides: IntArray
|
||||
|
||||
/**
|
||||
* Get linear index from multidimensional index
|
||||
@ -221,7 +221,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
||||
/**
|
||||
* Strides for memory access
|
||||
*/
|
||||
override val strides: List<Int> by lazy {
|
||||
override val strides: IntArray by lazy {
|
||||
sequence {
|
||||
var current = 1
|
||||
yield(1)
|
||||
@ -230,7 +230,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
||||
current *= it
|
||||
yield(current)
|
||||
}
|
||||
}.toList()
|
||||
}.toList().toIntArray()
|
||||
}
|
||||
|
||||
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||
|
@ -15,7 +15,7 @@ public open class BufferedTensor<T>(
|
||||
get() = TensorLinearStructure(shape)
|
||||
|
||||
public val numElements: Int
|
||||
get() = linearStructure.size
|
||||
get() = linearStructure.linearSize
|
||||
|
||||
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)]
|
||||
|
||||
@ -60,7 +60,7 @@ internal fun <T> TensorStructure<T>.copyToBufferedTensor(): BufferedTensor<T> =
|
||||
|
||||
internal fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
||||
is BufferedTensor<T> -> this
|
||||
is MutableBufferND<T> -> if (this.strides.strides.toIntArray() contentEquals TensorLinearStructure(this.shape).strides)
|
||||
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
|
||||
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
|
||||
else -> this.copyToBufferedTensor()
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||
newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
@ -28,7 +28,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
|
||||
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) {
|
||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||
for (i in 0 until tensor.linearStructure.size) {
|
||||
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
|
||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
@ -38,7 +38,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||
newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
@ -46,7 +46,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
|
||||
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) {
|
||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||
for (i in 0 until tensor.linearStructure.size) {
|
||||
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
|
||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
@ -56,7 +56,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||
newThis.mutableBuffer.array()[newThis.bufferStart + i] *
|
||||
newOther.mutableBuffer.array()[newOther.bufferStart + i]
|
||||
}
|
||||
@ -65,7 +65,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
|
||||
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) {
|
||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||
for (i in 0 until tensor.linearStructure.size) {
|
||||
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
|
||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
@ -75,7 +75,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||
newThis.mutableBuffer.array()[newOther.bufferStart + i] /
|
||||
newOther.mutableBuffer.array()[newOther.bufferStart + i]
|
||||
}
|
||||
@ -84,7 +84,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
|
||||
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) {
|
||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||
for (i in 0 until tensor.linearStructure.size) {
|
||||
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
|
@ -5,16 +5,17 @@
|
||||
|
||||
package space.kscience.kmath.tensors.core.algebras
|
||||
|
||||
import space.kscience.kmath.nd.Strides
|
||||
import kotlin.math.max
|
||||
|
||||
|
||||
internal inline fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int =
|
||||
internal fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int =
|
||||
index.mapIndexed { i, value ->
|
||||
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})")
|
||||
value * strides[i]
|
||||
}.sum()
|
||||
|
||||
internal inline fun stridesFromShape(shape: IntArray): IntArray {
|
||||
internal fun stridesFromShape(shape: IntArray): IntArray {
|
||||
val nDim = shape.size
|
||||
val res = IntArray(nDim)
|
||||
if (nDim == 0)
|
||||
@ -31,7 +32,7 @@ internal inline fun stridesFromShape(shape: IntArray): IntArray {
|
||||
|
||||
}
|
||||
|
||||
internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||
val res = IntArray(nDim)
|
||||
var current = offset
|
||||
var strideIndex = 0
|
||||
@ -44,7 +45,7 @@ internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int):
|
||||
return res
|
||||
}
|
||||
|
||||
internal inline fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
||||
internal fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
||||
val res = index.copyOf()
|
||||
var current = nDim - 1
|
||||
var carry = 0
|
||||
@ -62,26 +63,26 @@ internal inline fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntA
|
||||
}
|
||||
|
||||
|
||||
public class TensorLinearStructure(public val shape: IntArray)
|
||||
public class TensorLinearStructure(override val shape: IntArray) : Strides
|
||||
{
|
||||
public val strides: IntArray
|
||||
override val strides: IntArray
|
||||
get() = stridesFromShape(shape)
|
||||
|
||||
public fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
|
||||
override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
|
||||
|
||||
public fun index(offset: Int): IntArray =
|
||||
override fun index(offset: Int): IntArray =
|
||||
indexFromOffset(offset, strides, shape.size)
|
||||
|
||||
public fun stepIndex(index: IntArray): IntArray =
|
||||
stepIndex(index, shape, shape.size)
|
||||
|
||||
public val size: Int
|
||||
override val linearSize: Int
|
||||
get() = shape.reduce(Int::times)
|
||||
|
||||
public val dim: Int
|
||||
get() = shape.size
|
||||
|
||||
public fun indices(): Sequence<IntArray> = (0 until size).asSequence().map {
|
||||
override fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map {
|
||||
index(it)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user