diff --git a/kmath-core/api/kmath-core.api b/kmath-core/api/kmath-core.api index 687de9cc8..261aa1e24 100644 --- a/kmath-core/api/kmath-core.api +++ b/kmath-core/api/kmath-core.api @@ -792,7 +792,6 @@ public final class space/kscience/kmath/nd/DefaultStrides : space/kscience/kmath public fun getStrides ()[I public fun hashCode ()I public fun index (I)[I - public fun offset ([I)I } public final class space/kscience/kmath/nd/DefaultStrides$Companion { @@ -934,7 +933,7 @@ public abstract interface class space/kscience/kmath/nd/Strides { public abstract fun getStrides ()[I public abstract fun index (I)[I public fun indices ()Lkotlin/sequences/Sequence; - public abstract fun offset ([I)I + public fun offset ([I)I } public abstract interface class space/kscience/kmath/nd/Structure1D : space/kscience/kmath/nd/StructureND, space/kscience/kmath/structures/Buffer { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt index 65c233012..a3331d71a 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt @@ -189,7 +189,10 @@ public interface Strides { /** * Get linear index from multidimensional index */ - public fun offset(index: IntArray): Int + public fun offset(index: IntArray): Int = index.mapIndexed { i, value -> + if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})") + value * strides[i] + }.sum() /** * Get multidimensional from linear @@ -233,11 +236,6 @@ public class DefaultStrides private constructor(override val shape: IntArray) : }.toList().toIntArray() } - override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> - if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})") - value * strides[i] - }.sum() - override fun index(offset: Int): IntArray { val res = IntArray(shape.size) var current = offset diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/TensorLinearStructure.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/TensorLinearStructure.kt index 8e83dafd6..08aab5175 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/TensorLinearStructure.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/TensorLinearStructure.kt @@ -68,8 +68,6 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides override val strides: IntArray get() = stridesFromShape(shape) - override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides) - override fun index(offset: Int): IntArray = indexFromOffset(offset, strides, shape.size) @@ -82,7 +80,4 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides public val dim: Int get() = shape.size - override fun indices(): Sequence = (0 until linearSize).asSequence().map { - index(it) - } } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/broadcastUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/broadcastUtils.kt index 58d8654af..e883b7861 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/broadcastUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/broadcastUtils.kt @@ -75,14 +75,13 @@ internal fun broadcastTensors(vararg tensors: DoubleTensor): List val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray()) val n = totalShape.reduce { acc, i -> acc * i } - val res = ArrayList(0) - for (tensor in tensors) { - val resTensor = DoubleTensor(totalShape, DoubleArray(n)) - multiIndexBroadCasting(tensor, resTensor, n) - res.add(resTensor) + return buildList { + for (tensor in tensors) { + val resTensor = DoubleTensor(totalShape, DoubleArray(n)) + multiIndexBroadCasting(tensor, resTensor, n) + add(resTensor) + } } - - return res } internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List { diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt index 34fe5d5f1..a0efa4573 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt @@ -170,7 +170,6 @@ internal class TestDoubleTensorAlgebra { val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor3 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 5.0)) - val tensor4 = fromArray(intArrayOf(6, 1), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(tensor1 eq tensor1) assertTrue(tensor1 eq tensor2)