Merge branch 'dev' into feature/quaternion

This commit is contained in:
Iaroslav Postovalov 2021-01-24 02:00:22 +07:00
commit 9342824c96
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
21 changed files with 90 additions and 92 deletions

View File

@ -28,7 +28,7 @@ internal class ExpressionsInterpretersBenchmark {
@Benchmark @Benchmark
fun mstExpression() { fun mstExpression() {
val expr = algebra.mstInField { val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) symbol("x") * 2.0 + 2.0 / symbol("x") - 16.0
} }
invokeAndSum(expr) invokeAndSum(expr)
@ -37,7 +37,7 @@ internal class ExpressionsInterpretersBenchmark {
@Benchmark @Benchmark
fun asmExpression() { fun asmExpression() {
val expr = algebra.mstInField { val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) symbol("x") * 2.0 + 2.0 / symbol("x") - 16.0
}.compile() }.compile()
invokeAndSum(expr) invokeAndSum(expr)

View File

@ -2,9 +2,8 @@ package kscience.kmath.benchmarks
import kotlinx.benchmark.Benchmark import kotlinx.benchmark.Benchmark
import kscience.kmath.commons.linear.CMMatrixContext import kscience.kmath.commons.linear.CMMatrixContext
import kscience.kmath.commons.linear.toCM
import kscience.kmath.ejml.EjmlMatrixContext import kscience.kmath.ejml.EjmlMatrixContext
import kscience.kmath.ejml.toEjml
import kscience.kmath.linear.BufferMatrixContext import kscience.kmath.linear.BufferMatrixContext
import kscience.kmath.linear.RealMatrixContext import kscience.kmath.linear.RealMatrixContext
import kscience.kmath.linear.real import kscience.kmath.linear.real
@ -26,11 +25,11 @@ class DotBenchmark {
val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
val cmMatrix1 = matrix1.toCM() val cmMatrix1 = CMMatrixContext { matrix1.toCM() }
val cmMatrix2 = matrix2.toCM() val cmMatrix2 = CMMatrixContext { matrix2.toCM() }
val ejmlMatrix1 = matrix1.toEjml() val ejmlMatrix1 = EjmlMatrixContext { matrix1.toEjml() }
val ejmlMatrix2 = matrix2.toEjml() val ejmlMatrix2 = EjmlMatrixContext { matrix2.toEjml() }
} }
@Benchmark @Benchmark
@ -49,22 +48,23 @@ class DotBenchmark {
@Benchmark @Benchmark
fun ejmlMultiplicationwithConversion() { fun ejmlMultiplicationwithConversion() {
EjmlMatrixContext {
val ejmlMatrix1 = matrix1.toEjml() val ejmlMatrix1 = matrix1.toEjml()
val ejmlMatrix2 = matrix2.toEjml() val ejmlMatrix2 = matrix2.toEjml()
EjmlMatrixContext {
ejmlMatrix1 dot ejmlMatrix2 ejmlMatrix1 dot ejmlMatrix2
} }
} }
@Benchmark @Benchmark
fun bufferedMultiplication() { fun bufferedMultiplication() {
BufferMatrixContext(RealField, Buffer.Companion::real).invoke{ BufferMatrixContext(RealField, Buffer.Companion::real).invoke {
matrix1 dot matrix2 matrix1 dot matrix2
} }
} }
@Benchmark @Benchmark
fun realMultiplication(){ fun realMultiplication() {
RealMatrixContext { RealMatrixContext {
matrix1 dot matrix2 matrix1 dot matrix2
} }

View File

@ -1,25 +0,0 @@
package kscience.kmath.benchmarks
import kscience.kmath.structures.NDField
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import org.openjdk.jmh.infra.Blackhole
import kotlin.random.Random
@State(Scope.Benchmark)
class LargeNDBenchmark {
val arraySize = 10000
val RANDOM = Random(222)
val src1 = DoubleArray(arraySize) { RANDOM.nextDouble() }
val src2 = DoubleArray(arraySize) { RANDOM.nextDouble() }
val field = NDField.real(arraySize)
val kmathArray1 = field.produce { (a) -> src1[a] }
val kmathArray2 = field.produce { (a) -> src2[a] }
@Benchmark
fun test10000(bh: Blackhole) {
bh.consume(field.add(kmathArray1, kmathArray2))
}
}

View File

@ -5,10 +5,8 @@ import kotlinx.benchmark.Benchmark
import kscience.kmath.commons.linear.CMMatrixContext import kscience.kmath.commons.linear.CMMatrixContext
import kscience.kmath.commons.linear.CMMatrixContext.dot import kscience.kmath.commons.linear.CMMatrixContext.dot
import kscience.kmath.commons.linear.inverse import kscience.kmath.commons.linear.inverse
import kscience.kmath.commons.linear.toCM
import kscience.kmath.ejml.EjmlMatrixContext import kscience.kmath.ejml.EjmlMatrixContext
import kscience.kmath.ejml.inverse import kscience.kmath.ejml.inverse
import kscience.kmath.ejml.toEjml
import kscience.kmath.operations.invoke import kscience.kmath.operations.invoke
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.Scope
@ -35,16 +33,14 @@ class LinearAlgebraBenchmark {
@Benchmark @Benchmark
fun cmLUPInversion() { fun cmLUPInversion() {
CMMatrixContext { CMMatrixContext {
val cm = matrix.toCM() //avoid overhead on conversion inverse(matrix)
inverse(cm)
} }
} }
@Benchmark @Benchmark
fun ejmlInverse() { fun ejmlInverse() {
EjmlMatrixContext { EjmlMatrixContext {
val km = matrix.toEjml() //avoid overhead on conversion inverse(matrix)
inverse(km)
} }
} }
} }

View File

@ -2,8 +2,8 @@ package kscience.kmath.commons.linear
import kscience.kmath.linear.DiagonalFeature import kscience.kmath.linear.DiagonalFeature
import kscience.kmath.linear.MatrixContext import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.MatrixWrapper
import kscience.kmath.linear.Point import kscience.kmath.linear.Point
import kscience.kmath.linear.origin
import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import org.apache.commons.math3.linear.* import org.apache.commons.math3.linear.*
@ -47,9 +47,9 @@ public object CMMatrixContext : MatrixContext<Double, CMMatrix> {
return CMMatrix(Array2DRowRealMatrix(array)) return CMMatrix(Array2DRowRealMatrix(array))
} }
public fun Matrix<Double>.toCM(): CMMatrix = when { @OptIn(UnstableKMathAPI::class)
this is CMMatrix -> this public fun Matrix<Double>.toCM(): CMMatrix = when (val matrix = origin) {
this is MatrixWrapper && matrix is CMMatrix -> matrix as CMMatrix is CMMatrix -> matrix
else -> { else -> {
//TODO add feature analysis //TODO add feature analysis
val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } }

View File

@ -95,8 +95,9 @@ public open class FunctionalExpressionRing<T, A : Ring<T>>(
super<FunctionalExpressionSpace>.binaryOperationFunction(operation) super<FunctionalExpressionSpace>.binaryOperationFunction(operation)
} }
public open class FunctionalExpressionField<T, A : Field<T>>(algebra: A) : public open class FunctionalExpressionField<T, A : Field<T>>(
FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>> { algebra: A,
) : FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>> {
/** /**
* Builds an Expression of division an expression by another one. * Builds an Expression of division an expression by another one.
*/ */

View File

@ -43,7 +43,7 @@ public class BufferMatrix<T : Any>(
if (this === other) return true if (this === other) return true
return when (other) { return when (other) {
is NDStructure<*> -> NDStructure.equals(this, other) is NDStructure<*> -> NDStructure.contentEquals(this, other)
else -> false else -> false
} }
} }

View File

@ -224,6 +224,7 @@ public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext
): Matrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) ): Matrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular)
@OptIn(UnstableKMathAPI::class)
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> { public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> {
// Use existing decomposition if it is provided by matrix // Use existing decomposition if it is provided by matrix
val bufferFactory: MutableBufferFactory<Double> = MutableBuffer.Companion::real val bufferFactory: MutableBufferFactory<Double> = MutableBuffer.Companion::real

View File

@ -15,29 +15,37 @@ import kotlin.reflect.safeCast
* *
* @param T the type of items. * @param T the type of items.
*/ */
public class MatrixWrapper<T : Any>( public class MatrixWrapper<T : Any> internal constructor(
public val matrix: Matrix<T>, public val origin: Matrix<T>,
public val features: Set<MatrixFeature>, public val features: Set<MatrixFeature>,
) : Matrix<T> by matrix { ) : Matrix<T> by origin {
/** /**
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
*/ */
@UnstableKMathAPI @UnstableKMathAPI
override fun <T : Any> getFeature(type: KClass<T>): T? = type.safeCast(features.find { type.isInstance(it) }) override fun <T : Any> getFeature(type: KClass<T>): T? = type.safeCast(features.find { type.isInstance(it) })
?: origin.getFeature(type)
override fun equals(other: Any?): Boolean = matrix == other override fun equals(other: Any?): Boolean = origin == other
override fun hashCode(): Int = matrix.hashCode() override fun hashCode(): Int = origin.hashCode()
override fun toString(): String { override fun toString(): String {
return "MatrixWrapper(matrix=$matrix, features=$features)" return "MatrixWrapper(matrix=$origin, features=$features)"
} }
} }
/**
* Return the original matrix. If this is a wrapper, return its origin. If not, this matrix.
* Origin does not necessary store all features.
*/
@UnstableKMathAPI
public val <T : Any> Matrix<T>.origin: Matrix<T> get() = (this as? MatrixWrapper)?.origin ?: this
/** /**
* Add a single feature to a [Matrix] * Add a single feature to a [Matrix]
*/ */
public operator fun <T : Any> Matrix<T>.plus(newFeature: MatrixFeature): MatrixWrapper<T> = if (this is MatrixWrapper) { public operator fun <T : Any> Matrix<T>.plus(newFeature: MatrixFeature): MatrixWrapper<T> = if (this is MatrixWrapper) {
MatrixWrapper(matrix, features + newFeature) MatrixWrapper(origin, features + newFeature)
} else { } else {
MatrixWrapper(this, setOf(newFeature)) MatrixWrapper(this, setOf(newFeature))
} }
@ -47,7 +55,7 @@ public operator fun <T : Any> Matrix<T>.plus(newFeature: MatrixFeature): MatrixW
*/ */
public operator fun <T : Any> Matrix<T>.plus(newFeatures: Collection<MatrixFeature>): MatrixWrapper<T> = public operator fun <T : Any> Matrix<T>.plus(newFeatures: Collection<MatrixFeature>): MatrixWrapper<T> =
if (this is MatrixWrapper) { if (this is MatrixWrapper) {
MatrixWrapper(matrix, features + newFeatures) MatrixWrapper(origin, features + newFeatures)
} else { } else {
MatrixWrapper(this, newFeatures.toSet()) MatrixWrapper(this, newFeatures.toSet())
} }

View File

@ -92,7 +92,7 @@ public interface Algebra<T> {
* Call a block with an [Algebra] as receiver. * Call a block with an [Algebra] as receiver.
*/ */
// TODO add contract when KT-32313 is fixed // TODO add contract when KT-32313 is fixed
public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = block() public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block)
/** /**
* Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as * Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as

View File

@ -3,6 +3,7 @@ package kscience.kmath.operations
import kscience.kmath.memory.MemoryReader import kscience.kmath.memory.MemoryReader
import kscience.kmath.memory.MemorySpec import kscience.kmath.memory.MemorySpec
import kscience.kmath.memory.MemoryWriter import kscience.kmath.memory.MemoryWriter
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Buffer import kscience.kmath.structures.Buffer
import kscience.kmath.structures.MemoryBuffer import kscience.kmath.structures.MemoryBuffer
import kscience.kmath.structures.MutableBuffer import kscience.kmath.structures.MutableBuffer
@ -41,6 +42,7 @@ private val PI_DIV_2 = Complex(PI / 2, 0)
/** /**
* A field of [Complex]. * A field of [Complex].
*/ */
@OptIn(UnstableKMathAPI::class)
public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex>, RingWithNumbers<Complex> { public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex>, RingWithNumbers<Complex> {
override val zero: Complex = 0.0.toComplex() override val zero: Complex = 0.0.toComplex()
override val one: Complex = 1.0.toComplex() override val one: Complex = 1.0.toComplex()

View File

@ -54,7 +54,7 @@ public interface NDStructure<T> {
/** /**
* Indicates whether some [NDStructure] is equal to another one. * Indicates whether some [NDStructure] is equal to another one.
*/ */
public fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { public fun contentEquals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
if (st1 === st2) return true if (st1 === st2) return true
// fast comparison of buffers if possible // fast comparison of buffers if possible
@ -275,7 +275,7 @@ public abstract class NDBuffer<T> : NDStructure<T> {
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] } override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] }
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false) return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false)
} }
override fun hashCode(): Int { override fun hashCode(): Int {

View File

@ -11,7 +11,7 @@ public typealias RealNDElement = BufferedNDFieldElement<Double, RealField>
public class RealNDField(override val shape: IntArray) : public class RealNDField(override val shape: IntArray) :
BufferedNDField<Double, RealField>, BufferedNDField<Double, RealField>,
ExtendedNDField<Double, RealField, NDBuffer<Double>>, ExtendedNDField<Double, RealField, NDBuffer<Double>>,
RingWithNumbers<NDBuffer<Double>>{ RingWithNumbers<NDBuffer<Double>> {
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
@ -24,35 +24,31 @@ public class RealNDField(override val shape: IntArray) :
return produce { d } return produce { d }
} }
private inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> = @Suppress("OVERRIDE_BY_INLINE")
RealBuffer(DoubleArray(size) { initializer(it) }) override inline fun map(
/**
* Inline transform an NDStructure to
*/
override fun map(
arg: NDBuffer<Double>, arg: NDBuffer<Double>,
transform: RealField.(Double) -> Double transform: RealField.(Double) -> Double,
): RealNDElement { ): RealNDElement {
check(arg) check(arg)
val array = buildBuffer(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) } val array = RealBuffer(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) }
return BufferedNDFieldElement(this, array) return BufferedNDFieldElement(this, array)
} }
override fun produce(initializer: RealField.(IntArray) -> Double): RealNDElement { @Suppress("OVERRIDE_BY_INLINE")
val array = buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) } override inline fun produce(initializer: RealField.(IntArray) -> Double): RealNDElement {
val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
return BufferedNDFieldElement(this, array) return BufferedNDFieldElement(this, array)
} }
override fun mapIndexed( @Suppress("OVERRIDE_BY_INLINE")
override inline fun mapIndexed(
arg: NDBuffer<Double>, arg: NDBuffer<Double>,
transform: RealField.(index: IntArray, Double) -> Double transform: RealField.(index: IntArray, Double) -> Double,
): RealNDElement { ): RealNDElement {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> RealBuffer(arg.strides.linearSize) { offset ->
elementContext.transform( elementContext.transform(
arg.strides.index(offset), arg.strides.index(offset),
arg.buffer[offset] arg.buffer[offset]
@ -60,16 +56,17 @@ public class RealNDField(override val shape: IntArray) :
}) })
} }
override fun combine( @Suppress("OVERRIDE_BY_INLINE")
override inline fun combine(
a: NDBuffer<Double>, a: NDBuffer<Double>,
b: NDBuffer<Double>, b: NDBuffer<Double>,
transform: RealField.(Double, Double) -> Double transform: RealField.(Double, Double) -> Double,
): RealNDElement { ): RealNDElement {
check(a, b) check(a, b)
return BufferedNDFieldElement( val buffer = RealBuffer(strides.linearSize) { offset ->
this, elementContext.transform(a.buffer[offset], b.buffer[offset])
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) } }
) return BufferedNDFieldElement(this, buffer)
} }
override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> = override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> =

