Merge branch 'dev' into feature/torch
This commit is contained in:
commit
cfe93886ac
@ -2,19 +2,22 @@ package kscience.kmath.benchmarks
|
|||||||
|
|
||||||
import kotlinx.benchmark.Benchmark
|
import kotlinx.benchmark.Benchmark
|
||||||
import kscience.kmath.commons.linear.CMMatrixContext
|
import kscience.kmath.commons.linear.CMMatrixContext
|
||||||
import kscience.kmath.commons.linear.CMMatrixContext.dot
|
|
||||||
import kscience.kmath.commons.linear.toCM
|
import kscience.kmath.commons.linear.toCM
|
||||||
import kscience.kmath.ejml.EjmlMatrixContext
|
import kscience.kmath.ejml.EjmlMatrixContext
|
||||||
import kscience.kmath.ejml.toEjml
|
import kscience.kmath.ejml.toEjml
|
||||||
|
import kscience.kmath.linear.BufferMatrixContext
|
||||||
|
import kscience.kmath.linear.RealMatrixContext
|
||||||
import kscience.kmath.linear.real
|
import kscience.kmath.linear.real
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
|
import kscience.kmath.structures.Buffer
|
||||||
import kscience.kmath.structures.Matrix
|
import kscience.kmath.structures.Matrix
|
||||||
import org.openjdk.jmh.annotations.Scope
|
import org.openjdk.jmh.annotations.Scope
|
||||||
import org.openjdk.jmh.annotations.State
|
import org.openjdk.jmh.annotations.State
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
class MultiplicationBenchmark {
|
class DotBenchmark {
|
||||||
companion object {
|
companion object {
|
||||||
val random = Random(12224)
|
val random = Random(12224)
|
||||||
val dim = 1000
|
val dim = 1000
|
||||||
@ -32,14 +35,14 @@ class MultiplicationBenchmark {
|
|||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun commonsMathMultiplication() {
|
fun commonsMathMultiplication() {
|
||||||
CMMatrixContext.invoke {
|
CMMatrixContext {
|
||||||
cmMatrix1 dot cmMatrix2
|
cmMatrix1 dot cmMatrix2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun ejmlMultiplication() {
|
fun ejmlMultiplication() {
|
||||||
EjmlMatrixContext.invoke {
|
EjmlMatrixContext {
|
||||||
ejmlMatrix1 dot ejmlMatrix2
|
ejmlMatrix1 dot ejmlMatrix2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -48,13 +51,22 @@ class MultiplicationBenchmark {
|
|||||||
fun ejmlMultiplicationwithConversion() {
|
fun ejmlMultiplicationwithConversion() {
|
||||||
val ejmlMatrix1 = matrix1.toEjml()
|
val ejmlMatrix1 = matrix1.toEjml()
|
||||||
val ejmlMatrix2 = matrix2.toEjml()
|
val ejmlMatrix2 = matrix2.toEjml()
|
||||||
EjmlMatrixContext.invoke {
|
EjmlMatrixContext {
|
||||||
ejmlMatrix1 dot ejmlMatrix2
|
ejmlMatrix1 dot ejmlMatrix2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun bufferedMultiplication() {
|
fun bufferedMultiplication() {
|
||||||
|
BufferMatrixContext(RealField, Buffer.Companion::real).invoke{
|
||||||
matrix1 dot matrix2
|
matrix1 dot matrix2
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun realMultiplication(){
|
||||||
|
RealMatrixContext {
|
||||||
|
matrix1 dot matrix2
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -4,9 +4,8 @@ import kscience.kmath.dimensions.D2
|
|||||||
import kscience.kmath.dimensions.D3
|
import kscience.kmath.dimensions.D3
|
||||||
import kscience.kmath.dimensions.DMatrixContext
|
import kscience.kmath.dimensions.DMatrixContext
|
||||||
import kscience.kmath.dimensions.Dimension
|
import kscience.kmath.dimensions.Dimension
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
|
|
||||||
private fun DMatrixContext<Double, RealField>.simple() {
|
private fun DMatrixContext<Double>.simple() {
|
||||||
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
||||||
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
||||||
|
|
||||||
@ -18,7 +17,7 @@ private object D5 : Dimension {
|
|||||||
override val dim: UInt = 5u
|
override val dim: UInt = 5u
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun DMatrixContext<Double, RealField>.custom() {
|
private fun DMatrixContext<Double>.custom() {
|
||||||
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
|
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
|
||||||
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
|
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
|
||||||
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }
|
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }
|
||||||
|
@ -2,10 +2,9 @@ package kscience.kmath.ast
|
|||||||
|
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
import kscience.kmath.operations.NumericAlgebra
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A Mathematical Syntax Tree node for mathematical expressions.
|
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
|
||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
@ -57,21 +56,22 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
|||||||
?: error("Numeric nodes are not supported by $this")
|
?: error("Numeric nodes are not supported by $this")
|
||||||
|
|
||||||
is MST.Symbolic -> symbol(node.value)
|
is MST.Symbolic -> symbol(node.value)
|
||||||
is MST.Unary -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
|
||||||
|
|
||||||
is MST.Binary -> when {
|
is MST.Unary -> when {
|
||||||
this !is NumericAlgebra -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
|
||||||
|
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
||||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
|
||||||
val number = RealField
|
|
||||||
.binaryOperationFunction(node.operation)
|
|
||||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
|
||||||
|
|
||||||
number(number)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
node.left is MST.Numeric -> leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
is MST.Binary -> when {
|
||||||
node.right is MST.Numeric -> rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
|
||||||
|
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
|
||||||
|
|
||||||
|
this is NumericAlgebra && node.left is MST.Numeric ->
|
||||||
|
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
||||||
|
|
||||||
|
this is NumericAlgebra && node.right is MST.Numeric ->
|
||||||
|
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
||||||
|
|
||||||
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
package kscience.kmath.estree
|
package kscience.kmath.estree
|
||||||
|
|
||||||
import kscience.kmath.ast.MST
|
import kscience.kmath.ast.MST
|
||||||
|
import kscience.kmath.ast.MST.*
|
||||||
import kscience.kmath.ast.MstExpression
|
import kscience.kmath.ast.MstExpression
|
||||||
import kscience.kmath.estree.internal.ESTreeBuilder
|
import kscience.kmath.estree.internal.ESTreeBuilder
|
||||||
import kscience.kmath.estree.internal.estree.BaseExpression
|
import kscience.kmath.estree.internal.estree.BaseExpression
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
import kscience.kmath.operations.NumericAlgebra
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
|
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
||||||
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
|
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
|
||||||
is MST.Symbolic -> {
|
is Symbolic -> {
|
||||||
val symbol = try {
|
val symbol = try {
|
||||||
algebra.symbol(node.value)
|
algebra.symbol(node.value)
|
||||||
} catch (ignored: IllegalStateException) {
|
} catch (ignored: IllegalStateException) {
|
||||||
@ -25,25 +25,29 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
|||||||
variable(node.value)
|
variable(node.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Numeric -> constant(node.value)
|
is Numeric -> constant(node.value)
|
||||||
is MST.Unary -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
|
||||||
|
|
||||||
is MST.Binary -> when {
|
is Unary -> when {
|
||||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> constant(
|
algebra is NumericAlgebra && node.value is Numeric -> constant(
|
||||||
algebra.number(
|
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
|
||||||
RealField
|
|
||||||
|
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
||||||
|
}
|
||||||
|
|
||||||
|
is Binary -> when {
|
||||||
|
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
|
||||||
|
algebra
|
||||||
.binaryOperationFunction(node.operation)
|
.binaryOperationFunction(node.operation)
|
||||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> call(
|
algebra is NumericAlgebra && node.left is Numeric -> call(
|
||||||
algebra.leftSideNumberOperationFunction(node.operation),
|
algebra.leftSideNumberOperationFunction(node.operation),
|
||||||
visit(node.left),
|
visit(node.left),
|
||||||
visit(node.right),
|
visit(node.right),
|
||||||
)
|
)
|
||||||
|
|
||||||
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> call(
|
algebra is NumericAlgebra && node.right is Numeric -> call(
|
||||||
algebra.rightSideNumberOperationFunction(node.operation),
|
algebra.rightSideNumberOperationFunction(node.operation),
|
||||||
visit(node.left),
|
visit(node.left),
|
||||||
visit(node.right),
|
visit(node.right),
|
||||||
|
@ -3,11 +3,11 @@ package kscience.kmath.asm
|
|||||||
import kscience.kmath.asm.internal.AsmBuilder
|
import kscience.kmath.asm.internal.AsmBuilder
|
||||||
import kscience.kmath.asm.internal.buildName
|
import kscience.kmath.asm.internal.buildName
|
||||||
import kscience.kmath.ast.MST
|
import kscience.kmath.ast.MST
|
||||||
|
import kscience.kmath.ast.MST.*
|
||||||
import kscience.kmath.ast.MstExpression
|
import kscience.kmath.ast.MstExpression
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
import kscience.kmath.operations.NumericAlgebra
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compiles given MST to an Expression using AST compiler.
|
* Compiles given MST to an Expression using AST compiler.
|
||||||
@ -20,7 +20,7 @@ import kscience.kmath.operations.RealField
|
|||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
||||||
is MST.Symbolic -> {
|
is Symbolic -> {
|
||||||
val symbol = try {
|
val symbol = try {
|
||||||
algebra.symbol(node.value)
|
algebra.symbol(node.value)
|
||||||
} catch (ignored: IllegalStateException) {
|
} catch (ignored: IllegalStateException) {
|
||||||
@ -33,24 +33,29 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
|
|||||||
loadVariable(node.value)
|
loadVariable(node.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Numeric -> loadNumberConstant(node.value)
|
is Numeric -> loadNumberConstant(node.value)
|
||||||
is MST.Unary -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
|
||||||
|
|
||||||
is MST.Binary -> when {
|
is Unary -> when {
|
||||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant(
|
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
||||||
algebra.number(
|
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
|
||||||
RealField
|
|
||||||
.binaryOperationFunction(node.operation)
|
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
||||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
}
|
||||||
)
|
|
||||||
|
is Binary -> when {
|
||||||
|
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
|
||||||
|
algebra.binaryOperationFunction(node.operation)
|
||||||
|
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
|
||||||
)
|
)
|
||||||
|
|
||||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperationFunction(node.operation)) {
|
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
|
||||||
|
algebra.leftSideNumberOperationFunction(node.operation)) {
|
||||||
visit(node.left)
|
visit(node.left)
|
||||||
visit(node.right)
|
visit(node.right)
|
||||||
}
|
}
|
||||||
|
|
||||||
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperationFunction(node.operation)) {
|
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
|
||||||
|
algebra.rightSideNumberOperationFunction(node.operation)) {
|
||||||
visit(node.left)
|
visit(node.left)
|
||||||
visit(node.right)
|
visit(node.right)
|
||||||
}
|
}
|
||||||
|
@ -191,7 +191,7 @@ internal class AsmBuilder<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
val cls = classLoader.defineClass(className, classWriter.toByteArray())
|
val cls = classLoader.defineClass(className, classWriter.toByteArray())
|
||||||
java.io.File("dump.class").writeBytes(classWriter.toByteArray())
|
// java.io.File("dump.class").writeBytes(classWriter.toByteArray())
|
||||||
val l = MethodHandles.publicLookup()
|
val l = MethodHandles.publicLookup()
|
||||||
|
|
||||||
if (hasConstants)
|
if (hasConstants)
|
||||||
|
@ -29,6 +29,7 @@ public class CMMatrix(public val origin: RealMatrix, features: Set<MatrixFeature
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//TODO move inside context
|
||||||
public fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
public fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
||||||
this
|
this
|
||||||
} else {
|
} else {
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
package kscience.kmath.linear
|
package kscience.kmath.linear
|
||||||
|
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
import kscience.kmath.operations.Ring
|
import kscience.kmath.operations.Ring
|
||||||
import kscience.kmath.structures.*
|
import kscience.kmath.structures.Buffer
|
||||||
|
import kscience.kmath.structures.BufferFactory
|
||||||
|
import kscience.kmath.structures.NDStructure
|
||||||
|
import kscience.kmath.structures.asSequence
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Basic implementation of Matrix space based on [NDStructure]
|
* Basic implementation of Matrix space based on [NDStructure]
|
||||||
@ -21,24 +23,6 @@ public class BufferMatrixContext<T : Any, R : Ring<T>>(
|
|||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
|
||||||
public object RealMatrixContext : GenericMatrixContext<Double, RealField, BufferMatrix<Double>> {
|
|
||||||
public override val elementContext: RealField
|
|
||||||
get() = RealField
|
|
||||||
|
|
||||||
public override inline fun produce(
|
|
||||||
rows: Int,
|
|
||||||
columns: Int,
|
|
||||||
initializer: (i: Int, j: Int) -> Double,
|
|
||||||
): BufferMatrix<Double> {
|
|
||||||
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
|
||||||
return BufferMatrix(rows, columns, buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
public override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> =
|
|
||||||
RealBuffer(size, initializer)
|
|
||||||
}
|
|
||||||
|
|
||||||
public class BufferMatrix<T : Any>(
|
public class BufferMatrix<T : Any>(
|
||||||
public override val rowNum: Int,
|
public override val rowNum: Int,
|
||||||
public override val colNum: Int,
|
public override val colNum: Int,
|
||||||
|
@ -213,17 +213,8 @@ public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext
|
|||||||
return decomposition.solveWithLUP(bufferFactory, b)
|
return decomposition.solveWithLUP(bufferFactory, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): FeaturedMatrix<Double> =
|
|
||||||
solveWithLUP(a, b) { it < 1e-11 }
|
|
||||||
|
|
||||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.inverseWithLUP(
|
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.inverseWithLUP(
|
||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto,
|
noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto,
|
||||||
noinline checkSingular: (T) -> Boolean,
|
noinline checkSingular: (T) -> Boolean,
|
||||||
): FeaturedMatrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular)
|
): FeaturedMatrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular)
|
||||||
|
|
||||||
/**
|
|
||||||
* Inverses a square matrix using LUP decomposition. Non square matrix will throw a error.
|
|
||||||
*/
|
|
||||||
public fun RealMatrixContext.inverseWithLUP(matrix: Matrix<Double>): FeaturedMatrix<Double> =
|
|
||||||
solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), Buffer.Companion::real) { it < 1e-11 }
|
|
||||||
|
@ -18,6 +18,11 @@ public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Mat
|
|||||||
*/
|
*/
|
||||||
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M
|
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce a point compatible with matrix space (and possibly optimized for it)
|
||||||
|
*/
|
||||||
|
public fun point(size: Int, initializer: (Int) -> T): Point<T> = Buffer.boxing(size, initializer)
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public override fun binaryOperationFunction(operation: String): (left: Matrix<T>, right: Matrix<T>) -> M =
|
public override fun binaryOperationFunction(operation: String): (left: Matrix<T>, right: Matrix<T>) -> M =
|
||||||
when (operation) {
|
when (operation) {
|
||||||
@ -62,10 +67,6 @@ public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Mat
|
|||||||
public operator fun T.times(m: Matrix<T>): M = m * this
|
public operator fun T.times(m: Matrix<T>): M = m * this
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
|
||||||
* Non-boxing double matrix
|
|
||||||
*/
|
|
||||||
public val real: RealMatrixContext = RealMatrixContext
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A structured matrix with custom buffer
|
* A structured matrix with custom buffer
|
||||||
@ -89,11 +90,6 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>, out M : Matrix<T>> :
|
|||||||
*/
|
*/
|
||||||
public val elementContext: R
|
public val elementContext: R
|
||||||
|
|
||||||
/**
|
|
||||||
* Produce a point compatible with matrix space
|
|
||||||
*/
|
|
||||||
public fun point(size: Int, initializer: (Int) -> T): Point<T>
|
|
||||||
|
|
||||||
public override infix fun Matrix<T>.dot(other: Matrix<T>): M {
|
public override infix fun Matrix<T>.dot(other: Matrix<T>): M {
|
||||||
//TODO add typed error
|
//TODO add typed error
|
||||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||||
|
@ -0,0 +1,84 @@
|
|||||||
|
package kscience.kmath.linear
|
||||||
|
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kscience.kmath.structures.Matrix
|
||||||
|
import kscience.kmath.structures.MutableBuffer
|
||||||
|
import kscience.kmath.structures.MutableBufferFactory
|
||||||
|
import kscience.kmath.structures.RealBuffer
|
||||||
|
|
||||||
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
|
public object RealMatrixContext : MatrixContext<Double, BufferMatrix<Double>> {
|
||||||
|
|
||||||
|
public override inline fun produce(
|
||||||
|
rows: Int,
|
||||||
|
columns: Int,
|
||||||
|
initializer: (i: Int, j: Int) -> Double,
|
||||||
|
): BufferMatrix<Double> {
|
||||||
|
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||||
|
return BufferMatrix(rows, columns, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun Matrix<Double>.wrap(): BufferMatrix<Double> = if (this is BufferMatrix) this else {
|
||||||
|
produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun one(rows: Int, columns: Int): FeaturedMatrix<Double> = VirtualMatrix(rows, columns, DiagonalFeature) { i, j ->
|
||||||
|
if (i == j) 1.0 else 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
public override infix fun Matrix<Double>.dot(other: Matrix<Double>): BufferMatrix<Double> {
|
||||||
|
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||||
|
return produce(rowNum, other.colNum) { i, j ->
|
||||||
|
var res = 0.0
|
||||||
|
for (l in 0 until colNum) {
|
||||||
|
res += get(i, l) * other.get(l, j)
|
||||||
|
}
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public override infix fun Matrix<Double>.dot(vector: Point<Double>): Point<Double> {
|
||||||
|
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
|
||||||
|
return RealBuffer(rowNum) { i ->
|
||||||
|
var res = 0.0
|
||||||
|
for (j in 0 until colNum) {
|
||||||
|
res += get(i, j) * vector[j]
|
||||||
|
}
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun add(a: Matrix<Double>, b: Matrix<Double>): BufferMatrix<Double> {
|
||||||
|
require(a.rowNum == b.rowNum) { "Row number mismatch in matrix addition. Left side: ${a.rowNum}, right side: ${b.rowNum}" }
|
||||||
|
require(a.colNum == b.colNum) { "Column number mismatch in matrix addition. Left side: ${a.colNum}, right side: ${b.colNum}" }
|
||||||
|
return produce(a.rowNum, a.colNum) { i, j ->
|
||||||
|
a[i, j] + b[i, j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Matrix<Double>.times(value: Double): BufferMatrix<Double> =
|
||||||
|
produce(rowNum, colNum) { i, j -> get(i, j) * value }
|
||||||
|
|
||||||
|
|
||||||
|
override fun multiply(a: Matrix<Double>, k: Number): BufferMatrix<Double> =
|
||||||
|
produce(a.rowNum, a.colNum) { i, j -> a.get(i, j) * k.toDouble() }
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Partially optimized real-valued matrix
|
||||||
|
*/
|
||||||
|
public val MatrixContext.Companion.real: RealMatrixContext get() = RealMatrixContext
|
||||||
|
|
||||||
|
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): FeaturedMatrix<Double> {
|
||||||
|
// Use existing decomposition if it is provided by matrix
|
||||||
|
val bufferFactory: MutableBufferFactory<Double> = MutableBuffer.Companion::real
|
||||||
|
val decomposition = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 }
|
||||||
|
return decomposition.solveWithLUP(bufferFactory, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Inverses a square matrix using LUP decomposition. Non square matrix will throw a error.
|
||||||
|
*/
|
||||||
|
public fun RealMatrixContext.inverseWithLUP(matrix: Matrix<Double>): FeaturedMatrix<Double> =
|
||||||
|
solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum))
|
@ -1,11 +1,6 @@
|
|||||||
package kscience.kmath.dimensions
|
package kscience.kmath.dimensions
|
||||||
|
|
||||||
import kscience.kmath.linear.GenericMatrixContext
|
import kscience.kmath.linear.*
|
||||||
import kscience.kmath.linear.MatrixContext
|
|
||||||
import kscience.kmath.linear.Point
|
|
||||||
import kscience.kmath.linear.transpose
|
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
import kscience.kmath.operations.Ring
|
|
||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
import kscience.kmath.structures.Matrix
|
import kscience.kmath.structures.Matrix
|
||||||
import kscience.kmath.structures.Structure2D
|
import kscience.kmath.structures.Structure2D
|
||||||
@ -42,7 +37,7 @@ public interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
|||||||
* An inline wrapper for a Matrix
|
* An inline wrapper for a Matrix
|
||||||
*/
|
*/
|
||||||
public inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
public inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
||||||
private val structure: Structure2D<T>
|
private val structure: Structure2D<T>,
|
||||||
) : DMatrix<T, R, C> {
|
) : DMatrix<T, R, C> {
|
||||||
override val shape: IntArray get() = structure.shape
|
override val shape: IntArray get() = structure.shape
|
||||||
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||||
@ -81,7 +76,7 @@ public inline class DPointWrapper<T, D : Dimension>(public val point: Point<T>)
|
|||||||
/**
|
/**
|
||||||
* Basic operations on dimension-safe matrices. Operates on [Matrix]
|
* Basic operations on dimension-safe matrices. Operates on [Matrix]
|
||||||
*/
|
*/
|
||||||
public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: GenericMatrixContext<T, Ri, Matrix<T>>) {
|
public inline class DMatrixContext<T : Any>(public val context: MatrixContext<T, Matrix<T>>) {
|
||||||
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
|
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
|
||||||
require(rowNum == Dimension.dim<R>().toInt()) {
|
require(rowNum == Dimension.dim<R>().toInt()) {
|
||||||
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
|
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
|
||||||
@ -115,7 +110,7 @@ public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: Ge
|
|||||||
}
|
}
|
||||||
|
|
||||||
public inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
public inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
||||||
other: DMatrix<T, C1, C2>
|
other: DMatrix<T, C1, C2>,
|
||||||
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
|
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
|
||||||
|
|
||||||
public inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
|
public inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
|
||||||
@ -139,18 +134,19 @@ public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: Ge
|
|||||||
public inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
|
public inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
|
||||||
context { (this@transpose as Matrix<T>).transpose() }.coerce()
|
context { (this@transpose as Matrix<T>).transpose() }.coerce()
|
||||||
|
|
||||||
/**
|
|
||||||
* A square unit matrix
|
|
||||||
*/
|
|
||||||
public inline fun <reified D : Dimension> one(): DMatrix<T, D, D> = produce { i, j ->
|
|
||||||
if (i == j) context.elementContext.one else context.elementContext.zero
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <reified R : Dimension, reified C : Dimension> zero(): DMatrix<T, R, C> = produce { _, _ ->
|
|
||||||
context.elementContext.zero
|
|
||||||
}
|
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public val real: DMatrixContext<Double, RealField> = DMatrixContext(MatrixContext.real)
|
public val real: DMatrixContext<Double> = DMatrixContext(MatrixContext.real)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A square unit matrix
|
||||||
|
*/
|
||||||
|
public inline fun <reified D : Dimension> DMatrixContext<Double>.one(): DMatrix<Double, D, D> = produce { i, j ->
|
||||||
|
if (i == j) 1.0 else 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <reified R : Dimension, reified C : Dimension> DMatrixContext<Double>.zero(): DMatrix<Double, R, C> = produce { _, _ ->
|
||||||
|
0.0
|
||||||
|
}
|
@ -3,6 +3,7 @@ package kscience.dimensions
|
|||||||
import kscience.kmath.dimensions.D2
|
import kscience.kmath.dimensions.D2
|
||||||
import kscience.kmath.dimensions.D3
|
import kscience.kmath.dimensions.D3
|
||||||
import kscience.kmath.dimensions.DMatrixContext
|
import kscience.kmath.dimensions.DMatrixContext
|
||||||
|
import kscience.kmath.dimensions.one
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
|
||||||
internal class DMatrixContextTest {
|
internal class DMatrixContextTest {
|
||||||
|
@ -1,13 +1,7 @@
|
|||||||
package kscience.kmath.real
|
package kscience.kmath.real
|
||||||
|
|
||||||
import kscience.kmath.linear.FeaturedMatrix
|
import kscience.kmath.linear.*
|
||||||
import kscience.kmath.linear.MatrixContext
|
|
||||||
import kscience.kmath.linear.RealMatrixContext.elementContext
|
|
||||||
import kscience.kmath.linear.VirtualMatrix
|
|
||||||
import kscience.kmath.linear.inverseWithLUP
|
|
||||||
import kscience.kmath.misc.UnstableKMathAPI
|
import kscience.kmath.misc.UnstableKMathAPI
|
||||||
import kscience.kmath.operations.invoke
|
|
||||||
import kscience.kmath.operations.sum
|
|
||||||
import kscience.kmath.structures.Buffer
|
import kscience.kmath.structures.Buffer
|
||||||
import kscience.kmath.structures.RealBuffer
|
import kscience.kmath.structures.RealBuffer
|
||||||
import kscience.kmath.structures.asIterable
|
import kscience.kmath.structures.asIterable
|
||||||
@ -122,8 +116,7 @@ public fun RealMatrix.extractColumn(columnIndex: Int): RealMatrix =
|
|||||||
extractColumns(columnIndex..columnIndex)
|
extractColumns(columnIndex..columnIndex)
|
||||||
|
|
||||||
public fun RealMatrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
public fun RealMatrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||||
val column = columns[j]
|
columns[j].asIterable().sum()
|
||||||
elementContext { sum(column.asIterable()) }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun RealMatrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
public fun RealMatrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||||
|
Loading…
Reference in New Issue
Block a user