forked from kscience/kmath
Optimizing inversion performance
This commit is contained in:
parent
14f05eb1e1
commit
271e762a95
@ -15,19 +15,20 @@ fun main() {
|
|||||||
|
|
||||||
val n = 5000 // iterations
|
val n = 5000 // iterations
|
||||||
|
|
||||||
val solver = LUSolver.real
|
MatrixContext.real.run {
|
||||||
|
|
||||||
repeat(50) {
|
repeat(50) {
|
||||||
val res = solver.inverse(matrix)
|
val res = inverse(matrix)
|
||||||
}
|
|
||||||
|
|
||||||
val inverseTime = measureTimeMillis {
|
|
||||||
repeat(n) {
|
|
||||||
val res = solver.inverse(matrix)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
val inverseTime = measureTimeMillis {
|
||||||
|
repeat(n) {
|
||||||
|
val res = inverse(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
||||||
|
}
|
||||||
|
|
||||||
//commons-math
|
//commons-math
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
|
|
||||||
@ -17,6 +18,10 @@ class BufferMatrixContext<T : Any, R : Ring<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class BufferMatrix<T : Any>(
|
class BufferMatrix<T : Any>(
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
import scientifik.kmath.structures.MutableBuffer.Companion.boxing
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Common implementation of [LUPDecompositionFeature]
|
* Common implementation of [LUPDecompositionFeature]
|
||||||
*/
|
*/
|
||||||
class LUPDecomposition<T : Comparable<T>>(
|
class LUPDecomposition<T : Any>(
|
||||||
private val elementContext: Ring<T>,
|
private val elementContext: Ring<T>,
|
||||||
internal val lu: NDStructure<T>,
|
val lu: Structure2D<T>,
|
||||||
val pivot: IntArray,
|
val pivot: IntArray,
|
||||||
private val even: Boolean
|
private val even: Boolean
|
||||||
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
||||||
@ -62,146 +63,222 @@ class LUPDecomposition<T : Comparable<T>>(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
open class BufferAccessor<T : Any>(val type: KClass<T>, val field: Field<T>, val rowNum: Int, val colNum: Int) {
|
||||||
|
open operator fun MutableBuffer<T>.get(i: Int, j: Int) = get(i + colNum * j)
|
||||||
|
open operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
|
||||||
|
set(i + colNum * j, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun create(init: (i: Int, j: Int) -> T) =
|
||||||
|
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
||||||
|
|
||||||
|
fun create(mat: Structure2D<T>) = create { i, j -> mat[i, j] }
|
||||||
|
|
||||||
|
//TODO optimize wrapper
|
||||||
|
fun MutableBuffer<T>.collect(): Structure2D<T> =
|
||||||
|
NDStructure.auto(type, rowNum, colNum) { (i, j) -> get(i, j) }.as2D()
|
||||||
|
|
||||||
|
open fun MutableBuffer<T>.innerProduct(row: Int, col: Int, max: Int): T {
|
||||||
|
var sum = field.zero
|
||||||
|
field.run {
|
||||||
|
for (i in 0 until max) {
|
||||||
|
sum += get(row, i) * get(i, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
open fun MutableBuffer<T>.divideInPlace(i: Int, j: Int, factor: T) {
|
||||||
|
field.run { set(i, j, get(i, j) / factor) }
|
||||||
|
}
|
||||||
|
|
||||||
|
open fun MutableBuffer<T>.subtractInPlace(i: Int, j: Int, lu: MutableBuffer<T>, col: Int) {
|
||||||
|
field.run {
|
||||||
|
set(i, j, get(i, j) - get(col, j) * lu[i, col])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Common implementation of LUP [LinearSolver] based on commons-math code
|
* Specialized LU operations for Doubles
|
||||||
*/
|
*/
|
||||||
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
class RealBufferAccessor(rowNum: Int, colNum: Int) : BufferAccessor<Double>(Double::class, RealField, rowNum, colNum) {
|
||||||
val context: GenericMatrixContext<T, F>,
|
override inline fun MutableBuffer<Double>.get(i: Int, j: Int) = (this as DoubleBuffer).array[i + colNum * j]
|
||||||
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
override inline fun MutableBuffer<Double>.set(i: Int, j: Int, value: Double) {
|
||||||
val singularityCheck: (T) -> Boolean
|
(this as DoubleBuffer).array[i + colNum * j] = value
|
||||||
) : LinearSolver<T> {
|
|
||||||
|
|
||||||
|
|
||||||
private fun abs(value: T) =
|
|
||||||
if (value > context.elementContext.zero) value else with(context.elementContext) { -value }
|
|
||||||
|
|
||||||
fun buildDecomposition(matrix: Matrix<T>): LUPDecomposition<T> {
|
|
||||||
if (matrix.rowNum != matrix.colNum) {
|
|
||||||
error("LU decomposition supports only square matrices")
|
|
||||||
}
|
|
||||||
|
|
||||||
val m = matrix.colNum
|
|
||||||
val pivot = IntArray(matrix.rowNum)
|
|
||||||
|
|
||||||
val lu = Mutable2DStructure.create(matrix.rowNum, matrix.colNum, bufferFactory) { i, j ->
|
|
||||||
matrix[i, j]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
with(context.elementContext) {
|
|
||||||
// Initialize permutation array and parity
|
|
||||||
for (row in 0 until m) {
|
|
||||||
pivot[row] = row
|
|
||||||
}
|
|
||||||
var even = true
|
|
||||||
|
|
||||||
// Loop over columns
|
|
||||||
for (col in 0 until m) {
|
|
||||||
|
|
||||||
// upper
|
|
||||||
for (row in 0 until col) {
|
|
||||||
var sum = lu[row, col]
|
|
||||||
for (i in 0 until row) {
|
|
||||||
sum -= lu[row, i] * lu[i, col]
|
|
||||||
}
|
|
||||||
lu[row, col] = sum
|
|
||||||
}
|
|
||||||
|
|
||||||
// lower
|
|
||||||
val max = (col until m).maxBy { row ->
|
|
||||||
var sum = lu[row, col]
|
|
||||||
for (i in 0 until col) {
|
|
||||||
sum -= lu[row, i] * lu[i, col]
|
|
||||||
}
|
|
||||||
lu[row, col] = sum
|
|
||||||
|
|
||||||
abs(sum)
|
|
||||||
} ?: col
|
|
||||||
|
|
||||||
// Singularity check
|
|
||||||
if (singularityCheck(lu[max, col])) {
|
|
||||||
error("Singular matrix")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pivot if necessary
|
|
||||||
if (max != col) {
|
|
||||||
for (i in 0 until m) {
|
|
||||||
lu[max, i] = lu[col, i]
|
|
||||||
lu[col, i] = lu[max, i]
|
|
||||||
}
|
|
||||||
val temp = pivot[max]
|
|
||||||
pivot[max] = pivot[col]
|
|
||||||
pivot[col] = temp
|
|
||||||
even = !even
|
|
||||||
}
|
|
||||||
|
|
||||||
// Divide the lower elements by the "winning" diagonal elt.
|
|
||||||
val luDiag = lu[col, col]
|
|
||||||
for (row in col + 1 until m) {
|
|
||||||
lu[row, col] = lu[row, col] / luDiag
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return LUPDecomposition(context.elementContext, lu, pivot, even)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
override fun MutableBuffer<Double>.innerProduct(row: Int, col: Int, max: Int): Double {
|
||||||
* Produce a matrix with added decomposition feature
|
var sum = 0.0
|
||||||
*/
|
for (i in 0 until max) {
|
||||||
fun decompose(matrix: Matrix<T>): Matrix<T> {
|
sum += get(row, i) * get(i, col)
|
||||||
if (matrix.hasFeature<LUPDecomposition<*>>()) {
|
|
||||||
return matrix
|
|
||||||
} else {
|
|
||||||
val decomposition = buildDecomposition(matrix)
|
|
||||||
return VirtualMatrix.wrap(matrix, decomposition)
|
|
||||||
}
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun MutableBuffer<Double>.divideInPlace(i: Int, j: Int, factor: Double) {
|
||||||
|
set(i, j, get(i, j) / factor)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun MutableBuffer<Double>.subtractInPlace(i: Int, j: Int, lu: MutableBuffer<Double>, col: Int) {
|
||||||
|
set(i, j, get(i, j) - get(col, j) * lu[i, col])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.buildAccessor(
|
||||||
|
type:KClass<T>,
|
||||||
|
rowNum: Int,
|
||||||
|
colNum: Int
|
||||||
|
): BufferAccessor<T> {
|
||||||
|
return if (elementContext == RealField) {
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
RealBufferAccessor(rowNum, colNum) as BufferAccessor<T>
|
||||||
|
} else {
|
||||||
|
BufferAccessor(type, elementContext, rowNum, colNum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) =
|
||||||
|
if (value > elementContext.zero) value else with(elementContext) { -value }
|
||||||
|
|
||||||
|
|
||||||
|
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lupDecompose(
|
||||||
|
type: KClass<T>,
|
||||||
|
matrix: Matrix<T>,
|
||||||
|
checkSingular: (T) -> Boolean
|
||||||
|
): LUPDecomposition<T> {
|
||||||
|
if (matrix.rowNum != matrix.colNum) {
|
||||||
|
error("LU decomposition supports only square matrices")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
val m = matrix.colNum
|
||||||
if (b.rowNum != a.colNum) {
|
val pivot = IntArray(matrix.rowNum)
|
||||||
error("Matrix dimension mismatch expected ${a.rowNum}, but got ${b.colNum}")
|
|
||||||
|
buildAccessor(type, matrix.rowNum, matrix.colNum).run {
|
||||||
|
|
||||||
|
val lu = create(matrix)
|
||||||
|
|
||||||
|
// Initialize permutation array and parity
|
||||||
|
for (row in 0 until m) {
|
||||||
|
pivot[row] = row
|
||||||
}
|
}
|
||||||
|
var even = true
|
||||||
|
|
||||||
// Use existing decomposition if it is provided by matrix
|
// Loop over columns
|
||||||
val decomposition = a.getFeature() ?: buildDecomposition(a)
|
for (col in 0 until m) {
|
||||||
|
|
||||||
with(decomposition) {
|
// upper
|
||||||
with(context.elementContext) {
|
for (row in 0 until col) {
|
||||||
// Apply permutations to b
|
// var sum = lu[row, col]
|
||||||
val bp = Mutable2DStructure.create(a.rowNum, a.colNum, bufferFactory) { i, j ->
|
// for (i in 0 until row) {
|
||||||
b[pivot[i], j]
|
// sum -= lu[row, i] * lu[i, col]
|
||||||
|
// }
|
||||||
|
val sum = lu.innerProduct(row, col, row)
|
||||||
|
lu[row, col] = field.run { lu[row, col] - sum }
|
||||||
|
}
|
||||||
|
|
||||||
|
// lower
|
||||||
|
val max = (col until m).maxBy { row ->
|
||||||
|
// var sum = lu[row, col]
|
||||||
|
// for (i in 0 until col) {
|
||||||
|
// sum -= lu[row, i] * lu[i, col]
|
||||||
|
// }
|
||||||
|
// lu[row, col] = sum
|
||||||
|
val sum = lu.innerProduct(row, col, col)
|
||||||
|
lu[row, col] = field.run { lu[row, col] - sum }
|
||||||
|
abs(sum)
|
||||||
|
} ?: col
|
||||||
|
|
||||||
|
// Singularity check
|
||||||
|
if (checkSingular(lu[max, col])) {
|
||||||
|
error("Singular matrix")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pivot if necessary
|
||||||
|
if (max != col) {
|
||||||
|
for (i in 0 until m) {
|
||||||
|
lu[max, i] = lu[col, i]
|
||||||
|
lu[col, i] = lu[max, i]
|
||||||
}
|
}
|
||||||
|
val temp = pivot[max]
|
||||||
|
pivot[max] = pivot[col]
|
||||||
|
pivot[col] = temp
|
||||||
|
even = !even
|
||||||
|
}
|
||||||
|
|
||||||
// Solve LY = b
|
// Divide the lower elements by the "winning" diagonal elt.
|
||||||
for (col in 0 until a.rowNum) {
|
val luDiag = lu[col, col]
|
||||||
for (i in col + 1 until a.rowNum) {
|
for (row in col + 1 until m) {
|
||||||
for (j in 0 until b.colNum) {
|
lu.divideInPlace(row, col, luDiag)
|
||||||
bp[i, j] -= bp[col, j] * lu[i, col]
|
//lu[row, col] = lu[row, col] / luDiag
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Solve UX = Y
|
|
||||||
for (col in a.rowNum - 1 downTo 0) {
|
|
||||||
for (j in 0 until b.colNum) {
|
|
||||||
bp[col, j] /= lu[col, col]
|
|
||||||
}
|
|
||||||
for (i in 0 until col) {
|
|
||||||
for (j in 0 until b.colNum) {
|
|
||||||
bp[i, j] -= bp[col, j] * lu[i, col]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return context.produce(a.rowNum, a.colNum) { i, j -> bp[i, j] }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return scientifik.kmath.linear.LUPDecomposition(elementContext, lu.collect(), pivot, even)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Solve a linear equation **a*x = b**
|
||||||
|
*/
|
||||||
|
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.solve(
|
||||||
|
type: KClass<T>,
|
||||||
|
a: Matrix<T>,
|
||||||
|
b: Matrix<T>,
|
||||||
|
checkSingular: (T) -> Boolean
|
||||||
|
): Matrix<T> {
|
||||||
|
if (b.rowNum != a.colNum) {
|
||||||
|
error("Matrix dimension mismatch. Expected ${a.rowNum}, but got ${b.colNum}")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun inverse(a: Matrix<T>): Matrix<T> = solve(a, context.one(a.rowNum, a.colNum))
|
// Use existing decomposition if it is provided by matrix
|
||||||
|
val decomposition = a.getFeature() ?: lupDecompose(type, a, checkSingular)
|
||||||
|
|
||||||
companion object {
|
buildAccessor(type, a.rowNum, a.colNum).run {
|
||||||
val real = LUSolver(MatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 }
|
|
||||||
|
val lu = create(decomposition.lu)
|
||||||
|
|
||||||
|
// Apply permutations to b
|
||||||
|
val bp = create { i, j ->
|
||||||
|
b[decomposition.pivot[i], j]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Solve LY = b
|
||||||
|
for (col in 0 until a.rowNum) {
|
||||||
|
for (i in col + 1 until a.rowNum) {
|
||||||
|
for (j in 0 until b.colNum) {
|
||||||
|
bp.subtractInPlace(i, j, lu, col)
|
||||||
|
//bp[i, j] -= bp[col, j] * lu[i, col]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Solve UX = Y
|
||||||
|
for (col in a.rowNum - 1 downTo 0) {
|
||||||
|
val luDiag = lu[col, col]
|
||||||
|
for (j in 0 until b.colNum) {
|
||||||
|
bp.divideInPlace(col, j, luDiag)
|
||||||
|
//bp[col, j] /= lu[col, col]
|
||||||
|
}
|
||||||
|
for (i in 0 until col) {
|
||||||
|
for (j in 0 until b.colNum) {
|
||||||
|
bp.subtractInPlace(i, j, lu, col)
|
||||||
|
//bp[i, j] -= bp[col, j] * lu[i, col]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return produce(a.rowNum, a.colNum) { i, j -> bp[i, j] }
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
|
||||||
|
matrix: Matrix<T>,
|
||||||
|
noinline checkSingular: (T) -> Boolean
|
||||||
|
) =
|
||||||
|
solve(T::class, matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
||||||
|
|
||||||
|
fun GenericMatrixContext<Double, RealField>.inverse(matrix: Matrix<Double>) =
|
||||||
|
inverse(matrix) { it < 1e-11 }
|
@ -1,40 +0,0 @@
|
|||||||
package scientifik.kmath.linear
|
|
||||||
|
|
||||||
import scientifik.kmath.structures.MutableBuffer
|
|
||||||
import scientifik.kmath.structures.MutableBufferFactory
|
|
||||||
import scientifik.kmath.structures.MutableNDStructure
|
|
||||||
|
|
||||||
class Mutable2DStructure<T>(val rowNum: Int, val colNum: Int, val buffer: MutableBuffer<T>) : MutableNDStructure<T> {
|
|
||||||
override val shape: IntArray
|
|
||||||
get() = intArrayOf(rowNum, colNum)
|
|
||||||
|
|
||||||
operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
|
|
||||||
|
|
||||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
|
||||||
for (i in 0 until rowNum) {
|
|
||||||
for (j in 0 until colNum) {
|
|
||||||
yield(intArrayOf(i, j) to get(i, j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
operator fun set(i: Int, j: Int, value: T) {
|
|
||||||
buffer[i * colNum + j] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun set(index: IntArray, value: T) = set(index[0], index[1], value)
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
fun <T> create(
|
|
||||||
rowNum: Int,
|
|
||||||
colNum: Int,
|
|
||||||
bufferFactory: MutableBufferFactory<T>,
|
|
||||||
init: (i: Int, j: Int) -> T
|
|
||||||
): Mutable2DStructure<T> {
|
|
||||||
val buffer = bufferFactory(rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
|
||||||
return Mutable2DStructure(rowNum, colNum, buffer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -2,6 +2,7 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.complex
|
import scientifik.kmath.operations.complex
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
|
||||||
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
||||||
@ -41,13 +42,10 @@ interface Buffer<T> {
|
|||||||
*/
|
*/
|
||||||
inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> = ListBuffer(List(size, initializer))
|
inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> = ListBuffer(List(size, initializer))
|
||||||
|
|
||||||
/**
|
|
||||||
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible
|
|
||||||
*/
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <reified T : Any> auto(size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
||||||
//TODO add resolution based on Annotation or companion resolution
|
//TODO add resolution based on Annotation or companion resolution
|
||||||
return when (T::class) {
|
return when (type) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
||||||
@ -56,6 +54,13 @@ interface Buffer<T> {
|
|||||||
else -> boxing(size, initializer)
|
else -> boxing(size, initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
inline fun <reified T : Any> auto(size: Int, crossinline initializer: (Int) -> T): Buffer<T> =
|
||||||
|
auto(T::class, size, initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,12 +83,9 @@ interface MutableBuffer<T> : Buffer<T> {
|
|||||||
inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
MutableListBuffer(MutableList(size, initializer))
|
MutableListBuffer(MutableList(size, initializer))
|
||||||
|
|
||||||
/**
|
|
||||||
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
|
||||||
*/
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
inline fun <T : Any> auto(type: KClass<T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
||||||
return when (T::class) {
|
return when (type) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
||||||
@ -91,6 +93,17 @@ interface MutableBuffer<T> : Buffer<T> {
|
|||||||
else -> boxing(size, initializer)
|
else -> boxing(size, initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
|
auto(T::class, size, initializer)
|
||||||
|
|
||||||
|
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
|
||||||
|
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.jvm.JvmName
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
|
||||||
interface NDStructure<T> {
|
interface NDStructure<T> {
|
||||||
|
|
||||||
@ -40,6 +43,9 @@ interface NDStructure<T> {
|
|||||||
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
|
inline fun <T : Any> auto(type: KClass<T>, strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||||
|
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
fun <T> build(
|
fun <T> build(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
@ -48,6 +54,13 @@ interface NDStructure<T> {
|
|||||||
|
|
||||||
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||||
auto(DefaultStrides(shape), initializer)
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
|
@JvmName("autoVarArg")
|
||||||
|
inline fun <reified T : Any> auto(vararg shape: Int, crossinline initializer: (IntArray) -> T) =
|
||||||
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
|
inline fun <T : Any> auto(type: KClass<T>, vararg shape: Int, crossinline initializer: (IntArray) -> T) =
|
||||||
|
auto(type, DefaultStrides(shape), initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,7 +70,7 @@ interface MutableNDStructure<T> : NDStructure<T> {
|
|||||||
operator fun set(index: IntArray, value: T)
|
operator fun set(index: IntArray, value: T)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
||||||
elements().forEach { (index, oldValue) ->
|
elements().forEach { (index, oldValue) ->
|
||||||
this[index] = action(index, oldValue)
|
this[index] = action(index, oldValue)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user