View File

@ -7,6 +7,7 @@ import kscience.kmath.structures.as2D
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@Suppress("UNUSED_VARIABLE")
class MatrixTest { class MatrixTest {
@Test @Test
fun testTranspose() { fun testTranspose() {

View File

@ -8,6 +8,7 @@ import kotlin.math.pow
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@Suppress("UNUSED_VARIABLE")
class NumberNDFieldTest { class NumberNDFieldTest {
val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() } val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() }
val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() } val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() }

View File

@ -24,7 +24,7 @@ public class LazyNDStructure<T>(
} }
public override fun equals(other: Any?): Boolean { public override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false) return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false)
} }
public override fun hashCode(): Int { public override fun hashCode(): Int {

View File

@ -6,6 +6,7 @@ import kscience.kmath.dimensions.DMatrixContext
import kscience.kmath.dimensions.one import kscience.kmath.dimensions.one
import kotlin.test.Test import kotlin.test.Test
@Suppress("UNUSED_VARIABLE")
internal class DMatrixContextTest { internal class DMatrixContextTest {
@Test @Test
fun testDimensionSafeMatrix() { fun testDimensionSafeMatrix() {

View File

@ -3,6 +3,7 @@ package kscience.kmath.ejml
import kscience.kmath.linear.* import kscience.kmath.linear.*
import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import kscience.kmath.structures.NDStructure
import kscience.kmath.structures.RealBuffer import kscience.kmath.structures.RealBuffer
import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix import org.ejml.simple.SimpleMatrix
@ -15,7 +16,7 @@ import kotlin.reflect.cast
* @property origin the underlying [SimpleMatrix]. * @property origin the underlying [SimpleMatrix].
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public inline class EjmlMatrix( public class EjmlMatrix(
public val origin: SimpleMatrix, public val origin: SimpleMatrix,
) : Matrix<Double> { ) : Matrix<Double> {
public override val rowNum: Int get() = origin.numRows() public override val rowNum: Int get() = origin.numRows()
@ -73,7 +74,17 @@ public inline class EjmlMatrix(
override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) } override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) }
} }
else -> null else -> null
}?.let{type.cast(it)} }?.let { type.cast(it) }
public override operator fun get(i: Int, j: Int): Double = origin[i, j] public override operator fun get(i: Int, j: Int): Double = origin[i, j]
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is Matrix<*>) return false
return NDStructure.contentEquals(this, other)
}
override fun hashCode(): Int = origin.hashCode()
} }

