Fix #532 by making ShapeND a non-value class #522

Merged
altavir merged 1 commits from bug/defaultStridesCache into dev 2024-05-09 09:16:53 +03:00
21 changed files with 92 additions and 83 deletions

View File

@ -35,7 +35,7 @@ class StreamDoubleFieldND(override val shape: ShapeND) : FieldND<Double, Float64
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
private val StructureND<Double>.buffer: Float64Buffer private val StructureND<Double>.buffer: Float64Buffer
get() = when { get() = when {
!shape.contentEquals(this@StreamDoubleFieldND.shape) -> throw ShapeMismatchException( shape != this@StreamDoubleFieldND.shape -> throw ShapeMismatchException(
this@StreamDoubleFieldND.shape, this@StreamDoubleFieldND.shape,
shape shape
) )

View File

@ -6,7 +6,6 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
@ -62,7 +61,7 @@ fun main() {
// figure out MSE of approximation // figure out MSE of approximation
fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double { fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double {
require(yTrue.shape.size == 1) require(yTrue.shape.size == 1)
require(yTrue.shape contentEquals yPred.shape) require(yTrue.shape == yPred.shape)
val diff = yTrue - yPred val diff = yTrue - yPred
return sqrt(diff.dot(diff)).value() return sqrt(diff.dot(diff)).value()

View File

@ -6,7 +6,6 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.asIterable import space.kscience.kmath.operations.asIterable
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.* import space.kscience.kmath.tensors.core.*
@ -94,7 +93,7 @@ class Dense(
// simple accuracy equal to the proportion of correct answers // simple accuracy equal to the proportion of correct answers
fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double { fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double {
check(yPred.shape contentEquals yTrue.shape) check(yPred.shape == yTrue.shape)
val n = yPred.shape[0] val n = yPred.shape[0]
var correctCnt = 0 var correctCnt = 0
for (i in 0 until n) { for (i in 0 until n) {

View File

@ -32,12 +32,12 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
} }
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra { override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D() asND().plus(other.asND()).as2D()
} }
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra { override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D() asND().minus(other.asND()).as2D()
} }

View File

@ -33,12 +33,12 @@ public object Float64LinearSpace : LinearSpace<Double, Float64Field> {
} }
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND { override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D() asND().plus(other.asND()).as2D()
} }
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND { override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D() asND().minus(other.asND()).as2D()
} }

View File

@ -55,7 +55,7 @@ public interface AlgebraND<T, out C : Algebra<T>> : Algebra<StructureND<T>> {
*/ */
@PerformancePitfall("Very slow on remote execution algebras") @PerformancePitfall("Very slow on remote execution algebras")
public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T> { public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T> {
require(left.shape.contentEquals(right.shape)) { require(left.shape == right.shape) {
"Expected left and right of the same shape, but left - ${left.shape} and right - ${right.shape}" "Expected left and right of the same shape, but left - ${left.shape} and right - ${right.shape}"
} }
return structureND(left.shape) { index -> return structureND(left.shape) { index ->

View File

@ -106,10 +106,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other !is ColumnStrides) return false if (other !is ColumnStrides) return false
return shape.contentEquals(other.shape) return shape == other.shape
} }
override fun hashCode(): Int = shape.contentHashCode() override fun hashCode(): Int = shape.hashCode()
public companion object public companion object
@ -156,10 +156,10 @@ public class RowStrides(override val shape: ShapeND) : Strides() {
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other !is RowStrides) return false if (other !is RowStrides) return false
return shape.contentEquals(other.shape) return shape == other.shape
} }
override fun hashCode(): Int = shape.contentHashCode() override fun hashCode(): Int = shape.hashCode()
public companion object public companion object

View File

@ -11,19 +11,30 @@ import kotlin.jvm.JvmInline
/** /**
* A read-only ND shape * A read-only ND shape
*/ */
@JvmInline public class ShapeND(@PublishedApi internal val array: IntArray) {
public value class ShapeND(@PublishedApi internal val array: IntArray) {
public val size: Int get() = array.size public val size: Int get() = array.size
public operator fun get(index: Int): Int = array[index] public operator fun get(index: Int): Int = array[index]
override fun toString(): String = array.contentToString() override fun toString(): String = array.contentToString()
override fun hashCode(): Int = array.contentHashCode()
override fun equals(other: Any?): Boolean = other is ShapeND && array.contentEquals(other.array)
} }
public inline fun ShapeND.forEach(block: (value: Int) -> Unit): Unit = array.forEach(block) public inline fun ShapeND.forEach(block: (value: Int) -> Unit): Unit = array.forEach(block)
public inline fun ShapeND.forEachIndexed(block: (index: Int, value: Int) -> Unit): Unit = array.forEachIndexed(block) public inline fun ShapeND.forEachIndexed(block: (index: Int, value: Int) -> Unit): Unit = array.forEachIndexed(block)
@Deprecated(
message = "ShapeND is made a usual class with correct `equals`. Use it instead of the `contentEquals`.",
replaceWith = ReplaceWith("this == other"),
level = DeprecationLevel.WARNING,
)
public infix fun ShapeND.contentEquals(other: ShapeND): Boolean = array.contentEquals(other.array) public infix fun ShapeND.contentEquals(other: ShapeND): Boolean = array.contentEquals(other.array)
@Deprecated(
message = "ShapeND is made a usual class with correct `hashCode`. Use it instead of the `contentHashCode`.",
replaceWith = ReplaceWith("this.hashCode()"),
level = DeprecationLevel.WARNING,
)
public fun ShapeND.contentHashCode(): Int = array.contentHashCode() public fun ShapeND.contentHashCode(): Int = array.contentHashCode()
public val ShapeND.indices: IntRange get() = array.indices public val ShapeND.indices: IntRange get() = array.indices

View File

@ -48,12 +48,12 @@ public object Float64ParallelLinearSpace : LinearSpace<Double, Float64Field> {
} }
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND { override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D() asND().plus(other.asND()).as2D()
} }
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND { override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D() asND().minus(other.asND()).as2D()
} }

