Merge branch 'dev' into feature/torch

This commit is contained in:
rgrit91 2021-01-08 06:57:45 +00:00
commit cfe93886ac
14 changed files with 183 additions and 117 deletions

View File

@ -2,19 +2,22 @@ package kscience.kmath.benchmarks
import kotlinx.benchmark.Benchmark import kotlinx.benchmark.Benchmark
import kscience.kmath.commons.linear.CMMatrixContext import kscience.kmath.commons.linear.CMMatrixContext
import kscience.kmath.commons.linear.CMMatrixContext.dot
import kscience.kmath.commons.linear.toCM import kscience.kmath.commons.linear.toCM
import kscience.kmath.ejml.EjmlMatrixContext import kscience.kmath.ejml.EjmlMatrixContext
import kscience.kmath.ejml.toEjml import kscience.kmath.ejml.toEjml
import kscience.kmath.linear.BufferMatrixContext
import kscience.kmath.linear.RealMatrixContext
import kscience.kmath.linear.real import kscience.kmath.linear.real
import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke import kscience.kmath.operations.invoke
import kscience.kmath.structures.Buffer
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State import org.openjdk.jmh.annotations.State
import kotlin.random.Random import kotlin.random.Random
@State(Scope.Benchmark) @State(Scope.Benchmark)
class MultiplicationBenchmark { class DotBenchmark {
companion object { companion object {
val random = Random(12224) val random = Random(12224)
val dim = 1000 val dim = 1000
@ -32,14 +35,14 @@ class MultiplicationBenchmark {
@Benchmark @Benchmark
fun commonsMathMultiplication() { fun commonsMathMultiplication() {
CMMatrixContext.invoke { CMMatrixContext {
cmMatrix1 dot cmMatrix2 cmMatrix1 dot cmMatrix2
} }
} }
@Benchmark @Benchmark
fun ejmlMultiplication() { fun ejmlMultiplication() {
EjmlMatrixContext.invoke { EjmlMatrixContext {
ejmlMatrix1 dot ejmlMatrix2 ejmlMatrix1 dot ejmlMatrix2
} }
} }
@ -48,13 +51,22 @@ class MultiplicationBenchmark {
fun ejmlMultiplicationwithConversion() { fun ejmlMultiplicationwithConversion() {
val ejmlMatrix1 = matrix1.toEjml() val ejmlMatrix1 = matrix1.toEjml()
val ejmlMatrix2 = matrix2.toEjml() val ejmlMatrix2 = matrix2.toEjml()
EjmlMatrixContext.invoke { EjmlMatrixContext {
ejmlMatrix1 dot ejmlMatrix2 ejmlMatrix1 dot ejmlMatrix2
} }
} }
@Benchmark @Benchmark
fun bufferedMultiplication() { fun bufferedMultiplication() {
matrix1 dot matrix2 BufferMatrixContext(RealField, Buffer.Companion::real).invoke{
matrix1 dot matrix2
}
}
@Benchmark
fun realMultiplication(){
RealMatrixContext {
matrix1 dot matrix2
}
} }
} }

View File

@ -4,9 +4,8 @@ import kscience.kmath.dimensions.D2
import kscience.kmath.dimensions.D3 import kscience.kmath.dimensions.D3
import kscience.kmath.dimensions.DMatrixContext import kscience.kmath.dimensions.DMatrixContext
import kscience.kmath.dimensions.Dimension import kscience.kmath.dimensions.Dimension
import kscience.kmath.operations.RealField
private fun DMatrixContext<Double, RealField>.simple() { private fun DMatrixContext<Double>.simple() {
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() } val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() } val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
@ -18,7 +17,7 @@ private object D5 : Dimension {
override val dim: UInt = 5u override val dim: UInt = 5u
} }
private fun DMatrixContext<Double, RealField>.custom() { private fun DMatrixContext<Double>.custom() {
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() } val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() } val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() } val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }

View File

@ -2,10 +2,9 @@ package kscience.kmath.ast
import kscience.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
/** /**
* A Mathematical Syntax Tree node for mathematical expressions. * A Mathematical Syntax Tree (MST) node for mathematical expressions.
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
@ -57,21 +56,22 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
?: error("Numeric nodes are not supported by $this") ?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value) is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperationFunction(node.operation)(evaluate(node.value))
is MST.Unary -> when {
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
}
is MST.Binary -> when { is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right)) this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
node.left is MST.Numeric && node.right is MST.Numeric -> { this is NumericAlgebra && node.left is MST.Numeric ->
val number = RealField leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
.binaryOperationFunction(node.operation)
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
number(number) this is NumericAlgebra && node.right is MST.Numeric ->
} rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
node.left is MST.Numeric -> leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right)) else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
} }
} }

View File

@ -1,18 +1,18 @@
package kscience.kmath.estree package kscience.kmath.estree
import kscience.kmath.ast.MST import kscience.kmath.ast.MST
import kscience.kmath.ast.MST.*
import kscience.kmath.ast.MstExpression import kscience.kmath.ast.MstExpression
import kscience.kmath.estree.internal.ESTreeBuilder import kscience.kmath.estree.internal.ESTreeBuilder
import kscience.kmath.estree.internal.estree.BaseExpression import kscience.kmath.estree.internal.estree.BaseExpression
import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
@PublishedApi @PublishedApi
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> { internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) { fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
is MST.Symbolic -> { is Symbolic -> {
val symbol = try { val symbol = try {
algebra.symbol(node.value) algebra.symbol(node.value)
} catch (ignored: IllegalStateException) { } catch (ignored: IllegalStateException) {
@ -25,25 +25,29 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
variable(node.value) variable(node.value)
} }
is MST.Numeric -> constant(node.value) is Numeric -> constant(node.value)
is MST.Unary -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
is MST.Binary -> when { is Unary -> when {
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> constant( algebra is NumericAlgebra && node.value is Numeric -> constant(
algebra.number( algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
RealField
.binaryOperationFunction(node.operation) else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
.invoke(node.left.value.toDouble(), node.right.value.toDouble()) }
)
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
algebra
.binaryOperationFunction(node.operation)
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
) )
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> call( algebra is NumericAlgebra && node.left is Numeric -> call(
algebra.leftSideNumberOperationFunction(node.operation), algebra.leftSideNumberOperationFunction(node.operation),
visit(node.left), visit(node.left),
visit(node.right), visit(node.right),
) )
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> call( algebra is NumericAlgebra && node.right is Numeric -> call(
algebra.rightSideNumberOperationFunction(node.operation), algebra.rightSideNumberOperationFunction(node.operation),
visit(node.left), visit(node.left),
visit(node.right), visit(node.right),

View File

@ -3,11 +3,11 @@ package kscience.kmath.asm
import kscience.kmath.asm.internal.AsmBuilder import kscience.kmath.asm.internal.AsmBuilder
import kscience.kmath.asm.internal.buildName import kscience.kmath.asm.internal.buildName
import kscience.kmath.ast.MST import kscience.kmath.ast.MST
import kscience.kmath.ast.MST.*
import kscience.kmath.ast.MstExpression import kscience.kmath.ast.MstExpression
import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
/** /**
* Compiles given MST to an Expression using AST compiler. * Compiles given MST to an Expression using AST compiler.
@ -20,7 +20,7 @@ import kscience.kmath.operations.RealField
@PublishedApi @PublishedApi
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> { internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) { fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
is MST.Symbolic -> { is Symbolic -> {
val symbol = try { val symbol = try {
algebra.symbol(node.value) algebra.symbol(node.value)
} catch (ignored: IllegalStateException) { } catch (ignored: IllegalStateException) {
@ -33,24 +33,29 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
loadVariable(node.value) loadVariable(node.value)
} }
is MST.Numeric -> loadNumberConstant(node.value) is Numeric -> loadNumberConstant(node.value)
is MST.Unary -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
is MST.Binary -> when { is Unary -> when {
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant( algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
algebra.number( algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
RealField
.binaryOperationFunction(node.operation) else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
.invoke(node.left.value.toDouble(), node.right.value.toDouble()) }
)
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
algebra.binaryOperationFunction(node.operation)
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
) )
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperationFunction(node.operation)) { algebra is NumericAlgebra && node.left is Numeric -> buildCall(
algebra.leftSideNumberOperationFunction(node.operation)) {
visit(node.left) visit(node.left)
visit(node.right) visit(node.right)
} }
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperationFunction(node.operation)) { algebra is NumericAlgebra && node.right is Numeric -> buildCall(
algebra.rightSideNumberOperationFunction(node.operation)) {
visit(node.left) visit(node.left)
visit(node.right) visit(node.right)
} }

View File

@ -191,7 +191,7 @@ internal class AsmBuilder<T>(
} }
val cls = classLoader.defineClass(className, classWriter.toByteArray()) val cls = classLoader.defineClass(className, classWriter.toByteArray())
java.io.File("dump.class").writeBytes(classWriter.toByteArray()) // java.io.File("dump.class").writeBytes(classWriter.toByteArray())
val l = MethodHandles.publicLookup() val l = MethodHandles.publicLookup()
if (hasConstants) if (hasConstants)

View File

@ -29,6 +29,7 @@ public class CMMatrix(public val origin: RealMatrix, features: Set<MatrixFeature
} }
} }
//TODO move inside context
public fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) { public fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
this this
} else { } else {

View File

@ -1,8 +1,10 @@
package kscience.kmath.linear package kscience.kmath.linear
import kscience.kmath.operations.RealField
import kscience.kmath.operations.Ring import kscience.kmath.operations.Ring
import kscience.kmath.structures.* import kscience.kmath.structures.Buffer
import kscience.kmath.structures.BufferFactory
import kscience.kmath.structures.NDStructure
import kscience.kmath.structures.asSequence
/** /**
* Basic implementation of Matrix space based on [NDStructure] * Basic implementation of Matrix space based on [NDStructure]
@ -21,24 +23,6 @@ public class BufferMatrixContext<T : Any, R : Ring<T>>(
public companion object public companion object
} }
@Suppress("OVERRIDE_BY_INLINE")
public object RealMatrixContext : GenericMatrixContext<Double, RealField, BufferMatrix<Double>> {
public override val elementContext: RealField
get() = RealField
public override inline fun produce(
rows: Int,
columns: Int,
initializer: (i: Int, j: Int) -> Double,
): BufferMatrix<Double> {
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
return BufferMatrix(rows, columns, buffer)
}
public override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> =
RealBuffer(size, initializer)
}
public class BufferMatrix<T : Any>( public class BufferMatrix<T : Any>(
public override val rowNum: Int, public override val rowNum: Int,
public override val colNum: Int, public override val colNum: Int,

View File

@ -213,17 +213,8 @@ public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext
return decomposition.solveWithLUP(bufferFactory, b) return decomposition.solveWithLUP(bufferFactory, b)
} }
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): FeaturedMatrix<Double> =
solveWithLUP(a, b) { it < 1e-11 }
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.inverseWithLUP( public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.inverseWithLUP(
matrix: Matrix<T>, matrix: Matrix<T>,
noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto, noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto,
noinline checkSingular: (T) -> Boolean, noinline checkSingular: (T) -> Boolean,
): FeaturedMatrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) ): FeaturedMatrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular)
/**
* Inverses a square matrix using LUP decomposition. Non square matrix will throw a error.
*/
public fun RealMatrixContext.inverseWithLUP(matrix: Matrix<Double>): FeaturedMatrix<Double> =
solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), Buffer.Companion::real) { it < 1e-11 }