View File

@ -2,8 +2,8 @@ package kscience.kmath.ejml
import kscience.kmath.linear.InverseMatrixFeature import kscience.kmath.linear.InverseMatrixFeature
import kscience.kmath.linear.MatrixContext import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.MatrixWrapper
import kscience.kmath.linear.Point import kscience.kmath.linear.Point
import kscience.kmath.linear.origin
import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import kscience.kmath.structures.getFeature import kscience.kmath.structures.getFeature
@ -19,9 +19,9 @@ public object EjmlMatrixContext : MatrixContext<Double, EjmlMatrix> {
/** /**
* Converts this matrix to EJML one. * Converts this matrix to EJML one.
*/ */
public fun Matrix<Double>.toEjml(): EjmlMatrix = when { @OptIn(UnstableKMathAPI::class)
this is EjmlMatrix -> this public fun Matrix<Double>.toEjml(): EjmlMatrix = when (val matrix = origin) {
this is MatrixWrapper && matrix is EjmlMatrix -> matrix as EjmlMatrix is EjmlMatrix -> matrix
else -> produce(rowNum, colNum) { i, j -> get(i, j) } else -> produce(rowNum, colNum) { i, j -> get(i, j) }
} }

View File

@ -4,6 +4,7 @@ import kscience.kmath.linear.DeterminantFeature
import kscience.kmath.linear.LupDecompositionFeature import kscience.kmath.linear.LupDecompositionFeature
import kscience.kmath.linear.MatrixFeature import kscience.kmath.linear.MatrixFeature
import kscience.kmath.linear.plus import kscience.kmath.linear.plus
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.getFeature import kscience.kmath.structures.getFeature
import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix import org.ejml.simple.SimpleMatrix
@ -39,6 +40,7 @@ internal class EjmlMatrixTest {
assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList()) assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList())
} }
@OptIn(UnstableKMathAPI::class)
@Test @Test
fun features() { fun features() {
val m = randomMatrix val m = randomMatrix
@ -57,6 +59,7 @@ internal class EjmlMatrixTest {
private object SomeFeature : MatrixFeature {} private object SomeFeature : MatrixFeature {}
@OptIn(UnstableKMathAPI::class)
@Test @Test
fun suggestFeature() { fun suggestFeature() {
assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature<SomeFeature>()) assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature<SomeFeature>())

View File

@ -62,6 +62,7 @@ class MCScopeTest {
} }
@OptIn(ObsoleteCoroutinesApi::class)
fun compareResult(test: ATest) { fun compareResult(test: ATest) {
val res1 = runBlocking(Dispatchers.Default) { test() } val res1 = runBlocking(Dispatchers.Default) { test() }
val res2 = runBlocking(newSingleThreadContext("test")) { test() } val res2 = runBlocking(newSingleThreadContext("test")) { test() }