View File

@ -92,7 +92,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> { override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException require(left.shape == right.shape) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
val leftArray = left.asMultik().array val leftArray = left.asMultik().array
val rightArray = right.asMultik().array val rightArray = right.asMultik().array
val data = initMemoryView<T>(leftArray.size, dataType) val data = initMemoryView<T>(leftArray.size, dataType)
@ -124,7 +124,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray()) public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1)) { override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1)) {
get(intArrayOf(0)) get(intArrayOf(0))
} else null } else null
@ -224,7 +224,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
} }
val mt = asMultik().array val mt = asMultik().array
return if (ShapeND(mt.shape).contentEquals(shape)) { return if (ShapeND(mt.shape) == shape) {
mt mt
} else { } else {
@OptIn(UnsafeKMathAPI::class) @OptIn(UnsafeKMathAPI::class)

View File

@ -63,7 +63,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
right: StructureND<T>, right: StructureND<T>,
transform: C.(T, T) -> T, transform: C.(T, T) -> T,
): Nd4jArrayStructure<T> { ): Nd4jArrayStructure<T> {
require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" } require(left.shape == right.shape) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
val new = Nd4j.create(*left.shape.asArray()).wrap() val new = Nd4j.create(*left.shape.asArray()).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) } new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
return new return new

View File

