Merge remote-tracking branch 'mipt-npm/adv-expr' into adv-expr-asm

# Conflicts:
#	kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt
#	kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt
This commit is contained in:
Iaroslav 2020-06-09 22:17:42 +07:00
commit 4dcdc0f99c
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
6 changed files with 84 additions and 114 deletions

View File

@ -2,8 +2,9 @@ package scientifik.kmath.commons.expressions
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionField
import scientifik.kmath.expressions.ExpressionContext
import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field
import kotlin.properties.ReadOnlyProperty
import kotlin.reflect.KProperty
@ -112,7 +113,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1)
/**
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
*/
object DiffExpressionContext : ExpressionField<Double, DiffExpression> {
object DiffExpressionContext : ExpressionContext<Double, DiffExpression>, Field<DiffExpression> {
override fun variable(name: String, default: Double?) =
DiffExpression { variable(name, default?.const()) }

View File

@ -1,7 +1,6 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.Algebra
/**
* An elementary function that could be invoked on a map of arguments
@ -15,7 +14,7 @@ operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke
/**
* A context for expression construction
*/
interface ExpressionContext<T, E> {
interface ExpressionContext<T, E> : Algebra<E> {
/**
* Introduce a variable into expression context
*/
@ -25,87 +24,13 @@ interface ExpressionContext<T, E> {
* A constant expression which does not depend on arguments
*/
fun const(value: T): E
fun produce(node: SyntaxTreeNode): E
}
interface ExpressionSpace<T, E> : Space<E>, ExpressionContext<T, E> {
fun produceSingular(value: String): E = variable(value)
fun produceUnary(operation: String, value: E): E {
return when (operation) {
UnaryNode.PLUS_OPERATION -> value
UnaryNode.MINUS_OPERATION -> -value
else -> error("Unary operation $operation is not supported by $this")
}
}
fun produceBinary(operation: String, left: E, right: E): E {
return when (operation) {
BinaryNode.PLUS_OPERATION -> left + right
BinaryNode.MINUS_OPERATION -> left - right
else -> error("Binary operation $operation is not supported by $this")
}
}
override fun produce(node: SyntaxTreeNode): E {
fun <T, E> ExpressionContext<T, E>.produce(node: SyntaxTreeNode): E {
return when (node) {
is NumberNode -> error("Single number nodes are not supported")
is SingularNode -> produceSingular(node.value)
is UnaryNode -> produceUnary(node.operation, produce(node.value))
is BinaryNode -> {
when (node.operation) {
BinaryNode.TIMES_OPERATION -> {
if (node.left is NumberNode) {
return produce(node.right) * node.left.value
} else if (node.right is NumberNode) {
return produce(node.left) * node.right.value
is SingularNode -> variable(node.value)
is UnaryNode -> unaryOperation(node.operation, produce(node.value))
is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right))
}
}
BinaryNode.DIV_OPERATION -> {
if (node.right is NumberNode) {
return produce(node.left) / node.right.value
}
}
}
produceBinary(node.operation, produce(node.left), produce(node.right))
}
}
}
}
interface ExpressionField<T, E> : Field<E>, ExpressionSpace<T, E> {
fun number(value: Number): E = one * value
override fun produce(node: SyntaxTreeNode): E {
if (node is BinaryNode) {
when (node.operation) {
BinaryNode.PLUS_OPERATION -> {
if (node.left is NumberNode) {
return produce(node.right) + one * node.left.value
} else if (node.right is NumberNode) {
return produce(node.left) + one * node.right.value
}
}
BinaryNode.MINUS_OPERATION -> {
if (node.left is NumberNode) {
return one * node.left.value - produce(node.right)
} else if (node.right is NumberNode) {
return produce(node.left) - one * node.right.value
}
}
}
}
return super.produce(node)
}
override fun produceBinary(operation: String, left: E, right: E): E {
return when (operation) {
BinaryNode.TIMES_OPERATION -> left * right
BinaryNode.DIV_OPERATION -> left / right
else -> super.produceBinary(operation, left, right)
}
}
}

View File

@ -3,7 +3,6 @@ package scientifik.kmath.expressions
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
@ -28,41 +27,54 @@ internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<
context.multiply(first.invoke(arguments), second.invoke(arguments))
}
internal class ConstProductExpression<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
internal class DivExpression<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
}
open class FunctionalExpressionSpace<T>(val space: Space<T>) : Space<Expression<T>>, ExpressionSpace<T, Expression<T>> {
open class FunctionalExpressionSpace<T>(
val space: Space<T>
) : Space<Expression<T>>, ExpressionContext<T,Expression<T>> {
override val zero: Expression<T> = ConstantExpression(space.zero)
override fun const(value: T): Expression<T> = ConstantExpression(value)
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpression(space, a, k)
operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
operator fun Expression<T>.plus(arg: T) = this + const(arg)
operator fun Expression<T>.minus(arg: T) = this - const(arg)
operator fun T.plus(arg: Expression<T>) = arg + this
operator fun T.minus(arg: Expression<T>) = arg - this
}
open class FunctionalExpressionField<T>(val field: Field<T>) :
ExpressionField<T, Expression<T>>,
FunctionalExpressionSpace<T>(field) {
open class FunctionalExpressionField<T>(
val field: Field<T>
) : Field<Expression<T>>, ExpressionContext<T,Expression<T>>, FunctionalExpressionSpace<T>(field) {
override val one: Expression<T>
get() = const(this.field.one)
override fun number(value: Number): Expression<T> = const(field { one * value })
fun const(value: Double): Expression<T> = const(field.run { one*value})
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpression(field, a, b)
operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
operator fun Expression<T>.times(arg: T) = this * const(arg)
operator fun Expression<T>.div(arg: T) = this / const(arg)
operator fun T.times(arg: Expression<T>) = arg * this
operator fun T.div(arg: Expression<T>) = arg / this
}

View File

@ -8,9 +8,6 @@ data class NumberNode(val value: Number) : SyntaxTreeNode()
data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxTreeNode() {
companion object {
const val PLUS_OPERATION = "+"
const val MINUS_OPERATION = "-"
const val NOT_OPERATION = "!"
const val ABS_OPERATION = "abs"
const val SIN_OPERATION = "sin"
const val COS_OPERATION = "cos"
@ -22,10 +19,6 @@ data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxT
data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode) : SyntaxTreeNode() {
companion object {
const val PLUS_OPERATION = "+"
const val MINUS_OPERATION = "-"
const val TIMES_OPERATION = "*"
const val DIV_OPERATION = "/"
//TODO add operations
}
}

View File

@ -6,9 +6,12 @@ annotation class KMathContext
/**
* Marker interface for any algebra
*/
interface Algebra<T>
interface Algebra<T> {
fun unaryOperation(operation: String, arg: T): T
fun binaryOperation(operation: String, left: T, right: T): T
}
inline operator fun <T : Algebra<*>, R> T.invoke(block: T.() -> R): R = run(block)
inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block)
/**
* Space-like operations without neutral element
@ -24,7 +27,7 @@ interface SpaceOperations<T> : Algebra<T> {
*/
fun multiply(a: T, k: Number): T
//Operation to be performed in this context
//Operation to be performed in this context. Could be moved to extensions in case of KEEP-176
operator fun T.unaryMinus(): T = multiply(this, -1.0)
operator fun T.plus(b: T): T = add(this, b)
@ -32,6 +35,24 @@ interface SpaceOperations<T> : Algebra<T> {
operator fun T.times(k: Number) = multiply(this, k.toDouble())
operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble())
operator fun Number.times(b: T) = b * this
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
PLUS_OPERATION -> arg
MINUS_OPERATION -> -arg
else -> error("Unary operation $operation not defined in $this")
}
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
PLUS_OPERATION -> add(left, right)
MINUS_OPERATION -> left - right
else -> error("Binary operation $operation not defined in $this")
}
companion object {
const val PLUS_OPERATION = "+"
const val MINUS_OPERATION = "-"
const val NOT_OPERATION = "!"
}
}
@ -60,6 +81,15 @@ interface RingOperations<T> : SpaceOperations<T> {
fun multiply(a: T, b: T): T
operator fun T.times(b: T): T = multiply(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
TIMES_OPERATION -> multiply(left, right)
else -> super.binaryOperation(operation, left, right)
}
companion object{
const val TIMES_OPERATION = "*"
}
}
/**
@ -85,6 +115,15 @@ interface FieldOperations<T> : RingOperations<T> {
fun divide(a: T, b: T): T
operator fun T.div(b: T): T = divide(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
DIV_OPERATION -> divide(left, right)
else -> super.binaryOperation(operation, left, right)
}
companion object{
const val DIV_OPERATION = "/"
}
}
/**