forked from kscience/kmath
fix dot bug introduced in the last refactor. Add test for parallel linear algebra.
This commit is contained in:
parent
79642a869d
commit
41a325d428
@ -9,6 +9,7 @@
|
||||
- New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers.
|
||||
- Explicit `mutableStructureND` builders for mutable structures.
|
||||
- `Buffer.asList()` zero-copy transformation.
|
||||
- Wasm support.
|
||||
- Parallel implementation of `LinearSpace` for Float64
|
||||
- Parallel buffer factories
|
||||
|
||||
|
@ -54,7 +54,7 @@ public object Float64LinearSpace : LinearSpace<Double, Float64Field> {
|
||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||
val rows = this@dot.rows.map { it.linearize() }
|
||||
val columns = other.columns.map { it.linearize() }
|
||||
val indices = 0 until this.rowNum
|
||||
val indices = 0 until this.colNum
|
||||
return buildMatrix(rowNum, other.colNum) { i, j ->
|
||||
val r = rows[i]
|
||||
val c = columns[j]
|
||||
@ -70,7 +70,7 @@ public object Float64LinearSpace : LinearSpace<Double, Float64Field> {
|
||||
override fun Matrix<Double>.dot(vector: Point<Double>): Float64Buffer {
|
||||
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
|
||||
val rows = this@dot.rows.map { it.linearize() }
|
||||
val indices = 0 until this.rowNum
|
||||
val indices = 0 until this.colNum
|
||||
return Float64Buffer(rowNum) { i ->
|
||||
val r = rows[i]
|
||||
var res = 0.0
|
||||
|
@ -28,13 +28,13 @@ class DoubleLUSolverTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDecomposition() = Double.algebra.linearSpace.run {
|
||||
fun testDecomposition() = with(Double.algebra.linearSpace){
|
||||
val matrix = matrix(2, 2)(
|
||||
3.0, 1.0,
|
||||
2.0, 3.0
|
||||
)
|
||||
|
||||
val lup = lup(matrix)
|
||||
val lup = elementAlgebra.lup(matrix)
|
||||
|
||||
//Check determinant
|
||||
// assertEquals(7.0, lup.determinant)
|
||||
|
@ -8,6 +8,7 @@ package space.kscience.kmath.linear
|
||||
import space.kscience.kmath.PerformancePitfall
|
||||
import space.kscience.kmath.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.Float64Field
|
||||
import space.kscience.kmath.operations.algebra
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
@ -58,7 +59,7 @@ class MatrixTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun test2DDot() = Double.algebra.linearSpace.run {
|
||||
fun test2DDot() = Float64Field.linearSpace {
|
||||
val firstMatrix = buildMatrix(2, 3) { i, j -> (i + j).toDouble() }
|
||||
val secondMatrix = buildMatrix(3, 2) { i, j -> (i + j).toDouble() }
|
||||
|
||||
@ -70,6 +71,5 @@ class MatrixTest {
|
||||
assertEquals(8.0, result[0, 1])
|
||||
assertEquals(8.0, result[1, 0])
|
||||
assertEquals(14.0, result[1, 1])
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ public object Float64ParallelLinearSpace : LinearSpace<Double, Float64Field> {
|
||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||
val rows = this@dot.rows.map { it.linearize() }
|
||||
val columns = other.columns.map { it.linearize() }
|
||||
val indices = 0 until this.rowNum
|
||||
val indices = 0 until this.colNum
|
||||
return buildMatrix(rowNum, other.colNum) { i, j ->
|
||||
val r = rows[i]
|
||||
val c = columns[j]
|
||||
@ -85,7 +85,7 @@ public object Float64ParallelLinearSpace : LinearSpace<Double, Float64Field> {
|
||||
override fun Matrix<Double>.dot(vector: Point<Double>): Float64Buffer {
|
||||
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
|
||||
val rows = this@dot.rows.map { it.linearize() }
|
||||
val indices = 0 until this.rowNum
|
||||
val indices = 0 until this.colNum
|
||||
return Float64Buffer(rowNum) { i ->
|
||||
val r = rows[i]
|
||||
var res = 0.0
|
||||
|
@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.linear
|
||||
|
||||
import space.kscience.kmath.PerformancePitfall
|
||||
import space.kscience.kmath.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.Float64Field
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
@UnstableKMathAPI
|
||||
@OptIn(PerformancePitfall::class)
|
||||
@Suppress("UNUSED_VARIABLE")
|
||||
class ParallelMatrixTest {
|
||||
|
||||
@Test
|
||||
fun testTranspose() = Float64Field.linearSpace.parallel{
|
||||
val matrix = one(3, 3)
|
||||
val transposed = matrix.transposed()
|
||||
assertTrue { StructureND.contentEquals(matrix, transposed) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBuilder() = Float64Field.linearSpace.parallel{
|
||||
val matrix = matrix(2, 3)(
|
||||
1.0, 0.0, 0.0,
|
||||
0.0, 1.0, 2.0
|
||||
)
|
||||
|
||||
assertEquals(2.0, matrix[1, 2])
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMatrixExtension() = Float64Field.linearSpace.parallel{
|
||||
val transitionMatrix: Matrix<Double> = VirtualMatrix(type,6, 6) { row, col ->
|
||||
when {
|
||||
col == 0 -> .50
|
||||
row + 1 == col -> .50
|
||||
row == 5 && col == 5 -> 1.0
|
||||
else -> 0.0
|
||||
}
|
||||
}
|
||||
|
||||
infix fun Matrix<Double>.pow(power: Int): Matrix<Double> {
|
||||
var res = this
|
||||
repeat(power - 1) {
|
||||
res = res dot this@pow
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
val toTenthPower = transitionMatrix pow 10
|
||||
}
|
||||
|
||||
@Test
|
||||
fun test2DDot() = Float64Field.linearSpace.parallel {
|
||||
val firstMatrix = buildMatrix(2, 3) { i, j -> (i + j).toDouble() }
|
||||
val secondMatrix = buildMatrix(3, 2) { i, j -> (i + j).toDouble() }
|
||||
|
||||
// val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() }
|
||||
// val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() }
|
||||
val result = firstMatrix dot secondMatrix
|
||||
assertEquals(2, result.rowNum)
|
||||
assertEquals(2, result.colNum)
|
||||
assertEquals(8.0, result[0, 1])
|
||||
assertEquals(8.0, result[1, 0])
|
||||
assertEquals(14.0, result[1, 1])
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user