Fixed LUP solver implementation
This commit is contained in:
parent
271d7886df
commit
b4e8b9050f
@ -1,52 +1,28 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.RealField
|
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
import scientifik.kmath.structures.MutableBuffer.Companion.boxing
|
import scientifik.kmath.structures.MutableBuffer.Companion.boxing
|
||||||
|
|
||||||
/**
|
|
||||||
* Matrix LUP decomposition
|
|
||||||
*/
|
|
||||||
interface LUPDecompositionFeature<T : Any> : DeterminantFeature<T> {
|
|
||||||
/**
|
|
||||||
* A reference to L-matrix
|
|
||||||
*/
|
|
||||||
val l: Matrix<T>
|
|
||||||
/**
|
|
||||||
* A reference to u-matrix
|
|
||||||
*/
|
|
||||||
val u: Matrix<T>
|
|
||||||
/**
|
|
||||||
* Pivoting points for each row
|
|
||||||
*/
|
|
||||||
val pivot: IntArray
|
|
||||||
/**
|
|
||||||
* Permutation matrix based on [pivot]
|
|
||||||
*/
|
|
||||||
val p: Matrix<T>
|
|
||||||
}
|
|
||||||
|
|
||||||
|
class LUPDecomposition<T : Comparable<T>>(
|
||||||
private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
|
private val elementContext: Ring<T>,
|
||||||
val context: R,
|
private val lu: NDStructure<T>,
|
||||||
val lu: NDStructure<T>,
|
val pivot: IntArray,
|
||||||
override val pivot: IntArray,
|
|
||||||
private val even: Boolean
|
private val even: Boolean
|
||||||
) : LUPDecompositionFeature<T> {
|
) : DeterminantFeature<T> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the matrix L of the decomposition.
|
* Returns the matrix L of the decomposition.
|
||||||
*
|
*
|
||||||
* L is a lower-triangular matrix
|
* L is a lower-triangular matrix with [Ring.one] in diagonal
|
||||||
* @return the L matrix (or null if decomposed matrix is singular)
|
|
||||||
*/
|
*/
|
||||||
override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||||
when {
|
when {
|
||||||
j < i -> lu[i, j]
|
j < i -> lu[i, j]
|
||||||
j == i -> context.one
|
j == i -> elementContext.one
|
||||||
else -> context.zero
|
else -> elementContext.zero
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,26 +30,21 @@ private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
|
|||||||
/**
|
/**
|
||||||
* Returns the matrix U of the decomposition.
|
* Returns the matrix U of the decomposition.
|
||||||
*
|
*
|
||||||
* U is an upper-triangular matrix
|
* U is an upper-triangular matrix including the diagonal
|
||||||
* @return the U matrix (or null if decomposed matrix is singular)
|
|
||||||
*/
|
*/
|
||||||
override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||||
if (j >= i) lu[i, j] else context.zero
|
if (j >= i) lu[i, j] else elementContext.zero
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the P rows permutation matrix.
|
* Returns the P rows permutation matrix.
|
||||||
*
|
*
|
||||||
* P is a sparse matrix with exactly one element set to 1.0 in
|
* P is a sparse matrix with exactly one element set to [Ring.one] in
|
||||||
* each row and each column, all other elements being set to 0.0.
|
* each row and each column, all other elements being set to [Ring.zero].
|
||||||
*
|
|
||||||
* The positions of the 1 elements are given by the [ pivot permutation vector][.getPivot].
|
|
||||||
* @return the P rows permutation matrix (or null if decomposed matrix is singular)
|
|
||||||
* @see .getPivot
|
|
||||||
*/
|
*/
|
||||||
override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||||
if (j == pivot[i]) context.one else context.zero
|
if (j == pivot[i]) elementContext.one else elementContext.zero
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -82,7 +53,7 @@ private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
|
|||||||
* @return determinant of the matrix
|
* @return determinant of the matrix
|
||||||
*/
|
*/
|
||||||
override val determinant: T by lazy {
|
override val determinant: T by lazy {
|
||||||
with(context) {
|
with(elementContext) {
|
||||||
(0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] }
|
(0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -90,14 +61,12 @@ private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||||
* Implementation based on Apache common-maths LU-decomposition
|
override val context: MatrixContext<T, F>,
|
||||||
*/
|
|
||||||
class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
|
|
||||||
val context: F,
|
|
||||||
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
||||||
val singularityCheck: (T) -> Boolean
|
val singularityCheck: (T) -> Boolean
|
||||||
) {
|
) : LinearSolver<T, F> {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* In-place transformation for [MutableNDStructure], using given transformation for each element
|
* In-place transformation for [MutableNDStructure], using given transformation for each element
|
||||||
@ -106,9 +75,10 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
|
|||||||
this[intArrayOf(i, j)] = value
|
this[intArrayOf(i, j)] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun abs(value: T) = if (value > context.zero) value else with(context) { -value }
|
private fun abs(value: T) =
|
||||||
|
if (value > context.elementContext.zero) value else with(context.elementContext) { -value }
|
||||||
|
|
||||||
fun decompose(matrix: Matrix<T>): LUPDecompositionFeature<T> {
|
fun buildDecomposition(matrix: Matrix<T>): LUPDecomposition<T> {
|
||||||
if (matrix.rowNum != matrix.colNum) {
|
if (matrix.rowNum != matrix.colNum) {
|
||||||
error("LU decomposition supports only square matrices")
|
error("LU decomposition supports only square matrices")
|
||||||
}
|
}
|
||||||
@ -121,7 +91,7 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
with(context) {
|
with(context.elementContext) {
|
||||||
// Initialize permutation array and parity
|
// Initialize permutation array and parity
|
||||||
for (row in 0 until m) {
|
for (row in 0 until m) {
|
||||||
pivot[row] = row
|
pivot[row] = row
|
||||||
@ -174,23 +144,22 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
|
|||||||
lu[row, col] = lu[row, col] / luDiag
|
lu[row, col] = lu[row, col] / luDiag
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return LUPDecomposition(context, lu, pivot, even)
|
return LUPDecomposition(context.elementContext, lu, pivot, even)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
/**
|
||||||
val real: LUPDecompositionBuilder<Double, RealField> = LUPDecompositionBuilder(RealField) { it < 1e-11 }
|
* Produce a matrix with added decomposition feature
|
||||||
|
*/
|
||||||
|
fun decompose(matrix: Matrix<T>): Matrix<T> {
|
||||||
|
if (matrix.hasFeature<LUPDecomposition<*>>()) {
|
||||||
|
return matrix
|
||||||
|
} else {
|
||||||
|
val decomposition = buildDecomposition(matrix)
|
||||||
|
return VirtualMatrix.wrap(matrix, decomposition)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
|
||||||
override val context: MatrixContext<T, F>,
|
|
||||||
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
|
||||||
val singularityCheck: (T) -> Boolean
|
|
||||||
) : LinearSolver<T, F> {
|
|
||||||
|
|
||||||
|
|
||||||
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||||
if (b.rowNum != a.colNum) {
|
if (b.rowNum != a.colNum) {
|
||||||
@ -198,10 +167,7 @@ class LUSolver<T : Comparable<T>, F : Field<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Use existing decomposition if it is provided by matrix
|
// Use existing decomposition if it is provided by matrix
|
||||||
@Suppress("UNCHECKED_CAST")
|
val decomposition = a.getFeature() ?: buildDecomposition(a)
|
||||||
val decomposition = a.features.find { it is LUPDecompositionFeature<*> }?.let {
|
|
||||||
it as LUPDecompositionFeature<T>
|
|
||||||
} ?: LUPDecompositionBuilder(context.elementContext, bufferFactory, singularityCheck).decompose(a)
|
|
||||||
|
|
||||||
with(decomposition) {
|
with(decomposition) {
|
||||||
with(context.elementContext) {
|
with(context.elementContext) {
|
||||||
@ -219,12 +185,14 @@ class LUSolver<T : Comparable<T>, F : Field<T>>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Solve UX = Y
|
// Solve UX = Y
|
||||||
for (col in a.rowNum - 1 downTo 0) {
|
for (col in a.rowNum - 1 downTo 0) {
|
||||||
|
for(j in 0 until b.colNum){
|
||||||
|
bp[col,j]/= u[col,col]
|
||||||
|
}
|
||||||
for (i in 0 until col) {
|
for (i in 0 until col) {
|
||||||
for (j in 0 until b.colNum) {
|
for (j in 0 until b.colNum) {
|
||||||
bp[i, j] -= bp[col, j] / u[col, col] * u[i, col]
|
bp[i, j] -= bp[col, j] * u[i, col]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -5,6 +5,7 @@ import scientifik.kmath.operations.Ring
|
|||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
import scientifik.kmath.structures.Buffer.Companion.DoubleBufferFactory
|
import scientifik.kmath.structures.Buffer.Companion.DoubleBufferFactory
|
||||||
import scientifik.kmath.structures.Buffer.Companion.boxing
|
import scientifik.kmath.structures.Buffer.Companion.boxing
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
|
||||||
interface MatrixContext<T : Any, R : Ring<T>> {
|
interface MatrixContext<T : Any, R : Ring<T>> {
|
||||||
@ -146,43 +147,56 @@ interface Matrix<T : Any> : NDStructure<T> {
|
|||||||
companion object {
|
companion object {
|
||||||
fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
||||||
MatrixContext.real.produce(rows, columns, initializer)
|
MatrixContext.real.produce(rows, columns, initializer)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build a square matrix from given elements.
|
||||||
|
*/
|
||||||
|
fun <T : Any> build(vararg elements: T): Matrix<T> {
|
||||||
|
val buffer = elements.asBuffer()
|
||||||
|
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||||
|
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
|
||||||
|
val structure = Mutable2DStructure(size, size, buffer)
|
||||||
|
return StructureMatrix(structure)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if matrix has the given feature class
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean = features.find { T::class.isInstance(it) } != null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> Matrix<*>.getFeature(): T? = features.filterIsInstance<T>().firstOrNull()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||||
*/
|
*/
|
||||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> = VirtualMatrix<T>(rows,columns){i, j->
|
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
|
||||||
if (i == j) elementContext.one else elementContext.zero
|
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||||
}
|
if (i == j) elementContext.one else elementContext.zero
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A virtual matrix of zeroes
|
* A virtual matrix of zeroes
|
||||||
*/
|
*/
|
||||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> = VirtualMatrix<T>(rows,columns){i, j->
|
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
|
||||||
elementContext.zero
|
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||||
}
|
elementContext.zero
|
||||||
|
}
|
||||||
|
|
||||||
|
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
|
||||||
inline class TransposedMatrix<T : Any>(val original: Matrix<T>) : Matrix<T> {
|
|
||||||
override val rowNum: Int get() = original.colNum
|
|
||||||
override val colNum: Int get() = original.rowNum
|
|
||||||
override val features: Set<MatrixFeature> get() = emptySet() //TODO retain some features
|
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T = original[j, i]
|
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
|
||||||
original.elements().map { (key, value) -> intArrayOf(key[1], key[0]) to value }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
||||||
*/
|
*/
|
||||||
fun <T : Any, R : Ring<T>> Matrix<T>.transpose(): Matrix<T> {
|
fun <T : Any, R : Ring<T>> Matrix<T>.transpose(): Matrix<T> {
|
||||||
return if (this is TransposedMatrix) {
|
return this.getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
|
||||||
original
|
this.colNum,
|
||||||
} else {
|
this.rowNum,
|
||||||
TransposedMatrix(this)
|
setOf(TransposedFeature(this))
|
||||||
}
|
) { i, j -> get(j, i) }
|
||||||
}
|
}
|
@ -1,10 +1,7 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.structures.BufferFactory
|
import scientifik.kmath.structures.*
|
||||||
import scientifik.kmath.structures.NDStructure
|
|
||||||
import scientifik.kmath.structures.get
|
|
||||||
import scientifik.kmath.structures.ndStructure
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Basic implementation of Matrix space based on [NDStructure]
|
* Basic implementation of Matrix space based on [NDStructure]
|
||||||
@ -24,7 +21,7 @@ class StructureMatrixContext<T : Any, R : Ring<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
class StructureMatrix<T : Any>(
|
class StructureMatrix<T : Any>(
|
||||||
val structure: NDStructure<T>,
|
val structure: NDStructure<out T>,
|
||||||
override val features: Set<MatrixFeature> = emptySet()
|
override val features: Set<MatrixFeature> = emptySet()
|
||||||
) : Matrix<T> {
|
) : Matrix<T> {
|
||||||
|
|
||||||
@ -51,8 +48,7 @@ class StructureMatrix<T : Any>(
|
|||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
return when (other) {
|
return when (other) {
|
||||||
is StructureMatrix<*> -> return this.structure == other.structure
|
is NDStructure<*> -> return NDStructure.equals(this, other)
|
||||||
is Matrix<*> -> elements().all { (index, value) -> value == other[index] }
|
|
||||||
else -> false
|
else -> false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -63,5 +59,16 @@ class StructureMatrix<T : Any>(
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun toString(): String {
|
||||||
|
return if (rowNum <= 5 && colNum <= 5) {
|
||||||
|
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
|
||||||
|
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") {
|
||||||
|
it.asSequence().joinToString(separator = "\t") { it.toString() }
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
@ -27,4 +27,18 @@ class VirtualMatrix<T : Any>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* Wrap a matrix adding additional features to it
|
||||||
|
*/
|
||||||
|
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): Matrix<T> {
|
||||||
|
return if (matrix is VirtualMatrix) {
|
||||||
|
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
|
||||||
|
} else {
|
||||||
|
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j ->
|
||||||
|
matrix[i, j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -11,6 +11,16 @@ interface NDStructure<T> {
|
|||||||
operator fun get(index: IntArray): T
|
operator fun get(index: IntArray): T
|
||||||
|
|
||||||
fun elements(): Sequence<Pair<IntArray, T>>
|
fun elements(): Sequence<Pair<IntArray, T>>
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||||
|
return when {
|
||||||
|
st1 === st2 -> true
|
||||||
|
st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals(st2.buffer)
|
||||||
|
else -> st1.elements().all { (index, value) -> value == st2[index] }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
|
operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
|
||||||
@ -140,7 +150,7 @@ interface NDBuffer<T> : NDStructure<T> {
|
|||||||
/**
|
/**
|
||||||
* Boxing generic [NDStructure]
|
* Boxing generic [NDStructure]
|
||||||
*/
|
*/
|
||||||
data class BufferNDStructure<T>(
|
class BufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : NDBuffer<T> {
|
) : NDBuffer<T> {
|
||||||
@ -160,7 +170,7 @@ data class BufferNDStructure<T>(
|
|||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
return when {
|
return when {
|
||||||
this === other -> true
|
this === other -> true
|
||||||
other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer)
|
other is BufferNDStructure<*> && this.strides == other.strides -> this.buffer.contentEquals(other.buffer)
|
||||||
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
|
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
|
||||||
else -> false
|
else -> false
|
||||||
}
|
}
|
||||||
@ -193,7 +203,11 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
|||||||
*
|
*
|
||||||
* Strides should be reused if possible
|
* Strides should be reused if possible
|
||||||
*/
|
*/
|
||||||
fun <T> ndStructure(strides: Strides, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, initializer: (IntArray) -> T) =
|
fun <T> ndStructure(
|
||||||
|
strides: Strides,
|
||||||
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
|
initializer: (IntArray) -> T
|
||||||
|
) =
|
||||||
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -202,7 +216,11 @@ fun <T> ndStructure(strides: Strides, bufferFactory: BufferFactory<T> = Buffer.C
|
|||||||
inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
fun <T> ndStructure(shape: IntArray, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, initializer: (IntArray) -> T) =
|
fun <T> ndStructure(
|
||||||
|
shape: IntArray,
|
||||||
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
|
initializer: (IntArray) -> T
|
||||||
|
) =
|
||||||
ndStructure(DefaultStrides(shape), bufferFactory, initializer)
|
ndStructure(DefaultStrides(shape), bufferFactory, initializer)
|
||||||
|
|
||||||
inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||||
|
@ -11,10 +11,31 @@ class RealLUSolverTest {
|
|||||||
assertEquals(matrix, inverted)
|
assertEquals(matrix, inverted)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Test
|
@Test
|
||||||
// fun testInvert() {
|
fun testInvert() {
|
||||||
// val matrix = realMatrix(2,2){}
|
val matrix = Matrix.build(
|
||||||
// val inverted = LUSolver.real.inverse(matrix)
|
3.0, 1.0,
|
||||||
// assertEquals(matrix, inverted)
|
1.0, 3.0
|
||||||
// }
|
)
|
||||||
|
|
||||||
|
val decomposed = LUSolver.real.decompose(matrix)
|
||||||
|
val decomposition = decomposed.getFeature<LUPDecomposition<Double>>()!!
|
||||||
|
|
||||||
|
//Check determinant
|
||||||
|
assertEquals(8.0, decomposition.determinant)
|
||||||
|
|
||||||
|
//Check decomposition
|
||||||
|
with(MatrixContext.real) {
|
||||||
|
assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u)
|
||||||
|
}
|
||||||
|
|
||||||
|
val inverted = LUSolver.real.inverse(decomposed)
|
||||||
|
|
||||||
|
val expected = Matrix.build(
|
||||||
|
0.375, -0.125,
|
||||||
|
-0.125, 0.375
|
||||||
|
)
|
||||||
|
|
||||||
|
assertEquals(expected, inverted)
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user