diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index 65e915ef4..8c19395d3 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -2,8 +2,9 @@ package scientifik.kmath.commons.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionField +import scientifik.kmath.expressions.ExpressionContext import scientifik.kmath.operations.ExtendedField +import scientifik.kmath.operations.Field import kotlin.properties.ReadOnlyProperty import kotlin.reflect.KProperty @@ -112,7 +113,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1) /** * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) */ -object DiffExpressionContext : ExpressionField { +object DiffExpressionContext : ExpressionContext, Field { override fun variable(name: String, default: Double?) = DiffExpression { variable(name, default?.const()) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index dae069879..73745f78f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -1,7 +1,6 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Space +import scientifik.kmath.operations.Algebra /** * An elementary function that could be invoked on a map of arguments @@ -15,7 +14,7 @@ operator fun Expression.invoke(vararg pairs: Pair): T = invoke /** * A context for expression construction */ -interface ExpressionContext { +interface ExpressionContext : Algebra { /** * Introduce a variable into expression context */ @@ -25,87 +24,13 @@ interface ExpressionContext { * A constant expression which does not depend on arguments */ fun const(value: T): E - - fun produce(node: SyntaxTreeNode): E } -interface ExpressionSpace : Space, ExpressionContext { - - fun produceSingular(value: String): E = variable(value) - - fun produceUnary(operation: String, value: E): E { - return when (operation) { - UnaryNode.PLUS_OPERATION -> value - UnaryNode.MINUS_OPERATION -> -value - else -> error("Unary operation $operation is not supported by $this") - } +fun ExpressionContext.produce(node: SyntaxTreeNode): E { + return when (node) { + is NumberNode -> error("Single number nodes are not supported") + is SingularNode -> variable(node.value) + is UnaryNode -> unaryOperation(node.operation, produce(node.value)) + is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right)) } - - fun produceBinary(operation: String, left: E, right: E): E { - return when (operation) { - BinaryNode.PLUS_OPERATION -> left + right - BinaryNode.MINUS_OPERATION -> left - right - else -> error("Binary operation $operation is not supported by $this") - } - } - - override fun produce(node: SyntaxTreeNode): E { - return when (node) { - is NumberNode -> error("Single number nodes are not supported") - is SingularNode -> produceSingular(node.value) - is UnaryNode -> produceUnary(node.operation, produce(node.value)) - is BinaryNode -> { - when (node.operation) { - BinaryNode.TIMES_OPERATION -> { - if (node.left is NumberNode) { - return produce(node.right) * node.left.value - } else if (node.right is NumberNode) { - return produce(node.left) * node.right.value - } - } - BinaryNode.DIV_OPERATION -> { - if (node.right is NumberNode) { - return produce(node.left) / node.right.value - } - } - } - produceBinary(node.operation, produce(node.left), produce(node.right)) - } - } - } -} - -interface ExpressionField : Field, ExpressionSpace { - fun number(value: Number): E = one * value - - override fun produce(node: SyntaxTreeNode): E { - if (node is BinaryNode) { - when (node.operation) { - BinaryNode.PLUS_OPERATION -> { - if (node.left is NumberNode) { - return produce(node.right) + one * node.left.value - } else if (node.right is NumberNode) { - return produce(node.left) + one * node.right.value - } - } - BinaryNode.MINUS_OPERATION -> { - if (node.left is NumberNode) { - return one * node.left.value - produce(node.right) - } else if (node.right is NumberNode) { - return produce(node.left) - one * node.right.value - } - } - } - } - return super.produce(node) - } - - override fun produceBinary(operation: String, left: E, right: E): E { - return when (operation) { - BinaryNode.TIMES_OPERATION -> left * right - BinaryNode.DIV_OPERATION -> left / right - else -> super.produceBinary(operation, left, right) - } - } - } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index 7b5a93465..4f38e3a71 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -3,7 +3,6 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.Field import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space -import scientifik.kmath.operations.invoke internal class VariableExpression(val name: String, val default: T? = null) : Expression { override fun invoke(arguments: Map): T = @@ -28,41 +27,54 @@ internal class ProductExpression(val context: Ring, val first: Expression< context.multiply(first.invoke(arguments), second.invoke(arguments)) } -internal class ConstProductExpression(val context: Space, val expr: Expression, val const: Number) : +internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : Expression { override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } -internal class DivExpression(val context: Field, val expr: Expression, val second: Expression) : +internal class DivExpession(val context: Field, val expr: Expression, val second: Expression) : Expression { override fun invoke(arguments: Map): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) } -open class FunctionalExpressionSpace(val space: Space) : Space>, ExpressionSpace> { +open class FunctionalExpressionSpace( + val space: Space +) : Space>, ExpressionContext> { + override val zero: Expression = ConstantExpression(space.zero) override fun const(value: T): Expression = ConstantExpression(value) + override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) + override fun add(a: Expression, b: Expression): Expression = SumExpression(space, a, b) - override fun multiply(a: Expression, k: Number): Expression = ConstProductExpression(space, a, k) - operator fun Expression.plus(arg: T): Expression = this + const(arg) - operator fun Expression.minus(arg: T): Expression = this - const(arg) - operator fun T.plus(arg: Expression): Expression = arg + this - operator fun T.minus(arg: Expression): Expression = arg - this + + override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(space, a, k) + + + operator fun Expression.plus(arg: T) = this + const(arg) + operator fun Expression.minus(arg: T) = this - const(arg) + + operator fun T.plus(arg: Expression) = arg + this + operator fun T.minus(arg: Expression) = arg - this } -open class FunctionalExpressionField(val field: Field) : - ExpressionField>, - FunctionalExpressionSpace(field) { +open class FunctionalExpressionField( + val field: Field +) : Field>, ExpressionContext>, FunctionalExpressionSpace(field) { override val one: Expression get() = const(this.field.one) - override fun number(value: Number): Expression = const(field { one * value }) + fun const(value: Double): Expression = const(field.run { one*value}) + override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) - override fun divide(a: Expression, b: Expression): Expression = DivExpression(field, a, b) - operator fun Expression.times(arg: T): Expression = this * const(arg) - operator fun Expression.div(arg: T): Expression = this / const(arg) - operator fun T.times(arg: Expression): Expression = arg * this - operator fun T.div(arg: Expression): Expression = arg / this -} + + override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) + + operator fun Expression.times(arg: T) = this * const(arg) + operator fun Expression.div(arg: T) = this / const(arg) + + operator fun T.times(arg: Expression) = arg * this + operator fun T.div(arg: Expression) = arg / this +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt index fff3c5bc0..b4038f680 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt @@ -8,9 +8,6 @@ data class NumberNode(val value: Number) : SyntaxTreeNode() data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxTreeNode() { companion object { - const val PLUS_OPERATION = "+" - const val MINUS_OPERATION = "-" - const val NOT_OPERATION = "!" const val ABS_OPERATION = "abs" const val SIN_OPERATION = "sin" const val COS_OPERATION = "cos" @@ -22,10 +19,6 @@ data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxT data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode) : SyntaxTreeNode() { companion object { - const val PLUS_OPERATION = "+" - const val MINUS_OPERATION = "-" - const val TIMES_OPERATION = "*" - const val DIV_OPERATION = "/" //TODO add operations } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index 485185526..b3282e5f3 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -6,9 +6,12 @@ annotation class KMathContext /** * Marker interface for any algebra */ -interface Algebra +interface Algebra { + fun unaryOperation(operation: String, arg: T): T + fun binaryOperation(operation: String, left: T, right: T): T +} -inline operator fun , R> T.invoke(block: T.() -> R): R = run(block) +inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) /** * Space-like operations without neutral element @@ -24,7 +27,7 @@ interface SpaceOperations : Algebra { */ fun multiply(a: T, k: Number): T - //Operation to be performed in this context + //Operation to be performed in this context. Could be moved to extensions in case of KEEP-176 operator fun T.unaryMinus(): T = multiply(this, -1.0) operator fun T.plus(b: T): T = add(this, b) @@ -32,6 +35,24 @@ interface SpaceOperations : Algebra { operator fun T.times(k: Number) = multiply(this, k.toDouble()) operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble()) operator fun Number.times(b: T) = b * this + + override fun unaryOperation(operation: String, arg: T): T = when (operation) { + PLUS_OPERATION -> arg + MINUS_OPERATION -> -arg + else -> error("Unary operation $operation not defined in $this") + } + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + PLUS_OPERATION -> add(left, right) + MINUS_OPERATION -> left - right + else -> error("Binary operation $operation not defined in $this") + } + + companion object { + const val PLUS_OPERATION = "+" + const val MINUS_OPERATION = "-" + const val NOT_OPERATION = "!" + } } @@ -60,6 +81,15 @@ interface RingOperations : SpaceOperations { fun multiply(a: T, b: T): T operator fun T.times(b: T): T = multiply(this, b) + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + TIMES_OPERATION -> multiply(left, right) + else -> super.binaryOperation(operation, left, right) + } + + companion object{ + const val TIMES_OPERATION = "*" + } } /** @@ -85,6 +115,15 @@ interface FieldOperations : RingOperations { fun divide(a: T, b: T): T operator fun T.div(b: T): T = divide(this, b) + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + DIV_OPERATION -> divide(left, right) + else -> super.binaryOperation(operation, left, right) + } + + companion object{ + const val DIV_OPERATION = "/" + } } /** diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt new file mode 100644 index 000000000..e69de29bb