forked from kscience/kmath
Refactored Expression tree API
This commit is contained in:
parent
4a26c1080e
commit
1a869ace0e
@ -2,9 +2,8 @@ package scientifik.kmath.commons.expressions
|
|||||||
|
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.expressions.ExpressionContext
|
import scientifik.kmath.expressions.ExpressionField
|
||||||
import scientifik.kmath.operations.ExtendedField
|
import scientifik.kmath.operations.ExtendedField
|
||||||
import scientifik.kmath.operations.Field
|
|
||||||
import kotlin.properties.ReadOnlyProperty
|
import kotlin.properties.ReadOnlyProperty
|
||||||
import kotlin.reflect.KProperty
|
import kotlin.reflect.KProperty
|
||||||
|
|
||||||
@ -113,7 +112,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1)
|
|||||||
/**
|
/**
|
||||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
||||||
*/
|
*/
|
||||||
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> {
|
object DiffExpressionContext : ExpressionField<Double, DiffExpression> {
|
||||||
override fun variable(name: String, default: Double?) =
|
override fun variable(name: String, default: Double?) =
|
||||||
DiffExpression { variable(name, default?.const()) }
|
DiffExpression { variable(name, default?.const()) }
|
||||||
|
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An elementary function that could be invoked on a map of arguments
|
* An elementary function that could be invoked on a map of arguments
|
||||||
*/
|
*/
|
||||||
@ -12,16 +15,97 @@ operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke
|
|||||||
/**
|
/**
|
||||||
* A context for expression construction
|
* A context for expression construction
|
||||||
*/
|
*/
|
||||||
interface ExpressionContext<T> {
|
interface ExpressionContext<T, E : Expression<T>> {
|
||||||
/**
|
/**
|
||||||
* Introduce a variable into expression context
|
* Introduce a variable into expression context
|
||||||
*/
|
*/
|
||||||
fun variable(name: String, default: T? = null): Expression<T>
|
fun variable(name: String, default: T? = null): E
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A constant expression which does not depend on arguments
|
* A constant expression which does not depend on arguments
|
||||||
*/
|
*/
|
||||||
fun const(value: T): Expression<T>
|
fun const(value: T): E
|
||||||
|
|
||||||
|
fun produce(node: SyntaxTreeNode): E
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ExpressionSpace<T, E : Expression<T>> : Space<E>, ExpressionContext<T, E> {
|
||||||
|
|
||||||
|
open fun produceSingular(value: String): E = variable(value)
|
||||||
|
|
||||||
|
open fun produceUnary(operation: String, value: E): E {
|
||||||
|
return when (operation) {
|
||||||
|
UnaryNode.PLUS_OPERATION -> value
|
||||||
|
UnaryNode.MINUS_OPERATION -> -value
|
||||||
|
else -> error("Unary operation $operation is not supported by $this")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
open fun produceBinary(operation: String, left: E, right: E): E {
|
||||||
|
return when (operation) {
|
||||||
|
BinaryNode.PLUS_OPERATION -> left + right
|
||||||
|
BinaryNode.MINUS_OPERATION -> left - right
|
||||||
|
else -> error("Binary operation $operation is not supported by $this")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun produce(node: SyntaxTreeNode): E {
|
||||||
|
return when (node) {
|
||||||
|
is NumberNode -> error("Single number nodes are not supported")
|
||||||
|
is SingularNode -> produceSingular(node.value)
|
||||||
|
is UnaryNode -> produceUnary(node.operation, produce(node.value))
|
||||||
|
is BinaryNode -> {
|
||||||
|
when (node.operation) {
|
||||||
|
BinaryNode.TIMES_OPERATION -> {
|
||||||
|
if (node.left is NumberNode) {
|
||||||
|
return produce(node.right) * node.left.value
|
||||||
|
} else if (node.right is NumberNode) {
|
||||||
|
return produce(node.left) * node.right.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
BinaryNode.DIV_OPERATION -> {
|
||||||
|
if (node.right is NumberNode) {
|
||||||
|
return produce(node.left) / node.right.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
produceBinary(node.operation, produce(node.left), produce(node.right))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ExpressionField<T, E : Expression<T>> : Field<E>, ExpressionSpace<T, E> {
|
||||||
|
fun const(value: Double): E = one.times(value)
|
||||||
|
|
||||||
|
override fun produce(node: SyntaxTreeNode): E {
|
||||||
|
if (node is BinaryNode) {
|
||||||
|
when (node.operation) {
|
||||||
|
BinaryNode.PLUS_OPERATION -> {
|
||||||
|
if (node.left is NumberNode) {
|
||||||
|
return produce(node.right) + one * node.left.value
|
||||||
|
} else if (node.right is NumberNode) {
|
||||||
|
return produce(node.left) + one * node.right.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
BinaryNode.MINUS_OPERATION -> {
|
||||||
|
if (node.left is NumberNode) {
|
||||||
|
return one * node.left.value - produce(node.right)
|
||||||
|
} else if (node.right is NumberNode) {
|
||||||
|
return produce(node.left) - one * node.right.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return super.produce(node)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun produceBinary(operation: String, left: E, right: E): E {
|
||||||
|
return when (operation) {
|
||||||
|
BinaryNode.TIMES_OPERATION -> left * right
|
||||||
|
BinaryNode.DIV_OPERATION -> left / right
|
||||||
|
else -> super.produceBinary(operation, left, right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fun produce(node: SyntaxTreeNode): Expression<T>
|
|
||||||
}
|
}
|
@ -4,20 +4,24 @@ sealed class SyntaxTreeNode
|
|||||||
|
|
||||||
data class SingularNode(val value: String) : SyntaxTreeNode()
|
data class SingularNode(val value: String) : SyntaxTreeNode()
|
||||||
|
|
||||||
data class UnaryNode(val operation: String, val value: SyntaxTreeNode): SyntaxTreeNode(){
|
data class NumberNode(val value: Number) : SyntaxTreeNode()
|
||||||
companion object{
|
|
||||||
|
data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxTreeNode() {
|
||||||
|
companion object {
|
||||||
const val PLUS_OPERATION = "+"
|
const val PLUS_OPERATION = "+"
|
||||||
const val MINUS_OPERATION = "-"
|
const val MINUS_OPERATION = "-"
|
||||||
const val NOT_OPERATION = "!"
|
const val NOT_OPERATION = "!"
|
||||||
const val ABS_OPERATION = "abs"
|
const val ABS_OPERATION = "abs"
|
||||||
const val SIN_OPERATION = "sin"
|
const val SIN_OPERATION = "sin"
|
||||||
const val cos_OPERATION = "cos"
|
const val COS_OPERATION = "cos"
|
||||||
|
const val EXP_OPERATION = "exp"
|
||||||
|
const val LN_OPERATION = "ln"
|
||||||
//TODO add operations
|
//TODO add operations
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode): SyntaxTreeNode(){
|
data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode) : SyntaxTreeNode() {
|
||||||
companion object{
|
companion object {
|
||||||
const val PLUS_OPERATION = "+"
|
const val PLUS_OPERATION = "+"
|
||||||
const val MINUS_OPERATION = "-"
|
const val MINUS_OPERATION = "-"
|
||||||
const val TIMES_OPERATION = "*"
|
const val TIMES_OPERATION = "*"
|
@ -40,12 +40,10 @@ internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, v
|
|||||||
open class FunctionalExpressionSpace<T>(
|
open class FunctionalExpressionSpace<T>(
|
||||||
val space: Space<T>,
|
val space: Space<T>,
|
||||||
one: T
|
one: T
|
||||||
) : Space<Expression<T>>, ExpressionContext<T> {
|
) : Space<Expression<T>>, ExpressionSpace<T,Expression<T>> {
|
||||||
|
|
||||||
override val zero: Expression<T> = ConstantExpression(space.zero)
|
override val zero: Expression<T> = ConstantExpression(space.zero)
|
||||||
|
|
||||||
val one: Expression<T> = ConstantExpression(one)
|
|
||||||
|
|
||||||
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
||||||
|
|
||||||
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
||||||
@ -60,46 +58,17 @@ open class FunctionalExpressionSpace<T>(
|
|||||||
|
|
||||||
operator fun T.plus(arg: Expression<T>) = arg + this
|
operator fun T.plus(arg: Expression<T>) = arg + this
|
||||||
operator fun T.minus(arg: Expression<T>) = arg - this
|
operator fun T.minus(arg: Expression<T>) = arg - this
|
||||||
|
|
||||||
fun const(value: Double): Expression<T> = one.times(value)
|
|
||||||
|
|
||||||
open fun produceSingular(value: String): Expression<T> {
|
|
||||||
val numberValue = value.toDoubleOrNull()
|
|
||||||
return if (numberValue == null) {
|
|
||||||
variable(value)
|
|
||||||
} else {
|
|
||||||
const(numberValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
open fun produceUnary(operation: String, value: Expression<T>): Expression<T> {
|
|
||||||
return when (operation) {
|
|
||||||
UnaryNode.PLUS_OPERATION -> value
|
|
||||||
UnaryNode.MINUS_OPERATION -> -value
|
|
||||||
else -> error("Unary operation $operation is not supported by $this")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
open fun produceBinary(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> {
|
|
||||||
return when (operation) {
|
|
||||||
BinaryNode.PLUS_OPERATION -> left + right
|
|
||||||
BinaryNode.MINUS_OPERATION -> left - right
|
|
||||||
else -> error("Binary operation $operation is not supported by $this")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun produce(node: SyntaxTreeNode): Expression<T> {
|
|
||||||
return when (node) {
|
|
||||||
is SingularNode -> produceSingular(node.value)
|
|
||||||
is UnaryNode -> produceUnary(node.operation, produce(node.value))
|
|
||||||
is BinaryNode -> produceBinary(node.operation, produce(node.left), produce(node.right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
open class FunctionalExpressionField<T>(
|
open class FunctionalExpressionField<T>(
|
||||||
val field: Field<T>
|
val field: Field<T>
|
||||||
) : Field<Expression<T>>, FunctionalExpressionSpace<T>(field, field.one) {
|
) : ExpressionField<T,Expression<T>>, FunctionalExpressionSpace<T>(field, field.one) {
|
||||||
|
|
||||||
|
override val one: Expression<T>
|
||||||
|
get() = const(this.field.one)
|
||||||
|
|
||||||
|
override fun const(value: Double): Expression<T> = const(field.run { one*value})
|
||||||
|
|
||||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
||||||
|
|
||||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
||||||
@ -109,17 +78,4 @@ open class FunctionalExpressionField<T>(
|
|||||||
|
|
||||||
operator fun T.times(arg: Expression<T>) = arg * this
|
operator fun T.times(arg: Expression<T>) = arg * this
|
||||||
operator fun T.div(arg: Expression<T>) = arg / this
|
operator fun T.div(arg: Expression<T>) = arg / this
|
||||||
|
|
||||||
override fun produce(node: SyntaxTreeNode): Expression<T> {
|
|
||||||
//TODO bring together numeric and typed expressions
|
|
||||||
return super.produce(node)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun produceBinary(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> {
|
|
||||||
return when (operation) {
|
|
||||||
BinaryNode.TIMES_OPERATION -> left * right
|
|
||||||
BinaryNode.DIV_OPERATION -> left / right
|
|
||||||
else -> super.produceBinary(operation, left, right)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user