From 568e72021200427b09320339cd272fe67646d3bf Mon Sep 17 00:00:00 2001 From: Peter Klimai Date: Tue, 5 May 2020 18:43:20 +0300 Subject: [PATCH] More tests in kmath-for-real and fix for Double.plus --- .../kmath/real/DoubleMatrixOperations.kt | 2 +- .../scientific.kmath.real/RealMatrixTest.kt | 147 +++++++++++++++++- 2 files changed, 145 insertions(+), 4 deletions(-) diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/DoubleMatrixOperations.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/DoubleMatrixOperations.kt index d9c6b5eab..7eeba3031 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/DoubleMatrixOperations.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/DoubleMatrixOperations.kt @@ -57,7 +57,7 @@ operator fun Double.times(matrix: Matrix) = MatrixContext.real.produce(m } operator fun Double.plus(matrix: Matrix) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { - row, col -> this * matrix[row, col] + row, col -> this + matrix[row, col] } operator fun Double.minus(matrix: Matrix) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { diff --git a/kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt b/kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt index 808794442..31b8b5252 100644 --- a/kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt @@ -1,8 +1,9 @@ package scientific.kmath.real -import scientifik.kmath.real.average -import scientifik.kmath.real.realMatrix -import scientifik.kmath.real.sum +import scientifik.kmath.real.* +import scientifik.kmath.linear.VirtualMatrix +import scientifik.kmath.linear.build +import scientifik.kmath.structures.Matrix import kotlin.test.Test import kotlin.test.assertEquals @@ -13,4 +14,144 @@ class RealMatrixTest { assertEquals(m.sum(), 900.0) assertEquals(m.average(), 9.0) } + + @Test + fun testSequenceToMatrix() { + val m = Sequence { + listOf( + DoubleArray(10) { 10.0 }, + DoubleArray(10) { 20.0 }, + DoubleArray(10) { 30.0 }).iterator() + }.toMatrix() + assertEquals(m.sum(), 20.0 * 30) + } + + @Test + fun testRepeatStackVertical() { + val matrix1 = Matrix.build(2, 3)( + 1.0, 0.0, 0.0, + 0.0, 1.0, 2.0 + ) + val matrix2 = Matrix.build(6, 3)( + 1.0, 0.0, 0.0, + 0.0, 1.0, 2.0, + 1.0, 0.0, 0.0, + 0.0, 1.0, 2.0, + 1.0, 0.0, 0.0, + 0.0, 1.0, 2.0 + ) + assertEquals(VirtualMatrix.wrap(matrix2), matrix1.repeatStackVertical(3)) + } + + @Test + fun testMatrixAndDouble() { + val matrix1 = Matrix.build(2, 3)( + 1.0, 0.0, 3.0, + 4.0, 6.0, 2.0 + ) + val matrix2 = (matrix1 * 2.5 + 1.0 - 2.0) / 2.0 + val expectedResult = Matrix.build(2, 3)( + 0.75, -0.5, 3.25, + 4.5, 7.0, 2.0 + ) + assertEquals(matrix2, expectedResult) + } + + @Test + fun testDoubleAndMatrix() { + val matrix1 = Matrix.build(2, 3)( + 1.0, 0.0, 3.0, + 4.0, 6.0, 2.0 + ) + val matrix2 = 20.0 - (10.0 + (5.0 * matrix1)) + //val matrix2 = 10.0 + (5.0 * matrix1) + val expectedResult = Matrix.build(2, 3)( + 5.0, 10.0, -5.0, + -10.0, -20.0, 0.0 + ) + assertEquals(matrix2, expectedResult) + } + + @Test + fun testSquareAndPower() { + val matrix1 = Matrix.build(2, 3)( + -1.0, 0.0, 3.0, + 4.0, -6.0, -2.0 + ) + val matrix2 = Matrix.build(2, 3)( + 1.0, 0.0, 9.0, + 16.0, 36.0, 4.0 + ) + val matrix3 = Matrix.build(2, 3)( + -1.0, 0.0, 27.0, + 64.0, -216.0, -8.0 + ) + assertEquals(matrix1.square(), matrix2) + assertEquals(matrix1.pow(3), matrix3) + } + + @Test + fun testTwoMatrixOperations() { + val matrix1 = Matrix.build(2, 3)( + -1.0, 0.0, 3.0, + 4.0, -6.0, 7.0 + ) + val matrix2 = Matrix.build(2, 3)( + 1.0, 0.0, 3.0, + 4.0, 6.0, -2.0 + ) + val result = matrix1 * matrix2 + matrix1 - matrix2 + val expectedResult = Matrix.build(2, 3)( + -3.0, 0.0, 9.0, + 16.0, -48.0, -5.0 + ) + assertEquals(result, expectedResult) + } + + @Test + fun testColumnOperations() { + val matrix1 = Matrix.build(2, 4)( + -1.0, 0.0, 3.0, 15.0, + 4.0, -6.0, 7.0, -11.0 + ) + val matrix2 = Matrix.build(2, 5)( + -1.0, 0.0, 3.0, 15.0, -1.0, + 4.0, -6.0, 7.0, -11.0, 4.0 + ) + val col1 = Matrix.build(2, 1)(0.0, -6.0) + val cols1to2 = Matrix.build(2, 2)( + 0.0, 3.0, + -6.0, 7.0 + ) + assertEquals(matrix1.appendColumn { it[0] }, matrix2) + assertEquals(matrix1.extractColumn(1), col1) + assertEquals(matrix1.extractColumns(1..2), cols1to2) + assertEquals(matrix1.sumByColumn(), Matrix.build(4, 1)(3.0, -6.0, 10.0, 4.0)) + assertEquals(matrix1.minByColumn(), Matrix.build(4, 1)(-1.0, -6.0, 3.0, -11.0)) + assertEquals(matrix1.maxByColumn(), Matrix.build(4, 1)(4.0, 0.0, 7.0, 15.0)) + assertEquals(matrix1.averageByColumn(), Matrix.build(4, 1)(1.5, -3.0, 5.0, 2.0)) + } + + @Test + fun testAllElementOperations() { + val matrix1 = Matrix.build(2, 4)( + -1.0, 0.0, 3.0, 15.0, + 4.0, -6.0, 7.0, -11.0 + ) + assertEquals(matrix1.sum(), 11.0) + assertEquals(matrix1.min(), -11.0) + assertEquals(matrix1.max(), 15.0) + assertEquals(matrix1.average(), 1.375) + } + +// fun printMatrix(m: Matrix) { +// for (row in 0 until m.shape[0]) { +// for (col in 0 until m.shape[1]) { +// print(m[row, col]) +// print(" ") +// } +// println() +// } +// } + }