forked from kscience/kmath
WIP Matrix refactor
This commit is contained in:
parent
19d3998c3b
commit
d5ba816b7d
@ -22,7 +22,7 @@ public typealias Vector<T> = Point<T>
|
|||||||
* @param T the type of items in the matrices.
|
* @param T the type of items in the matrices.
|
||||||
* @param M the type of operated matrices.
|
* @param M the type of operated matrices.
|
||||||
*/
|
*/
|
||||||
public interface LinearSpace<T : Any, A : Ring<T>> {
|
public interface LinearSpace<T : Any, out A : Ring<T>> {
|
||||||
public val elementAlgebra: A
|
public val elementAlgebra: A
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -228,7 +228,7 @@ public inline fun <reified T : Comparable<T>> LinearSpace<T, Field<T>>.inverseWi
|
|||||||
|
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public fun RealLinearSpace.solveWithLup(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> {
|
public fun LinearSpace<Double, RealField>.solveWithLup(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> {
|
||||||
// Use existing decomposition if it is provided by matrix
|
// Use existing decomposition if it is provided by matrix
|
||||||
val bufferFactory: MutableBufferFactory<Double> = MutableBuffer.Companion::real
|
val bufferFactory: MutableBufferFactory<Double> = MutableBuffer.Companion::real
|
||||||
val decomposition: LupDecomposition<Double> = a.getFeature() ?: lup(bufferFactory, a) { it < 1e-11 }
|
val decomposition: LupDecomposition<Double> = a.getFeature() ?: lup(bufferFactory, a) { it < 1e-11 }
|
||||||
@ -238,5 +238,5 @@ public fun RealLinearSpace.solveWithLup(a: Matrix<Double>, b: Matrix<Double>): M
|
|||||||
/**
|
/**
|
||||||
* Inverses a square matrix using LUP decomposition. Non square matrix will throw a error.
|
* Inverses a square matrix using LUP decomposition. Non square matrix will throw a error.
|
||||||
*/
|
*/
|
||||||
public fun RealLinearSpace.inverseWithLup(matrix: Matrix<Double>): Matrix<Double> =
|
public fun LinearSpace<Double, RealField>.inverseWithLup(matrix: Matrix<Double>): Matrix<Double> =
|
||||||
solveWithLup(matrix, one(matrix.rowNum, matrix.colNum))
|
solveWithLup(matrix, one(matrix.rowNum, matrix.colNum))
|
@ -2,7 +2,6 @@ package space.kscience.kmath.linear
|
|||||||
|
|
||||||
import space.kscience.kmath.nd.NDStructure
|
import space.kscience.kmath.nd.NDStructure
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.operations.invoke
|
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -17,7 +16,7 @@ class MatrixTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testBuilder() {
|
fun testBuilder() {
|
||||||
val matrix = Matrix.build(2, 3)(
|
val matrix = LinearSpace.real.matrix(2, 3)(
|
||||||
1.0, 0.0, 0.0,
|
1.0, 0.0, 0.0,
|
||||||
0.0, 1.0, 2.0
|
0.0, 1.0, 2.0
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,8 @@ package space.kscience.kmath.dimensions
|
|||||||
|
|
||||||
import space.kscience.kmath.linear.*
|
import space.kscience.kmath.linear.*
|
||||||
import space.kscience.kmath.nd.Structure2D
|
import space.kscience.kmath.nd.Structure2D
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.RealField
|
||||||
|
import space.kscience.kmath.operations.Ring
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A matrix with compile-time controlled dimension
|
* A matrix with compile-time controlled dimension
|
||||||
@ -77,7 +78,7 @@ public inline class DPointWrapper<T, D : Dimension>(public val point: Point<T>)
|
|||||||
/**
|
/**
|
||||||
* Basic operations on dimension-safe matrices. Operates on [Matrix]
|
* Basic operations on dimension-safe matrices. Operates on [Matrix]
|
||||||
*/
|
*/
|
||||||
public inline class DMatrixContext<T : Any>(public val context: LinearSpace<T, Matrix<T>>) {
|
public inline class DMatrixContext<T : Any, out A : Ring<T>>(public val context: LinearSpace<T, A>) {
|
||||||
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
|
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
|
||||||
require(rowNum == Dimension.dim<R>().toInt()) {
|
require(rowNum == Dimension.dim<R>().toInt()) {
|
||||||
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
|
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
|
||||||
@ -93,13 +94,15 @@ public inline class DMatrixContext<T : Any>(public val context: LinearSpace<T, M
|
|||||||
/**
|
/**
|
||||||
* Produce a matrix with this context and given dimensions
|
* Produce a matrix with this context and given dimensions
|
||||||
*/
|
*/
|
||||||
public inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> {
|
public inline fun <reified R : Dimension, reified C : Dimension> produce(
|
||||||
|
noinline initializer: A.(i: Int, j: Int) -> T
|
||||||
|
): DMatrix<T, R, C> {
|
||||||
val rows = Dimension.dim<R>()
|
val rows = Dimension.dim<R>()
|
||||||
val cols = Dimension.dim<C>()
|
val cols = Dimension.dim<C>()
|
||||||
return context.buildMatrix(rows.toInt(), cols.toInt(), initializer).coerce<R, C>()
|
return context.buildMatrix(rows.toInt(), cols.toInt(), initializer).coerce<R, C>()
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> {
|
public inline fun <reified D : Dimension> point(noinline initializer: A.(Int) -> T): DPoint<T, D> {
|
||||||
val size = Dimension.dim<D>()
|
val size = Dimension.dim<D>()
|
||||||
|
|
||||||
return DPoint.coerceUnsafe(
|
return DPoint.coerceUnsafe(
|
||||||
@ -112,31 +115,31 @@ public inline class DMatrixContext<T : Any>(public val context: LinearSpace<T, M
|
|||||||
|
|
||||||
public inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
public inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
||||||
other: DMatrix<T, C1, C2>,
|
other: DMatrix<T, C1, C2>,
|
||||||
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
|
): DMatrix<T, R1, C2> = context.run { this@dot dot other }.coerce()
|
||||||
|
|
||||||
public inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
|
public inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
|
||||||
DPoint.coerceUnsafe(context { this@dot dot vector })
|
DPoint.coerceUnsafe(context.run { this@dot dot vector })
|
||||||
|
|
||||||
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> =
|
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> =
|
||||||
context { this@times.times(value) }.coerce()
|
context.run { this@times.times(value) }.coerce()
|
||||||
|
|
||||||
public inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> =
|
public inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> =
|
||||||
m * this
|
m * this
|
||||||
|
|
||||||
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
|
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
|
||||||
context { this@plus + other }.coerce()
|
context.run { this@plus + other }.coerce()
|
||||||
|
|
||||||
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
|
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
|
||||||
context { this@minus + other }.coerce()
|
context.run { this@minus + other }.coerce()
|
||||||
|
|
||||||
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> =
|
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> =
|
||||||
context { this@unaryMinus.unaryMinus() }.coerce()
|
context.run { this@unaryMinus.unaryMinus() }.coerce()
|
||||||
|
|
||||||
public inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
|
public inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
|
||||||
context { (this@transpose as Matrix<T>).transpose() }.coerce()
|
context.run { (this@transpose as Matrix<T>).transpose() }.coerce()
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public val real: DMatrixContext<Double> = DMatrixContext(LinearSpace.real)
|
public val real: DMatrixContext<Double, RealField> = DMatrixContext(LinearSpace.real)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,11 +147,11 @@ public inline class DMatrixContext<T : Any>(public val context: LinearSpace<T, M
|
|||||||
/**
|
/**
|
||||||
* A square unit matrix
|
* A square unit matrix
|
||||||
*/
|
*/
|
||||||
public inline fun <reified D : Dimension> DMatrixContext<Double>.one(): DMatrix<Double, D, D> = produce { i, j ->
|
public inline fun <reified D : Dimension> DMatrixContext<Double, RealField>.one(): DMatrix<Double, D, D> = produce { i, j ->
|
||||||
if (i == j) 1.0 else 0.0
|
if (i == j) 1.0 else 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <reified R : Dimension, reified C : Dimension> DMatrixContext<Double>.zero(): DMatrix<Double, R, C> =
|
public inline fun <reified R : Dimension, reified C : Dimension> DMatrixContext<Double, RealField>.zero(): DMatrix<Double, R, C> =
|
||||||
produce { _, _ ->
|
produce { _, _ ->
|
||||||
0.0
|
0.0
|
||||||
}
|
}
|
@ -44,7 +44,7 @@ public object EjmlLinearSpace : LinearSpace<Double, RealField>, ScaleOperations<
|
|||||||
|
|
||||||
override fun buildVector(size: Int, initializer: RealField.(Int) -> Double): Vector<Double> =
|
override fun buildVector(size: Int, initializer: RealField.(Int) -> Double): Vector<Double> =
|
||||||
EjmlVector(SimpleMatrix(size, 1).also {
|
EjmlVector(SimpleMatrix(size, 1).also {
|
||||||
(0 until it.numRows()).forEach { row -> it[row, 0] = initializer(row) }
|
(0 until it.numRows()).forEach { row -> it[row, 0] = RealField.initializer(row) }
|
||||||
})
|
})
|
||||||
|
|
||||||
private fun SimpleMatrix.wrapMatrix() = EjmlMatrix(this)
|
private fun SimpleMatrix.wrapMatrix() = EjmlMatrix(this)
|
||||||
@ -59,43 +59,34 @@ public object EjmlLinearSpace : LinearSpace<Double, RealField>, ScaleOperations<
|
|||||||
EjmlVector(toEjml().origin.mult(vector.toEjml().origin))
|
EjmlVector(toEjml().origin.mult(vector.toEjml().origin))
|
||||||
|
|
||||||
public override operator fun Matrix<Double>.minus(other: Matrix<Double>): EjmlMatrix =
|
public override operator fun Matrix<Double>.minus(other: Matrix<Double>): EjmlMatrix =
|
||||||
EjmlMatrix(toEjml().origin - other.toEjml().origin)
|
(toEjml().origin - other.toEjml().origin).wrapMatrix()
|
||||||
|
|
||||||
public override fun scale(a: Matrix<Double>, value: Double): EjmlMatrix =
|
public override fun scale(a: Matrix<Double>, value: Double): EjmlMatrix =
|
||||||
a.toEjml().origin.scale(value).wrapMatrix()
|
a.toEjml().origin.scale(value).wrapMatrix()
|
||||||
|
|
||||||
public override operator fun Matrix<Double>.times(value: Double): EjmlMatrix =
|
public override operator fun Matrix<Double>.times(value: Double): EjmlMatrix =
|
||||||
EjmlMatrix(toEjml().origin.scale(value))
|
toEjml().origin.scale(value).wrapMatrix()
|
||||||
|
|
||||||
override fun Vector<Double>.unaryMinus(): EjmlVector =
|
override fun Vector<Double>.unaryMinus(): EjmlVector =
|
||||||
toEjml().origin.negative().wrapVector()
|
toEjml().origin.negative().wrapVector()
|
||||||
|
|
||||||
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> {
|
override fun Matrix<Double>.plus(other: Matrix<Double>): EjmlMatrix =
|
||||||
TODO("Not yet implemented")
|
(toEjml().origin + other.toEjml().origin).wrapMatrix()
|
||||||
}
|
|
||||||
|
|
||||||
override fun Vector<Double>.plus(other: Vector<Double>): Vector<Double> {
|
override fun Vector<Double>.plus(other: Vector<Double>): EjmlVector =
|
||||||
TODO("Not yet implemented")
|
(toEjml().origin + other.toEjml().origin).wrapVector()
|
||||||
}
|
|
||||||
|
|
||||||
override fun Vector<Double>.minus(other: Vector<Double>): Vector<Double> {
|
override fun Vector<Double>.minus(other: Vector<Double>): EjmlVector =
|
||||||
TODO("Not yet implemented")
|
(toEjml().origin - other.toEjml().origin).wrapVector()
|
||||||
}
|
|
||||||
|
|
||||||
override fun Double.times(m: Matrix<Double>): Matrix<Double> {
|
override fun Double.times(m: Matrix<Double>): EjmlMatrix =
|
||||||
TODO("Not yet implemented")
|
m.toEjml().origin.scale(this).wrapMatrix()
|
||||||
}
|
|
||||||
|
|
||||||
override fun Vector<Double>.times(value: Double): Vector<Double> {
|
override fun Vector<Double>.times(value: Double): EjmlVector =
|
||||||
TODO("Not yet implemented")
|
toEjml().origin.scale(value).wrapVector()
|
||||||
}
|
|
||||||
|
|
||||||
override fun Double.times(v: Vector<Double>): Vector<Double> {
|
override fun Double.times(v: Vector<Double>): EjmlVector =
|
||||||
TODO("Not yet implemented")
|
v.toEjml().origin.scale(this).wrapVector()
|
||||||
}
|
|
||||||
|
|
||||||
public override fun add(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
|
|
||||||
EjmlMatrix(a.toEjml().origin + b.toEjml().origin)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2,6 +2,7 @@ package space.kscience.kmath.real
|
|||||||
|
|
||||||
import space.kscience.kmath.linear.*
|
import space.kscience.kmath.linear.*
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.RealField
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.RealBuffer
|
import space.kscience.kmath.structures.RealBuffer
|
||||||
import space.kscience.kmath.structures.asIterable
|
import space.kscience.kmath.structures.asIterable
|
||||||
@ -21,11 +22,11 @@ import kotlin.math.pow
|
|||||||
|
|
||||||
public typealias RealMatrix = Matrix<Double>
|
public typealias RealMatrix = Matrix<Double>
|
||||||
|
|
||||||
public fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
|
public fun realMatrix(rowNum: Int, colNum: Int, initializer: RealField.(i: Int, j: Int) -> Double): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(rowNum, colNum, initializer)
|
LinearSpace.real.buildMatrix(rowNum, colNum, initializer)
|
||||||
|
|
||||||
public fun Array<DoubleArray>.toMatrix(): RealMatrix {
|
public fun Array<DoubleArray>.toMatrix(): RealMatrix {
|
||||||
return LinearSpace.real.buildMatrix(size, this[0].size) { row, col -> this[row][col] }
|
return LinearSpace.real.buildMatrix(size, this[0].size) { row, col -> this@toMatrix[row][col] }
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun Sequence<DoubleArray>.toMatrix(): RealMatrix = toList().let {
|
public fun Sequence<DoubleArray>.toMatrix(): RealMatrix = toList().let {
|
||||||
@ -43,37 +44,37 @@ public fun RealMatrix.repeatStackVertical(n: Int): RealMatrix =
|
|||||||
|
|
||||||
public operator fun RealMatrix.times(double: Double): RealMatrix =
|
public operator fun RealMatrix.times(double: Double): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col ->
|
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col ->
|
||||||
this[row, col] * double
|
get(row, col) * double
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun RealMatrix.plus(double: Double): RealMatrix =
|
public operator fun RealMatrix.plus(double: Double): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col ->
|
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col ->
|
||||||
this[row, col] + double
|
get(row, col) + double
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun RealMatrix.minus(double: Double): RealMatrix =
|
public operator fun RealMatrix.minus(double: Double): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col ->
|
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col ->
|
||||||
this[row, col] - double
|
get(row, col) - double
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun RealMatrix.div(double: Double): RealMatrix =
|
public operator fun RealMatrix.div(double: Double): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col ->
|
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col ->
|
||||||
this[row, col] / double
|
get(row, col) / double
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun Double.times(matrix: RealMatrix): RealMatrix =
|
public operator fun Double.times(matrix: RealMatrix): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col ->
|
LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col ->
|
||||||
this * matrix[row, col]
|
this@times * matrix[row, col]
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun Double.plus(matrix: RealMatrix): RealMatrix =
|
public operator fun Double.plus(matrix: RealMatrix): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col ->
|
LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col ->
|
||||||
this + matrix[row, col]
|
this@plus + matrix[row, col]
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun Double.minus(matrix: RealMatrix): RealMatrix =
|
public operator fun Double.minus(matrix: RealMatrix): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col ->
|
LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col ->
|
||||||
this - matrix[row, col]
|
this@minus - matrix[row, col]
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: does this operation make sense? Should it be 'this/matrix[row, col]'?
|
// TODO: does this operation make sense? Should it be 'this/matrix[row, col]'?
|
||||||
@ -87,13 +88,13 @@ public operator fun Double.minus(matrix: RealMatrix): RealMatrix =
|
|||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public operator fun RealMatrix.times(other: RealMatrix): RealMatrix =
|
public operator fun RealMatrix.times(other: RealMatrix): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> this[row, col] * other[row, col] }
|
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> this@times[row, col] * other[row, col] }
|
||||||
|
|
||||||
public operator fun RealMatrix.plus(other: RealMatrix): RealMatrix =
|
public operator fun RealMatrix.plus(other: RealMatrix): RealMatrix =
|
||||||
LinearSpace.real.add(this, other)
|
LinearSpace.real.run { this@plus + other }
|
||||||
|
|
||||||
public operator fun RealMatrix.minus(other: RealMatrix): RealMatrix =
|
public operator fun RealMatrix.minus(other: RealMatrix): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> this[row, col] - other[row, col] }
|
LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> this@minus[row, col] - other[row, col] }
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Operations on columns
|
* Operations on columns
|
||||||
@ -102,14 +103,14 @@ public operator fun RealMatrix.minus(other: RealMatrix): RealMatrix =
|
|||||||
public inline fun RealMatrix.appendColumn(crossinline mapper: (Buffer<Double>) -> Double): RealMatrix =
|
public inline fun RealMatrix.appendColumn(crossinline mapper: (Buffer<Double>) -> Double): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(rowNum, colNum + 1) { row, col ->
|
LinearSpace.real.buildMatrix(rowNum, colNum + 1) { row, col ->
|
||||||
if (col < colNum)
|
if (col < colNum)
|
||||||
this[row, col]
|
get(row, col)
|
||||||
else
|
else
|
||||||
mapper(rows[row])
|
mapper(rows[row])
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun RealMatrix.extractColumns(columnRange: IntRange): RealMatrix =
|
public fun RealMatrix.extractColumns(columnRange: IntRange): RealMatrix =
|
||||||
LinearSpace.real.buildMatrix(rowNum, columnRange.count()) { row, col ->
|
LinearSpace.real.buildMatrix(rowNum, columnRange.count()) { row, col ->
|
||||||
this[row, columnRange.first + col]
|
this@extractColumns[row, columnRange.first + col]
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun RealMatrix.extractColumn(columnIndex: Int): RealMatrix =
|
public fun RealMatrix.extractColumn(columnIndex: Int): RealMatrix =
|
||||||
|
@ -1,31 +1,31 @@
|
|||||||
package space.kscience.kmath.real
|
package space.kscience.kmath.real
|
||||||
|
|
||||||
import space.kscience.kmath.linear.BufferMatrix
|
import space.kscience.kmath.linear.LinearSpace
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.nd.Matrix
|
||||||
import space.kscience.kmath.structures.RealBuffer
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimized dot product for real matrices
|
* Optimized dot product for real matrices
|
||||||
*/
|
*/
|
||||||
public infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> {
|
public infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = LinearSpace.real.run{
|
||||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
this@dot dot other
|
||||||
val resultArray = DoubleArray(this.rowNum * other.colNum)
|
// require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||||
|
// val resultArray = DoubleArray(this.rowNum * other.colNum)
|
||||||
//convert to array to insure there is no memory indirection
|
//
|
||||||
fun Buffer<Double>.unsafeArray() = if (this is RealBuffer)
|
// //convert to array to insure there is no memory indirection
|
||||||
this.array
|
// fun Buffer<Double>.unsafeArray() = if (this is RealBuffer)
|
||||||
else
|
// this.array
|
||||||
DoubleArray(size) { get(it) }
|
// else
|
||||||
|
// DoubleArray(size) { get(it) }
|
||||||
val a = this.buffer.unsafeArray()
|
//
|
||||||
val b = other.buffer.unsafeArray()
|
// val a = this.buffer.unsafeArray()
|
||||||
|
// val b = other.buffer.unsafeArray()
|
||||||
for (i in (0 until rowNum))
|
//
|
||||||
for (j in (0 until other.colNum))
|
// for (i in (0 until rowNum))
|
||||||
for (k in (0 until colNum))
|
// for (j in (0 until other.colNum))
|
||||||
resultArray[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j]
|
// for (k in (0 until colNum))
|
||||||
|
// resultArray[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j]
|
||||||
val buffer = RealBuffer(resultArray)
|
//
|
||||||
return BufferMatrix(rowNum, other.colNum, buffer)
|
// val buffer = RealBuffer(resultArray)
|
||||||
|
// return BufferMatrix(rowNum, other.colNum, buffer)
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user