Merge remote-tracking branch 'mipt-npm/adv-expr' into adv-expr-asm

# Conflicts:
#	kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt
#	settings.gradle.kts
This commit is contained in:
Iaroslav 2020-06-12 21:58:50 +07:00
commit 07b938e582
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
18 changed files with 316 additions and 97 deletions

View File

@ -0,0 +1,21 @@
plugins {
id("scientifik.mpp")
}
repositories{
maven("https://dl.bintray.com/hotkeytlt/maven")
}
kotlin.sourceSets {
commonMain {
dependencies {
api(project(":kmath-core"))
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3")
}
}
jvmMain{
dependencies{
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
}
}
}

View File

@ -0,0 +1,62 @@
package scientifik.kmath.ast
import scientifik.kmath.operations.NumericAlgebra
import scientifik.kmath.operations.RealField
/**
* A Mathematical Syntax Tree node for mathematical expressions
*/
sealed class MST {
/**
* A node containing unparsed string
*/
data class Singular(val value: String) : MST()
/**
* A node containing a number
*/
data class Numeric(val value: Number) : MST()
/**
* A node containing an unary operation
*/
data class Unary(val operation: String, val value: MST) : MST() {
companion object {
const val ABS_OPERATION = "abs"
//TODO add operations
}
}
/**
* A node containing binary operation
*/
data class Binary(val operation: String, val left: MST, val right: MST) : MST() {
companion object
}
}
//TODO add a function with positional arguments
//TODO add a function with named arguments
fun <T> NumericAlgebra<T>.evaluate(node: MST): T {
return when (node) {
is MST.Numeric -> number(node.value)
is MST.Singular -> raw(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when {
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation(
node.operation,
node.left.value.toDouble(),
node.right.value.toDouble()
)
number(number)
}
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
}
}
}

View File

@ -0,0 +1,19 @@
package scientifik.kmath.ast
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.NumericAlgebra
/**
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
*/
class MSTExpression<T>(val algebra: NumericAlgebra<T>, val mst: MST) : Expression<T> {
/**
* Substitute algebra raw value
*/
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> by algebra {
override fun raw(value: String): T = arguments[value] ?: super.raw(value)
}
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
}

View File

@ -0,0 +1,23 @@
package scientifik.kmath.ast
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
//TODO stubs for asm generation
interface AsmExpression<T>
interface AsmExpressionAlgebra<T, A : Algebra<T>> : NumericAlgebra<AsmExpression<T>> {
val algebra: A
}
fun <T> AsmExpression<T>.compile(): Expression<T> = TODO()
//TODO add converter for functional expressions
inline fun <reified T : Any, A : Algebra<T>> A.asm(
block: AsmExpressionAlgebra<T, A>.() -> AsmExpression<T>
): Expression<T> = TODO()
inline fun <reified T : Any, A : Algebra<T>> A.asm(ast: MST): Expression<T> = asm { evaluate(ast) }

View File

@ -0,0 +1,59 @@
package scientifik.kmath.ast
import com.github.h0tk3y.betterParse.combinators.*
import com.github.h0tk3y.betterParse.grammar.Grammar
import com.github.h0tk3y.betterParse.grammar.parseToEnd
import com.github.h0tk3y.betterParse.grammar.parser
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser
import scientifik.kmath.operations.FieldOperations
import scientifik.kmath.operations.PowerOperations
import scientifik.kmath.operations.RingOperations
import scientifik.kmath.operations.SpaceOperations
/**
* TODO move to common
*/
private object ArithmeticsEvaluator : Grammar<MST>() {
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?")
val lpar by token("\\(")
val rpar by token("\\)")
val mul by token("\\*")
val pow by token("\\^")
val div by token("/")
val minus by token("-")
val plus by token("\\+")
val ws by token("\\s+", ignore = true)
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
val term: Parser<MST> by number or
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or
(skip(lpar) and parser(this::rootParser) and skip(rpar))
val powChain by leftAssociative(term, pow) { a, _, b ->
MST.Binary(PowerOperations.POW_OPERATION, a, b)
}
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b ->
if (op == div) {
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
} else {
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
}
}
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b ->
if (op == plus) {
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
} else {
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
}
}
override val rootParser: Parser<MST> by subSumChain
}
fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this)
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)

View File

@ -0,0 +1,17 @@
package scietifik.kmath.ast
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.parseMath
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
internal class ParserTest{
@Test
fun parsedExpression(){
val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.evaluate(mst)
assertEquals(Complex(10.0,0.0), res)
}
}

View File

