Rework unary/binary operation API

This commit is contained in:
Iaroslav Postovalov 2020-11-28 13:33:05 +07:00
parent abe68a4fb6
commit 3e7c9d8dce
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
9 changed files with 194 additions and 203 deletions

View File

@ -55,24 +55,23 @@ public sealed class MST {
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) { public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value) is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this") ?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value) is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Unary -> unaryOperation(node.operation)(evaluate(node.value))
is MST.Binary -> when { is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) this !is NumericAlgebra -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> { node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation( val number = RealField
node.operation, .binaryOperation(node.operation)(node.left.value.toDouble(), node.right.value.toDouble())
node.left.value.toDouble(),
node.right.value.toDouble()
)
number(number) number(number)
} }
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) node.left is MST.Numeric -> leftSideNumberOperation(node.operation)(node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) node.right is MST.Numeric -> rightSideNumberOperation(node.operation)(evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) else -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right))
} }
} }

View File

@ -6,115 +6,111 @@ import kscience.kmath.operations.*
* [Algebra] over [MST] nodes. * [Algebra] over [MST] nodes.
*/ */
public object MstAlgebra : NumericAlgebra<MST> { public object MstAlgebra : NumericAlgebra<MST> {
override fun number(value: Number): MST.Numeric = MST.Numeric(value) public override fun number(value: Number): MST.Numeric = MST.Numeric(value)
override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = { arg -> MST.Unary(operation, arg) }
MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
MST.Binary(operation, left, right) { left, right -> MST.Binary(operation, left, right) }
} }
/** /**
* [Space] over [MST] nodes. * [Space] over [MST] nodes.
*/ */
public object MstSpace : Space<MST>, NumericAlgebra<MST> { public object MstSpace : Space<MST>, NumericAlgebra<MST> {
override val zero: MST.Numeric by lazy { number(0.0) } public override val zero: MST.Numeric by lazy { number(0.0) }
override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k))
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
MstAlgebra.binaryOperation(operation, left, right) MstAlgebra.binaryOperation(operation)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation)
} }
/** /**
* [Ring] over [MST] nodes. * [Ring] over [MST] nodes.
*/ */
public object MstRing : Ring<MST>, NumericAlgebra<MST> { public object MstRing : Ring<MST>, NumericAlgebra<MST> {
override val zero: MST.Numeric override val zero: MST
get() = MstSpace.zero get() = MstSpace.zero
override val one: MST = number(1.0)
override val one: MST.Numeric by lazy { number(1.0) } public override fun number(value: Number): MST = MstSpace.number(value)
public override fun symbol(value: String): MST = MstSpace.symbol(value)
public override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
public override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
override fun number(value: Number): MST.Numeric = MstSpace.number(value) public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) MstSpace.binaryOperation(operation)
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation)
MstSpace.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
} }
/** /**
* [Field] over [MST] nodes. * [Field] over [MST] nodes.
*/ */
public object MstField : Field<MST> { public object MstField : Field<MST> {
public override val zero: MST.Numeric public override val zero: MST
get() = MstRing.zero get() = MstRing.zero
public override val one: MST.Numeric public override val one: MST
get() = MstRing.one get() = MstRing.one
public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value) public override fun symbol(value: String): MST = MstRing.symbol(value)
public override fun number(value: Number): MST.Numeric = MstRing.number(value) public override fun number(value: Number): MST = MstRing.number(value)
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) public override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b) public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
MstRing.binaryOperation(operation, left, right) MstRing.binaryOperation(operation)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg) public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstRing.unaryOperation(operation)
} }
/** /**
* [ExtendedField] over [MST] nodes. * [ExtendedField] over [MST] nodes.
*/ */
public object MstExtendedField : ExtendedField<MST> { public object MstExtendedField : ExtendedField<MST> {
override val zero: MST.Numeric public override val zero: MST
get() = MstField.zero get() = MstField.zero
override val one: MST.Numeric public override val one: MST
get() = MstField.one get() = MstField.one
override fun symbol(value: String): MST.Symbolic = MstField.symbol(value) public override fun symbol(value: String): MST = MstField.symbol(value)
override fun number(value: Number): MST.Numeric = MstField.number(value) public override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg)
override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) public override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg)
override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) public override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg)
override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) public override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg)
override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) public override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg)
override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) public override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg)
override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) public override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg)
override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) public override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg)
override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) public override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg)
override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) public override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg)
override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) public override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg)
override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) public override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg)
override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) public override fun add(a: MST, b: MST): MST = MstField.add(a, b)
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) public override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) public override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) public override fun divide(a: MST, b: MST): MST = MstField.divide(a, b)
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) public override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))
public override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg)
public override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION)(arg)
override fun power(arg: MST, pow: Number): MST.Binary = public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) MstField.binaryOperation(operation)
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstField.unaryOperation(operation)
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg) override fun unaryOperation(operation: String, arg: MST): MST.Unary =
MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstField.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
} }

