forked from kscience/kmath
Linear algebra in progress
This commit is contained in:
parent
76c3025acc
commit
a9c084773b
@ -1,5 +1,5 @@
|
|||||||
buildscript {
|
buildscript {
|
||||||
ext.kotlin_version = '1.3.0-rc-116'
|
ext.kotlin_version = '1.3.0-rc-146'
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
jcenter()
|
jcenter()
|
||||||
|
@ -0,0 +1,3 @@
|
|||||||
|
package scientifik.kmath.commons
|
||||||
|
|
||||||
|
//val solver: DecompositionSolver
|
@ -0,0 +1,321 @@
|
|||||||
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.structures.MutableNDArray
|
||||||
|
import scientifik.kmath.structures.NDArray
|
||||||
|
import scientifik.kmath.structures.NDArrays
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates the LUP-decomposition of a square matrix.
|
||||||
|
*
|
||||||
|
* The LUP-decomposition of a matrix A consists of three matrices L, U and
|
||||||
|
* P that satisfy: PA = LU. L is lower triangular (with unit
|
||||||
|
* diagonal terms), U is upper triangular and P is a permutation matrix. All
|
||||||
|
* matrices are mm.
|
||||||
|
*
|
||||||
|
* As shown by the presence of the P matrix, this decomposition is
|
||||||
|
* implemented using partial pivoting.
|
||||||
|
*
|
||||||
|
* This class is based on the class with similar name from the
|
||||||
|
* [JAMA](http://math.nist.gov/javanumerics/jama/) library.
|
||||||
|
*
|
||||||
|
* * a [getP][.getP] method has been added,
|
||||||
|
* * the `det` method has been renamed as [ getDeterminant][.getDeterminant],
|
||||||
|
* * the `getDoublePivot` method has been removed (but the int based
|
||||||
|
* [getPivot][.getPivot] method has been kept),
|
||||||
|
* * the `solve` and `isNonSingular` methods have been replaced
|
||||||
|
* by a [getSolver][.getSolver] method and the equivalent methods
|
||||||
|
* provided by the returned [DecompositionSolver].
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* @see [MathWorld](http://mathworld.wolfram.com/LUDecomposition.html)
|
||||||
|
*
|
||||||
|
* @see [Wikipedia](http://en.wikipedia.org/wiki/LU_decomposition)
|
||||||
|
*
|
||||||
|
* @since 2.0 (changed to concrete class in 3.0)
|
||||||
|
*
|
||||||
|
* @param matrix The matrix to decompose.
|
||||||
|
* @param singularityThreshold threshold (based on partial row norm)
|
||||||
|
* under which a matrix is considered singular
|
||||||
|
* @throws NonSquareMatrixException if matrix is not square
|
||||||
|
|
||||||
|
*/
|
||||||
|
abstract class LUDecomposition<T : Comparable<T>>(val matrix: Matrix<T>) {
|
||||||
|
|
||||||
|
private val field get() = matrix.context.field
|
||||||
|
/** Entries of LU decomposition. */
|
||||||
|
internal val lu: NDArray<T>
|
||||||
|
/** Pivot permutation associated with LU decomposition. */
|
||||||
|
internal val pivot: IntArray
|
||||||
|
/** Parity of the permutation associated with the LU decomposition. */
|
||||||
|
private var even: Boolean = false
|
||||||
|
|
||||||
|
init {
|
||||||
|
val pair = matrix.context.field.calculateLU()
|
||||||
|
lu = pair.first
|
||||||
|
pivot = pair.second
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the matrix L of the decomposition.
|
||||||
|
*
|
||||||
|
* L is a lower-triangular matrix
|
||||||
|
* @return the L matrix (or null if decomposed matrix is singular)
|
||||||
|
*/
|
||||||
|
val l: Matrix<out T> by lazy {
|
||||||
|
matrix.context.produce { i, j ->
|
||||||
|
when {
|
||||||
|
j < i -> lu[i, j]
|
||||||
|
j == i -> matrix.context.field.one
|
||||||
|
else -> matrix.context.field.zero
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the matrix U of the decomposition.
|
||||||
|
*
|
||||||
|
* U is an upper-triangular matrix
|
||||||
|
* @return the U matrix (or null if decomposed matrix is singular)
|
||||||
|
*/
|
||||||
|
val u: Matrix<out T> by lazy {
|
||||||
|
matrix.context.produce { i, j ->
|
||||||
|
if (j >= i) lu[i, j] else field.zero
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the P rows permutation matrix.
|
||||||
|
*
|
||||||
|
* P is a sparse matrix with exactly one element set to 1.0 in
|
||||||
|
* each row and each column, all other elements being set to 0.0.
|
||||||
|
*
|
||||||
|
* The positions of the 1 elements are given by the [ pivot permutation vector][.getPivot].
|
||||||
|
* @return the P rows permutation matrix (or null if decomposed matrix is singular)
|
||||||
|
* @see .getPivot
|
||||||
|
*/
|
||||||
|
val p: Matrix<out T> by lazy {
|
||||||
|
matrix.context.produce { i, j ->
|
||||||
|
//TODO ineffective. Need sparse matrix for that
|
||||||
|
if (j == pivot[i]) field.one else field.zero
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the determinant of the matrix
|
||||||
|
* @return determinant of the matrix
|
||||||
|
*/
|
||||||
|
val determinant: T
|
||||||
|
get() {
|
||||||
|
with(matrix.context.field) {
|
||||||
|
var determinant = if (even) one else -one
|
||||||
|
for (i in 0 until matrix.rows) {
|
||||||
|
determinant *= lu[i, i]
|
||||||
|
}
|
||||||
|
return determinant
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * Get a solver for finding the A X = B solution in exact linear
|
||||||
|
// * sense.
|
||||||
|
// * @return a solver
|
||||||
|
// */
|
||||||
|
// val solver: DecompositionSolver
|
||||||
|
// get() = Solver(lu, pivot, singular)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In-place transformation for [MutableNDArray], using given transformation for each element
|
||||||
|
*/
|
||||||
|
operator fun <T> MutableNDArray<T>.set(i: Int, j: Int, value: T) {
|
||||||
|
this[listOf(i, j)] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract fun isSingular(value: T): Boolean
|
||||||
|
|
||||||
|
private fun Field<T>.calculateLU(): Pair<NDArray<T>, IntArray> {
|
||||||
|
if (matrix.rows != matrix.columns) {
|
||||||
|
error("LU decomposition supports only square matrices")
|
||||||
|
}
|
||||||
|
|
||||||
|
fun T.abs() = if (this > zero) this else -this
|
||||||
|
|
||||||
|
val m = matrix.columns
|
||||||
|
val pivot = IntArray(matrix.rows)
|
||||||
|
//TODO fix performance
|
||||||
|
val lu: MutableNDArray<T> = NDArrays.createMutable(matrix.context.field, listOf(matrix.rows, matrix.columns)) { index -> matrix[index[0], index[1]] }
|
||||||
|
|
||||||
|
// Initialize permutation array and parity
|
||||||
|
for (row in 0 until m) {
|
||||||
|
pivot[row] = row
|
||||||
|
}
|
||||||
|
even = true
|
||||||
|
|
||||||
|
// Loop over columns
|
||||||
|
for (col in 0 until m) {
|
||||||
|
|
||||||
|
// upper
|
||||||
|
for (row in 0 until col) {
|
||||||
|
var sum = lu[row, col]
|
||||||
|
for (i in 0 until row) {
|
||||||
|
sum -= lu[row, i] * lu[i, col]
|
||||||
|
}
|
||||||
|
lu[row, col] = sum
|
||||||
|
}
|
||||||
|
|
||||||
|
// lower
|
||||||
|
val max = (col until m).maxBy { row ->
|
||||||
|
var sum = lu[row, col]
|
||||||
|
for (i in 0 until col) {
|
||||||
|
sum -= lu[row, i] * lu[i, col]
|
||||||
|
}
|
||||||
|
//luRow[col] = sum
|
||||||
|
lu[row, col] = sum
|
||||||
|
|
||||||
|
sum.abs()
|
||||||
|
} ?: col
|
||||||
|
|
||||||
|
// Singularity check
|
||||||
|
if (isSingular(lu[max, col].abs())) {
|
||||||
|
error("Singular matrix")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pivot if necessary
|
||||||
|
if (max != col) {
|
||||||
|
//var tmp = zero
|
||||||
|
//val luMax = lu[max]
|
||||||
|
//val luCol = lu[col]
|
||||||
|
for (i in 0 until m) {
|
||||||
|
lu[max, i] = lu[col, i]
|
||||||
|
lu[col, i] = lu[max, i]
|
||||||
|
}
|
||||||
|
val temp = pivot[max]
|
||||||
|
pivot[max] = pivot[col]
|
||||||
|
pivot[col] = temp
|
||||||
|
even = !even
|
||||||
|
}
|
||||||
|
|
||||||
|
// Divide the lower elements by the "winning" diagonal elt.
|
||||||
|
val luDiag = lu[col, col]
|
||||||
|
for (row in col + 1 until m) {
|
||||||
|
lu[row, col] /= luDiag
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Pair(lu, pivot)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the pivot permutation vector.
|
||||||
|
* @return the pivot permutation vector
|
||||||
|
* @see .getP
|
||||||
|
*/
|
||||||
|
fun getPivot(): IntArray {
|
||||||
|
return pivot.copyOf()
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/** Default bound to determine effective singularity in LU decomposition. */
|
||||||
|
private const val DEFAULT_TOO_SMALL = 1e-11
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class RealLUDecomposition(matrix: Matrix<Double>, private val singularityThreshold: Double = 1e-11) : LUDecomposition<Double>(matrix) {
|
||||||
|
override fun isSingular(value: Double): Boolean {
|
||||||
|
return value < singularityThreshold
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/** Specialized solver. */
|
||||||
|
class RealLUSolver : LinearSolver<Double> {
|
||||||
|
|
||||||
|
//
|
||||||
|
// /** {@inheritDoc} */
|
||||||
|
// override fun solve(b: RealVector): RealVector {
|
||||||
|
// val m = pivot.size
|
||||||
|
// if (b.getDimension() != m) {
|
||||||
|
// throw DimensionMismatchException(b.getDimension(), m)
|
||||||
|
// }
|
||||||
|
// if (singular) {
|
||||||
|
// throw SingularMatrixException()
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// val bp = DoubleArray(m)
|
||||||
|
//
|
||||||
|
// // Apply permutations to b
|
||||||
|
// for (row in 0 until m) {
|
||||||
|
// bp[row] = b.getEntry(pivot[row])
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Solve LY = b
|
||||||
|
// for (col in 0 until m) {
|
||||||
|
// val bpCol = bp[col]
|
||||||
|
// for (i in col + 1 until m) {
|
||||||
|
// bp[i] -= bpCol * lu[i][col]
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Solve UX = Y
|
||||||
|
// for (col in m - 1 downTo 0) {
|
||||||
|
// bp[col] /= lu[col][col]
|
||||||
|
// val bpCol = bp[col]
|
||||||
|
// for (i in 0 until col) {
|
||||||
|
// bp[i] -= bpCol * lu[i][col]
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// return ArrayRealVector(bp, false)
|
||||||
|
// }
|
||||||
|
|
||||||
|
|
||||||
|
fun decompose(mat: Matrix<Double>, threshold: Double = 1e-11): RealLUDecomposition = RealLUDecomposition(mat, threshold)
|
||||||
|
|
||||||
|
override fun solve(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> {
|
||||||
|
val decomposition = decompose(a, a.context.field.zero)
|
||||||
|
|
||||||
|
if (b.rows != a.rows) {
|
||||||
|
error("Matrix dimension mismatch expected ${a.rows}, but got ${b.rows}")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply permutations to b
|
||||||
|
val bp = Array(a.rows) { DoubleArray(b.columns) }
|
||||||
|
for (row in 0 until a.rows) {
|
||||||
|
val bpRow = bp[row]
|
||||||
|
val pRow = decomposition.pivot[row]
|
||||||
|
for (col in 0 until b.columns) {
|
||||||
|
bpRow[col] = b[pRow, col]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Solve LY = b
|
||||||
|
for (col in 0 until a.rows) {
|
||||||
|
val bpCol = bp[col]
|
||||||
|
for (i in col + 1 until a.rows) {
|
||||||
|
val bpI = bp[i]
|
||||||
|
val luICol = decomposition.lu[i, col]
|
||||||
|
for (j in 0 until b.columns) {
|
||||||
|
bpI[j] -= bpCol[j] * luICol
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Solve UX = Y
|
||||||
|
for (col in a.rows - 1 downTo 0) {
|
||||||
|
val bpCol = bp[col]
|
||||||
|
val luDiag = decomposition.lu[col, col]
|
||||||
|
for (j in 0 until b.columns) {
|
||||||
|
bpCol[j] /= luDiag
|
||||||
|
}
|
||||||
|
for (i in 0 until col) {
|
||||||
|
val bpI = bp[i]
|
||||||
|
val luICol = decomposition.lu[i, col]
|
||||||
|
for (j in 0 until b.columns) {
|
||||||
|
bpI[j] -= bpCol[j] * luICol
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.context.produce { i, j -> bp[i][j] }
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,250 @@
|
|||||||
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.DoubleField
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.operations.SpaceElement
|
||||||
|
import scientifik.kmath.structures.NDArray
|
||||||
|
import scientifik.kmath.structures.NDArrays.createFactory
|
||||||
|
import scientifik.kmath.structures.NDFieldFactory
|
||||||
|
import scientifik.kmath.structures.realNDFieldFactory
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The space for linear elements. Supports scalar product alongside with standard linear operations.
|
||||||
|
* @param T type of individual element of the vector or matrix
|
||||||
|
* @param V the type of vector space element
|
||||||
|
*/
|
||||||
|
abstract class LinearSpace<T : Any, V : Matrix<T>>(val rows: Int, val columns: Int, val field: Field<T>) : Space<V> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce the element of this space
|
||||||
|
*/
|
||||||
|
abstract fun produce(initializer: (Int, Int) -> T): V
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce new linear space with given dimensions. The space produced could be raised from cache since [LinearSpace] does not have mutable elements
|
||||||
|
*/
|
||||||
|
abstract fun produceSpace(rows: Int, columns: Int): LinearSpace<T, V>
|
||||||
|
|
||||||
|
override val zero: V by lazy {
|
||||||
|
produce { _, _ -> field.zero }
|
||||||
|
}
|
||||||
|
|
||||||
|
val one: V by lazy {
|
||||||
|
produce { i, j -> if (i == j) field.one else field.zero }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun add(a: V, b: V): V {
|
||||||
|
return produce { i, j -> with(field) { a[i, j] + b[i, j] } }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: V, k: Double): V {
|
||||||
|
//TODO it is possible to implement scalable linear elements which normed values and adjustable scale to save memory and processing poser
|
||||||
|
return produce { i, j -> with(field) { a[i, j] * k } }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dot product. Throws exception on dimension mismatch
|
||||||
|
*/
|
||||||
|
fun multiply(a: V, b: V): V {
|
||||||
|
if (a.rows != b.columns) {
|
||||||
|
//TODO replace by specific exception
|
||||||
|
error("Dimension mismatch in linear structure dot product: [${a.rows},${a.columns}]*[${b.rows},${b.columns}]")
|
||||||
|
}
|
||||||
|
return produceSpace(a.rows, b.columns).produce { i, j ->
|
||||||
|
(0..a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
infix fun V.dot(b: V): V = multiply(this, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A specialized [LinearSpace] which works with vectors
|
||||||
|
*/
|
||||||
|
abstract class VectorSpace<T : Any, V : Vector<T>>(size: Int, field: Field<T>) : LinearSpace<T, V>(size, 1, field)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A matrix-like structure
|
||||||
|
*/
|
||||||
|
interface Matrix<T : Any> {
|
||||||
|
val context: LinearSpace<T, out Matrix<T>>
|
||||||
|
/**
|
||||||
|
* Number of rows
|
||||||
|
*/
|
||||||
|
val rows: Int
|
||||||
|
/**
|
||||||
|
* Number of columns
|
||||||
|
*/
|
||||||
|
val columns: Int
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get element in row [i] and column [j]. Throws error in case of call ounside structure dimensions
|
||||||
|
*/
|
||||||
|
operator fun get(i: Int, j: Int): T
|
||||||
|
|
||||||
|
fun transpose(): Matrix<T> {
|
||||||
|
return object : Matrix<T> {
|
||||||
|
override val context: LinearSpace<T, out Matrix<T>> = this@Matrix.context
|
||||||
|
override val rows: Int = this@Matrix.columns
|
||||||
|
override val columns: Int = this@Matrix.rows
|
||||||
|
override fun get(i: Int, j: Int): T = this@Matrix[j, i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface Vector<T : Any> : Matrix<T> {
|
||||||
|
override val context: VectorSpace<T, Vector<T>>
|
||||||
|
override val columns: Int
|
||||||
|
get() = 1
|
||||||
|
|
||||||
|
operator fun get(i: Int) = get(i, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* NDArray-based implementation of vector space. By default uses slow [SimpleNDField], but could be overridden with custom [NDField] factory.
|
||||||
|
*/
|
||||||
|
class ArraySpace<T : Any>(
|
||||||
|
rows: Int,
|
||||||
|
columns: Int,
|
||||||
|
field: Field<T>,
|
||||||
|
val ndFactory: NDFieldFactory<T> = createFactory(field)
|
||||||
|
) : LinearSpace<T, Matrix<T>>(rows, columns, field) {
|
||||||
|
|
||||||
|
val ndField by lazy {
|
||||||
|
ndFactory(listOf(rows, columns))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun produce(initializer: (Int, Int) -> T): Matrix<T> = ArrayMatrix(this, initializer)
|
||||||
|
|
||||||
|
override fun produceSpace(rows: Int, columns: Int): ArraySpace<T> {
|
||||||
|
return ArraySpace(rows, columns, field, ndFactory)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ArrayVectorSpace<T : Any>(
|
||||||
|
size: Int,
|
||||||
|
field: Field<T>,
|
||||||
|
val ndFactory: NDFieldFactory<T> = createFactory(field)
|
||||||
|
) : VectorSpace<T, Vector<T>>(size, field) {
|
||||||
|
val ndField by lazy {
|
||||||
|
ndFactory(listOf(size))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun produce(initializer: (Int, Int) -> T): Vector<T> = produceVector { i -> initializer(i, 0) }
|
||||||
|
|
||||||
|
fun produceVector(initializer: (Int) -> T): Vector<T> = ArrayVector(this, initializer)
|
||||||
|
|
||||||
|
override fun produceSpace(rows: Int, columns: Int): LinearSpace<T, Vector<T>> {
|
||||||
|
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Member of [ArraySpace] which wraps 2-D array
|
||||||
|
*/
|
||||||
|
class ArrayMatrix<T : Any> internal constructor(override val context: ArraySpace<T>, val array: NDArray<T>) : Matrix<T>, SpaceElement<Matrix<T>, ArraySpace<T>> {
|
||||||
|
|
||||||
|
constructor(context: ArraySpace<T>, initializer: (Int, Int) -> T) : this(context, context.ndField.produce { list -> initializer(list[0], list[1]) })
|
||||||
|
|
||||||
|
override val rows: Int get() = context.rows
|
||||||
|
|
||||||
|
override val columns: Int get() = context.columns
|
||||||
|
|
||||||
|
override fun get(i: Int, j: Int): T {
|
||||||
|
return array[i, j]
|
||||||
|
}
|
||||||
|
|
||||||
|
override val self: ArrayMatrix<T> get() = this
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayVector<T : Any> internal constructor(override val context: ArrayVectorSpace<T>, val array: NDArray<T>) : Vector<T>, SpaceElement<Vector<T>, ArrayVectorSpace<T>> {
|
||||||
|
|
||||||
|
constructor(context: ArrayVectorSpace<T>, initializer: (Int) -> T) : this(context, context.ndField.produce { list -> initializer(list[0]) })
|
||||||
|
|
||||||
|
init {
|
||||||
|
if (context.columns != 1) {
|
||||||
|
error("Vector must have single column")
|
||||||
|
}
|
||||||
|
if (context.rows != array.shape[0]) {
|
||||||
|
error("Array dimension mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//private val array = context.ndField.produce { list -> initializer(list[0]) }
|
||||||
|
|
||||||
|
|
||||||
|
override val rows: Int get() = context.rows
|
||||||
|
|
||||||
|
override val columns: Int = 1
|
||||||
|
|
||||||
|
override fun get(i: Int, j: Int): T {
|
||||||
|
return array[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
override val self: ArrayVector<T> get() = this
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
|
||||||
|
*/
|
||||||
|
interface LinearSolver<T : Any> {
|
||||||
|
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
||||||
|
fun solve(a: Matrix<T>, b: Vector<T>): Vector<T> = solve(a, b as Matrix<T>).toVector()
|
||||||
|
fun inverse(a: Matrix<T>): Matrix<T> = solve(a, a.context.one)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create vector with custom field
|
||||||
|
*/
|
||||||
|
fun <T : Any> vector(size: Int, field: Field<T>, initializer: (Int) -> T) =
|
||||||
|
ArrayVector(ArrayVectorSpace(size, field), initializer)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create vector of [Double]
|
||||||
|
*/
|
||||||
|
fun realVector(size: Int, initializer: (Int) -> Double) =
|
||||||
|
ArrayVector(ArrayVectorSpace(size, DoubleField, realNDFieldFactory), initializer)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert vector to array (copying content of array)
|
||||||
|
*/
|
||||||
|
fun <T : Any> Array<T>.asVector(field: Field<T>) = vector(size, field) { this[it] }
|
||||||
|
|
||||||
|
fun DoubleArray.asVector() = realVector(this.size) { this[it] }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create [ArrayMatrix] with custom field
|
||||||
|
*/
|
||||||
|
fun <T : Any> matrix(rows: Int, columns: Int, field: Field<T>, initializer: (Int, Int) -> T) =
|
||||||
|
ArrayMatrix(ArraySpace(rows, columns, field), initializer)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create [ArrayMatrix] of doubles.
|
||||||
|
*/
|
||||||
|
fun realMatrix(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
||||||
|
ArrayMatrix(ArraySpace(rows, columns, DoubleField, realNDFieldFactory), initializer)
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert matrix to vector if it is possible
|
||||||
|
*/
|
||||||
|
fun <T : Any> Matrix<T>.toVector(): Vector<T> {
|
||||||
|
return when {
|
||||||
|
this is Vector -> return this
|
||||||
|
this.columns == 1 -> {
|
||||||
|
if (this is ArrayMatrix) {
|
||||||
|
//Reuse existing underlying array
|
||||||
|
ArrayVector(ArrayVectorSpace(rows, context.field, context.ndFactory), array)
|
||||||
|
} else {
|
||||||
|
//Generic vector
|
||||||
|
vector(rows, context.field) { get(it, 0) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else -> error("Can't convert matrix with more than one column to vector")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,41 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A field for complex numbers
|
||||||
|
*/
|
||||||
|
object ComplexField : Field<Complex> {
|
||||||
|
override val zero: Complex = Complex(0.0, 0.0)
|
||||||
|
|
||||||
|
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
|
||||||
|
|
||||||
|
override fun multiply(a: Complex, k: Double): Complex = Complex(a.re * k, a.im * k)
|
||||||
|
|
||||||
|
override val one: Complex = Complex(1.0, 0.0)
|
||||||
|
|
||||||
|
override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
|
||||||
|
|
||||||
|
override fun divide(a: Complex, b: Complex): Complex = Complex(a.re * b.re + a.im * b.im, a.re * b.im - a.im * b.re) / b.square
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Complex number class
|
||||||
|
*/
|
||||||
|
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, ComplexField> {
|
||||||
|
override val self: Complex get() = this
|
||||||
|
override val context: ComplexField
|
||||||
|
get() = ComplexField
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A complex conjugate
|
||||||
|
*/
|
||||||
|
val conjugate: Complex
|
||||||
|
get() = Complex(re, -im)
|
||||||
|
|
||||||
|
val square: Double
|
||||||
|
get() = re * re + im * im
|
||||||
|
|
||||||
|
val abs: Double
|
||||||
|
get() = kotlin.math.sqrt(square)
|
||||||
|
|
||||||
|
}
|
@ -1,7 +1,6 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
import kotlin.math.sqrt
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Field for real values
|
* Field for real values
|
||||||
@ -51,46 +50,6 @@ data class Real(val value: Double) : Number(), FieldElement<Real, RealField> {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* A field for complex numbers
|
|
||||||
*/
|
|
||||||
object ComplexField : Field<Complex> {
|
|
||||||
override val zero: Complex = Complex(0.0, 0.0)
|
|
||||||
|
|
||||||
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
|
|
||||||
|
|
||||||
override fun multiply(a: Complex, k: Double): Complex = Complex(a.re * k, a.im * k)
|
|
||||||
|
|
||||||
override val one: Complex = Complex(1.0, 0.0)
|
|
||||||
|
|
||||||
override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
|
|
||||||
|
|
||||||
override fun divide(a: Complex, b: Complex): Complex = Complex(a.re * b.re + a.im * b.im, a.re * b.im - a.im * b.re) / b.square
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Complex number class
|
|
||||||
*/
|
|
||||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, ComplexField> {
|
|
||||||
override val self: Complex get() = this
|
|
||||||
override val context: ComplexField
|
|
||||||
get() = ComplexField
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A complex conjugate
|
|
||||||
*/
|
|
||||||
val conjugate: Complex
|
|
||||||
get() = Complex(re, -im)
|
|
||||||
|
|
||||||
val square: Double
|
|
||||||
get() = re * re + im * im
|
|
||||||
|
|
||||||
val module: Double
|
|
||||||
get() = sqrt(square)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field for double without boxing. Does not produce appropriate field element
|
* A field for double without boxing. Does not produce appropriate field element
|
||||||
*/
|
*/
|
||||||
|
@ -9,6 +9,11 @@ import scientifik.kmath.operations.Field
|
|||||||
interface Buffer<T> {
|
interface Buffer<T> {
|
||||||
operator fun get(index: Int): T
|
operator fun get(index: Int): T
|
||||||
operator fun set(index: Int, value: T)
|
operator fun set(index: Int, value: T)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A shallow copy of the buffer
|
||||||
|
*/
|
||||||
|
fun copy(): Buffer<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -63,8 +68,16 @@ abstract class BufferNDField<T>(shape: List<Int>, field: Field<T>) : NDField<T>(
|
|||||||
return BufferNDArray(this, buffer)
|
return BufferNDArray(this, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce mutable NDArray instance
|
||||||
|
*/
|
||||||
|
fun produceMutable(initializer: (List<Int>) -> T): MutableNDArray<T> {
|
||||||
|
val buffer = createBuffer(capacity) { initializer(index(it)) }
|
||||||
|
return MutableBufferedNDArray(this, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
class BufferNDArray<T>(override val context: BufferNDField<T>, val data: Buffer<T>) : NDArray<T> {
|
|
||||||
|
private open class BufferNDArray<T>(override val context: BufferNDField<T>, val data: Buffer<T>) : NDArray<T> {
|
||||||
|
|
||||||
override fun get(vararg index: Int): T {
|
override fun get(vararg index: Int): T {
|
||||||
return data[context.offset(index.asList())]
|
return data[context.offset(index.asList())]
|
||||||
@ -88,6 +101,12 @@ abstract class BufferNDField<T>(shape: List<Int>, field: Field<T>) : NDField<T>(
|
|||||||
|
|
||||||
override val self: NDArray<T> get() = this
|
override val self: NDArray<T> get() = this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class MutableBufferedNDArray<T>(context: BufferNDField<T>, data: Buffer<T>): BufferNDArray<T>(context,data), MutableNDArray<T>{
|
||||||
|
override operator fun set(index: List<Int>, value: T){
|
||||||
|
data[context.offset(index)] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,169 +0,0 @@
|
|||||||
package scientifik.kmath.structures
|
|
||||||
|
|
||||||
import scientifik.kmath.operations.DoubleField
|
|
||||||
import scientifik.kmath.operations.Field
|
|
||||||
import scientifik.kmath.operations.Space
|
|
||||||
import scientifik.kmath.operations.SpaceElement
|
|
||||||
import scientifik.kmath.structures.NDArrays.createSimpleNDFieldFactory
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The space for linear elements. Supports scalar product alongside with standard linear operations.
|
|
||||||
* @param T type of individual element of the vector or matrix
|
|
||||||
* @param V the type of vector space element
|
|
||||||
*/
|
|
||||||
abstract class LinearSpace<T : Any, V : LinearStructure<out T>>(val rows: Int, val columns: Int, val field: Field<T>) : Space<V> {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Produce the element of this space
|
|
||||||
*/
|
|
||||||
abstract fun produce(initializer: (Int, Int) -> T): V
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Produce new linear space with given dimensions. The space produced could be raised from cache since [LinearSpace] does not have mutable elements
|
|
||||||
*/
|
|
||||||
abstract fun produceSpace(rows: Int, columns: Int): LinearSpace<T, V>
|
|
||||||
|
|
||||||
override val zero: V by lazy {
|
|
||||||
produce { _, _ -> field.zero }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun add(a: V, b: V): V {
|
|
||||||
return produce { i, j -> with(field) { a[i, j] + b[i, j] } }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(a: V, k: Double): V {
|
|
||||||
//TODO it is possible to implement scalable linear elements which normed values and adjustable scale to save memory and processing poser
|
|
||||||
return produce { i, j -> with(field) { a[i, j] * k } }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Dot product. Throws exception on dimension mismatch
|
|
||||||
*/
|
|
||||||
fun multiply(a: V, b: V): V {
|
|
||||||
if (a.rows != b.columns) {
|
|
||||||
//TODO replace by specific exception
|
|
||||||
error("Dimension mismatch in linear structure dot product: [${a.rows},${a.columns}]*[${b.rows},${b.columns}]")
|
|
||||||
}
|
|
||||||
return produceSpace(a.rows, b.columns).produce { i, j ->
|
|
||||||
(0..a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
infix fun V.dot(b: V): V = multiply(this, b)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A matrix-like structure that is not dependent on specific space implementation
|
|
||||||
*/
|
|
||||||
interface LinearStructure<T : Any> {
|
|
||||||
/**
|
|
||||||
* Number of rows
|
|
||||||
*/
|
|
||||||
val rows: Int
|
|
||||||
/**
|
|
||||||
* Number of columns
|
|
||||||
*/
|
|
||||||
val columns: Int
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get element in row [i] and column [j]. Throws error in case of call ounside structure dimensions
|
|
||||||
*/
|
|
||||||
operator fun get(i: Int, j: Int): T
|
|
||||||
|
|
||||||
fun transpose(): LinearStructure<T> {
|
|
||||||
return object : LinearStructure<T> {
|
|
||||||
override val rows: Int = this@LinearStructure.columns
|
|
||||||
override val columns: Int = this@LinearStructure.rows
|
|
||||||
override fun get(i: Int, j: Int): T = this@LinearStructure[j, i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
interface Vector<T : Any> : LinearStructure<T> {
|
|
||||||
override val columns: Int
|
|
||||||
get() = 1
|
|
||||||
|
|
||||||
operator fun get(i: Int) = get(i, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* NDArray-based implementation of vector space. By default uses slow [SimpleNDField], but could be overridden with custom [NDField] factory.
|
|
||||||
*/
|
|
||||||
class ArraySpace<T : Any>(
|
|
||||||
rows: Int,
|
|
||||||
columns: Int,
|
|
||||||
field: Field<T>,
|
|
||||||
val ndFactory: NDFieldFactory<T> = createSimpleNDFieldFactory(field)
|
|
||||||
) : LinearSpace<T, LinearStructure<out T>>(rows, columns, field) {
|
|
||||||
|
|
||||||
val ndField by lazy {
|
|
||||||
ndFactory(listOf(rows, columns))
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun produce(initializer: (Int, Int) -> T): LinearStructure<T> = ArrayMatrix<T>(this, initializer)
|
|
||||||
|
|
||||||
override fun produceSpace(rows: Int, columns: Int): LinearSpace<T, LinearStructure<out T>> {
|
|
||||||
return ArraySpace(rows, columns, field, ndFactory)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Member of [ArraySpace] which wraps 2-D array
|
|
||||||
*/
|
|
||||||
class ArrayMatrix<T : Any>(override val context: ArraySpace<T>, initializer: (Int, Int) -> T) : LinearStructure<T>, SpaceElement<LinearStructure<out T>, ArraySpace<T>> {
|
|
||||||
|
|
||||||
private val array = context.ndField.produce { list -> initializer(list[0], list[1]) }
|
|
||||||
|
|
||||||
//val list: List<List<T>> = (0 until rows).map { i -> (0 until columns).map { j -> initializer(i, j) } }
|
|
||||||
|
|
||||||
override val rows: Int get() = context.rows
|
|
||||||
|
|
||||||
override val columns: Int get() = context.columns
|
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T {
|
|
||||||
return array[i, j]
|
|
||||||
}
|
|
||||||
|
|
||||||
override val self: ArrayMatrix<T> get() = this
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ArrayVector<T : Any>(override val context: ArraySpace<T>, initializer: (Int) -> T) : Vector<T>, SpaceElement<LinearStructure<out T>, ArraySpace<T>> {
|
|
||||||
|
|
||||||
init {
|
|
||||||
if (context.columns != 1) {
|
|
||||||
error("Vector must have single column")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private val array = context.ndField.produce { list -> initializer(list[0]) }
|
|
||||||
|
|
||||||
|
|
||||||
override val rows: Int get() = context.rows
|
|
||||||
|
|
||||||
override val columns: Int = 1
|
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T {
|
|
||||||
return array[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
override val self: ArrayVector<T> get() = this
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
fun <T : Any> vector(size: Int, field: Field<T>, initializer: (Int) -> T) =
|
|
||||||
ArrayVector(ArraySpace(size, 1, field), initializer)
|
|
||||||
|
|
||||||
fun realVector(size: Int, initializer: (Int) -> Double) =
|
|
||||||
ArrayVector(ArraySpace(size, 1, DoubleField, realNDFieldFactory), initializer)
|
|
||||||
|
|
||||||
fun <T : Any> Array<T>.asVector(field: Field<T>) = vector(size, field) { this[it] }
|
|
||||||
|
|
||||||
fun DoubleArray.asVector() = realVector(this.size) { this[it] }
|
|
||||||
|
|
||||||
fun <T : Any> matrix(rows: Int, columns: Int, field: Field<T>, initializer: (Int, Int) -> T) =
|
|
||||||
ArrayMatrix(ArraySpace(rows, columns, field), initializer)
|
|
||||||
|
|
||||||
fun realMatrix(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
|
||||||
ArrayMatrix(ArraySpace(rows, columns, DoubleField, realNDFieldFactory), initializer)
|
|
@ -73,7 +73,9 @@ abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDA
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Many-dimensional array
|
||||||
|
*/
|
||||||
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>> {
|
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -125,6 +127,22 @@ interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In-place mutable [NDArray]
|
||||||
|
*/
|
||||||
|
interface MutableNDArray<T> : NDArray<T> {
|
||||||
|
operator fun set(index: List<Int>, value: T)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In-place transformation for [MutableNDArray], using given transformation for each element
|
||||||
|
*/
|
||||||
|
fun <T> MutableNDArray<T>.transformInPlace(action: (List<Int>, T) -> T) {
|
||||||
|
for ((index, oldValue) in this) {
|
||||||
|
this[index] = action(index, oldValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||||
*/
|
*/
|
||||||
|
@ -9,6 +9,28 @@ typealias NDFieldFactory<T> = (shape: List<Int>) -> NDField<T>
|
|||||||
*/
|
*/
|
||||||
expect val realNDFieldFactory: NDFieldFactory<Double>
|
expect val realNDFieldFactory: NDFieldFactory<Double>
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleNDField<T : Any>(field: Field<T>, shape: List<Int>) : BufferNDField<T>(shape, field) {
|
||||||
|
override fun createBuffer(capacity: Int, initializer: (Int) -> T): Buffer<T> {
|
||||||
|
val array = ArrayList<T>(capacity)
|
||||||
|
(0 until capacity).forEach {
|
||||||
|
array.add(initializer(it))
|
||||||
|
}
|
||||||
|
|
||||||
|
return BufferOfObjects(array)
|
||||||
|
}
|
||||||
|
|
||||||
|
private class BufferOfObjects<T>(val array: ArrayList<T>) : Buffer<T> {
|
||||||
|
override fun get(index: Int): T = array[index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: T) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun copy(): Buffer<T> = BufferOfObjects(ArrayList(array))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
object NDArrays {
|
object NDArrays {
|
||||||
/**
|
/**
|
||||||
* Create a platform-optimized NDArray of doubles
|
* Create a platform-optimized NDArray of doubles
|
||||||
@ -29,26 +51,22 @@ object NDArrays {
|
|||||||
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||||
}
|
}
|
||||||
|
|
||||||
class SimpleNDField<T : Any>(field: Field<T>, shape: List<Int>) : BufferNDField<T>(shape, field) {
|
/**
|
||||||
override fun createBuffer(capacity: Int, initializer: (Int) -> T): Buffer<T> {
|
* Simple boxing NDField
|
||||||
val array = ArrayList<T>(capacity)
|
*/
|
||||||
(0 until capacity).forEach {
|
fun <T : Any> createFactory(field: Field<T>): NDFieldFactory<T> = { shape -> SimpleNDField(field, shape) }
|
||||||
array.add(initializer(it))
|
|
||||||
}
|
|
||||||
|
|
||||||
return object : Buffer<T> {
|
/**
|
||||||
override fun get(index: Int): T = array[index]
|
* Simple boxing NDArray
|
||||||
|
*/
|
||||||
override fun set(index: Int, value: T) {
|
fun <T : Any> create(field: Field<T>, shape: List<Int>, initializer: (List<Int>) -> T): NDArray<T> {
|
||||||
array[index] = initializer(index)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun <T : Any> createSimpleNDFieldFactory(field: Field<T>): NDFieldFactory<T> = { list -> SimpleNDField(field, list) }
|
|
||||||
|
|
||||||
fun <T : Any> simpleNDArray(field: Field<T>, shape: List<Int>, initializer: (List<Int>) -> T): NDArray<T> {
|
|
||||||
return SimpleNDField(field, shape).produce { initializer(it) }
|
return SimpleNDField(field, shape).produce { initializer(it) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Mutable boxing NDArray
|
||||||
|
*/
|
||||||
|
fun <T : Any> createMutable(field: Field<T>, shape: List<Int>, initializer: (List<Int>) -> T): MutableNDArray<T> {
|
||||||
|
return SimpleNDField(field, shape).produceMutable { initializer(it) }
|
||||||
|
}
|
||||||
}
|
}
|
@ -1,5 +1,6 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.realVector
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import scientifik.kmath.operations.DoubleField
|
import scientifik.kmath.operations.DoubleField
|
||||||
import scientifik.kmath.structures.NDArrays.simpleNDArray
|
import scientifik.kmath.structures.NDArrays.create
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -9,7 +9,7 @@ import kotlin.test.assertEquals
|
|||||||
class SimpleNDFieldTest{
|
class SimpleNDFieldTest{
|
||||||
@Test
|
@Test
|
||||||
fun testStrides(){
|
fun testStrides(){
|
||||||
val ndArray = simpleNDArray(DoubleField, listOf(10,10)){(it[0]+it[1]).toDouble()}
|
val ndArray = create(DoubleField, listOf(10,10)){(it[0]+it[1]).toDouble()}
|
||||||
assertEquals(ndArray[5,5], 10.0)
|
assertEquals(ndArray[5,5], 10.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,4 +5,4 @@ import scientifik.kmath.operations.DoubleField
|
|||||||
/**
|
/**
|
||||||
* Using boxing implementation for js
|
* Using boxing implementation for js
|
||||||
*/
|
*/
|
||||||
actual val realNDFieldFactory: NDFieldFactory<Double> = NDArrays.createSimpleNDFieldFactory(DoubleField)
|
actual val realNDFieldFactory: NDFieldFactory<Double> = NDArrays.createFactory(DoubleField)
|
@ -7,12 +7,18 @@ private class RealNDField(shape: List<Int>) : BufferNDField<Double>(shape, Doubl
|
|||||||
override fun createBuffer(capacity: Int, initializer: (Int) -> Double): Buffer<Double> {
|
override fun createBuffer(capacity: Int, initializer: (Int) -> Double): Buffer<Double> {
|
||||||
val array = DoubleArray(capacity, initializer)
|
val array = DoubleArray(capacity, initializer)
|
||||||
val buffer = DoubleBuffer.wrap(array)
|
val buffer = DoubleBuffer.wrap(array)
|
||||||
return object : Buffer<Double> {
|
return BufferOfDoubles(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
private class BufferOfDoubles(val buffer: DoubleBuffer): Buffer<Double>{
|
||||||
override fun get(index: Int): Double = buffer.get(index)
|
override fun get(index: Int): Double = buffer.get(index)
|
||||||
|
|
||||||
override fun set(index: Int, value: Double) {
|
override fun set(index: Int, value: Double) {
|
||||||
buffer.put(index, value)
|
buffer.put(index, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun copy(): Buffer<Double> {
|
||||||
|
return BufferOfDoubles(buffer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,4 +10,5 @@ enableFeaturePreview('GRADLE_METADATA')
|
|||||||
|
|
||||||
rootProject.name = 'kmath'
|
rootProject.name = 'kmath'
|
||||||
include ':kmath-core'
|
include ':kmath-core'
|
||||||
|
include ':kmath-commons'
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user