A prototype for advanced expressoins
This commit is contained in:
parent
c15f77acef
commit
4a26c1080e
@ -1,9 +1,5 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.operations.Space
|
||||
|
||||
/**
|
||||
* An elementary function that could be invoked on a map of arguments
|
||||
*/
|
||||
@ -26,67 +22,6 @@ interface ExpressionContext<T> {
|
||||
* A constant expression which does not depend on arguments
|
||||
*/
|
||||
fun const(value: T): Expression<T>
|
||||
}
|
||||
|
||||
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T =
|
||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||
}
|
||||
|
||||
internal class ConstantExpression<T>(val value: T) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = value
|
||||
}
|
||||
|
||||
internal class SumExpression<T>(val context: Space<T>, val first: Expression<T>, val second: Expression<T>) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<T>, val second: Expression<T>) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T =
|
||||
context.multiply(first.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||
}
|
||||
|
||||
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
open class ExpressionSpace<T>(val space: Space<T>) : Space<Expression<T>>, ExpressionContext<T> {
|
||||
override val zero: Expression<T> = ConstantExpression(space.zero)
|
||||
|
||||
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
||||
|
||||
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
||||
|
||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
|
||||
|
||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
|
||||
|
||||
|
||||
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
||||
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
||||
|
||||
operator fun T.plus(arg: Expression<T>) = arg + this
|
||||
operator fun T.minus(arg: Expression<T>) = arg - this
|
||||
}
|
||||
|
||||
|
||||
class ExpressionField<T>(val field: Field<T>) : Field<Expression<T>>, ExpressionSpace<T>(field) {
|
||||
override val one: Expression<T> = ConstantExpression(field.one)
|
||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
||||
|
||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
||||
|
||||
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
||||
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
||||
|
||||
operator fun T.times(arg: Expression<T>) = arg * this
|
||||
operator fun T.div(arg: Expression<T>) = arg / this
|
||||
fun produce(node: SyntaxTreeNode): Expression<T>
|
||||
}
|
@ -0,0 +1,125 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.operations.Space
|
||||
|
||||
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T =
|
||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||
}
|
||||
|
||||
internal class ConstantExpression<T>(val value: T) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = value
|
||||
}
|
||||
|
||||
internal class SumExpression<T>(
|
||||
val context: Space<T>,
|
||||
val first: Expression<T>,
|
||||
val second: Expression<T>
|
||||
) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<T>, val second: Expression<T>) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T =
|
||||
context.multiply(first.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||
}
|
||||
|
||||
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
open class FunctionalExpressionSpace<T>(
|
||||
val space: Space<T>,
|
||||
one: T
|
||||
) : Space<Expression<T>>, ExpressionContext<T> {
|
||||
|
||||
override val zero: Expression<T> = ConstantExpression(space.zero)
|
||||
|
||||
val one: Expression<T> = ConstantExpression(one)
|
||||
|
||||
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
||||
|
||||
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
||||
|
||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
|
||||
|
||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
|
||||
|
||||
|
||||
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
||||
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
||||
|
||||
operator fun T.plus(arg: Expression<T>) = arg + this
|
||||
operator fun T.minus(arg: Expression<T>) = arg - this
|
||||
|
||||
fun const(value: Double): Expression<T> = one.times(value)
|
||||
|
||||
open fun produceSingular(value: String): Expression<T> {
|
||||
val numberValue = value.toDoubleOrNull()
|
||||
return if (numberValue == null) {
|
||||
variable(value)
|
||||
} else {
|
||||
const(numberValue)
|
||||
}
|
||||
}
|
||||
|
||||
open fun produceUnary(operation: String, value: Expression<T>): Expression<T> {
|
||||
return when (operation) {
|
||||
UnaryNode.PLUS_OPERATION -> value
|
||||
UnaryNode.MINUS_OPERATION -> -value
|
||||
else -> error("Unary operation $operation is not supported by $this")
|
||||
}
|
||||
}
|
||||
|
||||
open fun produceBinary(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> {
|
||||
return when (operation) {
|
||||
BinaryNode.PLUS_OPERATION -> left + right
|
||||
BinaryNode.MINUS_OPERATION -> left - right
|
||||
else -> error("Binary operation $operation is not supported by $this")
|
||||
}
|
||||
}
|
||||
|
||||
override fun produce(node: SyntaxTreeNode): Expression<T> {
|
||||
return when (node) {
|
||||
is SingularNode -> produceSingular(node.value)
|
||||
is UnaryNode -> produceUnary(node.operation, produce(node.value))
|
||||
is BinaryNode -> produceBinary(node.operation, produce(node.left), produce(node.right))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
open class FunctionalExpressionField<T>(
|
||||
val field: Field<T>
|
||||
) : Field<Expression<T>>, FunctionalExpressionSpace<T>(field, field.one) {
|
||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
||||
|
||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
||||
|
||||
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
||||
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
||||
|
||||
operator fun T.times(arg: Expression<T>) = arg * this
|
||||
operator fun T.div(arg: Expression<T>) = arg / this
|
||||
|
||||
override fun produce(node: SyntaxTreeNode): Expression<T> {
|
||||
//TODO bring together numeric and typed expressions
|
||||
return super.produce(node)
|
||||
}
|
||||
|
||||
override fun produceBinary(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> {
|
||||
return when (operation) {
|
||||
BinaryNode.TIMES_OPERATION -> left * right
|
||||
BinaryNode.DIV_OPERATION -> left / right
|
||||
else -> super.produceBinary(operation, left, right)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
sealed class SyntaxTreeNode
|
||||
|
||||
data class SingularNode(val value: String) : SyntaxTreeNode()
|
||||
|
||||
data class UnaryNode(val operation: String, val value: SyntaxTreeNode): SyntaxTreeNode(){
|
||||
companion object{
|
||||
const val PLUS_OPERATION = "+"
|
||||
const val MINUS_OPERATION = "-"
|
||||
const val NOT_OPERATION = "!"
|
||||
const val ABS_OPERATION = "abs"
|
||||
const val SIN_OPERATION = "sin"
|
||||
const val cos_OPERATION = "cos"
|
||||
//TODO add operations
|
||||
}
|
||||
}
|
||||
|
||||
data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode): SyntaxTreeNode(){
|
||||
companion object{
|
||||
const val PLUS_OPERATION = "+"
|
||||
const val MINUS_OPERATION = "-"
|
||||
const val TIMES_OPERATION = "*"
|
||||
const val DIV_OPERATION = "/"
|
||||
//TODO add operations
|
||||
}
|
||||
}
|
||||
|
||||
//TODO add a function with positional arguments
|
||||
|
||||
//TODO add a function with named arguments
|
@ -9,7 +9,7 @@ import kotlin.test.assertEquals
|
||||
class ExpressionFieldTest {
|
||||
@Test
|
||||
fun testExpression() {
|
||||
val context = ExpressionField(RealField)
|
||||
val context = FunctionalExpressionField(RealField)
|
||||
val expression = with(context) {
|
||||
val x = variable("x", 2.0)
|
||||
x * x + 2 * x + one
|
||||
@ -20,7 +20,7 @@ class ExpressionFieldTest {
|
||||
|
||||
@Test
|
||||
fun testComplex() {
|
||||
val context = ExpressionField(ComplexField)
|
||||
val context = FunctionalExpressionField(ComplexField)
|
||||
val expression = with(context) {
|
||||
val x = variable("x", Complex(2.0, 0.0))
|
||||
x * x + 2 * x + one
|
||||
@ -31,23 +31,23 @@ class ExpressionFieldTest {
|
||||
|
||||
@Test
|
||||
fun separateContext() {
|
||||
fun <T> ExpressionField<T>.expression(): Expression<T> {
|
||||
fun <T> FunctionalExpressionField<T>.expression(): Expression<T> {
|
||||
val x = variable("x")
|
||||
return x * x + 2 * x + one
|
||||
}
|
||||
|
||||
val expression = ExpressionField(RealField).expression()
|
||||
val expression = FunctionalExpressionField(RealField).expression()
|
||||
assertEquals(expression("x" to 1.0), 4.0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun valueExpression() {
|
||||
val expressionBuilder: ExpressionField<Double>.() -> Expression<Double> = {
|
||||
val expressionBuilder: FunctionalExpressionField<Double>.() -> Expression<Double> = {
|
||||
val x = variable("x")
|
||||
x * x + 2 * x + one
|
||||
}
|
||||
|
||||
val expression = ExpressionField(RealField).expressionBuilder()
|
||||
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
||||
assertEquals(expression("x" to 1.0), 4.0)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user