From 201887187d18fdc41ebdbb183b1d602655a8827d Mon Sep 17 00:00:00 2001 From: Gleb Minaev <43728100+lounres@users.noreply.github.com> Date: Wed, 8 May 2024 21:59:49 +0300 Subject: [PATCH] Make `ShapeND` a usual non-value class. Implement its `equals` and `hashCode` methods. Deprecate `contentEquals` and `contentHashCode`. --- .../kmath/structures/StreamDoubleFieldND.kt | 2 +- .../kscience/kmath/tensors/OLSWithSVD.kt | 3 +- .../kscience/kmath/tensors/neuralNetwork.kt | 3 +- .../kmath/linear/BufferedLinearSpace.kt | 4 +- .../kmath/linear/Float64LinearSpace.kt | 4 +- .../space/kscience/kmath/nd/AlgebraND.kt | 2 +- .../space/kscience/kmath/nd/ShapeIndices.kt | 8 +-- .../kotlin/space/kscience/kmath/nd/ShapeND.kt | 15 +++++- .../linear/Float64ParallelLinearSpace.kt | 4 +- .../kmath/multik/MultikTensorAlgebra.kt | 6 +-- .../kscience/kmath/nd4j/Nd4jArrayAlgebra.kt | 2 +- .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt | 4 +- .../kmath/tensorflow/TensorFlowAlgebra.kt | 3 +- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 2 +- .../kmath/tensors/core/IntTensorAlgebra.kt | 4 +- .../kmath/tensors/core/internal/checks.kt | 3 +- .../kscience/kmath/tensors/core/tensorOps.kt | 4 +- .../kmath/tensors/core/TestBroadcasting.kt | 38 +++++++------- .../core/TestDoubleLinearOpsAlgebra.kt | 13 +++-- .../tensors/core/TestDoubleTensorAlgebra.kt | 49 ++++++++++--------- .../kscience/kmath/viktor/ViktorFieldOpsND.kt | 2 +- 21 files changed, 92 insertions(+), 83 deletions(-) diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt index 3a59423d4..1171231ab 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt @@ -35,7 +35,7 @@ class StreamDoubleFieldND(override val shape: ShapeND) : FieldND.buffer: Float64Buffer get() = when { - !shape.contentEquals(this@StreamDoubleFieldND.shape) -> throw ShapeMismatchException( + shape != this@StreamDoubleFieldND.shape -> throw ShapeMismatchException( this@StreamDoubleFieldND.shape, shape ) diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt index 8de4ab527..66f800f2c 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt @@ -6,7 +6,6 @@ package space.kscience.kmath.tensors import space.kscience.kmath.nd.ShapeND -import space.kscience.kmath.nd.contentEquals import space.kscience.kmath.operations.invoke import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensorAlgebra @@ -62,7 +61,7 @@ fun main() { // figure out MSE of approximation fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double { require(yTrue.shape.size == 1) - require(yTrue.shape contentEquals yPred.shape) + require(yTrue.shape == yPred.shape) val diff = yTrue - yPred return sqrt(diff.dot(diff)).value() diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt index 78f27c304..74b9c899a 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt @@ -6,7 +6,6 @@ package space.kscience.kmath.tensors import space.kscience.kmath.nd.ShapeND -import space.kscience.kmath.nd.contentEquals import space.kscience.kmath.operations.asIterable import space.kscience.kmath.operations.invoke import space.kscience.kmath.tensors.core.* @@ -94,7 +93,7 @@ class Dense( // simple accuracy equal to the proportion of correct answers fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double { - check(yPred.shape contentEquals yTrue.shape) + check(yPred.shape == yTrue.shape) val n = yPred.shape[0] var correctCnt = 0 for (i in 0 until n) { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt index fca29a9dd..878159a51 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt @@ -32,12 +32,12 @@ public class BufferedLinearSpace>( } override fun Matrix.plus(other: Matrix): Matrix = ndAlgebra { - require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } + require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } asND().plus(other.asND()).as2D() } override fun Matrix.minus(other: Matrix): Matrix = ndAlgebra { - require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } + require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } asND().minus(other.asND()).as2D() } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/Float64LinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/Float64LinearSpace.kt index bafdbbc3b..561cb5fd4 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/Float64LinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/Float64LinearSpace.kt @@ -33,12 +33,12 @@ public object Float64LinearSpace : LinearSpace { } override fun Matrix.plus(other: Matrix): Matrix = Floa64FieldOpsND { - require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } + require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } asND().plus(other.asND()).as2D() } override fun Matrix.minus(other: Matrix): Matrix = Floa64FieldOpsND { - require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } + require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } asND().minus(other.asND()).as2D() } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt index 744f760f3..022d56a85 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt @@ -55,7 +55,7 @@ public interface AlgebraND> : Algebra> { */ @PerformancePitfall("Very slow on remote execution algebras") public fun zip(left: StructureND, right: StructureND, transform: C.(T, T) -> T): StructureND { - require(left.shape.contentEquals(right.shape)) { + require(left.shape == right.shape) { "Expected left and right of the same shape, but left - ${left.shape} and right - ${right.shape}" } return structureND(left.shape) { index -> diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndices.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndices.kt index e8545125f..d22e09d1b 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndices.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndices.kt @@ -106,10 +106,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() { override fun equals(other: Any?): Boolean { if (this === other) return true if (other !is ColumnStrides) return false - return shape.contentEquals(other.shape) + return shape == other.shape } - override fun hashCode(): Int = shape.contentHashCode() + override fun hashCode(): Int = shape.hashCode() public companion object @@ -156,10 +156,10 @@ public class RowStrides(override val shape: ShapeND) : Strides() { override fun equals(other: Any?): Boolean { if (this === other) return true if (other !is RowStrides) return false - return shape.contentEquals(other.shape) + return shape == other.shape } - override fun hashCode(): Int = shape.contentHashCode() + override fun hashCode(): Int = shape.hashCode() public companion object diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeND.kt index d5cdc6286..2b30d35e4 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeND.kt @@ -11,19 +11,30 @@ import kotlin.jvm.JvmInline /** * A read-only ND shape */ -@JvmInline -public value class ShapeND(@PublishedApi internal val array: IntArray) { +public class ShapeND(@PublishedApi internal val array: IntArray) { public val size: Int get() = array.size public operator fun get(index: Int): Int = array[index] override fun toString(): String = array.contentToString() + override fun hashCode(): Int = array.contentHashCode() + override fun equals(other: Any?): Boolean = other is ShapeND && array.contentEquals(other.array) } public inline fun ShapeND.forEach(block: (value: Int) -> Unit): Unit = array.forEach(block) public inline fun ShapeND.forEachIndexed(block: (index: Int, value: Int) -> Unit): Unit = array.forEachIndexed(block) +@Deprecated( + message = "ShapeND is made a usual class with correct `equals`. Use it instead of the `contentEquals`.", + replaceWith = ReplaceWith("this == other"), + level = DeprecationLevel.WARNING, +) public infix fun ShapeND.contentEquals(other: ShapeND): Boolean = array.contentEquals(other.array) +@Deprecated( + message = "ShapeND is made a usual class with correct `hashCode`. Use it instead of the `contentHashCode`.", + replaceWith = ReplaceWith("this.hashCode()"), + level = DeprecationLevel.WARNING, +) public fun ShapeND.contentHashCode(): Int = array.contentHashCode() public val ShapeND.indices: IntRange get() = array.indices diff --git a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/linear/Float64ParallelLinearSpace.kt b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/linear/Float64ParallelLinearSpace.kt index 1fb5625b3..088b0fd9f 100644 --- a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/linear/Float64ParallelLinearSpace.kt +++ b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/linear/Float64ParallelLinearSpace.kt @@ -48,12 +48,12 @@ public object Float64ParallelLinearSpace : LinearSpace { } override fun Matrix.plus(other: Matrix): Matrix = Floa64FieldOpsND { - require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } + require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } asND().plus(other.asND()).as2D() } override fun Matrix.minus(other: Matrix): Matrix = Floa64FieldOpsND { - require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } + require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } asND().minus(other.asND()).as2D() } diff --git a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt index c0209338c..574a55605 100644 --- a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt @@ -92,7 +92,7 @@ public abstract class MultikTensorAlgebra>( @OptIn(PerformancePitfall::class) override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): MultikTensor { - require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException + require(left.shape == right.shape) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException val leftArray = left.asMultik().array val rightArray = right.asMultik().array val data = initMemoryView(leftArray.size, dataType) @@ -124,7 +124,7 @@ public abstract class MultikTensorAlgebra>( public fun MutableMultiArray.wrap(): MultikTensor = MultikTensor(this.asDNArray()) @OptIn(PerformancePitfall::class) - override fun StructureND.valueOrNull(): T? = if (shape contentEquals ShapeND(1)) { + override fun StructureND.valueOrNull(): T? = if (shape == ShapeND(1)) { get(intArrayOf(0)) } else null @@ -224,7 +224,7 @@ public abstract class MultikTensorAlgebra>( } val mt = asMultik().array - return if (ShapeND(mt.shape).contentEquals(shape)) { + return if (ShapeND(mt.shape) == shape) { mt } else { @OptIn(UnsafeKMathAPI::class) diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt index 648e5e318..51f59f731 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt @@ -63,7 +63,7 @@ public sealed interface Nd4jArrayAlgebra> : AlgebraND, transform: C.(T, T) -> T, ): Nd4jArrayStructure { - require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" } + require(left.shape == right.shape) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" } val new = Nd4j.create(*left.shape.asArray()).wrap() new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) } return new diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt index 095715070..954984a69 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -49,7 +49,7 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe @OptIn(PerformancePitfall::class) override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): Nd4jArrayStructure { - require(left.shape.contentEquals(right.shape)) + require(left.shape == right.shape) return mutableStructureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) } } @@ -203,7 +203,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra } override fun StructureND.valueOrNull(): Double? = - if (shape contentEquals ShapeND(1)) ndArray.getDouble(0) else null + if (shape == ShapeND(1)) ndArray.getDouble(0) else null // TODO rewrite override fun diagonalEmbedding( diff --git a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt index 2a8720275..73ab44c60 100644 --- a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt +++ b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt @@ -23,7 +23,6 @@ import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.asArray -import space.kscience.kmath.nd.contentEquals import space.kscience.kmath.operations.Ring import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra @@ -105,7 +104,7 @@ public abstract class TensorFlowAlgebra> internal c protected abstract fun const(value: T): Constant @OptIn(PerformancePitfall::class) - override fun StructureND.valueOrNull(): T? = if (shape contentEquals ShapeND(1)) + override fun StructureND.valueOrNull(): T? = if (shape == ShapeND(1)) get(intArrayOf(0)) else null /** diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 0024f41ad..d9e107d86 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -95,7 +95,7 @@ public open class DoubleTensorAlgebra : override fun StructureND.valueOrNull(): Double? { val dt = asDoubleTensor() - return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null + return if (dt.shape == ShapeND(1)) dt.source[0] else null } override fun StructureND.value(): Double = valueOrNull() diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensorAlgebra.kt index e75564ef5..46ec93059 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensorAlgebra.kt @@ -89,7 +89,7 @@ public open class IntTensorAlgebra : TensorAlgebra { override fun StructureND.valueOrNull(): Int? { val dt = asIntTensor() - return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null + return if (dt.shape == ShapeND(1)) dt.source[0] else null } override fun StructureND.value(): Int = valueOrNull() @@ -387,7 +387,7 @@ public open class IntTensorAlgebra : TensorAlgebra { public fun stack(tensors: List>): IntTensor { check(tensors.isNotEmpty()) { "List must have at least 1 element" } val shape = tensors[0].shape - check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } + check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" } val resShape = ShapeND(tensors.size) + shape // val resBuffer: List = tensors.flatMap { // it.asIntTensor().source.array.drop(it.asIntTensor().bufferStart) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/checks.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/checks.kt index 1a0aa29d1..736f9c33d 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/checks.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/checks.kt @@ -7,7 +7,6 @@ package space.kscience.kmath.tensors.core.internal import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.nd.contentEquals import space.kscience.kmath.nd.linearSize import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.DoubleTensor @@ -32,7 +31,7 @@ internal fun checkBufferShapeConsistency(shape: ShapeND, buffer: DoubleArray) = @PublishedApi internal fun checkShapesCompatible(a: StructureND, b: StructureND): Unit = - check(a.shape contentEquals b.shape) { + check(a.shape == b.shape) { "Incompatible shapes ${a.shape} and ${b.shape} " } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorOps.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorOps.kt index 9b31e5694..bd95cee5f 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorOps.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorOps.kt @@ -47,7 +47,7 @@ public fun DoubleTensorAlgebra.randomNormalLike(structure: WithShape, seed: Long public fun stack(tensors: List>): DoubleTensor { check(tensors.isNotEmpty()) { "List must have at least 1 element" } val shape = tensors[0].shape - check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } + check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" } val resShape = ShapeND(tensors.size) + shape // val resBuffer: List = tensors.flatMap { // it.asDoubleTensor().source.array.drop(it.asDoubleTensor().bufferStart) @@ -91,7 +91,7 @@ public fun DoubleTensorAlgebra.luPivot( ): Triple { checkSquareMatrix(luTensor.shape) check( - luTensor.shape.first(luTensor.shape.size - 2) contentEquals pivotsTensor.shape.first(pivotsTensor.shape.size - 1) || + luTensor.shape.first(luTensor.shape.size - 2) == pivotsTensor.shape.first(pivotsTensor.shape.size - 1) || luTensor.shape.last() == pivotsTensor.shape.last() - 1 ) { "Inappropriate shapes of input tensors" } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt index f930cf94c..600f8241a 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt @@ -6,29 +6,31 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.nd.ShapeND -import space.kscience.kmath.nd.contentEquals import space.kscience.kmath.operations.invoke import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors import space.kscience.kmath.tensors.core.internal.broadcastShapes import space.kscience.kmath.tensors.core.internal.broadcastTensors import space.kscience.kmath.tensors.core.internal.broadcastTo import kotlin.test.Test +import kotlin.test.assertEquals import kotlin.test.assertTrue internal class TestBroadcasting { @Test fun testBroadcastShapes() = DoubleTensorAlgebra { - assertTrue( + assertEquals( broadcastShapes( listOf(ShapeND(2, 3), ShapeND(1, 3), ShapeND(1, 1, 1)) - ) contentEquals ShapeND(1, 2, 3) + ), + ShapeND(1, 2, 3) ) - assertTrue( + assertEquals( broadcastShapes( listOf(ShapeND(6, 7), ShapeND(5, 6, 1), ShapeND(7), ShapeND(5, 1, 7)) - ) contentEquals ShapeND(5, 6, 7) + ), + ShapeND(5, 6, 7) ) } @@ -38,7 +40,7 @@ internal class TestBroadcasting { val tensor2 = fromArray(ShapeND(1, 3), doubleArrayOf(10.0, 20.0, 30.0)) val res = broadcastTo(tensor2, tensor1.shape) - assertTrue(res.shape contentEquals ShapeND(2, 3)) + assertTrue(res.shape == ShapeND(2, 3)) assertTrue(res.source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) } @@ -50,9 +52,9 @@ internal class TestBroadcasting { val res = broadcastTensors(tensor1, tensor2, tensor3) - assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3)) - assertTrue(res[1].shape contentEquals ShapeND(1, 2, 3)) - assertTrue(res[2].shape contentEquals ShapeND(1, 2, 3)) + assertEquals(res[0].shape, ShapeND(1, 2, 3)) + assertEquals(res[1].shape, ShapeND(1, 2, 3)) + assertEquals(res[2].shape, ShapeND(1, 2, 3)) assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) @@ -67,9 +69,9 @@ internal class TestBroadcasting { val res = broadcastOuterTensors(tensor1, tensor2, tensor3) - assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3)) - assertTrue(res[1].shape contentEquals ShapeND(1, 1, 3)) - assertTrue(res[2].shape contentEquals ShapeND(1, 1, 1)) + assertEquals(res[0].shape, ShapeND(1, 2, 3)) + assertEquals(res[1].shape, ShapeND(1, 1, 3)) + assertEquals(res[2].shape, ShapeND(1, 1, 1)) assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0)) @@ -84,9 +86,9 @@ internal class TestBroadcasting { val res = broadcastOuterTensors(tensor1, tensor2, tensor3) - assertTrue(res[0].shape contentEquals ShapeND(4, 2, 5, 3, 2, 3)) - assertTrue(res[1].shape contentEquals ShapeND(4, 2, 5, 3, 3, 3)) - assertTrue(res[2].shape contentEquals ShapeND(4, 2, 5, 3, 1, 1)) + assertEquals(res[0].shape, ShapeND(4, 2, 5, 3, 2, 3)) + assertEquals(res[1].shape, ShapeND(4, 2, 5, 3, 3, 3)) + assertEquals(res[2].shape, ShapeND(4, 2, 5, 3, 1, 1)) } @Test @@ -99,16 +101,16 @@ internal class TestBroadcasting { val tensor31 = tensor3 - tensor1 val tensor32 = tensor3 - tensor2 - assertTrue(tensor21.shape contentEquals ShapeND(2, 3)) + assertEquals(tensor21.shape, ShapeND(2, 3)) assertTrue(tensor21.source contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) - assertTrue(tensor31.shape contentEquals ShapeND(1, 2, 3)) + assertEquals(tensor31.shape, ShapeND(1, 2, 3)) assertTrue( tensor31.source contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0) ) - assertTrue(tensor32.shape contentEquals ShapeND(1, 1, 3)) + assertEquals(tensor32.shape, ShapeND(1, 1, 3)) assertTrue(tensor32.source contentEquals doubleArrayOf(490.0, 480.0, 470.0)) } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index 1c500af8b..d3c192d67 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -6,7 +6,6 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.nd.ShapeND -import space.kscience.kmath.nd.contentEquals import space.kscience.kmath.operations.invoke import space.kscience.kmath.tensors.core.internal.svd1d import kotlin.math.abs @@ -113,8 +112,8 @@ internal class TestDoubleLinearOpsTensorAlgebra { val (q, r) = qr(tensor) - assertTrue { q.shape contentEquals shape } - assertTrue { r.shape contentEquals shape } + assertEquals(q.shape, shape) + assertEquals(r.shape, shape) assertTrue((q matmul r).eq(tensor)) @@ -133,9 +132,9 @@ internal class TestDoubleLinearOpsTensorAlgebra { val (p, l, u) = lu(tensor) - assertTrue { p.shape contentEquals shape } - assertTrue { l.shape contentEquals shape } - assertTrue { u.shape contentEquals shape } + assertEquals(p.shape, shape) + assertEquals(l.shape, shape) + assertEquals(u.shape, shape) assertTrue((p matmul tensor).eq(l matmul u)) } @@ -157,7 +156,7 @@ internal class TestDoubleLinearOpsTensorAlgebra { val res = svd1d(tensor2) - assertTrue(res.shape contentEquals ShapeND(2)) + assertEquals(res.shape, ShapeND(2)) assertTrue { abs(abs(res.source[0]) - 0.386) < 0.01 } assertTrue { abs(abs(res.source[1]) - 0.922) < 0.01 } } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt index 4f58846aa..d33e29f26 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt @@ -6,7 +6,8 @@ package space.kscience.kmath.tensors.core -import space.kscience.kmath.nd.* +import space.kscience.kmath.nd.ShapeND +import space.kscience.kmath.nd.get import space.kscience.kmath.operations.invoke import space.kscience.kmath.testutils.assertBufferEquals import kotlin.test.Test @@ -43,7 +44,7 @@ internal class TestDoubleTensorAlgebra { val res = tensor.transposed(0, 0) assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(0.0)) - assertTrue(res.shape contentEquals ShapeND(1)) + assertEquals(res.shape, ShapeND(1)) } @Test @@ -52,7 +53,7 @@ internal class TestDoubleTensorAlgebra { val res = tensor.transposed(1, 0) assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) - assertTrue(res.shape contentEquals ShapeND(2, 3)) + assertEquals(res.shape, ShapeND(2, 3)) } @Test @@ -62,9 +63,9 @@ internal class TestDoubleTensorAlgebra { val res02 = tensor.transposed(-3, 2) val res12 = tensor.transposed() - assertTrue(res01.shape contentEquals ShapeND(2, 1, 3)) - assertTrue(res02.shape contentEquals ShapeND(3, 2, 1)) - assertTrue(res12.shape contentEquals ShapeND(1, 3, 2)) + assertEquals(res01.shape, ShapeND(2, 1, 3)) + assertEquals(res02.shape, ShapeND(3, 2, 1)) + assertEquals(res12.shape, ShapeND(1, 3, 2)) assertTrue(res01.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res02.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) @@ -114,19 +115,19 @@ internal class TestDoubleTensorAlgebra { val res12 = tensor1.dot(tensor2) assertTrue(res12.source contentEquals doubleArrayOf(140.0, 320.0)) - assertTrue(res12.shape contentEquals ShapeND(2)) + assertEquals(res12.shape, ShapeND(2)) val res32 = tensor3.matmul(tensor2) assertTrue(res32.source contentEquals doubleArrayOf(-140.0)) - assertTrue(res32.shape contentEquals ShapeND(1, 1)) + assertEquals(res32.shape, ShapeND(1, 1)) val res22 = tensor2.dot(tensor2) assertTrue(res22.source contentEquals doubleArrayOf(1400.0)) - assertTrue(res22.shape contentEquals ShapeND(1)) + assertEquals(res22.shape, ShapeND(1)) val res11 = tensor1.dot(tensor11) assertTrue(res11.source contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0)) - assertTrue(res11.shape contentEquals ShapeND(2, 2)) + assertEquals(res11.shape, ShapeND(2, 2)) val res45 = tensor4.matmul(tensor5) assertTrue( @@ -135,7 +136,7 @@ internal class TestDoubleTensorAlgebra { 468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0 ) ) - assertTrue(res45.shape contentEquals ShapeND(2, 3, 3)) + assertEquals(res45.shape, ShapeND(2, 3, 3)) } @Test @@ -144,35 +145,35 @@ internal class TestDoubleTensorAlgebra { val tensor2 = fromArray(ShapeND(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor3 = zeros(ShapeND(2, 3, 4, 5)) - assertTrue( - diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals - ShapeND(2, 3, 4, 5, 5) + assertEquals( + diagonalEmbedding(tensor3, 0, 3, 4).shape, + ShapeND(2, 3, 4, 5, 5) ) - assertTrue( - diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals - ShapeND(2, 3, 4, 6, 6) + assertEquals( + diagonalEmbedding(tensor3, 1, 3, 4).shape, + ShapeND(2, 3, 4, 6, 6) ) - assertTrue( - diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals - ShapeND(7, 2, 3, 7, 4) + assertEquals( + diagonalEmbedding(tensor3, 2, 0, 3).shape, + ShapeND(7, 2, 3, 7, 4) ) val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0) - assertTrue(diagonal1.shape contentEquals ShapeND(3, 3)) + assertEquals(diagonal1.shape, ShapeND(3, 3)) assertTrue( diagonal1.source contentEquals doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0) ) val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0) - assertTrue(diagonal1Offset.shape contentEquals ShapeND(4, 4)) + assertEquals(diagonal1Offset.shape, ShapeND(4, 4)) assertTrue( diagonal1Offset.source contentEquals doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0) ) val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2) - assertTrue(diagonal2.shape contentEquals ShapeND(4, 2, 4)) + assertEquals(diagonal2.shape, ShapeND(4, 2, 4)) assertTrue( diagonal2.source contentEquals doubleArrayOf( @@ -202,7 +203,7 @@ internal class TestDoubleTensorAlgebra { val l = tensor.getTensor(0).map { it + 1.0 } val r = tensor.getTensor(1).map { it - 1.0 } val res = l + r - assertTrue { ShapeND(5, 5) contentEquals res.shape } + assertEquals(ShapeND(5, 5), res.shape) assertEquals(2.0, res[4, 4]) } } diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt index 1d95546ac..61ef00c6f 100644 --- a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt @@ -67,7 +67,7 @@ public open class ViktorFieldOpsND : right: StructureND, transform: Float64Field.(Double, Double) -> Double, ): ViktorStructureND { - require(left.shape.contentEquals(right.shape)) + require(left.shape == right.shape) return F64Array(*left.shape.asArray()).apply { ColumnStrides(left.shape).asSequence().forEach { index -> set(value = Float64Field.transform(left[index], right[index]), indices = index)