diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt index f312323b9..cc3d38f9c 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt @@ -55,24 +55,23 @@ public sealed class MST { public fun Algebra.evaluate(node: MST): T = when (node) { is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) ?: error("Numeric nodes are not supported by $this") + is MST.Symbolic -> symbol(node.value) - is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) + is MST.Unary -> unaryOperation(node.operation)(evaluate(node.value)) + is MST.Binary -> when { - this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + this !is NumericAlgebra -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right)) node.left is MST.Numeric && node.right is MST.Numeric -> { - val number = RealField.binaryOperation( - node.operation, - node.left.value.toDouble(), - node.right.value.toDouble() - ) + val number = RealField + .binaryOperation(node.operation)(node.left.value.toDouble(), node.right.value.toDouble()) number(number) } - node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) - node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) - else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + node.left is MST.Numeric -> leftSideNumberOperation(node.operation)(node.left.value, evaluate(node.right)) + node.right is MST.Numeric -> rightSideNumberOperation(node.operation)(evaluate(node.left), node.right.value) + else -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right)) } } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt index 6ee6ab9af..cea708342 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -6,115 +6,111 @@ import kscience.kmath.operations.* * [Algebra] over [MST] nodes. */ public object MstAlgebra : NumericAlgebra { - override fun number(value: Number): MST.Numeric = MST.Numeric(value) + public override fun number(value: Number): MST.Numeric = MST.Numeric(value) - override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) + public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) - override fun unaryOperation(operation: String, arg: MST): MST.Unary = - MST.Unary(operation, arg) + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = { arg -> MST.Unary(operation, arg) } - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MST.Binary(operation, left, right) + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = + { left, right -> MST.Binary(operation, left, right) } } /** * [Space] over [MST] nodes. */ public object MstSpace : Space, NumericAlgebra { - override val zero: MST.Numeric by lazy { number(0.0) } + public override val zero: MST.Numeric by lazy { number(0.0) } - override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) - override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) - override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) + public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) + override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b) + override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k)) - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstAlgebra.binaryOperation(operation, left, right) + override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + MstAlgebra.binaryOperation(operation) - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation) } /** * [Ring] over [MST] nodes. */ public object MstRing : Ring, NumericAlgebra { - override val zero: MST.Numeric + override val zero: MST get() = MstSpace.zero + override val one: MST = number(1.0) - override val one: MST.Numeric by lazy { number(1.0) } + public override fun number(value: Number): MST = MstSpace.number(value) + public override fun symbol(value: String): MST = MstSpace.symbol(value) + public override fun add(a: MST, b: MST): MST = MstSpace.add(a, b) + public override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, b) - override fun number(value: Number): MST.Numeric = MstSpace.number(value) - override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) - override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) - override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) - override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + MstSpace.binaryOperation(operation) - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstSpace.binaryOperation(operation, left, right) - - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation) } /** * [Field] over [MST] nodes. */ public object MstField : Field { - public override val zero: MST.Numeric + public override val zero: MST get() = MstRing.zero - public override val one: MST.Numeric + public override val one: MST get() = MstRing.one - public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value) - public override fun number(value: Number): MST.Numeric = MstRing.number(value) - public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) - public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) - public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b) + public override fun symbol(value: String): MST = MstRing.symbol(value) + public override fun number(value: Number): MST = MstRing.number(value) + public override fun add(a: MST, b: MST): MST = MstRing.add(a, b) + public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) + public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION)(a, b) - public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstRing.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + MstRing.binaryOperation(operation) - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstRing.unaryOperation(operation) } /** * [ExtendedField] over [MST] nodes. */ public object MstExtendedField : ExtendedField { - override val zero: MST.Numeric + public override val zero: MST get() = MstField.zero - override val one: MST.Numeric + public override val one: MST get() = MstField.one - override fun symbol(value: String): MST.Symbolic = MstField.symbol(value) - override fun number(value: Number): MST.Numeric = MstField.number(value) - override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) - override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) - override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) - override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) - override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) - override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) - override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) - override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) - override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) - override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) - override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) - override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) - override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) - override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) - override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) - override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) + public override fun symbol(value: String): MST = MstField.symbol(value) + public override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg) + public override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg) + public override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg) + public override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg) + public override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg) + public override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg) + public override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg) + public override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg) + public override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg) + public override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg) + public override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg) + public override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg) + public override fun add(a: MST, b: MST): MST = MstField.add(a, b) + public override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) + public override fun divide(a: MST, b: MST): MST = MstField.divide(a, b) + public override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow)) + public override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg) + public override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION)(arg) - override fun power(arg: MST, pow: Number): MST.Binary = - binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + MstField.binaryOperation(operation) - override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) - override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstField.binaryOperation(operation, left, right) - - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstField.unaryOperation(operation) + override fun unaryOperation(operation: String, arg: MST): MST.Unary = + MST.Unary(operation, arg) } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt index f68e3f5f8..8a99fd2f4 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt @@ -16,10 +16,8 @@ import kotlin.contracts.contract public class MstExpression>(public val algebra: A, public val mst: MST) : Expression { private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) - override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: T, right: T): T = - algebra.binaryOperation(operation, left, right) + override fun unaryOperation(operation: String): (arg: T) -> T = algebra.unaryOperation(operation) + override fun binaryOperation(operation: String): (left: T, right: T) -> T = algebra.binaryOperation(operation) @Suppress("UNCHECKED_CAST") override fun number(value: Number): T = if (algebra is NumericAlgebra<*>) diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt index 8f8175acd..eb5a1c8f3 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt @@ -1,6 +1,5 @@ package kscience.kmath.asm -import kscience.kmath.asm.compile import kscience.kmath.ast.mstInField import kscience.kmath.expressions.invoke import kscience.kmath.operations.RealField @@ -10,44 +9,44 @@ import kotlin.test.assertEquals internal class TestAsmSpecialization { @Test fun testUnaryPlus() { - val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() + val expr = RealField.mstInField { unaryOperation("+")(symbol("x")) }.compile() assertEquals(2.0, expr("x" to 2.0)) } @Test fun testUnaryMinus() { - val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() + val expr = RealField.mstInField { unaryOperation("-")(symbol("x")) }.compile() assertEquals(-2.0, expr("x" to 2.0)) } @Test fun testAdd() { - val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() + val expr = RealField.mstInField { binaryOperation("+")(symbol("x"), symbol("x")) }.compile() assertEquals(4.0, expr("x" to 2.0)) } @Test fun testSine() { - val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() + val expr = RealField.mstInField { unaryOperation("sin")(symbol("x")) }.compile() assertEquals(0.0, expr("x" to 0.0)) } @Test fun testMinus() { - val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() + val expr = RealField.mstInField { binaryOperation("-")(symbol("x"), symbol("x")) }.compile() assertEquals(0.0, expr("x" to 2.0)) } @Test fun testDivide() { - val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() + val expr = RealField.mstInField { binaryOperation("/")(symbol("x"), symbol("x")) }.compile() assertEquals(1.0, expr("x" to 2.0)) } @Test fun testPower() { val expr = RealField - .mstInField { binaryOperation("power", symbol("x"), number(2)) } + .mstInField { binaryOperation("power")(symbol("x"), number(2)) } .compile() assertEquals(4.0, expr("x" to 2.0)) diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt index 2dc24597e..e2029ce19 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt @@ -1,8 +1,5 @@ package kscience.kmath.ast -import kscience.kmath.ast.evaluate -import kscience.kmath.ast.mstInField -import kscience.kmath.ast.parseMath import kscience.kmath.expressions.invoke import kscience.kmath.operations.Algebra import kscience.kmath.operations.Complex @@ -45,12 +42,15 @@ internal class ParserTest { val magicalAlgebra = object : Algebra { override fun symbol(value: String): String = value - override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError() - - override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) { - "magic" -> "$left ★ $right" - else -> throw NotImplementedError() + override fun unaryOperation(operation: String): (arg: String) -> String { + throw NotImplementedError() } + + override fun binaryOperation(operation: String): (left: String, right: String) -> String = + when (operation) { + "magic" -> { left, right -> "$left ★ $right" } + else -> throw NotImplementedError() + } } val mst = "magic(a, b)".parseMath() diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 0630e8e4b..e2413dea8 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -7,9 +7,8 @@ import kscience.kmath.operations.* * * @param algebra The algebra to provide for Expressions built. */ -public abstract class FunctionalExpressionAlgebra>( - public val algebra: A, -) : ExpressionAlgebra> { +public abstract class FunctionalExpressionAlgebra>(public val algebra: A) : + ExpressionAlgebra> { /** * Builds an Expression of constant expression which does not depend on arguments. */ @@ -25,19 +24,18 @@ public abstract class FunctionalExpressionAlgebra>( /** * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. */ - public override fun binaryOperation( - operation: String, - left: Expression, - right: Expression, - ): Expression = Expression { arguments -> - algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments)) - } + public override fun binaryOperation(operation: String): (left: Expression, right: Expression) -> Expression = + { left, right -> + Expression { arguments -> + algebra.binaryOperation(operation)(left.invoke(arguments), right.invoke(arguments)) + } + } /** * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. */ - public override fun unaryOperation(operation: String, arg: Expression): Expression = Expression { arguments -> - algebra.unaryOperation(operation, arg.invoke(arguments)) + public override fun unaryOperation(operation: String): (arg: Expression) -> Expression = { arg -> + Expression { arguments -> algebra.unaryOperation(operation)(arg.invoke(arguments)) } } } @@ -52,7 +50,7 @@ public open class FunctionalExpressionSpace>(algebra: A) : * Builds an Expression of addition of two another expressions. */ public override fun add(a: Expression, b: Expression): Expression = - binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b) /** * Builds an Expression of multiplication of expression by number. @@ -66,11 +64,11 @@ public open class FunctionalExpressionSpace>(algebra: A) : public operator fun T.plus(arg: Expression): Expression = arg + this public operator fun T.minus(arg: Expression): Expression = arg - this - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: Expression) -> Expression = + super.unaryOperation(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperation(operation) } public open class FunctionalExpressionRing(algebra: A) : FunctionalExpressionSpace(algebra), @@ -82,16 +80,16 @@ public open class FunctionalExpressionRing(algebra: A) : FunctionalExpress * Builds an Expression of multiplication of two expressions. */ public override fun multiply(a: Expression, b: Expression): Expression = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) + binaryOperation(RingOperations.TIMES_OPERATION)(a, b) public operator fun Expression.times(arg: T): Expression = this * const(arg) public operator fun T.times(arg: Expression): Expression = arg * this - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: Expression) -> Expression = + super.unaryOperation(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperation(operation) } public open class FunctionalExpressionField(algebra: A) : @@ -101,49 +99,49 @@ public open class FunctionalExpressionField(algebra: A) : * Builds an Expression of division an expression by another one. */ public override fun divide(a: Expression, b: Expression): Expression = - binaryOperation(FieldOperations.DIV_OPERATION, a, b) + binaryOperation(FieldOperations.DIV_OPERATION)(a, b) public operator fun Expression.div(arg: T): Expression = this / const(arg) public operator fun T.div(arg: Expression): Expression = arg / this - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: Expression) -> Expression = + super.unaryOperation(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperation(operation) } public open class FunctionalExpressionExtendedField(algebra: A) : FunctionalExpressionField(algebra), ExtendedField> where A : ExtendedField, A : NumericAlgebra { public override fun sin(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) + unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg) public override fun cos(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.COS_OPERATION, arg) + unaryOperation(TrigonometricOperations.COS_OPERATION)(arg) public override fun asin(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) + unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg) public override fun acos(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) + unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg) public override fun atan(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) + unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg) public override fun power(arg: Expression, pow: Number): Expression = - binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) + binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow)) public override fun exp(arg: Expression): Expression = - unaryOperation(ExponentialOperations.EXP_OPERATION, arg) + unaryOperation(ExponentialOperations.EXP_OPERATION)(arg) - public override fun ln(arg: Expression): Expression = unaryOperation(ExponentialOperations.LN_OPERATION, arg) + public override fun ln(arg: Expression): Expression = unaryOperation(ExponentialOperations.LN_OPERATION)(arg) - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: Expression) -> Expression = + super.unaryOperation(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperation(operation) } public inline fun > A.expressionInSpace(block: FunctionalExpressionSpace.() -> Expression): Expression = diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt index f4dbce89a..174f62c12 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt @@ -18,10 +18,11 @@ public interface MatrixContext : SpaceOperations> { */ public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix - public override fun binaryOperation(operation: String, left: Matrix, right: Matrix): Matrix = when (operation) { - "dot" -> left dot right - else -> super.binaryOperation(operation, left, right) - } + public override fun binaryOperation(operation: String): (left: Matrix, right: Matrix) -> Matrix = + when (operation) { + "dot" -> { left, right -> left dot right } + else -> super.binaryOperation(operation) + } /** * Computes the dot product of this matrix and another one. diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt index 12a45615a..19ca1a4cc 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt @@ -13,19 +13,19 @@ public annotation class KMathContext */ public interface Algebra { /** - * Wrap raw string or variable + * Wraps raw string or variable. */ public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") /** - * Dynamic call of unary operation with name [operation] on [arg] + * Dynamically dispatches an unary operation with name [operation]. */ - public fun unaryOperation(operation: String, arg: T): T + public fun unaryOperation(operation: String): (arg: T) -> T /** - * Dynamic call of binary operation [operation] on [left] and [right] + * Dynamically dispatches a binary operation with name [operation]. */ - public fun binaryOperation(operation: String, left: T, right: T): T + public fun binaryOperation(operation: String): (left: T, right: T) -> T } /** @@ -40,16 +40,28 @@ public interface NumericAlgebra : Algebra { public fun number(value: Number): T /** - * Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number]. + * Dynamically dispatches a binary operation with name [operation] where the left argument is [Number]. */ - public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = - binaryOperation(operation, number(left), right) + public fun leftSideNumberOperation(operation: String): (left: Number, right: T) -> T = + { l, r -> binaryOperation(operation)(number(l), r) } + +// /** +// * Dynamically calls a binary operation with name [operation] where the left argument is [Number]. +// */ +// public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = +// leftSideNumberOperation(operation)(left, right) /** - * Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number]. + * Dynamically dispatches a binary operation with name [operation] where the right argument is [Number]. */ - public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = - leftSideNumberOperation(operation, right, left) + public fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T = + { l, r -> binaryOperation(operation)(l, number(r)) } + +// /** +// * Dynamically calls a binary operation with name [operation] where the right argument is [Number]. +// */ +// public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = +// rightSideNumberOperation(operation)(left, right) } /** @@ -146,15 +158,15 @@ public interface SpaceOperations : Algebra { */ public operator fun Number.times(b: T): T = b * this - override fun unaryOperation(operation: String, arg: T): T = when (operation) { - PLUS_OPERATION -> arg - MINUS_OPERATION -> -arg + override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) { + PLUS_OPERATION -> { arg -> arg } + MINUS_OPERATION -> { arg -> -arg } else -> error("Unary operation $operation not defined in $this") } - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - PLUS_OPERATION -> add(left, right) - MINUS_OPERATION -> left - right + override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) { + PLUS_OPERATION -> ::add + MINUS_OPERATION -> { left, right -> left - right } else -> error("Binary operation $operation not defined in $this") } @@ -207,9 +219,9 @@ public interface RingOperations : SpaceOperations { */ public operator fun T.times(b: T): T = multiply(this, b) - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - TIMES_OPERATION -> multiply(left, right) - else -> super.binaryOperation(operation, left, right) + override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) { + TIMES_OPERATION -> ::multiply + else -> super.binaryOperation(operation) } public companion object { @@ -234,20 +246,6 @@ public interface Ring : Space, RingOperations, NumericAlgebra { override fun number(value: Number): T = one * value.toDouble() - override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) { - SpaceOperations.PLUS_OPERATION -> left + right - SpaceOperations.MINUS_OPERATION -> left - right - RingOperations.TIMES_OPERATION -> left * right - else -> super.leftSideNumberOperation(operation, left, right) - } - - override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { - SpaceOperations.PLUS_OPERATION -> left + right - SpaceOperations.MINUS_OPERATION -> left - right - RingOperations.TIMES_OPERATION -> left * right - else -> super.rightSideNumberOperation(operation, left, right) - } - /** * Addition of element and scalar. * @@ -308,9 +306,9 @@ public interface FieldOperations : RingOperations { */ public operator fun T.div(b: T): T = divide(this, b) - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - DIV_OPERATION -> divide(left, right) - else -> super.binaryOperation(operation, left, right) + override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) { + DIV_OPERATION -> ::divide + else -> super.binaryOperation(operation) } public companion object { diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt index a2b33a0c4..611222d34 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt @@ -15,23 +15,23 @@ public interface ExtendedFieldOperations : public override fun tan(arg: T): T = sin(arg) / cos(arg) public override fun tanh(arg: T): T = sinh(arg) / cosh(arg) - public override fun unaryOperation(operation: String, arg: T): T = when (operation) { - TrigonometricOperations.COS_OPERATION -> cos(arg) - TrigonometricOperations.SIN_OPERATION -> sin(arg) - TrigonometricOperations.TAN_OPERATION -> tan(arg) - TrigonometricOperations.ACOS_OPERATION -> acos(arg) - TrigonometricOperations.ASIN_OPERATION -> asin(arg) - TrigonometricOperations.ATAN_OPERATION -> atan(arg) - HyperbolicOperations.COSH_OPERATION -> cosh(arg) - HyperbolicOperations.SINH_OPERATION -> sinh(arg) - HyperbolicOperations.TANH_OPERATION -> tanh(arg) - HyperbolicOperations.ACOSH_OPERATION -> acosh(arg) - HyperbolicOperations.ASINH_OPERATION -> asinh(arg) - HyperbolicOperations.ATANH_OPERATION -> atanh(arg) - PowerOperations.SQRT_OPERATION -> sqrt(arg) - ExponentialOperations.EXP_OPERATION -> exp(arg) - ExponentialOperations.LN_OPERATION -> ln(arg) - else -> super.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) { + TrigonometricOperations.COS_OPERATION -> ::cos + TrigonometricOperations.SIN_OPERATION -> ::sin + TrigonometricOperations.TAN_OPERATION -> ::tan + TrigonometricOperations.ACOS_OPERATION -> ::acos + TrigonometricOperations.ASIN_OPERATION -> ::asin + TrigonometricOperations.ATAN_OPERATION -> ::atan + HyperbolicOperations.COSH_OPERATION -> ::cosh + HyperbolicOperations.SINH_OPERATION -> ::sinh + HyperbolicOperations.TANH_OPERATION -> ::tanh + HyperbolicOperations.ACOSH_OPERATION -> ::acosh + HyperbolicOperations.ASINH_OPERATION -> ::asinh + HyperbolicOperations.ATANH_OPERATION -> ::atanh + PowerOperations.SQRT_OPERATION -> ::sqrt + ExponentialOperations.EXP_OPERATION -> ::exp + ExponentialOperations.LN_OPERATION -> ::ln + else -> super.unaryOperation(operation) } } @@ -46,10 +46,11 @@ public interface ExtendedField : ExtendedFieldOperations, Field { public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2 - public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { - PowerOperations.POW_OPERATION -> power(left, right) - else -> super.rightSideNumberOperation(operation, left, right) - } + public override fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.rightSideNumberOperation(operation) + } } /** @@ -80,10 +81,11 @@ public object RealField : ExtendedField, Norm { public override val one: Double get() = 1.0 - public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { - PowerOperations.POW_OPERATION -> left pow right - else -> super.binaryOperation(operation, left, right) - } + public override fun binaryOperation(operation: String): (left: Double, right: Double) -> Double = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.binaryOperation(operation) + } public override inline fun add(a: Double, b: Double): Double = a + b public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() @@ -130,9 +132,9 @@ public object FloatField : ExtendedField, Norm { public override val one: Float get() = 1.0f - public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) { - PowerOperations.POW_OPERATION -> left pow right - else -> super.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: Float, right: Float) -> Float = when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.binaryOperation(operation) } public override inline fun add(a: Float, b: Float): Float = a + b