forked from kscience/kmath
More tests in kmath-for-real and fix for Double.plus
This commit is contained in:
parent
f04eeac3b4
commit
568e720212
@ -57,7 +57,7 @@ operator fun Double.times(matrix: Matrix<Double>) = MatrixContext.real.produce(m
|
|||||||
}
|
}
|
||||||
|
|
||||||
operator fun Double.plus(matrix: Matrix<Double>) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) {
|
operator fun Double.plus(matrix: Matrix<Double>) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) {
|
||||||
row, col -> this * matrix[row, col]
|
row, col -> this + matrix[row, col]
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun Double.minus(matrix: Matrix<Double>) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) {
|
operator fun Double.minus(matrix: Matrix<Double>) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) {
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package scientific.kmath.real
|
package scientific.kmath.real
|
||||||
|
|
||||||
import scientifik.kmath.real.average
|
import scientifik.kmath.real.*
|
||||||
import scientifik.kmath.real.realMatrix
|
import scientifik.kmath.linear.VirtualMatrix
|
||||||
import scientifik.kmath.real.sum
|
import scientifik.kmath.linear.build
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -13,4 +14,144 @@ class RealMatrixTest {
|
|||||||
assertEquals(m.sum(), 900.0)
|
assertEquals(m.sum(), 900.0)
|
||||||
assertEquals(m.average(), 9.0)
|
assertEquals(m.average(), 9.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSequenceToMatrix() {
|
||||||
|
val m = Sequence<DoubleArray> {
|
||||||
|
listOf(
|
||||||
|
DoubleArray(10) { 10.0 },
|
||||||
|
DoubleArray(10) { 20.0 },
|
||||||
|
DoubleArray(10) { 30.0 }).iterator()
|
||||||
|
}.toMatrix()
|
||||||
|
assertEquals(m.sum(), 20.0 * 30)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testRepeatStackVertical() {
|
||||||
|
val matrix1 = Matrix.build<Double>(2, 3)(
|
||||||
|
1.0, 0.0, 0.0,
|
||||||
|
0.0, 1.0, 2.0
|
||||||
|
)
|
||||||
|
val matrix2 = Matrix.build<Double>(6, 3)(
|
||||||
|
1.0, 0.0, 0.0,
|
||||||
|
0.0, 1.0, 2.0,
|
||||||
|
1.0, 0.0, 0.0,
|
||||||
|
0.0, 1.0, 2.0,
|
||||||
|
1.0, 0.0, 0.0,
|
||||||
|
0.0, 1.0, 2.0
|
||||||
|
)
|
||||||
|
assertEquals(VirtualMatrix.wrap(matrix2), matrix1.repeatStackVertical(3))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMatrixAndDouble() {
|
||||||
|
val matrix1 = Matrix.build<Double>(2, 3)(
|
||||||
|
1.0, 0.0, 3.0,
|
||||||
|
4.0, 6.0, 2.0
|
||||||
|
)
|
||||||
|
val matrix2 = (matrix1 * 2.5 + 1.0 - 2.0) / 2.0
|
||||||
|
val expectedResult = Matrix.build<Double>(2, 3)(
|
||||||
|
0.75, -0.5, 3.25,
|
||||||
|
4.5, 7.0, 2.0
|
||||||
|
)
|
||||||
|
assertEquals(matrix2, expectedResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDoubleAndMatrix() {
|
||||||
|
val matrix1 = Matrix.build<Double>(2, 3)(
|
||||||
|
1.0, 0.0, 3.0,
|
||||||
|
4.0, 6.0, 2.0
|
||||||
|
)
|
||||||
|
val matrix2 = 20.0 - (10.0 + (5.0 * matrix1))
|
||||||
|
//val matrix2 = 10.0 + (5.0 * matrix1)
|
||||||
|
val expectedResult = Matrix.build<Double>(2, 3)(
|
||||||
|
5.0, 10.0, -5.0,
|
||||||
|
-10.0, -20.0, 0.0
|
||||||
|
)
|
||||||
|
assertEquals(matrix2, expectedResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSquareAndPower() {
|
||||||
|
val matrix1 = Matrix.build<Double>(2, 3)(
|
||||||
|
-1.0, 0.0, 3.0,
|
||||||
|
4.0, -6.0, -2.0
|
||||||
|
)
|
||||||
|
val matrix2 = Matrix.build<Double>(2, 3)(
|
||||||
|
1.0, 0.0, 9.0,
|
||||||
|
16.0, 36.0, 4.0
|
||||||
|
)
|
||||||
|
val matrix3 = Matrix.build<Double>(2, 3)(
|
||||||
|
-1.0, 0.0, 27.0,
|
||||||
|
64.0, -216.0, -8.0
|
||||||
|
)
|
||||||
|
assertEquals(matrix1.square(), matrix2)
|
||||||
|
assertEquals(matrix1.pow(3), matrix3)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTwoMatrixOperations() {
|
||||||
|
val matrix1 = Matrix.build<Double>(2, 3)(
|
||||||
|
-1.0, 0.0, 3.0,
|
||||||
|
4.0, -6.0, 7.0
|
||||||
|
)
|
||||||
|
val matrix2 = Matrix.build<Double>(2, 3)(
|
||||||
|
1.0, 0.0, 3.0,
|
||||||
|
4.0, 6.0, -2.0
|
||||||
|
)
|
||||||
|
val result = matrix1 * matrix2 + matrix1 - matrix2
|
||||||
|
val expectedResult = Matrix.build<Double>(2, 3)(
|
||||||
|
-3.0, 0.0, 9.0,
|
||||||
|
16.0, -48.0, -5.0
|
||||||
|
)
|
||||||
|
assertEquals(result, expectedResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testColumnOperations() {
|
||||||
|
val matrix1 = Matrix.build<Double>(2, 4)(
|
||||||
|
-1.0, 0.0, 3.0, 15.0,
|
||||||
|
4.0, -6.0, 7.0, -11.0
|
||||||
|
)
|
||||||
|
val matrix2 = Matrix.build<Double>(2, 5)(
|
||||||
|
-1.0, 0.0, 3.0, 15.0, -1.0,
|
||||||
|
4.0, -6.0, 7.0, -11.0, 4.0
|
||||||
|
)
|
||||||
|
val col1 = Matrix.build<Double>(2, 1)(0.0, -6.0)
|
||||||
|
val cols1to2 = Matrix.build<Double>(2, 2)(
|
||||||
|
0.0, 3.0,
|
||||||
|
-6.0, 7.0
|
||||||
|
)
|
||||||
|
assertEquals(matrix1.appendColumn { it[0] }, matrix2)
|
||||||
|
assertEquals(matrix1.extractColumn(1), col1)
|
||||||
|
assertEquals(matrix1.extractColumns(1..2), cols1to2)
|
||||||
|
assertEquals(matrix1.sumByColumn(), Matrix.build<Double>(4, 1)(3.0, -6.0, 10.0, 4.0))
|
||||||
|
assertEquals(matrix1.minByColumn(), Matrix.build<Double>(4, 1)(-1.0, -6.0, 3.0, -11.0))
|
||||||
|
assertEquals(matrix1.maxByColumn(), Matrix.build<Double>(4, 1)(4.0, 0.0, 7.0, 15.0))
|
||||||
|
assertEquals(matrix1.averageByColumn(), Matrix.build<Double>(4, 1)(1.5, -3.0, 5.0, 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAllElementOperations() {
|
||||||
|
val matrix1 = Matrix.build<Double>(2, 4)(
|
||||||
|
-1.0, 0.0, 3.0, 15.0,
|
||||||
|
4.0, -6.0, 7.0, -11.0
|
||||||
|
)
|
||||||
|
assertEquals(matrix1.sum(), 11.0)
|
||||||
|
assertEquals(matrix1.min(), -11.0)
|
||||||
|
assertEquals(matrix1.max(), 15.0)
|
||||||
|
assertEquals(matrix1.average(), 1.375)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fun printMatrix(m: Matrix<Double>) {
|
||||||
|
// for (row in 0 until m.shape[0]) {
|
||||||
|
// for (col in 0 until m.shape[1]) {
|
||||||
|
// print(m[row, col])
|
||||||
|
// print(" ")
|
||||||
|
// }
|
||||||
|
// println()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user