Make ShapeND a usual non-value class. Implement its equals and hashCode methods. Deprecate contentEquals and contentHashCode.

This commit is contained in:
Gleb Minaev 2024-05-08 21:59:49 +03:00
parent fc0393436f
commit 201887187d
21 changed files with 92 additions and 83 deletions

View File

@ -35,7 +35,7 @@ class StreamDoubleFieldND(override val shape: ShapeND) : FieldND<Double, Float64
@OptIn(PerformancePitfall::class)
private val StructureND<Double>.buffer: Float64Buffer
get() = when {
!shape.contentEquals(this@StreamDoubleFieldND.shape) -> throw ShapeMismatchException(
shape != this@StreamDoubleFieldND.shape -> throw ShapeMismatchException(
this@StreamDoubleFieldND.shape,
shape
)

View File

@ -6,7 +6,6 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
@ -62,7 +61,7 @@ fun main() {
// figure out MSE of approximation
fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double {
require(yTrue.shape.size == 1)
require(yTrue.shape contentEquals yPred.shape)
require(yTrue.shape == yPred.shape)
val diff = yTrue - yPred
return sqrt(diff.dot(diff)).value()

View File

@ -6,7 +6,6 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.asIterable
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.*
@ -94,7 +93,7 @@ class Dense(
// simple accuracy equal to the proportion of correct answers
fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double {
check(yPred.shape contentEquals yTrue.shape)
check(yPred.shape == yTrue.shape)
val n = yPred.shape[0]
var correctCnt = 0
for (i in 0 until n) {

View File

@ -32,12 +32,12 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
}
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D()
}
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D()
}

View File

@ -33,12 +33,12 @@ public object Float64LinearSpace : LinearSpace<Double, Float64Field> {
}
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D()
}
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D()
}

View File

@ -55,7 +55,7 @@ public interface AlgebraND<T, out C : Algebra<T>> : Algebra<StructureND<T>> {
*/
@PerformancePitfall("Very slow on remote execution algebras")
public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T> {
require(left.shape.contentEquals(right.shape)) {
require(left.shape == right.shape) {
"Expected left and right of the same shape, but left - ${left.shape} and right - ${right.shape}"
}
return structureND(left.shape) { index ->

View File

@ -106,10 +106,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is ColumnStrides) return false
return shape.contentEquals(other.shape)
return shape == other.shape
}
override fun hashCode(): Int = shape.contentHashCode()
override fun hashCode(): Int = shape.hashCode()
public companion object
@ -156,10 +156,10 @@ public class RowStrides(override val shape: ShapeND) : Strides() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is RowStrides) return false
return shape.contentEquals(other.shape)
return shape == other.shape
}
override fun hashCode(): Int = shape.contentHashCode()
override fun hashCode(): Int = shape.hashCode()
public companion object

View File

@ -11,19 +11,30 @@ import kotlin.jvm.JvmInline
/**
* A read-only ND shape
*/
@JvmInline
public value class ShapeND(@PublishedApi internal val array: IntArray) {
public class ShapeND(@PublishedApi internal val array: IntArray) {
public val size: Int get() = array.size
public operator fun get(index: Int): Int = array[index]
override fun toString(): String = array.contentToString()
override fun hashCode(): Int = array.contentHashCode()
override fun equals(other: Any?): Boolean = other is ShapeND && array.contentEquals(other.array)
}
public inline fun ShapeND.forEach(block: (value: Int) -> Unit): Unit = array.forEach(block)
public inline fun ShapeND.forEachIndexed(block: (index: Int, value: Int) -> Unit): Unit = array.forEachIndexed(block)
@Deprecated(
message = "ShapeND is made a usual class with correct `equals`. Use it instead of the `contentEquals`.",
replaceWith = ReplaceWith("this == other"),
level = DeprecationLevel.WARNING,
)
public infix fun ShapeND.contentEquals(other: ShapeND): Boolean = array.contentEquals(other.array)
@Deprecated(
message = "ShapeND is made a usual class with correct `hashCode`. Use it instead of the `contentHashCode`.",
replaceWith = ReplaceWith("this.hashCode()"),
level = DeprecationLevel.WARNING,
)
public fun ShapeND.contentHashCode(): Int = array.contentHashCode()
public val ShapeND.indices: IntRange get() = array.indices

View File

@ -48,12 +48,12 @@ public object Float64ParallelLinearSpace : LinearSpace<Double, Float64Field> {
}
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D()
}
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D()
}

View File

@ -92,7 +92,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
@OptIn(PerformancePitfall::class)
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
require(left.shape == right.shape) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
val leftArray = left.asMultik().array
val rightArray = right.asMultik().array
val data = initMemoryView<T>(leftArray.size, dataType)
@ -124,7 +124,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
@OptIn(PerformancePitfall::class)
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1)) {
override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1)) {
get(intArrayOf(0))
} else null
@ -224,7 +224,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
}
val mt = asMultik().array
return if (ShapeND(mt.shape).contentEquals(shape)) {
return if (ShapeND(mt.shape) == shape) {
mt
} else {
@OptIn(UnsafeKMathAPI::class)

View File

@ -63,7 +63,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
right: StructureND<T>,
transform: C.(T, T) -> T,
): Nd4jArrayStructure<T> {
require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
require(left.shape == right.shape) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
val new = Nd4j.create(*left.shape.asArray()).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
return new

View File

@ -49,7 +49,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
@OptIn(PerformancePitfall::class)
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
require(left.shape.contentEquals(right.shape))
require(left.shape == right.shape)
return mutableStructureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
}
@ -203,7 +203,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, Float64Field>
}
override fun StructureND<Double>.valueOrNull(): Double? =
if (shape contentEquals ShapeND(1)) ndArray.getDouble(0) else null
if (shape == ShapeND(1)) ndArray.getDouble(0) else null
// TODO rewrite
override fun diagonalEmbedding(

View File

@ -23,7 +23,6 @@ import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.asArray
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra
@ -105,7 +104,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
protected abstract fun const(value: T): Constant<TT>
@OptIn(PerformancePitfall::class)
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1))
override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1))
get(intArrayOf(0)) else null
/**

View File

@ -95,7 +95,7 @@ public open class DoubleTensorAlgebra :
override fun StructureND<Double>.valueOrNull(): Double? {
val dt = asDoubleTensor()
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null
return if (dt.shape == ShapeND(1)) dt.source[0] else null
}
override fun StructureND<Double>.value(): Double = valueOrNull()

View File

@ -89,7 +89,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
override fun StructureND<Int>.valueOrNull(): Int? {
val dt = asIntTensor()
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null
return if (dt.shape == ShapeND(1)) dt.source[0] else null
}
override fun StructureND<Int>.value(): Int = valueOrNull()
@ -387,7 +387,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
public fun stack(tensors: List<Tensor<Int>>): IntTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val shape = tensors[0].shape
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
val resShape = ShapeND(tensors.size) + shape
// val resBuffer: List<Int> = tensors.flatMap {
// it.asIntTensor().source.array.drop(it.asIntTensor().bufferStart)

View File

@ -7,7 +7,6 @@ package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.nd.linearSize
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.DoubleTensor
@ -32,7 +31,7 @@ internal fun checkBufferShapeConsistency(shape: ShapeND, buffer: DoubleArray) =
@PublishedApi
internal fun <T> checkShapesCompatible(a: StructureND<T>, b: StructureND<T>): Unit =
check(a.shape contentEquals b.shape) {
check(a.shape == b.shape) {
"Incompatible shapes ${a.shape} and ${b.shape} "
}

View File

@ -47,7 +47,7 @@ public fun DoubleTensorAlgebra.randomNormalLike(structure: WithShape, seed: Long
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val shape = tensors[0].shape
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
val resShape = ShapeND(tensors.size) + shape
// val resBuffer: List<Double> = tensors.flatMap {
// it.asDoubleTensor().source.array.drop(it.asDoubleTensor().bufferStart)
@ -91,7 +91,7 @@ public fun DoubleTensorAlgebra.luPivot(
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
checkSquareMatrix(luTensor.shape)
check(
luTensor.shape.first(luTensor.shape.size - 2) contentEquals pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
luTensor.shape.first(luTensor.shape.size - 2) == pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
) { "Inappropriate shapes of input tensors" }

View File

@ -6,29 +6,31 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
import space.kscience.kmath.tensors.core.internal.broadcastShapes
import space.kscience.kmath.tensors.core.internal.broadcastTensors
import space.kscience.kmath.tensors.core.internal.broadcastTo
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class TestBroadcasting {
@Test
fun testBroadcastShapes() = DoubleTensorAlgebra {
assertTrue(
assertEquals(
broadcastShapes(
listOf(ShapeND(2, 3), ShapeND(1, 3), ShapeND(1, 1, 1))
) contentEquals ShapeND(1, 2, 3)
),
ShapeND(1, 2, 3)
)
assertTrue(
assertEquals(
broadcastShapes(
listOf(ShapeND(6, 7), ShapeND(5, 6, 1), ShapeND(7), ShapeND(5, 1, 7))
) contentEquals ShapeND(5, 6, 7)
),
ShapeND(5, 6, 7)
)
}
@ -38,7 +40,7 @@ internal class TestBroadcasting {
val tensor2 = fromArray(ShapeND(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val res = broadcastTo(tensor2, tensor1.shape)
assertTrue(res.shape contentEquals ShapeND(2, 3))
assertTrue(res.shape == ShapeND(2, 3))
assertTrue(res.source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
}
@ -50,9 +52,9 @@ internal class TestBroadcasting {
val res = broadcastTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3))
assertTrue(res[1].shape contentEquals ShapeND(1, 2, 3))
assertTrue(res[2].shape contentEquals ShapeND(1, 2, 3))
assertEquals(res[0].shape, ShapeND(1, 2, 3))
assertEquals(res[1].shape, ShapeND(1, 2, 3))
assertEquals(res[2].shape, ShapeND(1, 2, 3))
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
@ -67,9 +69,9 @@ internal class TestBroadcasting {
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3))
assertTrue(res[1].shape contentEquals ShapeND(1, 1, 3))
assertTrue(res[2].shape contentEquals ShapeND(1, 1, 1))
assertEquals(res[0].shape, ShapeND(1, 2, 3))
assertEquals(res[1].shape, ShapeND(1, 1, 3))
assertEquals(res[2].shape, ShapeND(1, 1, 1))
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0))
@ -84,9 +86,9 @@ internal class TestBroadcasting {
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals ShapeND(4, 2, 5, 3, 2, 3))
assertTrue(res[1].shape contentEquals ShapeND(4, 2, 5, 3, 3, 3))
assertTrue(res[2].shape contentEquals ShapeND(4, 2, 5, 3, 1, 1))
assertEquals(res[0].shape, ShapeND(4, 2, 5, 3, 2, 3))
assertEquals(res[1].shape, ShapeND(4, 2, 5, 3, 3, 3))
assertEquals(res[2].shape, ShapeND(4, 2, 5, 3, 1, 1))
}
@Test
@ -99,16 +101,16 @@ internal class TestBroadcasting {
val tensor31 = tensor3 - tensor1
val tensor32 = tensor3 - tensor2
assertTrue(tensor21.shape contentEquals ShapeND(2, 3))
assertEquals(tensor21.shape, ShapeND(2, 3))
assertTrue(tensor21.source contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue(tensor31.shape contentEquals ShapeND(1, 2, 3))
assertEquals(tensor31.shape, ShapeND(1, 2, 3))
assertTrue(
tensor31.source
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
)
assertTrue(tensor32.shape contentEquals ShapeND(1, 1, 3))
assertEquals(tensor32.shape, ShapeND(1, 1, 3))
assertTrue(tensor32.source contentEquals doubleArrayOf(490.0, 480.0, 470.0))
}

View File

@ -6,7 +6,6 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.internal.svd1d
import kotlin.math.abs
@ -113,8 +112,8 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val (q, r) = qr(tensor)
assertTrue { q.shape contentEquals shape }
assertTrue { r.shape contentEquals shape }
assertEquals(q.shape, shape)
assertEquals(r.shape, shape)
assertTrue((q matmul r).eq(tensor))
@ -133,9 +132,9 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val (p, l, u) = lu(tensor)
assertTrue { p.shape contentEquals shape }
assertTrue { l.shape contentEquals shape }
assertTrue { u.shape contentEquals shape }
assertEquals(p.shape, shape)
assertEquals(l.shape, shape)
assertEquals(u.shape, shape)
assertTrue((p matmul tensor).eq(l matmul u))
}
@ -157,7 +156,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val res = svd1d(tensor2)
assertTrue(res.shape contentEquals ShapeND(2))
assertEquals(res.shape, ShapeND(2))
assertTrue { abs(abs(res.source[0]) - 0.386) < 0.01 }
assertTrue { abs(abs(res.source[1]) - 0.922) < 0.01 }
}

View File

@ -6,7 +6,8 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.*
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.get
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.testutils.assertBufferEquals
import kotlin.test.Test
@ -43,7 +44,7 @@ internal class TestDoubleTensorAlgebra {
val res = tensor.transposed(0, 0)
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(0.0))
assertTrue(res.shape contentEquals ShapeND(1))
assertEquals(res.shape, ShapeND(1))
}
@Test
@ -52,7 +53,7 @@ internal class TestDoubleTensorAlgebra {
val res = tensor.transposed(1, 0)
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.shape contentEquals ShapeND(2, 3))
assertEquals(res.shape, ShapeND(2, 3))
}
@Test
@ -62,9 +63,9 @@ internal class TestDoubleTensorAlgebra {
val res02 = tensor.transposed(-3, 2)
val res12 = tensor.transposed()
assertTrue(res01.shape contentEquals ShapeND(2, 1, 3))
assertTrue(res02.shape contentEquals ShapeND(3, 2, 1))
assertTrue(res12.shape contentEquals ShapeND(1, 3, 2))
assertEquals(res01.shape, ShapeND(2, 1, 3))
assertEquals(res02.shape, ShapeND(3, 2, 1))
assertEquals(res12.shape, ShapeND(1, 3, 2))
assertTrue(res01.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
@ -114,19 +115,19 @@ internal class TestDoubleTensorAlgebra {
val res12 = tensor1.dot(tensor2)
assertTrue(res12.source contentEquals doubleArrayOf(140.0, 320.0))
assertTrue(res12.shape contentEquals ShapeND(2))
assertEquals(res12.shape, ShapeND(2))
val res32 = tensor3.matmul(tensor2)
assertTrue(res32.source contentEquals doubleArrayOf(-140.0))
assertTrue(res32.shape contentEquals ShapeND(1, 1))
assertEquals(res32.shape, ShapeND(1, 1))
val res22 = tensor2.dot(tensor2)
assertTrue(res22.source contentEquals doubleArrayOf(1400.0))
assertTrue(res22.shape contentEquals ShapeND(1))
assertEquals(res22.shape, ShapeND(1))
val res11 = tensor1.dot(tensor11)
assertTrue(res11.source contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
assertTrue(res11.shape contentEquals ShapeND(2, 2))
assertEquals(res11.shape, ShapeND(2, 2))
val res45 = tensor4.matmul(tensor5)
assertTrue(
@ -135,7 +136,7 @@ internal class TestDoubleTensorAlgebra {
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
)
)
assertTrue(res45.shape contentEquals ShapeND(2, 3, 3))
assertEquals(res45.shape, ShapeND(2, 3, 3))
}
@Test
@ -144,35 +145,35 @@ internal class TestDoubleTensorAlgebra {
val tensor2 = fromArray(ShapeND(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor3 = zeros(ShapeND(2, 3, 4, 5))
assertTrue(
diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals
ShapeND(2, 3, 4, 5, 5)
assertEquals(
diagonalEmbedding(tensor3, 0, 3, 4).shape,
ShapeND(2, 3, 4, 5, 5)
)
assertTrue(
diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals
ShapeND(2, 3, 4, 6, 6)
assertEquals(
diagonalEmbedding(tensor3, 1, 3, 4).shape,
ShapeND(2, 3, 4, 6, 6)
)
assertTrue(
diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals
ShapeND(7, 2, 3, 7, 4)
assertEquals(
diagonalEmbedding(tensor3, 2, 0, 3).shape,
ShapeND(7, 2, 3, 7, 4)
)
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
assertTrue(diagonal1.shape contentEquals ShapeND(3, 3))
assertEquals(diagonal1.shape, ShapeND(3, 3))
assertTrue(
diagonal1.source contentEquals
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0)
)
val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0)
assertTrue(diagonal1Offset.shape contentEquals ShapeND(4, 4))
assertEquals(diagonal1Offset.shape, ShapeND(4, 4))
assertTrue(
diagonal1Offset.source contentEquals
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0)
)
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
assertTrue(diagonal2.shape contentEquals ShapeND(4, 2, 4))
assertEquals(diagonal2.shape, ShapeND(4, 2, 4))
assertTrue(
diagonal2.source contentEquals
doubleArrayOf(
@ -202,7 +203,7 @@ internal class TestDoubleTensorAlgebra {
val l = tensor.getTensor(0).map { it + 1.0 }
val r = tensor.getTensor(1).map { it - 1.0 }
val res = l + r
assertTrue { ShapeND(5, 5) contentEquals res.shape }
assertEquals(ShapeND(5, 5), res.shape)
assertEquals(2.0, res[4, 4])
}
}

View File

@ -67,7 +67,7 @@ public open class ViktorFieldOpsND :
right: StructureND<Double>,
transform: Float64Field.(Double, Double) -> Double,
): ViktorStructureND {
require(left.shape.contentEquals(right.shape))
require(left.shape == right.shape)
return F64Array(*left.shape.asArray()).apply {
ColumnStrides(left.shape).asSequence().forEach { index ->
set(value = Float64Field.transform(left[index], right[index]), indices = index)