forked from kscience/kmath
Fixed strides code duplication
This commit is contained in:
parent
6be5caa93f
commit
86c2816cfd
@ -789,7 +789,7 @@ public final class space/kscience/kmath/nd/DefaultStrides : space/kscience/kmath
|
|||||||
public fun equals (Ljava/lang/Object;)Z
|
public fun equals (Ljava/lang/Object;)Z
|
||||||
public fun getLinearSize ()I
|
public fun getLinearSize ()I
|
||||||
public fun getShape ()[I
|
public fun getShape ()[I
|
||||||
public fun getStrides ()Ljava/util/List;
|
public fun getStrides ()[I
|
||||||
public fun hashCode ()I
|
public fun hashCode ()I
|
||||||
public fun index (I)[I
|
public fun index (I)[I
|
||||||
public fun offset ([I)I
|
public fun offset ([I)I
|
||||||
@ -931,7 +931,7 @@ public final class space/kscience/kmath/nd/ShortRingNDKt {
|
|||||||
public abstract interface class space/kscience/kmath/nd/Strides {
|
public abstract interface class space/kscience/kmath/nd/Strides {
|
||||||
public abstract fun getLinearSize ()I
|
public abstract fun getLinearSize ()I
|
||||||
public abstract fun getShape ()[I
|
public abstract fun getShape ()[I
|
||||||
public abstract fun getStrides ()Ljava/util/List;
|
public abstract fun getStrides ()[I
|
||||||
public abstract fun index (I)[I
|
public abstract fun index (I)[I
|
||||||
public fun indices ()Lkotlin/sequences/Sequence;
|
public fun indices ()Lkotlin/sequences/Sequence;
|
||||||
public abstract fun offset ([I)I
|
public abstract fun offset ([I)I
|
||||||
|
@ -184,7 +184,7 @@ public interface Strides {
|
|||||||
/**
|
/**
|
||||||
* Array strides
|
* Array strides
|
||||||
*/
|
*/
|
||||||
public val strides: List<Int>
|
public val strides: IntArray
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get linear index from multidimensional index
|
* Get linear index from multidimensional index
|
||||||
@ -221,7 +221,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
/**
|
/**
|
||||||
* Strides for memory access
|
* Strides for memory access
|
||||||
*/
|
*/
|
||||||
override val strides: List<Int> by lazy {
|
override val strides: IntArray by lazy {
|
||||||
sequence {
|
sequence {
|
||||||
var current = 1
|
var current = 1
|
||||||
yield(1)
|
yield(1)
|
||||||
@ -230,7 +230,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
current *= it
|
current *= it
|
||||||
yield(current)
|
yield(current)
|
||||||
}
|
}
|
||||||
}.toList()
|
}.toList().toIntArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||||
|
@ -15,7 +15,7 @@ public open class BufferedTensor<T>(
|
|||||||
get() = TensorLinearStructure(shape)
|
get() = TensorLinearStructure(shape)
|
||||||
|
|
||||||
public val numElements: Int
|
public val numElements: Int
|
||||||
get() = linearStructure.size
|
get() = linearStructure.linearSize
|
||||||
|
|
||||||
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)]
|
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)]
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ internal fun <T> TensorStructure<T>.copyToBufferedTensor(): BufferedTensor<T> =
|
|||||||
|
|
||||||
internal fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
internal fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
||||||
is BufferedTensor<T> -> this
|
is BufferedTensor<T> -> this
|
||||||
is MutableBufferND<T> -> if (this.strides.strides.toIntArray() contentEquals TensorLinearStructure(this.shape).strides)
|
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
|
||||||
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
|
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
|
||||||
else -> this.copyToBufferedTensor()
|
else -> this.copyToBufferedTensor()
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
|
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||||
newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
|
newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
@ -28,7 +28,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
|
|
||||||
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) {
|
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) {
|
||||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||||
for (i in 0 until tensor.linearStructure.size) {
|
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
|
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
|
||||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||||
}
|
}
|
||||||
@ -38,7 +38,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
|
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||||
newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
|
newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
@ -46,7 +46,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
|
|
||||||
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) {
|
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) {
|
||||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||||
for (i in 0 until tensor.linearStructure.size) {
|
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
|
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
|
||||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||||
}
|
}
|
||||||
@ -56,7 +56,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
|
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||||
newThis.mutableBuffer.array()[newThis.bufferStart + i] *
|
newThis.mutableBuffer.array()[newThis.bufferStart + i] *
|
||||||
newOther.mutableBuffer.array()[newOther.bufferStart + i]
|
newOther.mutableBuffer.array()[newOther.bufferStart + i]
|
||||||
}
|
}
|
||||||
@ -65,7 +65,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
|
|
||||||
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) {
|
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) {
|
||||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||||
for (i in 0 until tensor.linearStructure.size) {
|
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
|
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
|
||||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||||
}
|
}
|
||||||
@ -75,7 +75,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
|
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||||
newThis.mutableBuffer.array()[newOther.bufferStart + i] /
|
newThis.mutableBuffer.array()[newOther.bufferStart + i] /
|
||||||
newOther.mutableBuffer.array()[newOther.bufferStart + i]
|
newOther.mutableBuffer.array()[newOther.bufferStart + i]
|
||||||
}
|
}
|
||||||
@ -84,7 +84,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
|
|
||||||
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) {
|
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) {
|
||||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||||
for (i in 0 until tensor.linearStructure.size) {
|
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
||||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||||
}
|
}
|
||||||
|
@ -5,16 +5,17 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.core.algebras
|
package space.kscience.kmath.tensors.core.algebras
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.Strides
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
|
|
||||||
|
|
||||||
internal inline fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int =
|
internal fun offsetFromIndex(index: IntArray, shape: IntArray, strides: IntArray): Int =
|
||||||
index.mapIndexed { i, value ->
|
index.mapIndexed { i, value ->
|
||||||
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})")
|
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${shape[i]})")
|
||||||
value * strides[i]
|
value * strides[i]
|
||||||
}.sum()
|
}.sum()
|
||||||
|
|
||||||
internal inline fun stridesFromShape(shape: IntArray): IntArray {
|
internal fun stridesFromShape(shape: IntArray): IntArray {
|
||||||
val nDim = shape.size
|
val nDim = shape.size
|
||||||
val res = IntArray(nDim)
|
val res = IntArray(nDim)
|
||||||
if (nDim == 0)
|
if (nDim == 0)
|
||||||
@ -31,7 +32,7 @@ internal inline fun stridesFromShape(shape: IntArray): IntArray {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||||
val res = IntArray(nDim)
|
val res = IntArray(nDim)
|
||||||
var current = offset
|
var current = offset
|
||||||
var strideIndex = 0
|
var strideIndex = 0
|
||||||
@ -44,7 +45,7 @@ internal inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int):
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
internal fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
||||||
val res = index.copyOf()
|
val res = index.copyOf()
|
||||||
var current = nDim - 1
|
var current = nDim - 1
|
||||||
var carry = 0
|
var carry = 0
|
||||||
@ -62,26 +63,26 @@ internal inline fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntA
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public class TensorLinearStructure(public val shape: IntArray)
|
public class TensorLinearStructure(override val shape: IntArray) : Strides
|
||||||
{
|
{
|
||||||
public val strides: IntArray
|
override val strides: IntArray
|
||||||
get() = stridesFromShape(shape)
|
get() = stridesFromShape(shape)
|
||||||
|
|
||||||
public fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
|
override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
|
||||||
|
|
||||||
public fun index(offset: Int): IntArray =
|
override fun index(offset: Int): IntArray =
|
||||||
indexFromOffset(offset, strides, shape.size)
|
indexFromOffset(offset, strides, shape.size)
|
||||||
|
|
||||||
public fun stepIndex(index: IntArray): IntArray =
|
public fun stepIndex(index: IntArray): IntArray =
|
||||||
stepIndex(index, shape, shape.size)
|
stepIndex(index, shape, shape.size)
|
||||||
|
|
||||||
public val size: Int
|
override val linearSize: Int
|
||||||
get() = shape.reduce(Int::times)
|
get() = shape.reduce(Int::times)
|
||||||
|
|
||||||
public val dim: Int
|
public val dim: Int
|
||||||
get() = shape.size
|
get() = shape.size
|
||||||
|
|
||||||
public fun indices(): Sequence<IntArray> = (0 until size).asSequence().map {
|
override fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map {
|
||||||
index(it)
|
index(it)
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user