@ -2,7 +2,7 @@ package scientifik.kmath.commons.expressions
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionContext
import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field
import kotlin.properties.ReadOnlyProperty
@ -113,7 +113,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1)
/**
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
*/
object DiffExpressionContext : ExpressionContext<Double, DiffExpression>, Field<DiffExpression> {
object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
override fun variable(name: String, default: Double?) =
DiffExpression { variable(name, default?.const()) }

View File

@ -7,6 +7,12 @@ import scientifik.kmath.operations.Algebra
*/
interface Expression<T> {
operator fun invoke(arguments: Map<String, T>): T
companion object {
operator fun <T> invoke(block: (Map<String, T>) -> T): Expression<T> = object : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = block(arguments)
}
}
}
operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
@ -14,7 +20,7 @@ operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke
/**
* A context for expression construction
*/
interface ExpressionContext<T, E> : Algebra<E> {
interface ExpressionAlgebra<T, E> : Algebra<E> {
/**
* Introduce a variable into expression context
*/

View File

@ -1,52 +0,0 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.NumericAlgebra
/**
* A syntax tree node for mathematical expressions
*/
sealed class SyntaxTreeNode
/**
* A node containing unparsed string
*/
data class SingularNode(val value: String) : SyntaxTreeNode()
/**
* A node containing a number
*/
data class NumberNode(val value: Number) : SyntaxTreeNode()
/**
* A node containing an unary operation
*/
data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxTreeNode() {
companion object {
const val ABS_OPERATION = "abs"
const val SIN_OPERATION = "sin"
const val COS_OPERATION = "cos"
const val EXP_OPERATION = "exp"
const val LN_OPERATION = "ln"
//TODO add operations
}
}
/**
* A node containing binary operation
*/
data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode) : SyntaxTreeNode() {
companion object
}
//TODO add a function with positional arguments
//TODO add a function with named arguments
fun <T> NumericAlgebra<T>.compile(node: SyntaxTreeNode): T{
return when (node) {
is NumberNode -> number(node.value)
is SingularNode -> raw(node.value)
is UnaryNode -> unaryOperation(node.operation, compile(node.value))
is BinaryNode -> binaryOperation(node.operation, compile(node.left), compile(node.right))
}
}

View File

@ -128,14 +128,14 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
luRow[col] = sum
// maintain best permutation choice
if (abs(sum) > largest) {
largest = abs(sum)
if (this@lup.abs(sum) > largest) {
largest = this@lup.abs(sum)
max = row
}
}
// Singularity check
if (checkSingular(abs(lu[max, col]))) {
if (checkSingular(this@lup.abs(lu[max, col]))) {
error("The matrix is singular")
}

View File

@ -1,5 +1,7 @@
package scientifik.kmath.misc
import kotlin.math.abs
/**
* Convert double range to sequence.
*
@ -8,8 +10,7 @@ package scientifik.kmath.misc
*
* If step is negative, the same goes from upper boundary downwards
*/
fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double> =
when {
fun ClosedFloatingPointRange<Double>.toSequenceWithStep(step: Double): Sequence<Double> = when {
step == 0.0 -> error("Zero step in double progression")
step > 0 -> sequence {
var current = start
@ -25,11 +26,20 @@ fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double>
current += step
}
}
}
}
/**
* Convert double range to sequence with the fixed number of points
*/
fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Sequence<Double> {
require(numPoints > 1) { "The number of points should be more than 2" }
return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1))
}
/**
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
*/
@Deprecated("Replace by 'toSequenceWithPoints'")
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
if (numPoints < 2) error("Can't create generic grid with less than two points")
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }

View File

