LU decomposition and inversion working

This commit is contained in:
Alexander Nozik 2019-04-21 16:17:30 +03:00
parent bbc012d8cd
commit a9ed3fb72e
7 changed files with 150 additions and 74 deletions

View File

@ -8,7 +8,7 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
buildscript {
val kotlinVersion: String by rootProject.extra("1.3.30")
val ioVersion: String by rootProject.extra("0.1.5")
val coroutinesVersion: String by rootProject.extra("1.1.1")
val coroutinesVersion: String by rootProject.extra("1.2.0")
val atomicfuVersion: String by rootProject.extra("0.12.1")
val dokkaVersion: String by rootProject.extra("0.9.17")
val serializationVersion: String by rootProject.extra("0.10.0")

View File

@ -1,10 +1,9 @@
package scientifik.kmath.linear
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.sum
import scientifik.kmath.structures.*
import scientifik.kmath.structures.Buffer.Companion.boxing
import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D
import scientifik.kmath.structures.asBuffer
import kotlin.math.sqrt
/**
@ -70,7 +69,7 @@ inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
*/
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): FeaturedMatrix<T> =
VirtualMatrix<T>(rows, columns) { i, j ->
VirtualMatrix(rows, columns, DiagonalFeature) { i, j ->
if (i == j) elementContext.one else elementContext.zero
}

View File

@ -6,6 +6,7 @@ import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.BufferAccessor2D
import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D
import kotlin.reflect.KClass
/**
* Common implementation of [LUPDecompositionFeature]
@ -73,7 +74,8 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) =
/**
* Create a lup decomposition of generic matrix
*/
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
type: KClass<T>,
matrix: Matrix<T>,
checkSingular: (T) -> Boolean
): LUPDecomposition<T> {
@ -85,7 +87,7 @@ inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.
val pivot = IntArray(matrix.rowNum)
//TODO just waits for KEEP-176
BufferAccessor2D(T::class, matrix.rowNum, matrix.colNum).run {
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
elementContext.run {
val lu = create(matrix)
@ -100,7 +102,6 @@ inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.
for (row in 0 until m) {
pivot[row] = row
}
var singular = false
// Loop over columns
for (col in 0 until m) {
@ -165,21 +166,25 @@ inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.
}
}
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(matrix) { it < 1e-11 }
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
matrix: Matrix<T>,
noinline checkSingular: (T) -> Boolean
) = lup(T::class, matrix, checkSingular)
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matrix<T> {
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(Double::class, matrix) { it < 1e-11 }
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
if (matrix.rowNum != pivot.size) {
error("Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}")
}
BufferAccessor2D(T::class, matrix.rowNum, matrix.colNum).run {
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
elementContext.run {
val lu = create{i,j-> this@solve.lu[i,j]}
// Apply permutations to b
val bp = create { i, j -> zero }
for (row in 0 until pivot.size) {
val bpRow = bp.row(row)
val pRow = pivot[row]
@ -220,6 +225,7 @@ inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matri
}
}
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>) = solve(T::class, matrix)
/**
* Solve a linear equation **a*x = b**
@ -227,11 +233,11 @@ inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matri
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.solve(
a: Matrix<T>,
b: Matrix<T>,
crossinline checkSingular: (T) -> Boolean
noinline checkSingular: (T) -> Boolean
): Matrix<T> {
// Use existing decomposition if it is provided by matrix
val decomposition = a.getFeature() ?: lup(a, checkSingular)
return decomposition.solve(b)
val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular)
return decomposition.solve(T::class, b)
}
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>) =

View File

@ -9,6 +9,13 @@ class VirtualMatrix<T : Any>(
val generator: (i: Int, j: Int) -> T
) : FeaturedMatrix<T> {
constructor(rowNum: Int, colNum: Int, vararg features: MatrixFeature, generator: (i: Int, j: Int) -> T) : this(
rowNum,
colNum,
setOf(*features),
generator
)
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
override fun get(i: Int, j: Int): T = generator(i, j)

View File

@ -1,5 +1,6 @@
package scientifik.kmath.operations
import kotlin.math.abs
import kotlin.math.pow
/**
@ -31,79 +32,152 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
/**
* A field for double without boxing. Does not produce appropriate field element
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double, Double> {
override val zero: Double = 0.0
override fun add(a: Double, b: Double): Double = a + b
override fun multiply(a: Double, b: Double): Double = a * b
override fun multiply(a: Double, k: Number): Double = a * k.toDouble()
override inline fun add(a: Double, b: Double) = a + b
override inline fun multiply(a: Double, b: Double) = a * b
override inline fun multiply(a: Double, k: Number) = a * k.toDouble()
override val one: Double = 1.0
override fun divide(a: Double, b: Double): Double = a / b
override inline fun divide(a: Double, b: Double) = a / b
override fun sin(arg: Double): Double = kotlin.math.sin(arg)
override fun cos(arg: Double): Double = kotlin.math.cos(arg)
override inline fun sin(arg: Double) = kotlin.math.sin(arg)
override inline fun cos(arg: Double) = kotlin.math.cos(arg)
override fun power(arg: Double, pow: Number): Double = arg.pow(pow.toDouble())
override inline fun power(arg: Double, pow: Number) = arg.pow(pow.toDouble())
override fun exp(arg: Double): Double = kotlin.math.exp(arg)
override fun ln(arg: Double): Double = kotlin.math.ln(arg)
override inline fun exp(arg: Double) = kotlin.math.exp(arg)
override inline fun ln(arg: Double) = kotlin.math.ln(arg)
override fun norm(arg: Double): Double = kotlin.math.abs(arg)
override inline fun norm(arg: Double) = abs(arg)
override fun Double.unaryMinus(): Double = -this
override inline fun Double.unaryMinus() = -this
override fun Double.minus(b: Double): Double = this - b
override inline fun Double.plus(b: Double) = this + b
override inline fun Double.minus(b: Double) = this - b
override inline fun Double.times(b: Double) = this * b
override inline fun Double.div(b: Double) = this / b
}
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object FloatField : Field<Float>, ExtendedFieldOperations<Float>, Norm<Float, Float> {
override val zero: Float = 0f
override inline fun add(a: Float, b: Float) = a + b
override inline fun multiply(a: Float, b: Float) = a * b
override inline fun multiply(a: Float, k: Number) = a * k.toFloat()
override val one: Float = 1f
override inline fun divide(a: Float, b: Float) = a / b
override inline fun sin(arg: Float) = kotlin.math.sin(arg)
override inline fun cos(arg: Float) = kotlin.math.cos(arg)
override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat())
override inline fun exp(arg: Float) = kotlin.math.exp(arg)
override inline fun ln(arg: Float) = kotlin.math.ln(arg)
override inline fun norm(arg: Float) = abs(arg)
override inline fun Float.unaryMinus() = -this
override inline fun Float.plus(b: Float) = this + b
override inline fun Float.minus(b: Float) = this - b
override inline fun Float.times(b: Float) = this * b
override inline fun Float.div(b: Float) = this / b
}
/**
* A field for [Int] without boxing. Does not produce corresponding field element
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object IntRing : Ring<Int>, Norm<Int, Int> {
override val zero: Int = 0
override fun add(a: Int, b: Int): Int = a + b
override fun multiply(a: Int, b: Int): Int = a * b
override fun multiply(a: Int, k: Number): Int = (k * a)
override inline fun add(a: Int, b: Int) = a + b
override inline fun multiply(a: Int, b: Int) = a * b
override inline fun multiply(a: Int, k: Number) = (k * a)
override val one: Int = 1
override fun norm(arg: Int): Int = arg
override inline fun norm(arg: Int) = abs(arg)
override inline fun Int.unaryMinus() = -this
override inline fun Int.plus(b: Int): Int = this + b
override inline fun Int.minus(b: Int): Int = this - b
override inline fun Int.times(b: Int): Int = this * b
}
/**
* A field for [Short] without boxing. Does not produce appropriate field element
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ShortRing : Ring<Short>, Norm<Short, Short> {
override val zero: Short = 0
override fun add(a: Short, b: Short): Short = (a + b).toShort()
override fun multiply(a: Short, b: Short): Short = (a * b).toShort()
override fun multiply(a: Short, k: Number): Short = (a * k)
override inline fun add(a: Short, b: Short) = (a + b).toShort()
override inline fun multiply(a: Short, b: Short) = (a * b).toShort()
override inline fun multiply(a: Short, k: Number) = (a * k)
override val one: Short = 1
override fun norm(arg: Short): Short = arg
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
override inline fun Short.unaryMinus() = (-this).toShort()
override inline fun Short.plus(b: Short) = (this + b).toShort()
override inline fun Short.minus(b: Short) = (this - b).toShort()
override inline fun Short.times(b: Short) = (this * b).toShort()
}
/**
* A field for [Byte] values
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
override val zero: Byte = 0
override fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
override fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
override fun multiply(a: Byte, k: Number): Byte = (a * k)
override inline fun add(a: Byte, b: Byte) = (a + b).toByte()
override inline fun multiply(a: Byte, b: Byte) = (a * b).toByte()
override inline fun multiply(a: Byte, k: Number) = (a * k)
override val one: Byte = 1
override fun norm(arg: Byte): Byte = arg
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
override inline fun Byte.unaryMinus() = (-this).toByte()
override inline fun Byte.plus(b: Byte) = (this + b).toByte()
override inline fun Byte.minus(b: Byte) = (this - b).toByte()
override inline fun Byte.times(b: Byte) = (this * b).toByte()
}
/**
* A field for [Long] values
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object LongRing : Ring<Long>, Norm<Long, Long> {
override val zero: Long = 0
override fun add(a: Long, b: Long): Long = (a + b)
override fun multiply(a: Long, b: Long): Long = (a * b)
override fun multiply(a: Long, k: Number): Long = (a * k)
override inline fun add(a: Long, b: Long) = (a + b)
override inline fun multiply(a: Long, b: Long) = (a * b)
override inline fun multiply(a: Long, k: Number) = (a * k)
override val one: Long = 1
override fun norm(arg: Long): Long = arg
override fun norm(arg: Long): Long = abs(arg)
override inline fun Long.unaryMinus() = (-this)
override inline fun Long.plus(b: Long) = (this + b)
override inline fun Long.minus(b: Long) = (this - b)
override inline fun Long.times(b: Long) = (this * b)
}

View File

@ -16,21 +16,28 @@ class RealLUSolverTest {
}
@Test
fun testInvert() {
fun testDecomposition() {
val matrix = Matrix.square(
3.0, 1.0,
1.0, 3.0
)
val decomposition = MatrixContext.real.lup(matrix)
MatrixContext.real.run {
val lup = lup(matrix)
//Check determinant
assertEquals(8.0, decomposition.determinant)
//Check determinant
assertEquals(8.0, lup.determinant)
//Check decomposition
with(MatrixContext.real) {
assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u)
assertEquals(lup.p dot matrix, lup.l dot lup.u)
}
}
@Test
fun testInvert() {
val matrix = Matrix.square(
3.0, 1.0,
1.0, 3.0
)
val inverted = MatrixContext.real.inverse(matrix)

View File

@ -1,13 +1,13 @@
plugins {
kotlin("multiplatform")
id("kotlinx-atomicfu") version "0.12.1"
id("kotlinx-atomicfu") version "0.12.4"
}
val atomicfuVersion: String by rootProject.extra
kotlin {
jvm ()
//js()
js()
sourceSets {
val commonMain by getting {
@ -17,33 +17,16 @@ kotlin {
compileOnly("org.jetbrains.kotlinx:atomicfu-common:$atomicfuVersion")
}
}
val commonTest by getting {
dependencies {
implementation(kotlin("test-common"))
implementation(kotlin("test-annotations-common"))
}
}
val jvmMain by getting {
dependencies {
compileOnly("org.jetbrains.kotlinx:atomicfu:$atomicfuVersion")
}
}
val jvmTest by getting {
val jsMain by getting {
dependencies {
implementation(kotlin("test"))
implementation(kotlin("test-junit"))
compileOnly("org.jetbrains.kotlinx:atomicfu-js:$atomicfuVersion")
}
}
// val jsMain by getting {
// dependencies {
// compileOnly("org.jetbrains.kotlinx:atomicfu-js:$atomicfuVersion")
// }
// }
// val jsTest by getting {
// dependencies {
// implementation(kotlin("test-js"))
// }
// }
}
}