Fix #532 by making ShapeND
a non-value class
#522
@ -35,7 +35,7 @@ class StreamDoubleFieldND(override val shape: ShapeND) : FieldND<Double, Float64
|
||||
@OptIn(PerformancePitfall::class)
|
||||
private val StructureND<Double>.buffer: Float64Buffer
|
||||
get() = when {
|
||||
!shape.contentEquals(this@StreamDoubleFieldND.shape) -> throw ShapeMismatchException(
|
||||
shape != this@StreamDoubleFieldND.shape -> throw ShapeMismatchException(
|
||||
this@StreamDoubleFieldND.shape,
|
||||
shape
|
||||
)
|
||||
|
@ -6,7 +6,6 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
@ -62,7 +61,7 @@ fun main() {
|
||||
// figure out MSE of approximation
|
||||
fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double {
|
||||
require(yTrue.shape.size == 1)
|
||||
require(yTrue.shape contentEquals yPred.shape)
|
||||
require(yTrue.shape == yPred.shape)
|
||||
|
||||
val diff = yTrue - yPred
|
||||
return sqrt(diff.dot(diff)).value()
|
||||
|
@ -6,7 +6,6 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.operations.asIterable
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.*
|
||||
@ -94,7 +93,7 @@ class Dense(
|
||||
|
||||
// simple accuracy equal to the proportion of correct answers
|
||||
fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double {
|
||||
check(yPred.shape contentEquals yTrue.shape)
|
||||
check(yPred.shape == yTrue.shape)
|
||||
val n = yPred.shape[0]
|
||||
var correctCnt = 0
|
||||
for (i in 0 until n) {
|
||||
|
@ -32,12 +32,12 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
|
||||
}
|
||||
|
||||
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
asND().plus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
asND().minus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
|
@ -33,12 +33,12 @@ public object Float64LinearSpace : LinearSpace<Double, Float64Field> {
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
asND().plus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
asND().minus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,7 @@ public interface AlgebraND<T, out C : Algebra<T>> : Algebra<StructureND<T>> {
|
||||
*/
|
||||
@PerformancePitfall("Very slow on remote execution algebras")
|
||||
public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T> {
|
||||
require(left.shape.contentEquals(right.shape)) {
|
||||
require(left.shape == right.shape) {
|
||||
"Expected left and right of the same shape, but left - ${left.shape} and right - ${right.shape}"
|
||||
}
|
||||
return structureND(left.shape) { index ->
|
||||
|
@ -106,10 +106,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is ColumnStrides) return false
|
||||
return shape.contentEquals(other.shape)
|
||||
return shape == other.shape
|
||||
}
|
||||
|
||||
override fun hashCode(): Int = shape.contentHashCode()
|
||||
override fun hashCode(): Int = shape.hashCode()
|
||||
|
||||
|
||||
public companion object
|
||||
@ -156,10 +156,10 @@ public class RowStrides(override val shape: ShapeND) : Strides() {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is RowStrides) return false
|
||||
return shape.contentEquals(other.shape)
|
||||
return shape == other.shape
|
||||
}
|
||||
|
||||
override fun hashCode(): Int = shape.contentHashCode()
|
||||
override fun hashCode(): Int = shape.hashCode()
|
||||
|
||||
public companion object
|
||||
|
||||
|
@ -11,19 +11,30 @@ import kotlin.jvm.JvmInline
|
||||
/**
|
||||
* A read-only ND shape
|
||||
*/
|
||||
@JvmInline
|
||||
public value class ShapeND(@PublishedApi internal val array: IntArray) {
|
||||
public class ShapeND(@PublishedApi internal val array: IntArray) {
|
||||
public val size: Int get() = array.size
|
||||
public operator fun get(index: Int): Int = array[index]
|
||||
override fun toString(): String = array.contentToString()
|
||||
override fun hashCode(): Int = array.contentHashCode()
|
||||
override fun equals(other: Any?): Boolean = other is ShapeND && array.contentEquals(other.array)
|
||||
}
|
||||
|
||||
public inline fun ShapeND.forEach(block: (value: Int) -> Unit): Unit = array.forEach(block)
|
||||
|
||||
public inline fun ShapeND.forEachIndexed(block: (index: Int, value: Int) -> Unit): Unit = array.forEachIndexed(block)
|
||||
|
||||
@Deprecated(
|
||||
message = "ShapeND is made a usual class with correct `equals`. Use it instead of the `contentEquals`.",
|
||||
replaceWith = ReplaceWith("this == other"),
|
||||
level = DeprecationLevel.WARNING,
|
||||
)
|
||||
public infix fun ShapeND.contentEquals(other: ShapeND): Boolean = array.contentEquals(other.array)
|
||||
|
||||
@Deprecated(
|
||||
message = "ShapeND is made a usual class with correct `hashCode`. Use it instead of the `contentHashCode`.",
|
||||
replaceWith = ReplaceWith("this.hashCode()"),
|
||||
level = DeprecationLevel.WARNING,
|
||||
)
|
||||
public fun ShapeND.contentHashCode(): Int = array.contentHashCode()
|
||||
|
||||
public val ShapeND.indices: IntRange get() = array.indices
|
||||
|
@ -48,12 +48,12 @@ public object Float64ParallelLinearSpace : LinearSpace<Double, Float64Field> {
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
asND().plus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
asND().minus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
|
@ -92,7 +92,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
||||
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||
require(left.shape == right.shape) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||
val leftArray = left.asMultik().array
|
||||
val rightArray = right.asMultik().array
|
||||
val data = initMemoryView<T>(leftArray.size, dataType)
|
||||
@ -124,7 +124,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
||||
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1)) {
|
||||
override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1)) {
|
||||
get(intArrayOf(0))
|
||||
} else null
|
||||
|
||||
@ -224,7 +224,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
||||
}
|
||||
|
||||
val mt = asMultik().array
|
||||
return if (ShapeND(mt.shape).contentEquals(shape)) {
|
||||
return if (ShapeND(mt.shape) == shape) {
|
||||
mt
|
||||
} else {
|
||||
@OptIn(UnsafeKMathAPI::class)
|
||||
|
@ -63,7 +63,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
|
||||
right: StructureND<T>,
|
||||
transform: C.(T, T) -> T,
|
||||
): Nd4jArrayStructure<T> {
|
||||
require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
|
||||
require(left.shape == right.shape) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
|
||||
val new = Nd4j.create(*left.shape.asArray()).wrap()
|
||||
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
|
||||
return new
|
||||
|
@ -49,7 +49,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
|
||||
require(left.shape.contentEquals(right.shape))
|
||||
require(left.shape == right.shape)
|
||||
return mutableStructureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
|
||||
}
|
||||
|
||||
@ -203,7 +203,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, Float64Field>
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.valueOrNull(): Double? =
|
||||
if (shape contentEquals ShapeND(1)) ndArray.getDouble(0) else null
|
||||
if (shape == ShapeND(1)) ndArray.getDouble(0) else null
|
||||
|
||||
// TODO rewrite
|
||||
override fun diagonalEmbedding(
|
||||
|
@ -23,7 +23,6 @@ import space.kscience.kmath.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.asArray
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.operations.Ring
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
@ -105,7 +104,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
||||
protected abstract fun const(value: T): Constant<TT>
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1))
|
||||
override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1))
|
||||
get(intArrayOf(0)) else null
|
||||
|
||||
/**
|
||||
|
@ -95,7 +95,7 @@ public open class DoubleTensorAlgebra :
|
||||
|
||||
override fun StructureND<Double>.valueOrNull(): Double? {
|
||||
val dt = asDoubleTensor()
|
||||
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null
|
||||
return if (dt.shape == ShapeND(1)) dt.source[0] else null
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.value(): Double = valueOrNull()
|
||||
|
@ -89,7 +89,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
|
||||
|
||||
override fun StructureND<Int>.valueOrNull(): Int? {
|
||||
val dt = asIntTensor()
|
||||
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null
|
||||
return if (dt.shape == ShapeND(1)) dt.source[0] else null
|
||||
}
|
||||
|
||||
override fun StructureND<Int>.value(): Int = valueOrNull()
|
||||
@ -387,7 +387,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
|
||||
public fun stack(tensors: List<Tensor<Int>>): IntTensor {
|
||||
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
||||
val shape = tensors[0].shape
|
||||
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
|
||||
check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
|
||||
val resShape = ShapeND(tensors.size) + shape
|
||||
// val resBuffer: List<Int> = tensors.flatMap {
|
||||
// it.asIntTensor().source.array.drop(it.asIntTensor().bufferStart)
|
||||
|
@ -7,7 +7,6 @@ package space.kscience.kmath.tensors.core.internal
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.nd.linearSize
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
@ -32,7 +31,7 @@ internal fun checkBufferShapeConsistency(shape: ShapeND, buffer: DoubleArray) =
|
||||
|
||||
@PublishedApi
|
||||
internal fun <T> checkShapesCompatible(a: StructureND<T>, b: StructureND<T>): Unit =
|
||||
check(a.shape contentEquals b.shape) {
|
||||
check(a.shape == b.shape) {
|
||||
"Incompatible shapes ${a.shape} and ${b.shape} "
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,7 @@ public fun DoubleTensorAlgebra.randomNormalLike(structure: WithShape, seed: Long
|
||||
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
|
||||
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
||||
val shape = tensors[0].shape
|
||||
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
|
||||
check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
|
||||
val resShape = ShapeND(tensors.size) + shape
|
||||
// val resBuffer: List<Double> = tensors.flatMap {
|
||||
// it.asDoubleTensor().source.array.drop(it.asDoubleTensor().bufferStart)
|
||||
@ -91,7 +91,7 @@ public fun DoubleTensorAlgebra.luPivot(
|
||||
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
checkSquareMatrix(luTensor.shape)
|
||||
check(
|
||||
luTensor.shape.first(luTensor.shape.size - 2) contentEquals pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
|
||||
luTensor.shape.first(luTensor.shape.size - 2) == pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
|
||||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
|
||||
) { "Inappropriate shapes of input tensors" }
|
||||
|
||||
|
@ -6,29 +6,31 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastShapes
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastTensors
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastTo
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal class TestBroadcasting {
|
||||
|
||||
@Test
|
||||
fun testBroadcastShapes() = DoubleTensorAlgebra {
|
||||
assertTrue(
|
||||
assertEquals(
|
||||
broadcastShapes(
|
||||
listOf(ShapeND(2, 3), ShapeND(1, 3), ShapeND(1, 1, 1))
|
||||
) contentEquals ShapeND(1, 2, 3)
|
||||
),
|
||||
ShapeND(1, 2, 3)
|
||||
)
|
||||
|
||||
assertTrue(
|
||||
assertEquals(
|
||||
broadcastShapes(
|
||||
listOf(ShapeND(6, 7), ShapeND(5, 6, 1), ShapeND(7), ShapeND(5, 1, 7))
|
||||
) contentEquals ShapeND(5, 6, 7)
|
||||
),
|
||||
ShapeND(5, 6, 7)
|
||||
)
|
||||
}
|
||||
|
||||
@ -38,7 +40,7 @@ internal class TestBroadcasting {
|
||||
val tensor2 = fromArray(ShapeND(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
|
||||
val res = broadcastTo(tensor2, tensor1.shape)
|
||||
assertTrue(res.shape contentEquals ShapeND(2, 3))
|
||||
assertTrue(res.shape == ShapeND(2, 3))
|
||||
assertTrue(res.source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
}
|
||||
|
||||
@ -50,9 +52,9 @@ internal class TestBroadcasting {
|
||||
|
||||
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals ShapeND(1, 2, 3))
|
||||
assertTrue(res[2].shape contentEquals ShapeND(1, 2, 3))
|
||||
assertEquals(res[0].shape, ShapeND(1, 2, 3))
|
||||
assertEquals(res[1].shape, ShapeND(1, 2, 3))
|
||||
assertEquals(res[2].shape, ShapeND(1, 2, 3))
|
||||
|
||||
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
@ -67,9 +69,9 @@ internal class TestBroadcasting {
|
||||
|
||||
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals ShapeND(1, 1, 3))
|
||||
assertTrue(res[2].shape contentEquals ShapeND(1, 1, 1))
|
||||
assertEquals(res[0].shape, ShapeND(1, 2, 3))
|
||||
assertEquals(res[1].shape, ShapeND(1, 1, 3))
|
||||
assertEquals(res[2].shape, ShapeND(1, 1, 1))
|
||||
|
||||
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0))
|
||||
@ -84,9 +86,9 @@ internal class TestBroadcasting {
|
||||
|
||||
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertTrue(res[0].shape contentEquals ShapeND(4, 2, 5, 3, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals ShapeND(4, 2, 5, 3, 3, 3))
|
||||
assertTrue(res[2].shape contentEquals ShapeND(4, 2, 5, 3, 1, 1))
|
||||
assertEquals(res[0].shape, ShapeND(4, 2, 5, 3, 2, 3))
|
||||
assertEquals(res[1].shape, ShapeND(4, 2, 5, 3, 3, 3))
|
||||
assertEquals(res[2].shape, ShapeND(4, 2, 5, 3, 1, 1))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -99,16 +101,16 @@ internal class TestBroadcasting {
|
||||
val tensor31 = tensor3 - tensor1
|
||||
val tensor32 = tensor3 - tensor2
|
||||
|
||||
assertTrue(tensor21.shape contentEquals ShapeND(2, 3))
|
||||
assertEquals(tensor21.shape, ShapeND(2, 3))
|
||||
assertTrue(tensor21.source contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||
|
||||
assertTrue(tensor31.shape contentEquals ShapeND(1, 2, 3))
|
||||
assertEquals(tensor31.shape, ShapeND(1, 2, 3))
|
||||
assertTrue(
|
||||
tensor31.source
|
||||
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
|
||||
)
|
||||
|
||||
assertTrue(tensor32.shape contentEquals ShapeND(1, 1, 3))
|
||||
assertEquals(tensor32.shape, ShapeND(1, 1, 3))
|
||||
assertTrue(tensor32.source contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||
}
|
||||
|
||||
|
@ -6,7 +6,6 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.internal.svd1d
|
||||
import kotlin.math.abs
|
||||
@ -113,8 +112,8 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
val (q, r) = qr(tensor)
|
||||
|
||||
assertTrue { q.shape contentEquals shape }
|
||||
assertTrue { r.shape contentEquals shape }
|
||||
assertEquals(q.shape, shape)
|
||||
assertEquals(r.shape, shape)
|
||||
|
||||
assertTrue((q matmul r).eq(tensor))
|
||||
|
||||
@ -133,9 +132,9 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
val (p, l, u) = lu(tensor)
|
||||
|
||||
assertTrue { p.shape contentEquals shape }
|
||||
assertTrue { l.shape contentEquals shape }
|
||||
assertTrue { u.shape contentEquals shape }
|
||||
assertEquals(p.shape, shape)
|
||||
assertEquals(l.shape, shape)
|
||||
assertEquals(u.shape, shape)
|
||||
|
||||
assertTrue((p matmul tensor).eq(l matmul u))
|
||||
}
|
||||
@ -157,7 +156,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
val res = svd1d(tensor2)
|
||||
|
||||
assertTrue(res.shape contentEquals ShapeND(2))
|
||||
assertEquals(res.shape, ShapeND(2))
|
||||
assertTrue { abs(abs(res.source[0]) - 0.386) < 0.01 }
|
||||
assertTrue { abs(abs(res.source[1]) - 0.922) < 0.01 }
|
||||
}
|
||||
|
@ -6,7 +6,8 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.get
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.testutils.assertBufferEquals
|
||||
import kotlin.test.Test
|
||||
@ -43,7 +44,7 @@ internal class TestDoubleTensorAlgebra {
|
||||
val res = tensor.transposed(0, 0)
|
||||
|
||||
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(0.0))
|
||||
assertTrue(res.shape contentEquals ShapeND(1))
|
||||
assertEquals(res.shape, ShapeND(1))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -52,7 +53,7 @@ internal class TestDoubleTensorAlgebra {
|
||||
val res = tensor.transposed(1, 0)
|
||||
|
||||
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||
assertTrue(res.shape contentEquals ShapeND(2, 3))
|
||||
assertEquals(res.shape, ShapeND(2, 3))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -62,9 +63,9 @@ internal class TestDoubleTensorAlgebra {
|
||||
val res02 = tensor.transposed(-3, 2)
|
||||
val res12 = tensor.transposed()
|
||||
|
||||
assertTrue(res01.shape contentEquals ShapeND(2, 1, 3))
|
||||
assertTrue(res02.shape contentEquals ShapeND(3, 2, 1))
|
||||
assertTrue(res12.shape contentEquals ShapeND(1, 3, 2))
|
||||
assertEquals(res01.shape, ShapeND(2, 1, 3))
|
||||
assertEquals(res02.shape, ShapeND(3, 2, 1))
|
||||
assertEquals(res12.shape, ShapeND(1, 3, 2))
|
||||
|
||||
assertTrue(res01.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res02.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
@ -114,19 +115,19 @@ internal class TestDoubleTensorAlgebra {
|
||||
|
||||
val res12 = tensor1.dot(tensor2)
|
||||
assertTrue(res12.source contentEquals doubleArrayOf(140.0, 320.0))
|
||||
assertTrue(res12.shape contentEquals ShapeND(2))
|
||||
assertEquals(res12.shape, ShapeND(2))
|
||||
|
||||
val res32 = tensor3.matmul(tensor2)
|
||||
assertTrue(res32.source contentEquals doubleArrayOf(-140.0))
|
||||
assertTrue(res32.shape contentEquals ShapeND(1, 1))
|
||||
assertEquals(res32.shape, ShapeND(1, 1))
|
||||
|
||||
val res22 = tensor2.dot(tensor2)
|
||||
assertTrue(res22.source contentEquals doubleArrayOf(1400.0))
|
||||
assertTrue(res22.shape contentEquals ShapeND(1))
|
||||
assertEquals(res22.shape, ShapeND(1))
|
||||
|
||||
val res11 = tensor1.dot(tensor11)
|
||||
assertTrue(res11.source contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
|
||||
assertTrue(res11.shape contentEquals ShapeND(2, 2))
|
||||
assertEquals(res11.shape, ShapeND(2, 2))
|
||||
|
||||
val res45 = tensor4.matmul(tensor5)
|
||||
assertTrue(
|
||||
@ -135,7 +136,7 @@ internal class TestDoubleTensorAlgebra {
|
||||
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
|
||||
)
|
||||
)
|
||||
assertTrue(res45.shape contentEquals ShapeND(2, 3, 3))
|
||||
assertEquals(res45.shape, ShapeND(2, 3, 3))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -144,35 +145,35 @@ internal class TestDoubleTensorAlgebra {
|
||||
val tensor2 = fromArray(ShapeND(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor3 = zeros(ShapeND(2, 3, 4, 5))
|
||||
|
||||
assertTrue(
|
||||
diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals
|
||||
ShapeND(2, 3, 4, 5, 5)
|
||||
assertEquals(
|
||||
diagonalEmbedding(tensor3, 0, 3, 4).shape,
|
||||
ShapeND(2, 3, 4, 5, 5)
|
||||
)
|
||||
assertTrue(
|
||||
diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals
|
||||
ShapeND(2, 3, 4, 6, 6)
|
||||
assertEquals(
|
||||
diagonalEmbedding(tensor3, 1, 3, 4).shape,
|
||||
ShapeND(2, 3, 4, 6, 6)
|
||||
)
|
||||
assertTrue(
|
||||
diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals
|
||||
ShapeND(7, 2, 3, 7, 4)
|
||||
assertEquals(
|
||||
diagonalEmbedding(tensor3, 2, 0, 3).shape,
|
||||
ShapeND(7, 2, 3, 7, 4)
|
||||
)
|
||||
|
||||
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
|
||||
assertTrue(diagonal1.shape contentEquals ShapeND(3, 3))
|
||||
assertEquals(diagonal1.shape, ShapeND(3, 3))
|
||||
assertTrue(
|
||||
diagonal1.source contentEquals
|
||||
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0)
|
||||
)
|
||||
|
||||
val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0)
|
||||
assertTrue(diagonal1Offset.shape contentEquals ShapeND(4, 4))
|
||||
assertEquals(diagonal1Offset.shape, ShapeND(4, 4))
|
||||
assertTrue(
|
||||
diagonal1Offset.source contentEquals
|
||||
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0)
|
||||
)
|
||||
|
||||
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
|
||||
assertTrue(diagonal2.shape contentEquals ShapeND(4, 2, 4))
|
||||
assertEquals(diagonal2.shape, ShapeND(4, 2, 4))
|
||||
assertTrue(
|
||||
diagonal2.source contentEquals
|
||||
doubleArrayOf(
|
||||
@ -202,7 +203,7 @@ internal class TestDoubleTensorAlgebra {
|
||||
val l = tensor.getTensor(0).map { it + 1.0 }
|
||||
val r = tensor.getTensor(1).map { it - 1.0 }
|
||||
val res = l + r
|
||||
assertTrue { ShapeND(5, 5) contentEquals res.shape }
|
||||
assertEquals(ShapeND(5, 5), res.shape)
|
||||
assertEquals(2.0, res[4, 4])
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ public open class ViktorFieldOpsND :
|
||||
right: StructureND<Double>,
|
||||
transform: Float64Field.(Double, Double) -> Double,
|
||||
): ViktorStructureND {
|
||||
require(left.shape.contentEquals(right.shape))
|
||||
require(left.shape == right.shape)
|
||||
return F64Array(*left.shape.asArray()).apply {
|
||||
ColumnStrides(left.shape).asSequence().forEach { index ->
|
||||
set(value = Float64Field.transform(left[index], right[index]), indices = index)
|
||||
|
Loading…
Reference in New Issue
Block a user