diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index 65e915ef4..8c19395d3 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -2,8 +2,9 @@ package scientifik.kmath.commons.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionField +import scientifik.kmath.expressions.ExpressionContext import scientifik.kmath.operations.ExtendedField +import scientifik.kmath.operations.Field import kotlin.properties.ReadOnlyProperty import kotlin.reflect.KProperty @@ -112,7 +113,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1) /** * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) */ -object DiffExpressionContext : ExpressionField { +object DiffExpressionContext : ExpressionContext, Field { override fun variable(name: String, default: Double?) = DiffExpression { variable(name, default?.const()) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 94bc81413..73745f78f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -1,7 +1,6 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Space +import scientifik.kmath.operations.Algebra /** * An elementary function that could be invoked on a map of arguments @@ -15,7 +14,7 @@ operator fun Expression.invoke(vararg pairs: Pair): T = invoke /** * A context for expression construction */ -interface ExpressionContext> { +interface ExpressionContext : Algebra { /** * Introduce a variable into expression context */ @@ -25,87 +24,13 @@ interface ExpressionContext> { * A constant expression which does not depend on arguments */ fun const(value: T): E - - fun produce(node: SyntaxTreeNode): E } -interface ExpressionSpace> : Space, ExpressionContext { - - open fun produceSingular(value: String): E = variable(value) - - open fun produceUnary(operation: String, value: E): E { - return when (operation) { - UnaryNode.PLUS_OPERATION -> value - UnaryNode.MINUS_OPERATION -> -value - else -> error("Unary operation $operation is not supported by $this") - } +fun ExpressionContext.produce(node: SyntaxTreeNode): E { + return when (node) { + is NumberNode -> error("Single number nodes are not supported") + is SingularNode -> variable(node.value) + is UnaryNode -> unaryOperation(node.operation, produce(node.value)) + is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right)) } - - open fun produceBinary(operation: String, left: E, right: E): E { - return when (operation) { - BinaryNode.PLUS_OPERATION -> left + right - BinaryNode.MINUS_OPERATION -> left - right - else -> error("Binary operation $operation is not supported by $this") - } - } - - override fun produce(node: SyntaxTreeNode): E { - return when (node) { - is NumberNode -> error("Single number nodes are not supported") - is SingularNode -> produceSingular(node.value) - is UnaryNode -> produceUnary(node.operation, produce(node.value)) - is BinaryNode -> { - when (node.operation) { - BinaryNode.TIMES_OPERATION -> { - if (node.left is NumberNode) { - return produce(node.right) * node.left.value - } else if (node.right is NumberNode) { - return produce(node.left) * node.right.value - } - } - BinaryNode.DIV_OPERATION -> { - if (node.right is NumberNode) { - return produce(node.left) / node.right.value - } - } - } - produceBinary(node.operation, produce(node.left), produce(node.right)) - } - } - } -} - -interface ExpressionField> : Field, ExpressionSpace { - fun const(value: Double): E = one.times(value) - - override fun produce(node: SyntaxTreeNode): E { - if (node is BinaryNode) { - when (node.operation) { - BinaryNode.PLUS_OPERATION -> { - if (node.left is NumberNode) { - return produce(node.right) + one * node.left.value - } else if (node.right is NumberNode) { - return produce(node.left) + one * node.right.value - } - } - BinaryNode.MINUS_OPERATION -> { - if (node.left is NumberNode) { - return one * node.left.value - produce(node.right) - } else if (node.right is NumberNode) { - return produce(node.left) - one * node.right.value - } - } - } - } - return super.produce(node) - } - - override fun produceBinary(operation: String, left: E, right: E): E { - return when (operation) { - BinaryNode.TIMES_OPERATION -> left * right - BinaryNode.DIV_OPERATION -> left / right - else -> super.produceBinary(operation, left, right) - } - } - } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt index fff3c5bc0..b4038f680 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt @@ -8,9 +8,6 @@ data class NumberNode(val value: Number) : SyntaxTreeNode() data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxTreeNode() { companion object { - const val PLUS_OPERATION = "+" - const val MINUS_OPERATION = "-" - const val NOT_OPERATION = "!" const val ABS_OPERATION = "abs" const val SIN_OPERATION = "sin" const val COS_OPERATION = "cos" @@ -22,10 +19,6 @@ data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxT data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode) : SyntaxTreeNode() { companion object { - const val PLUS_OPERATION = "+" - const val MINUS_OPERATION = "-" - const val TIMES_OPERATION = "*" - const val DIV_OPERATION = "/" //TODO add operations } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt index ad0acbc4a..4f38e3a71 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt @@ -38,9 +38,8 @@ internal class DivExpession(val context: Field, val expr: Expression, v } open class FunctionalExpressionSpace( - val space: Space, - one: T -) : Space>, ExpressionSpace> { + val space: Space +) : Space>, ExpressionContext> { override val zero: Expression = ConstantExpression(space.zero) @@ -62,12 +61,12 @@ open class FunctionalExpressionSpace( open class FunctionalExpressionField( val field: Field -) : ExpressionField>, FunctionalExpressionSpace(field, field.one) { +) : Field>, ExpressionContext>, FunctionalExpressionSpace(field) { override val one: Expression get() = const(this.field.one) - override fun const(value: Double): Expression = const(field.run { one*value}) + fun const(value: Double): Expression = const(field.run { one*value}) override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index 485185526..b3282e5f3 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -6,9 +6,12 @@ annotation class KMathContext /** * Marker interface for any algebra */ -interface Algebra +interface Algebra { + fun unaryOperation(operation: String, arg: T): T + fun binaryOperation(operation: String, left: T, right: T): T +} -inline operator fun , R> T.invoke(block: T.() -> R): R = run(block) +inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) /** * Space-like operations without neutral element @@ -24,7 +27,7 @@ interface SpaceOperations : Algebra { */ fun multiply(a: T, k: Number): T - //Operation to be performed in this context + //Operation to be performed in this context. Could be moved to extensions in case of KEEP-176 operator fun T.unaryMinus(): T = multiply(this, -1.0) operator fun T.plus(b: T): T = add(this, b) @@ -32,6 +35,24 @@ interface SpaceOperations : Algebra { operator fun T.times(k: Number) = multiply(this, k.toDouble()) operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble()) operator fun Number.times(b: T) = b * this + + override fun unaryOperation(operation: String, arg: T): T = when (operation) { + PLUS_OPERATION -> arg + MINUS_OPERATION -> -arg + else -> error("Unary operation $operation not defined in $this") + } + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + PLUS_OPERATION -> add(left, right) + MINUS_OPERATION -> left - right + else -> error("Binary operation $operation not defined in $this") + } + + companion object { + const val PLUS_OPERATION = "+" + const val MINUS_OPERATION = "-" + const val NOT_OPERATION = "!" + } } @@ -60,6 +81,15 @@ interface RingOperations : SpaceOperations { fun multiply(a: T, b: T): T operator fun T.times(b: T): T = multiply(this, b) + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + TIMES_OPERATION -> multiply(left, right) + else -> super.binaryOperation(operation, left, right) + } + + companion object{ + const val TIMES_OPERATION = "*" + } } /** @@ -85,6 +115,15 @@ interface FieldOperations : RingOperations { fun divide(a: T, b: T): T operator fun T.div(b: T): T = divide(this, b) + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + DIV_OPERATION -> divide(left, right) + else -> super.binaryOperation(operation, left, right) + } + + companion object{ + const val DIV_OPERATION = "/" + } } /** diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt new file mode 100644 index 000000000..e69de29bb