A prototype for advanced expressoins

This commit is contained in:
Alexander Nozik 2020-05-14 20:30:43 +03:00
parent c15f77acef
commit 4a26c1080e
4 changed files with 163 additions and 72 deletions

View File

@ -1,9 +1,5 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
/**
* An elementary function that could be invoked on a map of arguments
*/
@ -26,67 +22,6 @@ interface ExpressionContext<T> {
* A constant expression which does not depend on arguments
*/
fun const(value: T): Expression<T>
}
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class ConstantExpression<T>(val value: T) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = value
}
internal class SumExpression<T>(val context: Space<T>, val first: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
}
internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
context.multiply(first.invoke(arguments), second.invoke(arguments))
}
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
}
open class ExpressionSpace<T>(val space: Space<T>) : Space<Expression<T>>, ExpressionContext<T> {
override val zero: Expression<T> = ConstantExpression(space.zero)
override fun const(value: T): Expression<T> = ConstantExpression(value)
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
operator fun Expression<T>.plus(arg: T) = this + const(arg)
operator fun Expression<T>.minus(arg: T) = this - const(arg)
operator fun T.plus(arg: Expression<T>) = arg + this
operator fun T.minus(arg: Expression<T>) = arg - this
}
class ExpressionField<T>(val field: Field<T>) : Field<Expression<T>>, ExpressionSpace<T>(field) {
override val one: Expression<T> = ConstantExpression(field.one)
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
operator fun Expression<T>.times(arg: T) = this * const(arg)
operator fun Expression<T>.div(arg: T) = this / const(arg)
operator fun T.times(arg: Expression<T>) = arg * this
operator fun T.div(arg: Expression<T>) = arg / this
fun produce(node: SyntaxTreeNode): Expression<T>
}

View File

@ -0,0 +1,125 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class ConstantExpression<T>(val value: T) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = value
}
internal class SumExpression<T>(
val context: Space<T>,
val first: Expression<T>,
val second: Expression<T>
) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
}
internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
context.multiply(first.invoke(arguments), second.invoke(arguments))
}
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
}
open class FunctionalExpressionSpace<T>(
val space: Space<T>,
one: T
) : Space<Expression<T>>, ExpressionContext<T> {
override val zero: Expression<T> = ConstantExpression(space.zero)
val one: Expression<T> = ConstantExpression(one)
override fun const(value: T): Expression<T> = ConstantExpression(value)
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
operator fun Expression<T>.plus(arg: T) = this + const(arg)
operator fun Expression<T>.minus(arg: T) = this - const(arg)
operator fun T.plus(arg: Expression<T>) = arg + this
operator fun T.minus(arg: Expression<T>) = arg - this
fun const(value: Double): Expression<T> = one.times(value)
open fun produceSingular(value: String): Expression<T> {
val numberValue = value.toDoubleOrNull()
return if (numberValue == null) {
variable(value)
} else {
const(numberValue)
}
}
open fun produceUnary(operation: String, value: Expression<T>): Expression<T> {
return when (operation) {
UnaryNode.PLUS_OPERATION -> value
UnaryNode.MINUS_OPERATION -> -value
else -> error("Unary operation $operation is not supported by $this")
}
}
open fun produceBinary(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> {
return when (operation) {
BinaryNode.PLUS_OPERATION -> left + right
BinaryNode.MINUS_OPERATION -> left - right
else -> error("Binary operation $operation is not supported by $this")
}
}
override fun produce(node: SyntaxTreeNode): Expression<T> {
return when (node) {
is SingularNode -> produceSingular(node.value)
is UnaryNode -> produceUnary(node.operation, produce(node.value))
is BinaryNode -> produceBinary(node.operation, produce(node.left), produce(node.right))
}
}
}
open class FunctionalExpressionField<T>(
val field: Field<T>
) : Field<Expression<T>>, FunctionalExpressionSpace<T>(field, field.one) {
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
operator fun Expression<T>.times(arg: T) = this * const(arg)
operator fun Expression<T>.div(arg: T) = this / const(arg)
operator fun T.times(arg: Expression<T>) = arg * this
operator fun T.div(arg: Expression<T>) = arg / this
override fun produce(node: SyntaxTreeNode): Expression<T> {
//TODO bring together numeric and typed expressions
return super.produce(node)
}
override fun produceBinary(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> {
return when (operation) {
BinaryNode.TIMES_OPERATION -> left * right
BinaryNode.DIV_OPERATION -> left / right
else -> super.produceBinary(operation, left, right)
}
}
}

View File

@ -0,0 +1,31 @@
package scientifik.kmath.expressions
sealed class SyntaxTreeNode
data class SingularNode(val value: String) : SyntaxTreeNode()
data class UnaryNode(val operation: String, val value: SyntaxTreeNode): SyntaxTreeNode(){
companion object{
const val PLUS_OPERATION = "+"
const val MINUS_OPERATION = "-"
const val NOT_OPERATION = "!"
const val ABS_OPERATION = "abs"
const val SIN_OPERATION = "sin"
const val cos_OPERATION = "cos"
//TODO add operations
}
}
data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode): SyntaxTreeNode(){
companion object{
const val PLUS_OPERATION = "+"
const val MINUS_OPERATION = "-"
const val TIMES_OPERATION = "*"
const val DIV_OPERATION = "/"
//TODO add operations
}
}
//TODO add a function with positional arguments
//TODO add a function with named arguments

View File

@ -9,7 +9,7 @@ import kotlin.test.assertEquals
class ExpressionFieldTest {
@Test
fun testExpression() {
val context = ExpressionField(RealField)
val context = FunctionalExpressionField(RealField)
val expression = with(context) {
val x = variable("x", 2.0)
x * x + 2 * x + one
@ -20,7 +20,7 @@ class ExpressionFieldTest {
@Test
fun testComplex() {
val context = ExpressionField(ComplexField)
val context = FunctionalExpressionField(ComplexField)
val expression = with(context) {
val x = variable("x", Complex(2.0, 0.0))
x * x + 2 * x + one
@ -31,23 +31,23 @@ class ExpressionFieldTest {
@Test
fun separateContext() {
fun <T> ExpressionField<T>.expression(): Expression<T> {
fun <T> FunctionalExpressionField<T>.expression(): Expression<T> {
val x = variable("x")
return x * x + 2 * x + one
}
val expression = ExpressionField(RealField).expression()
val expression = FunctionalExpressionField(RealField).expression()
assertEquals(expression("x" to 1.0), 4.0)
}
@Test
fun valueExpression() {
val expressionBuilder: ExpressionField<Double>.() -> Expression<Double> = {
val expressionBuilder: FunctionalExpressionField<Double>.() -> Expression<Double> = {
val x = variable("x")
x * x + 2 * x + one
}
val expression = ExpressionField(RealField).expressionBuilder()
val expression = FunctionalExpressionField(RealField).expressionBuilder()
assertEquals(expression("x" to 1.0), 4.0)
}
}