More tests in kmath-for-real and fix for Double.plus

This commit is contained in:
Peter Klimai 2020-05-05 18:43:20 +03:00
parent f04eeac3b4
commit 568e720212
2 changed files with 145 additions and 4 deletions

View File

@ -57,7 +57,7 @@ operator fun Double.times(matrix: Matrix<Double>) = MatrixContext.real.produce(m
}
operator fun Double.plus(matrix: Matrix<Double>) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) {
row, col -> this * matrix[row, col]
row, col -> this + matrix[row, col]
}
operator fun Double.minus(matrix: Matrix<Double>) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) {

View File

@ -1,8 +1,9 @@
package scientific.kmath.real
import scientifik.kmath.real.average
import scientifik.kmath.real.realMatrix
import scientifik.kmath.real.sum
import scientifik.kmath.real.*
import scientifik.kmath.linear.VirtualMatrix
import scientifik.kmath.linear.build
import scientifik.kmath.structures.Matrix
import kotlin.test.Test
import kotlin.test.assertEquals
@ -13,4 +14,144 @@ class RealMatrixTest {
assertEquals(m.sum(), 900.0)
assertEquals(m.average(), 9.0)
}
@Test
fun testSequenceToMatrix() {
val m = Sequence<DoubleArray> {
listOf(
DoubleArray(10) { 10.0 },
DoubleArray(10) { 20.0 },
DoubleArray(10) { 30.0 }).iterator()
}.toMatrix()
assertEquals(m.sum(), 20.0 * 30)
}
@Test
fun testRepeatStackVertical() {
val matrix1 = Matrix.build<Double>(2, 3)(
1.0, 0.0, 0.0,
0.0, 1.0, 2.0
)
val matrix2 = Matrix.build<Double>(6, 3)(
1.0, 0.0, 0.0,
0.0, 1.0, 2.0,
1.0, 0.0, 0.0,
0.0, 1.0, 2.0,
1.0, 0.0, 0.0,
0.0, 1.0, 2.0
)
assertEquals(VirtualMatrix.wrap(matrix2), matrix1.repeatStackVertical(3))
}
@Test
fun testMatrixAndDouble() {
val matrix1 = Matrix.build<Double>(2, 3)(
1.0, 0.0, 3.0,
4.0, 6.0, 2.0
)
val matrix2 = (matrix1 * 2.5 + 1.0 - 2.0) / 2.0
val expectedResult = Matrix.build<Double>(2, 3)(
0.75, -0.5, 3.25,
4.5, 7.0, 2.0
)
assertEquals(matrix2, expectedResult)
}
@Test
fun testDoubleAndMatrix() {
val matrix1 = Matrix.build<Double>(2, 3)(
1.0, 0.0, 3.0,
4.0, 6.0, 2.0
)
val matrix2 = 20.0 - (10.0 + (5.0 * matrix1))
//val matrix2 = 10.0 + (5.0 * matrix1)
val expectedResult = Matrix.build<Double>(2, 3)(
5.0, 10.0, -5.0,
-10.0, -20.0, 0.0
)
assertEquals(matrix2, expectedResult)
}
@Test
fun testSquareAndPower() {
val matrix1 = Matrix.build<Double>(2, 3)(
-1.0, 0.0, 3.0,
4.0, -6.0, -2.0
)
val matrix2 = Matrix.build<Double>(2, 3)(
1.0, 0.0, 9.0,
16.0, 36.0, 4.0
)
val matrix3 = Matrix.build<Double>(2, 3)(
-1.0, 0.0, 27.0,
64.0, -216.0, -8.0
)
assertEquals(matrix1.square(), matrix2)
assertEquals(matrix1.pow(3), matrix3)
}
@Test
fun testTwoMatrixOperations() {
val matrix1 = Matrix.build<Double>(2, 3)(
-1.0, 0.0, 3.0,
4.0, -6.0, 7.0
)
val matrix2 = Matrix.build<Double>(2, 3)(
1.0, 0.0, 3.0,
4.0, 6.0, -2.0
)
val result = matrix1 * matrix2 + matrix1 - matrix2
val expectedResult = Matrix.build<Double>(2, 3)(
-3.0, 0.0, 9.0,
16.0, -48.0, -5.0
)
assertEquals(result, expectedResult)
}
@Test
fun testColumnOperations() {
val matrix1 = Matrix.build<Double>(2, 4)(
-1.0, 0.0, 3.0, 15.0,
4.0, -6.0, 7.0, -11.0
)
val matrix2 = Matrix.build<Double>(2, 5)(
-1.0, 0.0, 3.0, 15.0, -1.0,
4.0, -6.0, 7.0, -11.0, 4.0
)
val col1 = Matrix.build<Double>(2, 1)(0.0, -6.0)
val cols1to2 = Matrix.build<Double>(2, 2)(
0.0, 3.0,
-6.0, 7.0
)
assertEquals(matrix1.appendColumn { it[0] }, matrix2)
assertEquals(matrix1.extractColumn(1), col1)
assertEquals(matrix1.extractColumns(1..2), cols1to2)
assertEquals(matrix1.sumByColumn(), Matrix.build<Double>(4, 1)(3.0, -6.0, 10.0, 4.0))
assertEquals(matrix1.minByColumn(), Matrix.build<Double>(4, 1)(-1.0, -6.0, 3.0, -11.0))
assertEquals(matrix1.maxByColumn(), Matrix.build<Double>(4, 1)(4.0, 0.0, 7.0, 15.0))
assertEquals(matrix1.averageByColumn(), Matrix.build<Double>(4, 1)(1.5, -3.0, 5.0, 2.0))
}
@Test
fun testAllElementOperations() {
val matrix1 = Matrix.build<Double>(2, 4)(
-1.0, 0.0, 3.0, 15.0,
4.0, -6.0, 7.0, -11.0
)
assertEquals(matrix1.sum(), 11.0)
assertEquals(matrix1.min(), -11.0)
assertEquals(matrix1.max(), 15.0)
assertEquals(matrix1.average(), 1.375)
}
// fun printMatrix(m: Matrix<Double>) {
// for (row in 0 until m.shape[0]) {
// for (col in 0 until m.shape[1]) {
// print(m[row, col])
// print(" ")
// }
// println()
// }
// }
}