From e447a15304f2c1b7df9ad00c48a721db9f92edc6 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Sat, 28 Nov 2020 13:42:18 +0700 Subject: [PATCH] Make MstAlgebra operations return specific types --- .../kotlin/kscience/kmath/ast/MstAlgebra.kt | 101 +++++++++--------- 1 file changed, 52 insertions(+), 49 deletions(-) diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt index cea708342..44a7f20b9 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -7,10 +7,10 @@ import kscience.kmath.operations.* */ public object MstAlgebra : NumericAlgebra { public override fun number(value: Number): MST.Numeric = MST.Numeric(value) - public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) - public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = { arg -> MST.Unary(operation, arg) } + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = + { arg -> MST.Unary(operation, arg) } public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = { left, right -> MST.Binary(operation, left, right) } @@ -25,92 +25,95 @@ public object MstSpace : Space, NumericAlgebra { public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b) - override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k)) + override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k)) - override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstAlgebra.binaryOperation(operation) - override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation) + override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstAlgebra.unaryOperation(operation) } /** * [Ring] over [MST] nodes. */ public object MstRing : Ring, NumericAlgebra { - override val zero: MST + override val zero: MST.Numeric get() = MstSpace.zero - override val one: MST = number(1.0) - public override fun number(value: Number): MST = MstSpace.number(value) - public override fun symbol(value: String): MST = MstSpace.symbol(value) - public override fun add(a: MST, b: MST): MST = MstSpace.add(a, b) - public override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, b) + override val one: MST.Numeric by lazy { number(1.0) } - public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + public override fun number(value: Number): MST.Numeric = MstSpace.number(value) + public override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) + public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, b) + + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstSpace.binaryOperation(operation) - public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation) + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = + MstAlgebra.unaryOperation(operation) } /** * [Field] over [MST] nodes. */ public object MstField : Field { - public override val zero: MST + public override val zero: MST.Numeric get() = MstRing.zero - public override val one: MST + public override val one: MST.Numeric get() = MstRing.one - public override fun symbol(value: String): MST = MstRing.symbol(value) - public override fun number(value: Number): MST = MstRing.number(value) - public override fun add(a: MST, b: MST): MST = MstRing.add(a, b) - public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) - public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION)(a, b) + public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value) + public override fun number(value: Number): MST.Numeric = MstRing.number(value) + public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) + public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION)(a, b) - public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstRing.binaryOperation(operation) - public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstRing.unaryOperation(operation) + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstRing.unaryOperation(operation) } /** * [ExtendedField] over [MST] nodes. */ public object MstExtendedField : ExtendedField { - public override val zero: MST + public override val zero: MST.Numeric get() = MstField.zero - public override val one: MST + public override val one: MST.Numeric get() = MstField.one public override fun symbol(value: String): MST = MstField.symbol(value) - public override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg) - public override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg) - public override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg) - public override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg) - public override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg) - public override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg) - public override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg) - public override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg) - public override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg) - public override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg) - public override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg) - public override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg) - public override fun add(a: MST, b: MST): MST = MstField.add(a, b) - public override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) - public override fun divide(a: MST, b: MST): MST = MstField.divide(a, b) - public override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow)) - public override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg) - public override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION)(arg) + public override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg) + public override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg) + public override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg) + public override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg) + public override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg) + public override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg) + public override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg) + public override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg) + public override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg) + public override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg) + public override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg) + public override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg) + public override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) + public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) - public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + public override fun power(arg: MST, pow: Number): MST.Binary = + binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow)) + + public override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg) + public override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION)(arg) + + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstField.binaryOperation(operation) - public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstField.unaryOperation(operation) - override fun unaryOperation(operation: String, arg: MST): MST.Unary = - MST.Unary(operation, arg) + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstField.unaryOperation(operation) }