View File

@ -18,6 +18,11 @@ public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Mat
*/ */
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M
/**
* Produce a point compatible with matrix space (and possibly optimized for it)
*/
public fun point(size: Int, initializer: (Int) -> T): Point<T> = Buffer.boxing(size, initializer)
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public override fun binaryOperationFunction(operation: String): (left: Matrix<T>, right: Matrix<T>) -> M = public override fun binaryOperationFunction(operation: String): (left: Matrix<T>, right: Matrix<T>) -> M =
when (operation) { when (operation) {
@ -62,10 +67,6 @@ public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Mat
public operator fun T.times(m: Matrix<T>): M = m * this public operator fun T.times(m: Matrix<T>): M = m * this
public companion object { public companion object {
/**
* Non-boxing double matrix
*/
public val real: RealMatrixContext = RealMatrixContext
/** /**
* A structured matrix with custom buffer * A structured matrix with custom buffer
@ -89,11 +90,6 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>, out M : Matrix<T>> :
*/ */
public val elementContext: R public val elementContext: R
/**
* Produce a point compatible with matrix space
*/
public fun point(size: Int, initializer: (Int) -> T): Point<T>
public override infix fun Matrix<T>.dot(other: Matrix<T>): M { public override infix fun Matrix<T>.dot(other: Matrix<T>): M {
//TODO add typed error //TODO add typed error
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }

View File

@ -0,0 +1,84 @@
package kscience.kmath.linear
import kscience.kmath.operations.RealField
import kscience.kmath.structures.Matrix
import kscience.kmath.structures.MutableBuffer
import kscience.kmath.structures.MutableBufferFactory
import kscience.kmath.structures.RealBuffer
@Suppress("OVERRIDE_BY_INLINE")
public object RealMatrixContext : MatrixContext<Double, BufferMatrix<Double>> {
public override inline fun produce(
rows: Int,
columns: Int,
initializer: (i: Int, j: Int) -> Double,
): BufferMatrix<Double> {
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
return BufferMatrix(rows, columns, buffer)
}
private fun Matrix<Double>.wrap(): BufferMatrix<Double> = if (this is BufferMatrix) this else {
produce(rowNum, colNum) { i, j -> get(i, j) }
}
public fun one(rows: Int, columns: Int): FeaturedMatrix<Double> = VirtualMatrix(rows, columns, DiagonalFeature) { i, j ->
if (i == j) 1.0 else 0.0
}
public override infix fun Matrix<Double>.dot(other: Matrix<Double>): BufferMatrix<Double> {
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
return produce(rowNum, other.colNum) { i, j ->
var res = 0.0
for (l in 0 until colNum) {
res += get(i, l) * other.get(l, j)
}
res
}
}
public override infix fun Matrix<Double>.dot(vector: Point<Double>): Point<Double> {
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
return RealBuffer(rowNum) { i ->
var res = 0.0
for (j in 0 until colNum) {
res += get(i, j) * vector[j]
}
res
}
}
override fun add(a: Matrix<Double>, b: Matrix<Double>): BufferMatrix<Double> {
require(a.rowNum == b.rowNum) { "Row number mismatch in matrix addition. Left side: ${a.rowNum}, right side: ${b.rowNum}" }
require(a.colNum == b.colNum) { "Column number mismatch in matrix addition. Left side: ${a.colNum}, right side: ${b.colNum}" }
return produce(a.rowNum, a.colNum) { i, j ->
a[i, j] + b[i, j]
}
}
override fun Matrix<Double>.times(value: Double): BufferMatrix<Double> =
produce(rowNum, colNum) { i, j -> get(i, j) * value }
override fun multiply(a: Matrix<Double>, k: Number): BufferMatrix<Double> =
produce(a.rowNum, a.colNum) { i, j -> a.get(i, j) * k.toDouble() }
}
/**
* Partially optimized real-valued matrix
*/
public val MatrixContext.Companion.real: RealMatrixContext get() = RealMatrixContext
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): FeaturedMatrix<Double> {
// Use existing decomposition if it is provided by matrix
val bufferFactory: MutableBufferFactory<Double> = MutableBuffer.Companion::real
val decomposition = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 }
return decomposition.solveWithLUP(bufferFactory, b)
}
/**
* Inverses a square matrix using LUP decomposition. Non square matrix will throw a error.
*/
public fun RealMatrixContext.inverseWithLUP(matrix: Matrix<Double>): FeaturedMatrix<Double> =
solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum))

View File

@ -1,11 +1,6 @@
package kscience.kmath.dimensions package kscience.kmath.dimensions
import kscience.kmath.linear.GenericMatrixContext import kscience.kmath.linear.*
import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.Point
import kscience.kmath.linear.transpose
import kscience.kmath.operations.RealField
import kscience.kmath.operations.Ring
import kscience.kmath.operations.invoke import kscience.kmath.operations.invoke
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import kscience.kmath.structures.Structure2D import kscience.kmath.structures.Structure2D
@ -42,7 +37,7 @@ public interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
* An inline wrapper for a Matrix * An inline wrapper for a Matrix
*/ */
public inline class DMatrixWrapper<T, R : Dimension, C : Dimension>( public inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
private val structure: Structure2D<T> private val structure: Structure2D<T>,
) : DMatrix<T, R, C> { ) : DMatrix<T, R, C> {
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override operator fun get(i: Int, j: Int): T = structure[i, j] override operator fun get(i: Int, j: Int): T = structure[i, j]
@ -81,7 +76,7 @@ public inline class DPointWrapper<T, D : Dimension>(public val point: Point<T>)
/** /**
* Basic operations on dimension-safe matrices. Operates on [Matrix] * Basic operations on dimension-safe matrices. Operates on [Matrix]
*/ */
public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: GenericMatrixContext<T, Ri, Matrix<T>>) { public inline class DMatrixContext<T : Any>(public val context: MatrixContext<T, Matrix<T>>) {
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> { public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
require(rowNum == Dimension.dim<R>().toInt()) { require(rowNum == Dimension.dim<R>().toInt()) {
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum" "Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
@ -115,7 +110,7 @@ public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: Ge
} }
public inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot( public inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
other: DMatrix<T, C1, C2> other: DMatrix<T, C1, C2>,
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce() ): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
public inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> = public inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
@ -139,18 +134,19 @@ public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: Ge
public inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> = public inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
context { (this@transpose as Matrix<T>).transpose() }.coerce() context { (this@transpose as Matrix<T>).transpose() }.coerce()
/**
* A square unit matrix
*/
public inline fun <reified D : Dimension> one(): DMatrix<T, D, D> = produce { i, j ->
if (i == j) context.elementContext.one else context.elementContext.zero
}
public inline fun <reified R : Dimension, reified C : Dimension> zero(): DMatrix<T, R, C> = produce { _, _ ->
context.elementContext.zero
}
public companion object { public companion object {
public val real: DMatrixContext<Double, RealField> = DMatrixContext(MatrixContext.real) public val real: DMatrixContext<Double> = DMatrixContext(MatrixContext.real)
} }
} }
/**
* A square unit matrix
*/
public inline fun <reified D : Dimension> DMatrixContext<Double>.one(): DMatrix<Double, D, D> = produce { i, j ->
if (i == j) 1.0 else 0.0
}
public inline fun <reified R : Dimension, reified C : Dimension> DMatrixContext<Double>.zero(): DMatrix<Double, R, C> = produce { _, _ ->
0.0
}

View File

@ -3,6 +3,7 @@ package kscience.dimensions
import kscience.kmath.dimensions.D2 import kscience.kmath.dimensions.D2
import kscience.kmath.dimensions.D3 import kscience.kmath.dimensions.D3
import kscience.kmath.dimensions.DMatrixContext import kscience.kmath.dimensions.DMatrixContext
import kscience.kmath.dimensions.one
import kotlin.test.Test import kotlin.test.Test
internal class DMatrixContextTest { internal class DMatrixContextTest {

View File

@ -1,13 +1,7 @@
package kscience.kmath.real package kscience.kmath.real
import kscience.kmath.linear.FeaturedMatrix import kscience.kmath.linear.*
import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.RealMatrixContext.elementContext
import kscience.kmath.linear.VirtualMatrix
import kscience.kmath.linear.inverseWithLUP
import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.operations.invoke
import kscience.kmath.operations.sum
import kscience.kmath.structures.Buffer import kscience.kmath.structures.Buffer
import kscience.kmath.structures.RealBuffer import kscience.kmath.structures.RealBuffer
import kscience.kmath.structures.asIterable import kscience.kmath.structures.asIterable
@ -122,8 +116,7 @@ public fun RealMatrix.extractColumn(columnIndex: Int): RealMatrix =
extractColumns(columnIndex..columnIndex) extractColumns(columnIndex..columnIndex)
public fun RealMatrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> public fun RealMatrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
val column = columns[j] columns[j].asIterable().sum()
elementContext { sum(column.asIterable()) }
} }
public fun RealMatrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> public fun RealMatrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->