View File

@ -16,10 +16,8 @@ import kotlin.contracts.contract
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> { public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> { private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String): (arg: T) -> T = algebra.unaryOperation(operation)
override fun binaryOperation(operation: String): (left: T, right: T) -> T = algebra.binaryOperation(operation)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>) override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)

View File

@ -1,6 +1,5 @@
package kscience.kmath.asm package kscience.kmath.asm
import kscience.kmath.asm.compile
import kscience.kmath.ast.mstInField import kscience.kmath.ast.mstInField
import kscience.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import kscience.kmath.operations.RealField import kscience.kmath.operations.RealField
@ -10,44 +9,44 @@ import kotlin.test.assertEquals
internal class TestAsmSpecialization { internal class TestAsmSpecialization {
@Test @Test
fun testUnaryPlus() { fun testUnaryPlus() {
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() val expr = RealField.mstInField { unaryOperation("+")(symbol("x")) }.compile()
assertEquals(2.0, expr("x" to 2.0)) assertEquals(2.0, expr("x" to 2.0))
} }
@Test @Test
fun testUnaryMinus() { fun testUnaryMinus() {
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() val expr = RealField.mstInField { unaryOperation("-")(symbol("x")) }.compile()
assertEquals(-2.0, expr("x" to 2.0)) assertEquals(-2.0, expr("x" to 2.0))
} }
@Test @Test
fun testAdd() { fun testAdd() {
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() val expr = RealField.mstInField { binaryOperation("+")(symbol("x"), symbol("x")) }.compile()
assertEquals(4.0, expr("x" to 2.0)) assertEquals(4.0, expr("x" to 2.0))
} }
@Test @Test
fun testSine() { fun testSine() {
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() val expr = RealField.mstInField { unaryOperation("sin")(symbol("x")) }.compile()
assertEquals(0.0, expr("x" to 0.0)) assertEquals(0.0, expr("x" to 0.0))
} }
@Test @Test
fun testMinus() { fun testMinus() {
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() val expr = RealField.mstInField { binaryOperation("-")(symbol("x"), symbol("x")) }.compile()
assertEquals(0.0, expr("x" to 2.0)) assertEquals(0.0, expr("x" to 2.0))
} }
@Test @Test
fun testDivide() { fun testDivide() {
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() val expr = RealField.mstInField { binaryOperation("/")(symbol("x"), symbol("x")) }.compile()
assertEquals(1.0, expr("x" to 2.0)) assertEquals(1.0, expr("x" to 2.0))
} }
@Test @Test
fun testPower() { fun testPower() {
val expr = RealField val expr = RealField
.mstInField { binaryOperation("power", symbol("x"), number(2)) } .mstInField { binaryOperation("power")(symbol("x"), number(2)) }
.compile() .compile()
assertEquals(4.0, expr("x" to 2.0)) assertEquals(4.0, expr("x" to 2.0))

View File

@ -1,8 +1,5 @@
package kscience.kmath.ast package kscience.kmath.ast
import kscience.kmath.ast.evaluate
import kscience.kmath.ast.mstInField
import kscience.kmath.ast.parseMath
import kscience.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import kscience.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kscience.kmath.operations.Complex import kscience.kmath.operations.Complex
@ -45,12 +42,15 @@ internal class ParserTest {
val magicalAlgebra = object : Algebra<String> { val magicalAlgebra = object : Algebra<String> {
override fun symbol(value: String): String = value override fun symbol(value: String): String = value
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError() override fun unaryOperation(operation: String): (arg: String) -> String {
throw NotImplementedError()
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
"magic" -> "$left$right"
else -> throw NotImplementedError()
} }
override fun binaryOperation(operation: String): (left: String, right: String) -> String =
when (operation) {
"magic" -> { left, right -> "$left$right" }
else -> throw NotImplementedError()
}
} }
val mst = "magic(a, b)".parseMath() val mst = "magic(a, b)".parseMath()

View File

@ -7,9 +7,8 @@ import kscience.kmath.operations.*
* *
* @param algebra The algebra to provide for Expressions built. * @param algebra The algebra to provide for Expressions built.
*/ */
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>( public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
public val algebra: A, ExpressionAlgebra<T, Expression<T>> {
) : ExpressionAlgebra<T, Expression<T>> {
/** /**
* Builds an Expression of constant expression which does not depend on arguments. * Builds an Expression of constant expression which does not depend on arguments.
*/ */
@ -25,19 +24,18 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
/** /**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/ */
public override fun binaryOperation( public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
operation: String, { left, right ->
left: Expression<T>, Expression { arguments ->
right: Expression<T>, algebra.binaryOperation(operation)(left.invoke(arguments), right.invoke(arguments))
): Expression<T> = Expression { arguments -> }
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments)) }
}
/** /**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. * Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/ */
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments -> public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
algebra.unaryOperation(operation, arg.invoke(arguments)) Expression { arguments -> algebra.unaryOperation(operation)(arg.invoke(arguments)) }
} }
} }
@ -52,7 +50,7 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
* Builds an Expression of addition of two another expressions. * Builds an Expression of addition of two another expressions.
*/ */
public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
/** /**
* Builds an Expression of multiplication of expression by number. * Builds an Expression of multiplication of expression by number.
@ -66,11 +64,11 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg) super<FunctionalExpressionAlgebra>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right) super<FunctionalExpressionAlgebra>.binaryOperation(operation)
} }
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra), public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
@ -82,16 +80,16 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
* Builds an Expression of multiplication of two expressions. * Builds an Expression of multiplication of two expressions.
*/ */
public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(RingOperations.TIMES_OPERATION, a, b) binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg) public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionSpace>.unaryOperation(operation, arg) super<FunctionalExpressionSpace>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right) super<FunctionalExpressionSpace>.binaryOperation(operation)
} }
public open class FunctionalExpressionField<T, A>(algebra: A) : public open class FunctionalExpressionField<T, A>(algebra: A) :
@ -101,49 +99,49 @@ public open class FunctionalExpressionField<T, A>(algebra: A) :
* Builds an Expression of division an expression by another one. * Builds an Expression of division an expression by another one.
*/ */
public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(FieldOperations.DIV_OPERATION, a, b) binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg) public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionRing>.unaryOperation(operation, arg) super<FunctionalExpressionRing>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionRing>.binaryOperation(operation, left, right) super<FunctionalExpressionRing>.binaryOperation(operation)
} }
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) : public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
FunctionalExpressionField<T, A>(algebra), FunctionalExpressionField<T, A>(algebra),
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> { ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
public override fun sin(arg: Expression<T>): Expression<T> = public override fun sin(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg)
public override fun cos(arg: Expression<T>): Expression<T> = public override fun cos(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.COS_OPERATION, arg) unaryOperation(TrigonometricOperations.COS_OPERATION)(arg)
public override fun asin(arg: Expression<T>): Expression<T> = public override fun asin(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg)
public override fun acos(arg: Expression<T>): Expression<T> = public override fun acos(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg)
public override fun atan(arg: Expression<T>): Expression<T> = public override fun atan(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg)
public override fun power(arg: Expression<T>, pow: Number): Expression<T> = public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))
public override fun exp(arg: Expression<T>): Expression<T> = public override fun exp(arg: Expression<T>): Expression<T> =
unaryOperation(ExponentialOperations.EXP_OPERATION, arg) unaryOperation(ExponentialOperations.EXP_OPERATION)(arg)
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg) public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION)(arg)
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionField>.unaryOperation(operation, arg) super<FunctionalExpressionField>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionField>.binaryOperation(operation, left, right) super<FunctionalExpressionField>.binaryOperation(operation)
} }
public inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> = public inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =

View File

@ -18,10 +18,11 @@ public interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
*/ */
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T> public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
public override fun binaryOperation(operation: String, left: Matrix<T>, right: Matrix<T>): Matrix<T> = when (operation) { public override fun binaryOperation(operation: String): (left: Matrix<T>, right: Matrix<T>) -> Matrix<T> =
"dot" -> left dot right when (operation) {
else -> super.binaryOperation(operation, left, right) "dot" -> { left, right -> left dot right }
} else -> super.binaryOperation(operation)
}
/** /**
* Computes the dot product of this matrix and another one. * Computes the dot product of this matrix and another one.

View File

@ -13,19 +13,19 @@ public annotation class KMathContext
*/ */
public interface Algebra<T> { public interface Algebra<T> {
/** /**
* Wrap raw string or variable * Wraps raw string or variable.
*/ */
public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
/** /**
* Dynamic call of unary operation with name [operation] on [arg] * Dynamically dispatches an unary operation with name [operation].
*/ */
public fun unaryOperation(operation: String, arg: T): T public fun unaryOperation(operation: String): (arg: T) -> T
/** /**
* Dynamic call of binary operation [operation] on [left] and [right] * Dynamically dispatches a binary operation with name [operation].
*/ */
public fun binaryOperation(operation: String, left: T, right: T): T public fun binaryOperation(operation: String): (left: T, right: T) -> T
} }
/** /**
@ -40,16 +40,28 @@ public interface NumericAlgebra<T> : Algebra<T> {
public fun number(value: Number): T public fun number(value: Number): T
/** /**
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number]. * Dynamically dispatches a binary operation with name [operation] where the left argument is [Number].
*/ */
public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = public fun leftSideNumberOperation(operation: String): (left: Number, right: T) -> T =
binaryOperation(operation, number(left), right) { l, r -> binaryOperation(operation)(number(l), r) }
// /**
// * Dynamically calls a binary operation with name [operation] where the left argument is [Number].
// */
// public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
// leftSideNumberOperation(operation)(left, right)
/** /**
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number]. * Dynamically dispatches a binary operation with name [operation] where the right argument is [Number].
*/ */
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = public fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T =
leftSideNumberOperation(operation, right, left) { l, r -> binaryOperation(operation)(l, number(r)) }
// /**
// * Dynamically calls a binary operation with name [operation] where the right argument is [Number].
// */
// public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
// rightSideNumberOperation(operation)(left, right)
} }
/** /**
@ -146,15 +158,15 @@ public interface SpaceOperations<T> : Algebra<T> {
*/ */
public operator fun Number.times(b: T): T = b * this public operator fun Number.times(b: T): T = b * this
override fun unaryOperation(operation: String, arg: T): T = when (operation) { override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
PLUS_OPERATION -> arg PLUS_OPERATION -> { arg -> arg }
MINUS_OPERATION -> -arg MINUS_OPERATION -> { arg -> -arg }
else -> error("Unary operation $operation not defined in $this") else -> error("Unary operation $operation not defined in $this")
} }
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
PLUS_OPERATION -> add(left, right) PLUS_OPERATION -> ::add
MINUS_OPERATION -> left - right MINUS_OPERATION -> { left, right -> left - right }
else -> error("Binary operation $operation not defined in $this") else -> error("Binary operation $operation not defined in $this")
} }
@ -207,9 +219,9 @@ public interface RingOperations<T> : SpaceOperations<T> {
*/ */
public operator fun T.times(b: T): T = multiply(this, b) public operator fun T.times(b: T): T = multiply(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
TIMES_OPERATION -> multiply(left, right) TIMES_OPERATION -> ::multiply
else -> super.binaryOperation(operation, left, right) else -> super.binaryOperation(operation)
} }
public companion object { public companion object {
@ -234,20 +246,6 @@ public interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = one * value.toDouble() override fun number(value: Number): T = one * value.toDouble()
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.leftSideNumberOperation(operation, left, right)
}
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.rightSideNumberOperation(operation, left, right)
}
/** /**
* Addition of element and scalar. * Addition of element and scalar.
* *
@ -308,9 +306,9 @@ public interface FieldOperations<T> : RingOperations<T> {
*/ */
public operator fun T.div(b: T): T = divide(this, b) public operator fun T.div(b: T): T = divide(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
DIV_OPERATION -> divide(left, right) DIV_OPERATION -> ::divide
else -> super.binaryOperation(operation, left, right) else -> super.binaryOperation(operation)
} }
public companion object { public companion object {

View File

@ -15,23 +15,23 @@ public interface ExtendedFieldOperations<T> :
public override fun tan(arg: T): T = sin(arg) / cos(arg) public override fun tan(arg: T): T = sin(arg) / cos(arg)
public override fun tanh(arg: T): T = sinh(arg) / cosh(arg) public override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
public override fun unaryOperation(operation: String, arg: T): T = when (operation) { public override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
TrigonometricOperations.COS_OPERATION -> cos(arg) TrigonometricOperations.COS_OPERATION -> ::cos
TrigonometricOperations.SIN_OPERATION -> sin(arg) TrigonometricOperations.SIN_OPERATION -> ::sin
TrigonometricOperations.TAN_OPERATION -> tan(arg) TrigonometricOperations.TAN_OPERATION -> ::tan
TrigonometricOperations.ACOS_OPERATION -> acos(arg) TrigonometricOperations.ACOS_OPERATION -> ::acos
TrigonometricOperations.ASIN_OPERATION -> asin(arg) TrigonometricOperations.ASIN_OPERATION -> ::asin
TrigonometricOperations.ATAN_OPERATION -> atan(arg) TrigonometricOperations.ATAN_OPERATION -> ::atan
HyperbolicOperations.COSH_OPERATION -> cosh(arg) HyperbolicOperations.COSH_OPERATION -> ::cosh
HyperbolicOperations.SINH_OPERATION -> sinh(arg) HyperbolicOperations.SINH_OPERATION -> ::sinh
HyperbolicOperations.TANH_OPERATION -> tanh(arg) HyperbolicOperations.TANH_OPERATION -> ::tanh
HyperbolicOperations.ACOSH_OPERATION -> acosh(arg) HyperbolicOperations.ACOSH_OPERATION -> ::acosh
HyperbolicOperations.ASINH_OPERATION -> asinh(arg) HyperbolicOperations.ASINH_OPERATION -> ::asinh
HyperbolicOperations.ATANH_OPERATION -> atanh(arg) HyperbolicOperations.ATANH_OPERATION -> ::atanh
PowerOperations.SQRT_OPERATION -> sqrt(arg) PowerOperations.SQRT_OPERATION -> ::sqrt
ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.EXP_OPERATION -> ::exp
ExponentialOperations.LN_OPERATION -> ln(arg) ExponentialOperations.LN_OPERATION -> ::ln
else -> super.unaryOperation(operation, arg) else -> super.unaryOperation(operation)
} }
} }
@ -46,10 +46,11 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2 public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { public override fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T =
PowerOperations.POW_OPERATION -> power(left, right) when (operation) {
else -> super.rightSideNumberOperation(operation, left, right) PowerOperations.POW_OPERATION -> ::power
} else -> super.rightSideNumberOperation(operation)
}
} }
/** /**
@ -80,10 +81,11 @@ public object RealField : ExtendedField<Double>, Norm<Double, Double> {
public override val one: Double public override val one: Double
get() = 1.0 get() = 1.0
public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { public override fun binaryOperation(operation: String): (left: Double, right: Double) -> Double =
PowerOperations.POW_OPERATION -> left pow right when (operation) {
else -> super.binaryOperation(operation, left, right) PowerOperations.POW_OPERATION -> ::power
} else -> super.binaryOperation(operation)
}
public override inline fun add(a: Double, b: Double): Double = a + b public override inline fun add(a: Double, b: Double): Double = a + b
public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
@ -130,9 +132,9 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
public override val one: Float public override val one: Float
get() = 1.0f get() = 1.0f
public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) { public override fun binaryOperation(operation: String): (left: Float, right: Float) -> Float = when (operation) {
PowerOperations.POW_OPERATION -> left pow right PowerOperations.POW_OPERATION -> ::power
else -> super.binaryOperation(operation, left, right) else -> super.binaryOperation(operation)
} }
public override inline fun add(a: Float, b: Float): Float = a + b public override inline fun add(a: Float, b: Float): Float = a + b