Optimizing decomposition performance
This commit is contained in:
parent
f1b1010c4d
commit
bbc012d8cd
@ -9,7 +9,7 @@ import kotlin.system.measureTimeMillis
|
|||||||
|
|
||||||
@ExperimentalContracts
|
@ExperimentalContracts
|
||||||
fun main() {
|
fun main() {
|
||||||
val random = Random(12224)
|
val random = Random(1224)
|
||||||
val dim = 100
|
val dim = 100
|
||||||
//creating invertible matrix
|
//creating invertible matrix
|
||||||
val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
import com.moowork.gradle.node.NodeExtension
|
import com.moowork.gradle.node.NodeExtension
|
||||||
import com.moowork.gradle.node.npm.NpmTask
|
import com.moowork.gradle.node.npm.NpmTask
|
||||||
import com.moowork.gradle.node.task.NodeTask
|
import com.moowork.gradle.node.task.NodeTask
|
||||||
|
import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension
|
||||||
|
import org.jetbrains.kotlin.gradle.tasks.Kotlin2JsCompile
|
||||||
|
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
|
||||||
|
|
||||||
buildscript {
|
buildscript {
|
||||||
val kotlinVersion: String by rootProject.extra("1.3.21")
|
val kotlinVersion: String by rootProject.extra("1.3.30")
|
||||||
val ioVersion: String by rootProject.extra("0.1.5")
|
val ioVersion: String by rootProject.extra("0.1.5")
|
||||||
val coroutinesVersion: String by rootProject.extra("1.1.1")
|
val coroutinesVersion: String by rootProject.extra("1.1.1")
|
||||||
val atomicfuVersion: String by rootProject.extra("0.12.1")
|
val atomicfuVersion: String by rootProject.extra("0.12.1")
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
|
|
||||||
@ -23,6 +24,18 @@ class BufferMatrixContext<T : Any, R : Ring<T>>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
||||||
|
|
||||||
|
override val elementContext = RealField
|
||||||
|
|
||||||
|
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
||||||
|
val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||||
|
return BufferMatrix(rows, columns, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = DoubleBuffer(size,initializer)
|
||||||
|
}
|
||||||
|
|
||||||
class BufferMatrix<T : Any>(
|
class BufferMatrix<T : Any>(
|
||||||
override val rowNum: Int,
|
override val rowNum: Int,
|
||||||
override val colNum: Int,
|
override val colNum: Int,
|
||||||
|
@ -3,22 +3,22 @@ package scientifik.kmath.linear
|
|||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.BufferAccessor2D
|
||||||
import kotlin.contracts.ExperimentalContracts
|
import scientifik.kmath.structures.Matrix
|
||||||
import kotlin.contracts.InvocationKind
|
import scientifik.kmath.structures.Structure2D
|
||||||
import kotlin.contracts.contract
|
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Common implementation of [LUPDecompositionFeature]
|
* Common implementation of [LUPDecompositionFeature]
|
||||||
*/
|
*/
|
||||||
class LUPDecomposition<T : Any>(
|
class LUPDecomposition<T : Any>(
|
||||||
private val elementContext: Ring<T>,
|
val context: GenericMatrixContext<T, out Field<T>>,
|
||||||
val lu: Structure2D<T>,
|
val lu: Structure2D<T>,
|
||||||
val pivot: IntArray,
|
val pivot: IntArray,
|
||||||
private val even: Boolean
|
private val even: Boolean
|
||||||
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
||||||
|
|
||||||
|
val elementContext get() = context.elementContext
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the matrix L of the decomposition.
|
* Returns the matrix L of the decomposition.
|
||||||
*
|
*
|
||||||
@ -66,102 +66,14 @@ class LUPDecomposition<T : Any>(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
internal open class BufferAccessor<T : Any>(
|
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) =
|
||||||
val type: KClass<T>,
|
|
||||||
val field: Field<T>,
|
|
||||||
val rowNum: Int,
|
|
||||||
val colNum: Int
|
|
||||||
) {
|
|
||||||
open operator fun MutableBuffer<T>.get(i: Int, j: Int) = get(i + colNum * j)
|
|
||||||
open operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
|
|
||||||
set(i + colNum * j, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun create(init: (i: Int, j: Int) -> T) =
|
|
||||||
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
|
||||||
|
|
||||||
fun create(mat: Structure2D<T>) = create { i, j -> mat[i, j] }
|
|
||||||
|
|
||||||
//TODO optimize wrapper
|
|
||||||
fun MutableBuffer<T>.collect(): Structure2D<T> =
|
|
||||||
NDStructure.auto(type, rowNum, colNum) { (i, j) -> get(i, j) }.as2D()
|
|
||||||
|
|
||||||
open fun MutableBuffer<T>.innerProduct(row: Int, col: Int, max: Int): T {
|
|
||||||
var sum = field.zero
|
|
||||||
field.run {
|
|
||||||
for (i in 0 until max) {
|
|
||||||
sum += get(row, i) * get(i, col)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sum
|
|
||||||
}
|
|
||||||
|
|
||||||
open fun MutableBuffer<T>.divideInPlace(i: Int, j: Int, factor: T) {
|
|
||||||
field.run { set(i, j, get(i, j) / factor) }
|
|
||||||
}
|
|
||||||
|
|
||||||
open fun MutableBuffer<T>.subtractInPlace(i: Int, j: Int, lu: MutableBuffer<T>, col: Int) {
|
|
||||||
field.run {
|
|
||||||
set(i, j, get(i, j) - get(col, j) * lu[i, col])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Specialized LU operations for Doubles
|
|
||||||
*/
|
|
||||||
private class RealBufferAccessor(rowNum: Int, colNum: Int) :
|
|
||||||
BufferAccessor<Double>(Double::class, RealField, rowNum, colNum) {
|
|
||||||
override fun MutableBuffer<Double>.get(i: Int, j: Int) = (this as DoubleBuffer).array[i + colNum * j]
|
|
||||||
override fun MutableBuffer<Double>.set(i: Int, j: Int, value: Double) {
|
|
||||||
(this as DoubleBuffer).array[i + colNum * j] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun MutableBuffer<Double>.innerProduct(row: Int, col: Int, max: Int): Double {
|
|
||||||
var sum = 0.0
|
|
||||||
for (i in 0 until max) {
|
|
||||||
sum += get(row, i) * get(i, col)
|
|
||||||
}
|
|
||||||
return sum
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun MutableBuffer<Double>.divideInPlace(i: Int, j: Int, factor: Double) {
|
|
||||||
set(i, j, get(i, j) / factor)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun MutableBuffer<Double>.subtractInPlace(i: Int, j: Int, lu: MutableBuffer<Double>, col: Int) {
|
|
||||||
set(i, j, get(i, j) - get(col, j) * lu[i, col])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@ExperimentalContracts
|
|
||||||
private inline fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.withAccessor(
|
|
||||||
type: KClass<T>,
|
|
||||||
rowNum: Int,
|
|
||||||
colNum: Int,
|
|
||||||
block: BufferAccessor<T>.() -> Unit
|
|
||||||
) {
|
|
||||||
contract {
|
|
||||||
callsInPlace(block, InvocationKind.EXACTLY_ONCE)
|
|
||||||
}
|
|
||||||
if (elementContext == RealField) {
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
RealBufferAccessor(rowNum, colNum) as BufferAccessor<T>
|
|
||||||
} else {
|
|
||||||
BufferAccessor(type, elementContext, rowNum, colNum)
|
|
||||||
}.run(block)
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) =
|
|
||||||
if (value > elementContext.zero) value else with(elementContext) { -value }
|
if (value > elementContext.zero) value else with(elementContext) { -value }
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a lup decomposition of generic matrix
|
* Create a lup decomposition of generic matrix
|
||||||
*/
|
*/
|
||||||
@ExperimentalContracts
|
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||||
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
|
||||||
type: KClass<T>,
|
|
||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
checkSingular: (T) -> Boolean
|
checkSingular: (T) -> Boolean
|
||||||
): LUPDecomposition<T> {
|
): LUPDecomposition<T> {
|
||||||
@ -169,11 +81,12 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
|||||||
error("LU decomposition supports only square matrices")
|
error("LU decomposition supports only square matrices")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
val m = matrix.colNum
|
val m = matrix.colNum
|
||||||
val pivot = IntArray(matrix.rowNum)
|
val pivot = IntArray(matrix.rowNum)
|
||||||
|
|
||||||
withAccessor(type, matrix.rowNum, matrix.colNum) {
|
//TODO just waits for KEEP-176
|
||||||
|
BufferAccessor2D(T::class, matrix.rowNum, matrix.colNum).run {
|
||||||
|
elementContext.run {
|
||||||
|
|
||||||
val lu = create(matrix)
|
val lu = create(matrix)
|
||||||
|
|
||||||
@ -183,32 +96,56 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
|||||||
}
|
}
|
||||||
var even = true
|
var even = true
|
||||||
|
|
||||||
|
// Initialize permutation array and parity
|
||||||
|
for (row in 0 until m) {
|
||||||
|
pivot[row] = row
|
||||||
|
}
|
||||||
|
var singular = false
|
||||||
|
|
||||||
// Loop over columns
|
// Loop over columns
|
||||||
for (col in 0 until m) {
|
for (col in 0 until m) {
|
||||||
|
|
||||||
// upper
|
// upper
|
||||||
for (row in 0 until col) {
|
for (row in 0 until col) {
|
||||||
val sum = lu.innerProduct(row, col, row)
|
val luRow = lu.row(row)
|
||||||
lu[row, col] = field.run { lu[row, col] - sum }
|
var sum = luRow[col]
|
||||||
|
for (i in 0 until row) {
|
||||||
|
sum -= luRow[i] * lu[i, col]
|
||||||
|
}
|
||||||
|
luRow[col] = sum
|
||||||
}
|
}
|
||||||
|
|
||||||
// lower
|
// lower
|
||||||
val max = (col until m).maxBy { row ->
|
var max = col // permutation row
|
||||||
val sum = lu.innerProduct(row, col, col)
|
var largest = -one
|
||||||
lu[row, col] = field.run { lu[row, col] - sum }
|
for (row in col until m) {
|
||||||
abs(sum)
|
val luRow = lu.row(row)
|
||||||
} ?: col
|
var sum = luRow[col]
|
||||||
|
for (i in 0 until col) {
|
||||||
|
sum -= luRow[i] * lu[i, col]
|
||||||
|
}
|
||||||
|
luRow[col] = sum
|
||||||
|
|
||||||
|
// maintain best permutation choice
|
||||||
|
if (abs(sum) > largest) {
|
||||||
|
largest = abs(sum)
|
||||||
|
max = row
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Singularity check
|
// Singularity check
|
||||||
if (checkSingular(lu[max, col])) {
|
if (checkSingular(abs(lu[max, col]))) {
|
||||||
error("Singular matrix")
|
error("The matrix is singular")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pivot if necessary
|
// Pivot if necessary
|
||||||
if (max != col) {
|
if (max != col) {
|
||||||
|
val luMax = lu.row(max)
|
||||||
|
val luCol = lu.row(col)
|
||||||
for (i in 0 until m) {
|
for (i in 0 until m) {
|
||||||
lu[max, i] = lu[col, i]
|
val tmp = luMax[i]
|
||||||
lu[col, i] = lu[max, i]
|
luMax[i] = luCol[i]
|
||||||
|
luCol[i] = tmp
|
||||||
}
|
}
|
||||||
val temp = pivot[max]
|
val temp = pivot[max]
|
||||||
pivot[max] = pivot[col]
|
pivot[max] = pivot[col]
|
||||||
@ -219,83 +156,91 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
|||||||
// Divide the lower elements by the "winning" diagonal elt.
|
// Divide the lower elements by the "winning" diagonal elt.
|
||||||
val luDiag = lu[col, col]
|
val luDiag = lu[col, col]
|
||||||
for (row in col + 1 until m) {
|
for (row in col + 1 until m) {
|
||||||
lu.divideInPlace(row, col, luDiag)
|
lu[row, col] /= luDiag
|
||||||
//lu[row, col] = lu[row, col] / luDiag
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return LUPDecomposition(elementContext, lu.collect(), pivot, even)
|
|
||||||
|
return LUPDecomposition(this@lup, lu.collect(), pivot, even)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ExperimentalContracts
|
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(matrix) { it < 1e-11 }
|
||||||
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(Double::class, matrix) { it < 1e-11 }
|
|
||||||
|
|
||||||
/**
|
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matrix<T> {
|
||||||
* Solve a linear equation **a*x = b**
|
|
||||||
*/
|
if (matrix.rowNum != pivot.size) {
|
||||||
@ExperimentalContracts
|
error("Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}")
|
||||||
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.solve(
|
|
||||||
type: KClass<T>,
|
|
||||||
a: Matrix<T>,
|
|
||||||
b: Matrix<T>,
|
|
||||||
checkSingular: (T) -> Boolean
|
|
||||||
): Matrix<T> {
|
|
||||||
if (b.rowNum != a.colNum) {
|
|
||||||
error("Matrix dimension mismatch. Expected ${a.rowNum}, but got ${b.colNum}")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use existing decomposition if it is provided by matrix
|
BufferAccessor2D(T::class, matrix.rowNum, matrix.colNum).run {
|
||||||
val decomposition = a.getFeature() ?: lup(type, a, checkSingular)
|
elementContext.run {
|
||||||
|
|
||||||
withAccessor(type, a.rowNum, a.colNum) {
|
val lu = create{i,j-> this@solve.lu[i,j]}
|
||||||
|
|
||||||
val lu = create(decomposition.lu)
|
|
||||||
|
|
||||||
// Apply permutations to b
|
// Apply permutations to b
|
||||||
val bp = create { i, j ->
|
val bp = create { i, j -> zero }
|
||||||
b[decomposition.pivot[i], j]
|
for (row in 0 until pivot.size) {
|
||||||
|
val bpRow = bp.row(row)
|
||||||
|
val pRow = pivot[row]
|
||||||
|
for (col in 0 until matrix.colNum) {
|
||||||
|
bpRow[col] = matrix[pRow, col]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Solve LY = b
|
// Solve LY = b
|
||||||
for (col in 0 until a.rowNum) {
|
for (col in 0 until pivot.size) {
|
||||||
for (i in col + 1 until a.rowNum) {
|
val bpCol = bp.row(col)
|
||||||
for (j in 0 until b.colNum) {
|
for (i in col + 1 until pivot.size) {
|
||||||
bp.subtractInPlace(i, j, lu, col)
|
val bpI = bp.row(i)
|
||||||
//bp[i, j] -= bp[col, j] * lu[i, col]
|
val luICol = lu[i, col]
|
||||||
|
for (j in 0 until matrix.colNum) {
|
||||||
|
bpI[j] -= bpCol[j] * luICol
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Solve UX = Y
|
// Solve UX = Y
|
||||||
for (col in a.rowNum - 1 downTo 0) {
|
for (col in pivot.size - 1 downTo 0) {
|
||||||
|
val bpCol = bp.row(col)
|
||||||
val luDiag = lu[col, col]
|
val luDiag = lu[col, col]
|
||||||
for (j in 0 until b.colNum) {
|
for (j in 0 until matrix.colNum) {
|
||||||
bp.divideInPlace(col, j, luDiag)
|
bpCol[j] /= luDiag
|
||||||
//bp[col, j] /= lu[col, col]
|
|
||||||
}
|
}
|
||||||
for (i in 0 until col) {
|
for (i in 0 until col) {
|
||||||
for (j in 0 until b.colNum) {
|
val bpI = bp.row(i)
|
||||||
bp.subtractInPlace(i, j, lu, col)
|
val luICol = lu[i, col]
|
||||||
//bp[i, j] -= bp[col, j] * lu[i, col]
|
for (j in 0 until matrix.colNum) {
|
||||||
|
bpI[j] -= bpCol[j] * luICol
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] }
|
||||||
return produce(a.rowNum, a.colNum) { i, j -> bp[i, j] }
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ExperimentalContracts
|
|
||||||
fun GenericMatrixContext<Double, RealField>.solve(a: Matrix<Double>, b: Matrix<Double>) =
|
|
||||||
solve(Double::class, a, b) { it < 1e-11 }
|
|
||||||
|
|
||||||
@ExperimentalContracts
|
/**
|
||||||
|
* Solve a linear equation **a*x = b**
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.solve(
|
||||||
|
a: Matrix<T>,
|
||||||
|
b: Matrix<T>,
|
||||||
|
crossinline checkSingular: (T) -> Boolean
|
||||||
|
): Matrix<T> {
|
||||||
|
// Use existing decomposition if it is provided by matrix
|
||||||
|
val decomposition = a.getFeature() ?: lup(a, checkSingular)
|
||||||
|
return decomposition.solve(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>) =
|
||||||
|
solve(a, b) { it < 1e-11 }
|
||||||
|
|
||||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
|
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
|
||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
noinline checkSingular: (T) -> Boolean
|
noinline checkSingular: (T) -> Boolean
|
||||||
) =
|
) = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
||||||
solve(T::class, matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
|
||||||
|
|
||||||
@ExperimentalContracts
|
fun RealMatrixContext.inverse(matrix: Matrix<Double>) =
|
||||||
fun GenericMatrixContext<Double, RealField>.inverse(matrix: Matrix<Double>) =
|
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }
|
||||||
inverse(matrix) { it < 1e-11 }
|
|
@ -1,6 +1,5 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import scientifik.kmath.operations.RealField
|
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.operations.SpaceOperations
|
import scientifik.kmath.operations.SpaceOperations
|
||||||
import scientifik.kmath.operations.sum
|
import scientifik.kmath.operations.sum
|
||||||
@ -30,7 +29,7 @@ interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
|
|||||||
/**
|
/**
|
||||||
* Non-boxing double matrix
|
* Non-boxing double matrix
|
||||||
*/
|
*/
|
||||||
val real = BufferMatrixContext(RealField, Buffer.Companion::auto)
|
val real = RealMatrixContext
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A structured matrix with custom buffer
|
* A structured matrix with custom buffer
|
||||||
|
@ -0,0 +1,45 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context that allows to operate on a [MutableBuffer] as on 2d array
|
||||||
|
*/
|
||||||
|
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
|
||||||
|
|
||||||
|
inline operator fun Buffer<T>.get(i: Int, j: Int) = get(i + colNum * j)
|
||||||
|
|
||||||
|
inline operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
|
||||||
|
set(i + colNum * j, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun create(init: (i: Int, j: Int) -> T) =
|
||||||
|
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
||||||
|
|
||||||
|
fun create(mat: Structure2D<T>) = create { i, j -> mat[i, j] }
|
||||||
|
|
||||||
|
//TODO optimize wrapper
|
||||||
|
fun MutableBuffer<T>.collect(): Structure2D<T> =
|
||||||
|
NDStructure.auto(type, rowNum, colNum) { (i, j) -> get(i, j) }.as2D()
|
||||||
|
|
||||||
|
|
||||||
|
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
|
||||||
|
override val size: Int get() = colNum
|
||||||
|
|
||||||
|
override fun get(index: Int): T = buffer[rowIndex, index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: T) {
|
||||||
|
buffer[rowIndex, index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<T> = MutableBuffer.auto(type, colNum) { get(it) }
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<T> = (0 until colNum).map(::get).iterator()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get row
|
||||||
|
*/
|
||||||
|
fun MutableBuffer<T>.row(i: Int) = Row(this, i)
|
||||||
|
}
|
@ -84,7 +84,7 @@ interface MutableBuffer<T> : Buffer<T> {
|
|||||||
MutableListBuffer(MutableList(size, initializer))
|
MutableListBuffer(MutableList(size, initializer))
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <T : Any> auto(type: KClass<T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
||||||
return when (type) {
|
return when (type) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
||||||
|
Loading…
Reference in New Issue
Block a user