Fix EJML to properly treat vectors as columns

This commit is contained in:
Alexander Nozik 2024-12-24 12:57:08 +03:00
parent b4b8f30b2a
commit 19139c0d4e
3 changed files with 14 additions and 5 deletions

View File

@ -18,7 +18,7 @@ import space.kscience.kmath.linear.Point
*/
public abstract class EjmlVector<out T, out M : Matrix>(public open val origin: M) : Point<T> {
override val size: Int
get() = origin.numCols
get() = origin.numRows
override operator fun iterator(): Iterator<T> = object : Iterator<T> {
private var cursor: Int = 0

View File

@ -45,11 +45,11 @@ public fun Complex_F64.toKMathComplex(): Complex = Complex(real, imaginary)
*/
public class EjmlDoubleVector<out M : DMatrix>(override val origin: M) : EjmlVector<Double, M>(origin) {
init {
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
require(origin.numCols == 1) { "The origin matrix must have only one column to form a vector" }
}
override operator fun get(index: Int): Double = origin[0, index]
override operator fun get(index: Int): Double = origin[index, 0]
}
/**
@ -57,10 +57,10 @@ public class EjmlDoubleVector<out M : DMatrix>(override val origin: M) : EjmlVec
*/
public class EjmlFloatVector<out M : FMatrix>(override val origin: M) : EjmlVector<Float, M>(origin) {
init {
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
require(origin.numCols == 1) { "The origin matrix must have only one column to form a vector" }
}
override operator fun get(index: Int): Float = origin[0, index]
override operator fun get(index: Int): Float = origin[index, 0]
}
/**

View File

@ -7,6 +7,9 @@ package space.kscience.kmath.ejml
import org.ejml.data.DMatrixRMaj
import org.ejml.dense.row.RandomMatrices_DDRM
import space.kscience.kmath.linear.invoke
import space.kscience.kmath.structures.asBuffer
import space.kscience.kmath.testutils.assertBufferEquals
import kotlin.random.Random
import kotlin.random.asJavaRandom
import kotlin.test.Test
@ -54,4 +57,10 @@ internal class EjmlVectorTest {
val w = EjmlDoubleVector(m)
assertSame(m, w.origin)
}
@Test
fun unaryMinus() = EjmlLinearSpaceDDRM {
val mu = doubleArrayOf(1.0, 2.0, 3.0).asBuffer()
assertBufferEquals(mu, -(-mu))
}
}