From 1a869ace0e23659eca398d7a768a54191bc084d7 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 15 May 2020 16:10:19 +0300 Subject: [PATCH] Refactored Expression tree API --- .../commons/expressions/DiffExpression.kt | 5 +- .../kmath/expressions/Expression.kt | 92 ++++++++++++++++++- .../{syntaxTree.kt => SyntaxTreeNode.kt} | 14 ++- .../expressions/functionalExpressions.kt | 60 ++---------- 4 files changed, 107 insertions(+), 64 deletions(-) rename kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/{syntaxTree.kt => SyntaxTreeNode.kt} (67%) diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index d5c038dc4..65e915ef4 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -2,9 +2,8 @@ package scientifik.kmath.commons.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.expressions.ExpressionField import scientifik.kmath.operations.ExtendedField -import scientifik.kmath.operations.Field import kotlin.properties.ReadOnlyProperty import kotlin.reflect.KProperty @@ -113,7 +112,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1) /** * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) */ -object DiffExpressionContext : ExpressionContext, Field { +object DiffExpressionContext : ExpressionField { override fun variable(name: String, default: Double?) = DiffExpression { variable(name, default?.const()) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index bc3d70b4f..94bc81413 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -1,5 +1,8 @@ package scientifik.kmath.expressions +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Space + /** * An elementary function that could be invoked on a map of arguments */ @@ -12,16 +15,97 @@ operator fun Expression.invoke(vararg pairs: Pair): T = invoke /** * A context for expression construction */ -interface ExpressionContext { +interface ExpressionContext> { /** * Introduce a variable into expression context */ - fun variable(name: String, default: T? = null): Expression + fun variable(name: String, default: T? = null): E /** * A constant expression which does not depend on arguments */ - fun const(value: T): Expression + fun const(value: T): E + + fun produce(node: SyntaxTreeNode): E +} + +interface ExpressionSpace> : Space, ExpressionContext { + + open fun produceSingular(value: String): E = variable(value) + + open fun produceUnary(operation: String, value: E): E { + return when (operation) { + UnaryNode.PLUS_OPERATION -> value + UnaryNode.MINUS_OPERATION -> -value + else -> error("Unary operation $operation is not supported by $this") + } + } + + open fun produceBinary(operation: String, left: E, right: E): E { + return when (operation) { + BinaryNode.PLUS_OPERATION -> left + right + BinaryNode.MINUS_OPERATION -> left - right + else -> error("Binary operation $operation is not supported by $this") + } + } + + override fun produce(node: SyntaxTreeNode): E { + return when (node) { + is NumberNode -> error("Single number nodes are not supported") + is SingularNode -> produceSingular(node.value) + is UnaryNode -> produceUnary(node.operation, produce(node.value)) + is BinaryNode -> { + when (node.operation) { + BinaryNode.TIMES_OPERATION -> { + if (node.left is NumberNode) { + return produce(node.right) * node.left.value + } else if (node.right is NumberNode) { + return produce(node.left) * node.right.value + } + } + BinaryNode.DIV_OPERATION -> { + if (node.right is NumberNode) { + return produce(node.left) / node.right.value + } + } + } + produceBinary(node.operation, produce(node.left), produce(node.right)) + } + } + } +} + +interface ExpressionField> : Field, ExpressionSpace { + fun const(value: Double): E = one.times(value) + + override fun produce(node: SyntaxTreeNode): E { + if (node is BinaryNode) { + when (node.operation) { + BinaryNode.PLUS_OPERATION -> { + if (node.left is NumberNode) { + return produce(node.right) + one * node.left.value + } else if (node.right is NumberNode) { + return produce(node.left) + one * node.right.value + } + } + BinaryNode.MINUS_OPERATION -> { + if (node.left is NumberNode) { + return one * node.left.value - produce(node.right) + } else if (node.right is NumberNode) { + return produce(node.left) - one * node.right.value + } + } + } + } + return super.produce(node) + } + + override fun produceBinary(operation: String, left: E, right: E): E { + return when (operation) { + BinaryNode.TIMES_OPERATION -> left * right + BinaryNode.DIV_OPERATION -> left / right + else -> super.produceBinary(operation, left, right) + } + } - fun produce(node: SyntaxTreeNode): Expression } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/syntaxTree.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt similarity index 67% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/syntaxTree.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt index daa17977d..fff3c5bc0 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/syntaxTree.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt @@ -4,20 +4,24 @@ sealed class SyntaxTreeNode data class SingularNode(val value: String) : SyntaxTreeNode() -data class UnaryNode(val operation: String, val value: SyntaxTreeNode): SyntaxTreeNode(){ - companion object{ +data class NumberNode(val value: Number) : SyntaxTreeNode() + +data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxTreeNode() { + companion object { const val PLUS_OPERATION = "+" const val MINUS_OPERATION = "-" const val NOT_OPERATION = "!" const val ABS_OPERATION = "abs" const val SIN_OPERATION = "sin" - const val cos_OPERATION = "cos" + const val COS_OPERATION = "cos" + const val EXP_OPERATION = "exp" + const val LN_OPERATION = "ln" //TODO add operations } } -data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode): SyntaxTreeNode(){ - companion object{ +data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode) : SyntaxTreeNode() { + companion object { const val PLUS_OPERATION = "+" const val MINUS_OPERATION = "-" const val TIMES_OPERATION = "*" diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt index 64b3ad72c..ad0acbc4a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt @@ -40,12 +40,10 @@ internal class DivExpession(val context: Field, val expr: Expression, v open class FunctionalExpressionSpace( val space: Space, one: T -) : Space>, ExpressionContext { +) : Space>, ExpressionSpace> { override val zero: Expression = ConstantExpression(space.zero) - val one: Expression = ConstantExpression(one) - override fun const(value: T): Expression = ConstantExpression(value) override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) @@ -60,46 +58,17 @@ open class FunctionalExpressionSpace( operator fun T.plus(arg: Expression) = arg + this operator fun T.minus(arg: Expression) = arg - this - - fun const(value: Double): Expression = one.times(value) - - open fun produceSingular(value: String): Expression { - val numberValue = value.toDoubleOrNull() - return if (numberValue == null) { - variable(value) - } else { - const(numberValue) - } - } - - open fun produceUnary(operation: String, value: Expression): Expression { - return when (operation) { - UnaryNode.PLUS_OPERATION -> value - UnaryNode.MINUS_OPERATION -> -value - else -> error("Unary operation $operation is not supported by $this") - } - } - - open fun produceBinary(operation: String, left: Expression, right: Expression): Expression { - return when (operation) { - BinaryNode.PLUS_OPERATION -> left + right - BinaryNode.MINUS_OPERATION -> left - right - else -> error("Binary operation $operation is not supported by $this") - } - } - - override fun produce(node: SyntaxTreeNode): Expression { - return when (node) { - is SingularNode -> produceSingular(node.value) - is UnaryNode -> produceUnary(node.operation, produce(node.value)) - is BinaryNode -> produceBinary(node.operation, produce(node.left), produce(node.right)) - } - } } open class FunctionalExpressionField( val field: Field -) : Field>, FunctionalExpressionSpace(field, field.one) { +) : ExpressionField>, FunctionalExpressionSpace(field, field.one) { + + override val one: Expression + get() = const(this.field.one) + + override fun const(value: Double): Expression = const(field.run { one*value}) + override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) @@ -109,17 +78,4 @@ open class FunctionalExpressionField( operator fun T.times(arg: Expression) = arg * this operator fun T.div(arg: Expression) = arg / this - - override fun produce(node: SyntaxTreeNode): Expression { - //TODO bring together numeric and typed expressions - return super.produce(node) - } - - override fun produceBinary(operation: String, left: Expression, right: Expression): Expression { - return when (operation) { - BinaryNode.TIMES_OPERATION -> left * right - BinaryNode.DIV_OPERATION -> left / right - else -> super.produceBinary(operation, left, right) - } - } } \ No newline at end of file