@ -49,7 +49,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> { override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
require(left.shape.contentEquals(right.shape)) require(left.shape == right.shape)
return mutableStructureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) } return mutableStructureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
} }
@ -203,7 +203,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, Float64Field>
} }
override fun StructureND<Double>.valueOrNull(): Double? = override fun StructureND<Double>.valueOrNull(): Double? =
if (shape contentEquals ShapeND(1)) ndArray.getDouble(0) else null if (shape == ShapeND(1)) ndArray.getDouble(0) else null
// TODO rewrite // TODO rewrite
override fun diagonalEmbedding( override fun diagonalEmbedding(

View File

@ -23,7 +23,6 @@ import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.asArray import space.kscience.kmath.nd.asArray
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
@ -105,7 +104,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
protected abstract fun const(value: T): Constant<TT> protected abstract fun const(value: T): Constant<TT>
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1)) override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1))
get(intArrayOf(0)) else null get(intArrayOf(0)) else null
/** /**

View File

@ -95,7 +95,7 @@ public open class DoubleTensorAlgebra :
override fun StructureND<Double>.valueOrNull(): Double? { override fun StructureND<Double>.valueOrNull(): Double? {
val dt = asDoubleTensor() val dt = asDoubleTensor()
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null return if (dt.shape == ShapeND(1)) dt.source[0] else null
} }
override fun StructureND<Double>.value(): Double = valueOrNull() override fun StructureND<Double>.value(): Double = valueOrNull()

View File

@ -89,7 +89,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
override fun StructureND<Int>.valueOrNull(): Int? { override fun StructureND<Int>.valueOrNull(): Int? {
val dt = asIntTensor() val dt = asIntTensor()
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null return if (dt.shape == ShapeND(1)) dt.source[0] else null
} }
override fun StructureND<Int>.value(): Int = valueOrNull() override fun StructureND<Int>.value(): Int = valueOrNull()
@ -387,7 +387,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
public fun stack(tensors: List<Tensor<Int>>): IntTensor { public fun stack(tensors: List<Tensor<Int>>): IntTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" } check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val shape = tensors[0].shape val shape = tensors[0].shape
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
val resShape = ShapeND(tensors.size) + shape val resShape = ShapeND(tensors.size) + shape
// val resBuffer: List<Int> = tensors.flatMap { // val resBuffer: List<Int> = tensors.flatMap {
// it.asIntTensor().source.array.drop(it.asIntTensor().bufferStart) // it.asIntTensor().source.array.drop(it.asIntTensor().bufferStart)

View File

@ -7,7 +7,6 @@ package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.nd.linearSize import space.kscience.kmath.nd.linearSize
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensor
@ -32,7 +31,7 @@ internal fun checkBufferShapeConsistency(shape: ShapeND, buffer: DoubleArray) =
@PublishedApi @PublishedApi
internal fun <T> checkShapesCompatible(a: StructureND<T>, b: StructureND<T>): Unit = internal fun <T> checkShapesCompatible(a: StructureND<T>, b: StructureND<T>): Unit =
check(a.shape contentEquals b.shape) { check(a.shape == b.shape) {
"Incompatible shapes ${a.shape} and ${b.shape} " "Incompatible shapes ${a.shape} and ${b.shape} "
} }

View File

@ -47,7 +47,7 @@ public fun DoubleTensorAlgebra.randomNormalLike(structure: WithShape, seed: Long
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor { public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" } check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val shape = tensors[0].shape val shape = tensors[0].shape
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
val resShape = ShapeND(tensors.size) + shape val resShape = ShapeND(tensors.size) + shape
// val resBuffer: List<Double> = tensors.flatMap { // val resBuffer: List<Double> = tensors.flatMap {
// it.asDoubleTensor().source.array.drop(it.asDoubleTensor().bufferStart) // it.asDoubleTensor().source.array.drop(it.asDoubleTensor().bufferStart)
@ -91,7 +91,7 @@ public fun DoubleTensorAlgebra.luPivot(
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { ): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
checkSquareMatrix(luTensor.shape) checkSquareMatrix(luTensor.shape)
check( check(
luTensor.shape.first(luTensor.shape.size - 2) contentEquals pivotsTensor.shape.first(pivotsTensor.shape.size - 1) || luTensor.shape.first(luTensor.shape.size - 2) == pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
luTensor.shape.last() == pivotsTensor.shape.last() - 1 luTensor.shape.last() == pivotsTensor.shape.last() - 1
) { "Inappropriate shapes of input tensors" } ) { "Inappropriate shapes of input tensors" }

View File

@ -6,29 +6,31 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
import space.kscience.kmath.tensors.core.internal.broadcastShapes import space.kscience.kmath.tensors.core.internal.broadcastShapes
import space.kscience.kmath.tensors.core.internal.broadcastTensors import space.kscience.kmath.tensors.core.internal.broadcastTensors
import space.kscience.kmath.tensors.core.internal.broadcastTo import space.kscience.kmath.tensors.core.internal.broadcastTo
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
internal class TestBroadcasting { internal class TestBroadcasting {
@Test @Test
fun testBroadcastShapes() = DoubleTensorAlgebra { fun testBroadcastShapes() = DoubleTensorAlgebra {
assertTrue( assertEquals(
broadcastShapes( broadcastShapes(
listOf(ShapeND(2, 3), ShapeND(1, 3), ShapeND(1, 1, 1)) listOf(ShapeND(2, 3), ShapeND(1, 3), ShapeND(1, 1, 1))
) contentEquals ShapeND(1, 2, 3) ),
ShapeND(1, 2, 3)
) )
assertTrue( assertEquals(
broadcastShapes( broadcastShapes(
listOf(ShapeND(6, 7), ShapeND(5, 6, 1), ShapeND(7), ShapeND(5, 1, 7)) listOf(ShapeND(6, 7), ShapeND(5, 6, 1), ShapeND(7), ShapeND(5, 1, 7))
) contentEquals ShapeND(5, 6, 7) ),
ShapeND(5, 6, 7)
) )
} }
@ -38,7 +40,7 @@ internal class TestBroadcasting {
val tensor2 = fromArray(ShapeND(1, 3), doubleArrayOf(10.0, 20.0, 30.0)) val tensor2 = fromArray(ShapeND(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val res = broadcastTo(tensor2, tensor1.shape) val res = broadcastTo(tensor2, tensor1.shape)
assertTrue(res.shape contentEquals ShapeND(2, 3)) assertTrue(res.shape == ShapeND(2, 3))
assertTrue(res.source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) assertTrue(res.source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
} }
@ -50,9 +52,9 @@ internal class TestBroadcasting {
val res = broadcastTensors(tensor1, tensor2, tensor3) val res = broadcastTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3)) assertEquals(res[0].shape, ShapeND(1, 2, 3))
assertTrue(res[1].shape contentEquals ShapeND(1, 2, 3)) assertEquals(res[1].shape, ShapeND(1, 2, 3))
assertTrue(res[2].shape contentEquals ShapeND(1, 2, 3)) assertEquals(res[2].shape, ShapeND(1, 2, 3))
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
@ -67,9 +69,9 @@ internal class TestBroadcasting {
val res = broadcastOuterTensors(tensor1, tensor2, tensor3) val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3)) assertEquals(res[0].shape, ShapeND(1, 2, 3))
assertTrue(res[1].shape contentEquals ShapeND(1, 1, 3)) assertEquals(res[1].shape, ShapeND(1, 1, 3))
assertTrue(res[2].shape contentEquals ShapeND(1, 1, 1)) assertEquals(res[2].shape, ShapeND(1, 1, 1))
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0)) assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0))
@ -84,9 +86,9 @@ internal class TestBroadcasting {
val res = broadcastOuterTensors(tensor1, tensor2, tensor3) val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals ShapeND(4, 2, 5, 3, 2, 3)) assertEquals(res[0].shape, ShapeND(4, 2, 5, 3, 2, 3))
assertTrue(res[1].shape contentEquals ShapeND(4, 2, 5, 3, 3, 3)) assertEquals(res[1].shape, ShapeND(4, 2, 5, 3, 3, 3))
assertTrue(res[2].shape contentEquals ShapeND(4, 2, 5, 3, 1, 1)) assertEquals(res[2].shape, ShapeND(4, 2, 5, 3, 1, 1))
} }
@Test @Test
@ -99,16 +101,16 @@ internal class TestBroadcasting {
val tensor31 = tensor3 - tensor1 val tensor31 = tensor3 - tensor1
val tensor32 = tensor3 - tensor2 val tensor32 = tensor3 - tensor2
assertTrue(tensor21.shape contentEquals ShapeND(2, 3)) assertEquals(tensor21.shape, ShapeND(2, 3))
assertTrue(tensor21.source contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) assertTrue(tensor21.source contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue(tensor31.shape contentEquals ShapeND(1, 2, 3)) assertEquals(tensor31.shape, ShapeND(1, 2, 3))
assertTrue( assertTrue(
tensor31.source tensor31.source
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0) contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
) )
assertTrue(tensor32.shape contentEquals ShapeND(1, 1, 3)) assertEquals(tensor32.shape, ShapeND(1, 1, 3))
assertTrue(tensor32.source contentEquals doubleArrayOf(490.0, 480.0, 470.0)) assertTrue(tensor32.source contentEquals doubleArrayOf(490.0, 480.0, 470.0))
} }

View File

@ -6,7 +6,6 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.internal.svd1d import space.kscience.kmath.tensors.core.internal.svd1d
import kotlin.math.abs import kotlin.math.abs
@ -113,8 +112,8 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val (q, r) = qr(tensor) val (q, r) = qr(tensor)
assertTrue { q.shape contentEquals shape } assertEquals(q.shape, shape)
assertTrue { r.shape contentEquals shape } assertEquals(r.shape, shape)
assertTrue((q matmul r).eq(tensor)) assertTrue((q matmul r).eq(tensor))
@ -133,9 +132,9 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val (p, l, u) = lu(tensor) val (p, l, u) = lu(tensor)
assertTrue { p.shape contentEquals shape } assertEquals(p.shape, shape)
assertTrue { l.shape contentEquals shape } assertEquals(l.shape, shape)
assertTrue { u.shape contentEquals shape } assertEquals(u.shape, shape)
assertTrue((p matmul tensor).eq(l matmul u)) assertTrue((p matmul tensor).eq(l matmul u))
} }
@ -157,7 +156,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val res = svd1d(tensor2) val res = svd1d(tensor2)
assertTrue(res.shape contentEquals ShapeND(2)) assertEquals(res.shape, ShapeND(2))
assertTrue { abs(abs(res.source[0]) - 0.386) < 0.01 } assertTrue { abs(abs(res.source[0]) - 0.386) < 0.01 }
assertTrue { abs(abs(res.source[1]) - 0.922) < 0.01 } assertTrue { abs(abs(res.source[1]) - 0.922) < 0.01 }
} }

View File

@ -6,7 +6,8 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.get
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.testutils.assertBufferEquals import space.kscience.kmath.testutils.assertBufferEquals
import kotlin.test.Test import kotlin.test.Test
@ -43,7 +44,7 @@ internal class TestDoubleTensorAlgebra {
val res = tensor.transposed(0, 0) val res = tensor.transposed(0, 0)
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(0.0)) assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(0.0))
assertTrue(res.shape contentEquals ShapeND(1)) assertEquals(res.shape, ShapeND(1))
} }
@Test @Test
@ -52,7 +53,7 @@ internal class TestDoubleTensorAlgebra {
val res = tensor.transposed(1, 0) val res = tensor.transposed(1, 0)
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.shape contentEquals ShapeND(2, 3)) assertEquals(res.shape, ShapeND(2, 3))
} }
@Test @Test
@ -62,9 +63,9 @@ internal class TestDoubleTensorAlgebra {
val res02 = tensor.transposed(-3, 2) val res02 = tensor.transposed(-3, 2)
val res12 = tensor.transposed() val res12 = tensor.transposed()
assertTrue(res01.shape contentEquals ShapeND(2, 1, 3)) assertEquals(res01.shape, ShapeND(2, 1, 3))
assertTrue(res02.shape contentEquals ShapeND(3, 2, 1)) assertEquals(res02.shape, ShapeND(3, 2, 1))
assertTrue(res12.shape contentEquals ShapeND(1, 3, 2)) assertEquals(res12.shape, ShapeND(1, 3, 2))
assertTrue(res01.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res01.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res02.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
@ -114,19 +115,19 @@ internal class TestDoubleTensorAlgebra {
val res12 = tensor1.dot(tensor2) val res12 = tensor1.dot(tensor2)
assertTrue(res12.source contentEquals doubleArrayOf(140.0, 320.0)) assertTrue(res12.source contentEquals doubleArrayOf(140.0, 320.0))
assertTrue(res12.shape contentEquals ShapeND(2)) assertEquals(res12.shape, ShapeND(2))
val res32 = tensor3.matmul(tensor2) val res32 = tensor3.matmul(tensor2)
assertTrue(res32.source contentEquals doubleArrayOf(-140.0)) assertTrue(res32.source contentEquals doubleArrayOf(-140.0))
assertTrue(res32.shape contentEquals ShapeND(1, 1)) assertEquals(res32.shape, ShapeND(1, 1))
val res22 = tensor2.dot(tensor2) val res22 = tensor2.dot(tensor2)
assertTrue(res22.source contentEquals doubleArrayOf(1400.0)) assertTrue(res22.source contentEquals doubleArrayOf(1400.0))
assertTrue(res22.shape contentEquals ShapeND(1)) assertEquals(res22.shape, ShapeND(1))
val res11 = tensor1.dot(tensor11) val res11 = tensor1.dot(tensor11)
assertTrue(res11.source contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0)) assertTrue(res11.source contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
assertTrue(res11.shape contentEquals ShapeND(2, 2)) assertEquals(res11.shape, ShapeND(2, 2))
val res45 = tensor4.matmul(tensor5) val res45 = tensor4.matmul(tensor5)
assertTrue( assertTrue(
@ -135,7 +136,7 @@ internal class TestDoubleTensorAlgebra {
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0 468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
) )
) )
assertTrue(res45.shape contentEquals ShapeND(2, 3, 3)) assertEquals(res45.shape, ShapeND(2, 3, 3))
} }
@Test @Test
@ -144,35 +145,35 @@ internal class TestDoubleTensorAlgebra {
val tensor2 = fromArray(ShapeND(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor2 = fromArray(ShapeND(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor3 = zeros(ShapeND(2, 3, 4, 5)) val tensor3 = zeros(ShapeND(2, 3, 4, 5))
assertTrue( assertEquals(
diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals diagonalEmbedding(tensor3, 0, 3, 4).shape,
ShapeND(2, 3, 4, 5, 5) ShapeND(2, 3, 4, 5, 5)
) )
assertTrue( assertEquals(
diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals diagonalEmbedding(tensor3, 1, 3, 4).shape,
ShapeND(2, 3, 4, 6, 6) ShapeND(2, 3, 4, 6, 6)
) )
assertTrue( assertEquals(
diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals diagonalEmbedding(tensor3, 2, 0, 3).shape,
ShapeND(7, 2, 3, 7, 4) ShapeND(7, 2, 3, 7, 4)
) )
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0) val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
assertTrue(diagonal1.shape contentEquals ShapeND(3, 3)) assertEquals(diagonal1.shape, ShapeND(3, 3))
assertTrue( assertTrue(
diagonal1.source contentEquals diagonal1.source contentEquals
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0) doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0)
) )
val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0) val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0)
assertTrue(diagonal1Offset.shape contentEquals ShapeND(4, 4)) assertEquals(diagonal1Offset.shape, ShapeND(4, 4))
assertTrue( assertTrue(
diagonal1Offset.source contentEquals diagonal1Offset.source contentEquals
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0) doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0)
) )
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2) val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
assertTrue(diagonal2.shape contentEquals ShapeND(4, 2, 4)) assertEquals(diagonal2.shape, ShapeND(4, 2, 4))
assertTrue( assertTrue(
diagonal2.source contentEquals diagonal2.source contentEquals
doubleArrayOf( doubleArrayOf(
@ -202,7 +203,7 @@ internal class TestDoubleTensorAlgebra {
val l = tensor.getTensor(0).map { it + 1.0 } val l = tensor.getTensor(0).map { it + 1.0 }
val r = tensor.getTensor(1).map { it - 1.0 } val r = tensor.getTensor(1).map { it - 1.0 }
val res = l + r val res = l + r
assertTrue { ShapeND(5, 5) contentEquals res.shape } assertEquals(ShapeND(5, 5), res.shape)
assertEquals(2.0, res[4, 4]) assertEquals(2.0, res[4, 4])
} }
} }

View File

@ -67,7 +67,7 @@ public open class ViktorFieldOpsND :
right: StructureND<Double>, right: StructureND<Double>,
transform: Float64Field.(Double, Double) -> Double, transform: Float64Field.(Double, Double) -> Double,
): ViktorStructureND { ): ViktorStructureND {
require(left.shape.contentEquals(right.shape)) require(left.shape == right.shape)
return F64Array(*left.shape.asArray()).apply { return F64Array(*left.shape.asArray()).apply {
ColumnStrides(left.shape).asSequence().forEach { index -> ColumnStrides(left.shape).asSequence().forEach { index ->
set(value = Float64Field.transform(left[index], right[index]), indices = index) set(value = Float64Field.transform(left[index], right[index]), indices = index)