forked from kscience/kmath
LU decomposition and inversion working
This commit is contained in:
parent
bbc012d8cd
commit
a9ed3fb72e
@ -8,7 +8,7 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
|
||||
buildscript {
|
||||
val kotlinVersion: String by rootProject.extra("1.3.30")
|
||||
val ioVersion: String by rootProject.extra("0.1.5")
|
||||
val coroutinesVersion: String by rootProject.extra("1.1.1")
|
||||
val coroutinesVersion: String by rootProject.extra("1.2.0")
|
||||
val atomicfuVersion: String by rootProject.extra("0.12.1")
|
||||
val dokkaVersion: String by rootProject.extra("0.9.17")
|
||||
val serializationVersion: String by rootProject.extra("0.10.0")
|
||||
|
@ -1,10 +1,9 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.operations.sum
|
||||
import scientifik.kmath.structures.*
|
||||
import scientifik.kmath.structures.Buffer.Companion.boxing
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import scientifik.kmath.structures.Structure2D
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
@ -70,7 +69,7 @@ inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
|
||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||
VirtualMatrix(rows, columns, DiagonalFeature) { i, j ->
|
||||
if (i == j) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,7 @@ import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.BufferAccessor2D
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import scientifik.kmath.structures.Structure2D
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
* Common implementation of [LUPDecompositionFeature]
|
||||
@ -73,7 +74,8 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) =
|
||||
/**
|
||||
* Create a lup decomposition of generic matrix
|
||||
*/
|
||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||
type: KClass<T>,
|
||||
matrix: Matrix<T>,
|
||||
checkSingular: (T) -> Boolean
|
||||
): LUPDecomposition<T> {
|
||||
@ -85,7 +87,7 @@ inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.
|
||||
val pivot = IntArray(matrix.rowNum)
|
||||
|
||||
//TODO just waits for KEEP-176
|
||||
BufferAccessor2D(T::class, matrix.rowNum, matrix.colNum).run {
|
||||
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
|
||||
elementContext.run {
|
||||
|
||||
val lu = create(matrix)
|
||||
@ -100,7 +102,6 @@ inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.
|
||||
for (row in 0 until m) {
|
||||
pivot[row] = row
|
||||
}
|
||||
var singular = false
|
||||
|
||||
// Loop over columns
|
||||
for (col in 0 until m) {
|
||||
@ -165,21 +166,25 @@ inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.
|
||||
}
|
||||
}
|
||||
|
||||
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(matrix) { it < 1e-11 }
|
||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||
matrix: Matrix<T>,
|
||||
noinline checkSingular: (T) -> Boolean
|
||||
) = lup(T::class, matrix, checkSingular)
|
||||
|
||||
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matrix<T> {
|
||||
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(Double::class, matrix) { it < 1e-11 }
|
||||
|
||||
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
|
||||
|
||||
if (matrix.rowNum != pivot.size) {
|
||||
error("Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}")
|
||||
}
|
||||
|
||||
BufferAccessor2D(T::class, matrix.rowNum, matrix.colNum).run {
|
||||
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
|
||||
elementContext.run {
|
||||
|
||||
val lu = create{i,j-> this@solve.lu[i,j]}
|
||||
|
||||
// Apply permutations to b
|
||||
val bp = create { i, j -> zero }
|
||||
|
||||
for (row in 0 until pivot.size) {
|
||||
val bpRow = bp.row(row)
|
||||
val pRow = pivot[row]
|
||||
@ -220,6 +225,7 @@ inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matri
|
||||
}
|
||||
}
|
||||
|
||||
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>) = solve(T::class, matrix)
|
||||
|
||||
/**
|
||||
* Solve a linear equation **a*x = b**
|
||||
@ -227,11 +233,11 @@ inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matri
|
||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.solve(
|
||||
a: Matrix<T>,
|
||||
b: Matrix<T>,
|
||||
crossinline checkSingular: (T) -> Boolean
|
||||
noinline checkSingular: (T) -> Boolean
|
||||
): Matrix<T> {
|
||||
// Use existing decomposition if it is provided by matrix
|
||||
val decomposition = a.getFeature() ?: lup(a, checkSingular)
|
||||
return decomposition.solve(b)
|
||||
val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular)
|
||||
return decomposition.solve(T::class, b)
|
||||
}
|
||||
|
||||
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>) =
|
||||
|
@ -9,6 +9,13 @@ class VirtualMatrix<T : Any>(
|
||||
val generator: (i: Int, j: Int) -> T
|
||||
) : FeaturedMatrix<T> {
|
||||
|
||||
constructor(rowNum: Int, colNum: Int, vararg features: MatrixFeature, generator: (i: Int, j: Int) -> T) : this(
|
||||
rowNum,
|
||||
colNum,
|
||||
setOf(*features),
|
||||
generator
|
||||
)
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
override fun get(i: Int, j: Int): T = generator(i, j)
|
||||
|
@ -1,5 +1,6 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.pow
|
||||
|
||||
/**
|
||||
@ -31,79 +32,152 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
||||
/**
|
||||
* A field for double without boxing. Does not produce appropriate field element
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double, Double> {
|
||||
override val zero: Double = 0.0
|
||||
override fun add(a: Double, b: Double): Double = a + b
|
||||
override fun multiply(a: Double, b: Double): Double = a * b
|
||||
override fun multiply(a: Double, k: Number): Double = a * k.toDouble()
|
||||
override inline fun add(a: Double, b: Double) = a + b
|
||||
override inline fun multiply(a: Double, b: Double) = a * b
|
||||
override inline fun multiply(a: Double, k: Number) = a * k.toDouble()
|
||||
|
||||
override val one: Double = 1.0
|
||||
override fun divide(a: Double, b: Double): Double = a / b
|
||||
override inline fun divide(a: Double, b: Double) = a / b
|
||||
|
||||
override fun sin(arg: Double): Double = kotlin.math.sin(arg)
|
||||
override fun cos(arg: Double): Double = kotlin.math.cos(arg)
|
||||
override inline fun sin(arg: Double) = kotlin.math.sin(arg)
|
||||
override inline fun cos(arg: Double) = kotlin.math.cos(arg)
|
||||
|
||||
override fun power(arg: Double, pow: Number): Double = arg.pow(pow.toDouble())
|
||||
override inline fun power(arg: Double, pow: Number) = arg.pow(pow.toDouble())
|
||||
|
||||
override fun exp(arg: Double): Double = kotlin.math.exp(arg)
|
||||
override fun ln(arg: Double): Double = kotlin.math.ln(arg)
|
||||
override inline fun exp(arg: Double) = kotlin.math.exp(arg)
|
||||
override inline fun ln(arg: Double) = kotlin.math.ln(arg)
|
||||
|
||||
override fun norm(arg: Double): Double = kotlin.math.abs(arg)
|
||||
override inline fun norm(arg: Double) = abs(arg)
|
||||
|
||||
override fun Double.unaryMinus(): Double = -this
|
||||
override inline fun Double.unaryMinus() = -this
|
||||
|
||||
override fun Double.minus(b: Double): Double = this - b
|
||||
override inline fun Double.plus(b: Double) = this + b
|
||||
|
||||
override inline fun Double.minus(b: Double) = this - b
|
||||
|
||||
override inline fun Double.times(b: Double) = this * b
|
||||
|
||||
override inline fun Double.div(b: Double) = this / b
|
||||
}
|
||||
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object FloatField : Field<Float>, ExtendedFieldOperations<Float>, Norm<Float, Float> {
|
||||
override val zero: Float = 0f
|
||||
override inline fun add(a: Float, b: Float) = a + b
|
||||
override inline fun multiply(a: Float, b: Float) = a * b
|
||||
override inline fun multiply(a: Float, k: Number) = a * k.toFloat()
|
||||
|
||||
override val one: Float = 1f
|
||||
override inline fun divide(a: Float, b: Float) = a / b
|
||||
|
||||
override inline fun sin(arg: Float) = kotlin.math.sin(arg)
|
||||
override inline fun cos(arg: Float) = kotlin.math.cos(arg)
|
||||
|
||||
override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat())
|
||||
|
||||
override inline fun exp(arg: Float) = kotlin.math.exp(arg)
|
||||
override inline fun ln(arg: Float) = kotlin.math.ln(arg)
|
||||
|
||||
override inline fun norm(arg: Float) = abs(arg)
|
||||
|
||||
override inline fun Float.unaryMinus() = -this
|
||||
|
||||
override inline fun Float.plus(b: Float) = this + b
|
||||
|
||||
override inline fun Float.minus(b: Float) = this - b
|
||||
|
||||
override inline fun Float.times(b: Float) = this * b
|
||||
|
||||
override inline fun Float.div(b: Float) = this / b
|
||||
}
|
||||
|
||||
/**
|
||||
* A field for [Int] without boxing. Does not produce corresponding field element
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||
override val zero: Int = 0
|
||||
override fun add(a: Int, b: Int): Int = a + b
|
||||
override fun multiply(a: Int, b: Int): Int = a * b
|
||||
override fun multiply(a: Int, k: Number): Int = (k * a)
|
||||
override inline fun add(a: Int, b: Int) = a + b
|
||||
override inline fun multiply(a: Int, b: Int) = a * b
|
||||
override inline fun multiply(a: Int, k: Number) = (k * a)
|
||||
override val one: Int = 1
|
||||
|
||||
override fun norm(arg: Int): Int = arg
|
||||
override inline fun norm(arg: Int) = abs(arg)
|
||||
|
||||
override inline fun Int.unaryMinus() = -this
|
||||
|
||||
override inline fun Int.plus(b: Int): Int = this + b
|
||||
|
||||
override inline fun Int.minus(b: Int): Int = this - b
|
||||
|
||||
override inline fun Int.times(b: Int): Int = this * b
|
||||
}
|
||||
|
||||
/**
|
||||
* A field for [Short] without boxing. Does not produce appropriate field element
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||
override val zero: Short = 0
|
||||
override fun add(a: Short, b: Short): Short = (a + b).toShort()
|
||||
override fun multiply(a: Short, b: Short): Short = (a * b).toShort()
|
||||
override fun multiply(a: Short, k: Number): Short = (a * k)
|
||||
override inline fun add(a: Short, b: Short) = (a + b).toShort()
|
||||
override inline fun multiply(a: Short, b: Short) = (a * b).toShort()
|
||||
override inline fun multiply(a: Short, k: Number) = (a * k)
|
||||
override val one: Short = 1
|
||||
|
||||
override fun norm(arg: Short): Short = arg
|
||||
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
||||
|
||||
override inline fun Short.unaryMinus() = (-this).toShort()
|
||||
|
||||
override inline fun Short.plus(b: Short) = (this + b).toShort()
|
||||
|
||||
override inline fun Short.minus(b: Short) = (this - b).toShort()
|
||||
|
||||
override inline fun Short.times(b: Short) = (this * b).toShort()
|
||||
}
|
||||
|
||||
/**
|
||||
* A field for [Byte] values
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||
override val zero: Byte = 0
|
||||
override fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
||||
override fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
|
||||
override fun multiply(a: Byte, k: Number): Byte = (a * k)
|
||||
override inline fun add(a: Byte, b: Byte) = (a + b).toByte()
|
||||
override inline fun multiply(a: Byte, b: Byte) = (a * b).toByte()
|
||||
override inline fun multiply(a: Byte, k: Number) = (a * k)
|
||||
override val one: Byte = 1
|
||||
|
||||
override fun norm(arg: Byte): Byte = arg
|
||||
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
|
||||
|
||||
override inline fun Byte.unaryMinus() = (-this).toByte()
|
||||
|
||||
override inline fun Byte.plus(b: Byte) = (this + b).toByte()
|
||||
|
||||
override inline fun Byte.minus(b: Byte) = (this - b).toByte()
|
||||
|
||||
override inline fun Byte.times(b: Byte) = (this * b).toByte()
|
||||
}
|
||||
|
||||
/**
|
||||
* A field for [Long] values
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object LongRing : Ring<Long>, Norm<Long, Long> {
|
||||
override val zero: Long = 0
|
||||
override fun add(a: Long, b: Long): Long = (a + b)
|
||||
override fun multiply(a: Long, b: Long): Long = (a * b)
|
||||
override fun multiply(a: Long, k: Number): Long = (a * k)
|
||||
override inline fun add(a: Long, b: Long) = (a + b)
|
||||
override inline fun multiply(a: Long, b: Long) = (a * b)
|
||||
override inline fun multiply(a: Long, k: Number) = (a * k)
|
||||
override val one: Long = 1
|
||||
|
||||
override fun norm(arg: Long): Long = arg
|
||||
override fun norm(arg: Long): Long = abs(arg)
|
||||
|
||||
override inline fun Long.unaryMinus() = (-this)
|
||||
|
||||
override inline fun Long.plus(b: Long) = (this + b)
|
||||
|
||||
override inline fun Long.minus(b: Long) = (this - b)
|
||||
|
||||
override inline fun Long.times(b: Long) = (this * b)
|
||||
}
|
@ -16,21 +16,28 @@ class RealLUSolverTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testInvert() {
|
||||
fun testDecomposition() {
|
||||
val matrix = Matrix.square(
|
||||
3.0, 1.0,
|
||||
1.0, 3.0
|
||||
)
|
||||
|
||||
val decomposition = MatrixContext.real.lup(matrix)
|
||||
MatrixContext.real.run {
|
||||
val lup = lup(matrix)
|
||||
|
||||
//Check determinant
|
||||
assertEquals(8.0, decomposition.determinant)
|
||||
//Check determinant
|
||||
assertEquals(8.0, lup.determinant)
|
||||
|
||||
//Check decomposition
|
||||
with(MatrixContext.real) {
|
||||
assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u)
|
||||
assertEquals(lup.p dot matrix, lup.l dot lup.u)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testInvert() {
|
||||
val matrix = Matrix.square(
|
||||
3.0, 1.0,
|
||||
1.0, 3.0
|
||||
)
|
||||
|
||||
val inverted = MatrixContext.real.inverse(matrix)
|
||||
|
||||
|
@ -1,13 +1,13 @@
|
||||
plugins {
|
||||
kotlin("multiplatform")
|
||||
id("kotlinx-atomicfu") version "0.12.1"
|
||||
id("kotlinx-atomicfu") version "0.12.4"
|
||||
}
|
||||
|
||||
val atomicfuVersion: String by rootProject.extra
|
||||
|
||||
kotlin {
|
||||
jvm ()
|
||||
//js()
|
||||
js()
|
||||
|
||||
sourceSets {
|
||||
val commonMain by getting {
|
||||
@ -17,33 +17,16 @@ kotlin {
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:$atomicfuVersion")
|
||||
}
|
||||
}
|
||||
val commonTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-common"))
|
||||
implementation(kotlin("test-annotations-common"))
|
||||
}
|
||||
}
|
||||
val jvmMain by getting {
|
||||
dependencies {
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu:$atomicfuVersion")
|
||||
}
|
||||
}
|
||||
val jvmTest by getting {
|
||||
val jsMain by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test"))
|
||||
implementation(kotlin("test-junit"))
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-js:$atomicfuVersion")
|
||||
}
|
||||
}
|
||||
// val jsMain by getting {
|
||||
// dependencies {
|
||||
// compileOnly("org.jetbrains.kotlinx:atomicfu-js:$atomicfuVersion")
|
||||
// }
|
||||
// }
|
||||
// val jsTest by getting {
|
||||
// dependencies {
|
||||
// implementation(kotlin("test-js"))
|
||||
// }
|
||||
// }
|
||||
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user