@ -31,6 +31,12 @@ interface NumericAlgebra<T> : Algebra<T> {
* Wrap a number
*/
fun number(value: Number): T
fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
binaryOperation(operation, number(left), right)
fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
leftSideNumberOperation(operation, right, left)
}
/**
@ -128,8 +134,14 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = one * value.toDouble()
// those operators are blocked by type conflict in RealField
// operator fun T.plus(b: Number) = this.plus(b * one)
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
RingOperations.TIMES_OPERATION -> left * right
else -> super.leftSideNumberOperation(operation, left, right)
}
//TODO those operators are blocked by type conflict in RealField
// operator fun T.plus(b: Number) = this.plus(b * one)
// operator fun Number.plus(b: T) = b + this
//
// operator fun T.minus(b: Number) = this.minus(b * one)

View File

@ -11,7 +11,7 @@ import kotlin.math.*
/**
* A field for complex numbers
*/
object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
object ComplexField : ExtendedField<Complex> {
override val zero: Complex = Complex(0.0, 0.0)
override val one: Complex = Complex(1.0, 0.0)
@ -50,6 +50,12 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
operator fun Complex.minus(d: Double) = add(this, -d.toComplex())
operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this)
override fun raw(value: String): Complex = if (value == "i") {
i
} else {
super.raw(value)
}
}
/**

View File

@ -10,9 +10,30 @@ interface ExtendedFieldOperations<T> :
FieldOperations<T>,
TrigonometricOperations<T>,
PowerOperations<T>,
ExponentialOperations<T>
ExponentialOperations<T> {
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>
override fun tan(arg: T): T = sin(arg) / cos(arg)
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
TrigonometricOperations.COS_OPERATION -> cos(arg)
TrigonometricOperations.SIN_OPERATION -> sin(arg)
PowerOperations.SQRT_OPERATION -> sqrt(arg)
ExponentialOperations.EXP_OPERATION -> exp(arg)
ExponentialOperations.LN_OPERATION -> ln(arg)
else -> super.unaryOperation(operation, arg)
}
}
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T {
return when (operation) {
PowerOperations.POW_OPERATION -> power(left, right)
else -> super.rightSideNumberOperation(operation, left, right)
}
}
}
/**
* Real field element wrapping double.
@ -44,6 +65,7 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
override inline fun sin(arg: Double) = kotlin.math.sin(arg)
override inline fun cos(arg: Double) = kotlin.math.cos(arg)
override inline fun tan(arg: Double) = kotlin.math.tan(arg)
override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble())
@ -76,6 +98,8 @@ object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override inline fun sin(arg: Float) = kotlin.math.sin(arg)
override inline fun cos(arg: Float) = kotlin.math.cos(arg)
override inline fun tan(arg: Float): Float = kotlin.math.tan(arg)
override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat())
override inline fun exp(arg: Float) = kotlin.math.exp(arg)

View File

@ -10,30 +10,37 @@ package scientifik.kmath.operations
* It also allows to override behavior for optional operations
*
*/
interface TrigonometricOperations<T> : FieldOperations<T> {
interface TrigonometricOperations<T> {
fun sin(arg: T): T
fun cos(arg: T): T
fun tg(arg: T): T = sin(arg) / cos(arg)
fun tan(arg: T): T
fun ctg(arg: T): T = cos(arg) / sin(arg)
companion object {
const val SIN_OPERATION = "sin"
const val COS_OPERATION = "cos"
}
}
fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> tg(arg: T): T = arg.context.tg(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> ctg(arg: T): T = arg.context.ctg(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
/* Power and roots */
/**
* A context extension to include power operations like square roots, etc
*/
interface PowerOperations<T> : Algebra<T> {
interface PowerOperations<T> {
fun power(arg: T, pow: Number): T
fun sqrt(arg: T) = power(arg, 0.5)
infix fun T.pow(pow: Number) = power(this, pow)
companion object {
const val POW_OPERATION = "pow"
const val SQRT_OPERATION = "sqrt"
}
}
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
@ -42,9 +49,14 @@ fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
/* Exponential */
interface ExponentialOperations<T>: Algebra<T> {
interface ExponentialOperations<T> {
fun exp(arg: T): T
fun ln(arg: T): T
companion object {
const val EXP_OPERATION = "exp"
const val LN_OPERATION = "ln"
}
}
fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)

View File

@ -1,13 +1,8 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.*
import scientifik.kmath.operations.ExtendedField
interface ExtendedNDField<T : Any, F, N : NDStructure<T>> :
NDField<T, F, N>,
TrigonometricOperations<N>,
PowerOperations<N>,
ExponentialOperations<N>
where F : ExtendedFieldOperations<T>, F : Field<T>
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N>
///**

View File

@ -1,10 +1,12 @@
package scientifik.kmath.operations
import scientifik.kmath.structures.*
import java.math.BigDecimal
import java.math.BigInteger
import java.math.MathContext
/**
* A field wrapper for Java [BigInteger]
*/
object JBigIntegerField : Field<BigInteger> {
override val zero: BigInteger = BigInteger.ZERO
override val one: BigInteger = BigInteger.ONE
@ -18,6 +20,9 @@ object JBigIntegerField : Field<BigInteger> {
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
}
/**
* A Field wrapper for Java [BigDecimal]
*/
class JBigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> {
override val zero: BigDecimal = BigDecimal.ZERO
override val one: BigDecimal = BigDecimal.ONE

View File

@ -44,6 +44,6 @@ include(
":kmath-dimensions",
":kmath-for-real",
":kmath-geometry",
":examples",
":kmath-asm"
":kmath-ast",
":examples"
)