Rework unary/binary operation API
This commit is contained in:
parent
abe68a4fb6
commit
3e7c9d8dce
@ -55,24 +55,23 @@ public sealed class MST {
|
||||
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
|
||||
is MST.Symbolic -> symbol(node.value)
|
||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||
is MST.Unary -> unaryOperation(node.operation)(evaluate(node.value))
|
||||
|
||||
is MST.Binary -> when {
|
||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
this !is NumericAlgebra -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||
|
||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||
val number = RealField.binaryOperation(
|
||||
node.operation,
|
||||
node.left.value.toDouble(),
|
||||
node.right.value.toDouble()
|
||||
)
|
||||
val number = RealField
|
||||
.binaryOperation(node.operation)(node.left.value.toDouble(), node.right.value.toDouble())
|
||||
|
||||
number(number)
|
||||
}
|
||||
|
||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation)(node.left.value, evaluate(node.right))
|
||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation)(evaluate(node.left), node.right.value)
|
||||
else -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,115 +6,111 @@ import kscience.kmath.operations.*
|
||||
* [Algebra] over [MST] nodes.
|
||||
*/
|
||||
public object MstAlgebra : NumericAlgebra<MST> {
|
||||
override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
||||
public override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
||||
|
||||
override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
|
||||
public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary =
|
||||
MST.Unary(operation, arg)
|
||||
public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = { arg -> MST.Unary(operation, arg) }
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MST.Binary(operation, left, right)
|
||||
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
{ left, right -> MST.Binary(operation, left, right) }
|
||||
}
|
||||
|
||||
/**
|
||||
* [Space] over [MST] nodes.
|
||||
*/
|
||||
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST.Numeric by lazy { number(0.0) }
|
||||
public override val zero: MST.Numeric by lazy { number(0.0) }
|
||||
|
||||
override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
||||
override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||
public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
||||
public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k))
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstAlgebra.binaryOperation(operation, left, right)
|
||||
override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
|
||||
MstAlgebra.binaryOperation(operation)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg)
|
||||
override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation)
|
||||
}
|
||||
|
||||
/**
|
||||
* [Ring] over [MST] nodes.
|
||||
*/
|
||||
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST.Numeric
|
||||
override val zero: MST
|
||||
get() = MstSpace.zero
|
||||
override val one: MST = number(1.0)
|
||||
|
||||
override val one: MST.Numeric by lazy { number(1.0) }
|
||||
public override fun number(value: Number): MST = MstSpace.number(value)
|
||||
public override fun symbol(value: String): MST = MstSpace.symbol(value)
|
||||
public override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
|
||||
|
||||
override fun number(value: Number): MST.Numeric = MstSpace.number(value)
|
||||
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
|
||||
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
|
||||
MstSpace.binaryOperation(operation)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstSpace.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
|
||||
public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation)
|
||||
}
|
||||
|
||||
/**
|
||||
* [Field] over [MST] nodes.
|
||||
*/
|
||||
public object MstField : Field<MST> {
|
||||
public override val zero: MST.Numeric
|
||||
public override val zero: MST
|
||||
get() = MstRing.zero
|
||||
|
||||
public override val one: MST.Numeric
|
||||
public override val one: MST
|
||||
get() = MstRing.one
|
||||
|
||||
public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value)
|
||||
public override fun number(value: Number): MST.Numeric = MstRing.number(value)
|
||||
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
||||
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||
public override fun symbol(value: String): MST = MstRing.symbol(value)
|
||||
public override fun number(value: Number): MST = MstRing.number(value)
|
||||
public override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
|
||||
public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstRing.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
|
||||
MstRing.binaryOperation(operation)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg)
|
||||
public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstRing.unaryOperation(operation)
|
||||
}
|
||||
|
||||
/**
|
||||
* [ExtendedField] over [MST] nodes.
|
||||
*/
|
||||
public object MstExtendedField : ExtendedField<MST> {
|
||||
override val zero: MST.Numeric
|
||||
public override val zero: MST
|
||||
get() = MstField.zero
|
||||
|
||||
override val one: MST.Numeric
|
||||
public override val one: MST
|
||||
get() = MstField.one
|
||||
|
||||
override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
|
||||
override fun number(value: Number): MST.Numeric = MstField.number(value)
|
||||
override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||
override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||
override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
||||
override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||
override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||
override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||
override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
||||
override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
||||
override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
||||
override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
||||
override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
||||
override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
||||
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
|
||||
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
||||
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
||||
public override fun symbol(value: String): MST = MstField.symbol(value)
|
||||
public override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg)
|
||||
public override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg)
|
||||
public override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg)
|
||||
public override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg)
|
||||
public override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg)
|
||||
public override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg)
|
||||
public override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg)
|
||||
public override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg)
|
||||
public override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg)
|
||||
public override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg)
|
||||
public override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg)
|
||||
public override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg)
|
||||
public override fun add(a: MST, b: MST): MST = MstField.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
|
||||
public override fun divide(a: MST, b: MST): MST = MstField.divide(a, b)
|
||||
public override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))
|
||||
public override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg)
|
||||
public override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION)(arg)
|
||||
|
||||
override fun power(arg: MST, pow: Number): MST.Binary =
|
||||
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
|
||||
MstField.binaryOperation(operation)
|
||||
|
||||
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstField.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
|
||||
public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstField.unaryOperation(operation)
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary =
|
||||
MST.Unary(operation, arg)
|
||||
}
|
||||
|
@ -16,10 +16,8 @@ import kotlin.contracts.contract
|
||||
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
|
||||
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||
algebra.binaryOperation(operation, left, right)
|
||||
override fun unaryOperation(operation: String): (arg: T) -> T = algebra.unaryOperation(operation)
|
||||
override fun binaryOperation(operation: String): (left: T, right: T) -> T = algebra.binaryOperation(operation)
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||
|
@ -1,6 +1,5 @@
|
||||
package kscience.kmath.asm
|
||||
|
||||
import kscience.kmath.asm.compile
|
||||
import kscience.kmath.ast.mstInField
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.RealField
|
||||
@ -10,44 +9,44 @@ import kotlin.test.assertEquals
|
||||
internal class TestAsmSpecialization {
|
||||
@Test
|
||||
fun testUnaryPlus() {
|
||||
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { unaryOperation("+")(symbol("x")) }.compile()
|
||||
assertEquals(2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { unaryOperation("-")(symbol("x")) }.compile()
|
||||
assertEquals(-2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { binaryOperation("+")(symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSine() {
|
||||
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { unaryOperation("sin")(symbol("x")) }.compile()
|
||||
assertEquals(0.0, expr("x" to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus() {
|
||||
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { binaryOperation("-")(symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(0.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() {
|
||||
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { binaryOperation("/")(symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(1.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() {
|
||||
val expr = RealField
|
||||
.mstInField { binaryOperation("power", symbol("x"), number(2)) }
|
||||
.mstInField { binaryOperation("power")(symbol("x"), number(2)) }
|
||||
.compile()
|
||||
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
|
@ -1,8 +1,5 @@
|
||||
package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.ast.evaluate
|
||||
import kscience.kmath.ast.mstInField
|
||||
import kscience.kmath.ast.parseMath
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.Complex
|
||||
@ -45,12 +42,15 @@ internal class ParserTest {
|
||||
val magicalAlgebra = object : Algebra<String> {
|
||||
override fun symbol(value: String): String = value
|
||||
|
||||
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
|
||||
|
||||
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
|
||||
"magic" -> "$left ★ $right"
|
||||
else -> throw NotImplementedError()
|
||||
override fun unaryOperation(operation: String): (arg: String) -> String {
|
||||
throw NotImplementedError()
|
||||
}
|
||||
|
||||
override fun binaryOperation(operation: String): (left: String, right: String) -> String =
|
||||
when (operation) {
|
||||
"magic" -> { left, right -> "$left ★ $right" }
|
||||
else -> throw NotImplementedError()
|
||||
}
|
||||
}
|
||||
|
||||
val mst = "magic(a, b)".parseMath()
|
||||
|
@ -7,9 +7,8 @@ import kscience.kmath.operations.*
|
||||
*
|
||||
* @param algebra The algebra to provide for Expressions built.
|
||||
*/
|
||||
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
|
||||
public val algebra: A,
|
||||
) : ExpressionAlgebra<T, Expression<T>> {
|
||||
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
|
||||
ExpressionAlgebra<T, Expression<T>> {
|
||||
/**
|
||||
* Builds an Expression of constant expression which does not depend on arguments.
|
||||
*/
|
||||
@ -25,19 +24,18 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
|
||||
/**
|
||||
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||
*/
|
||||
public override fun binaryOperation(
|
||||
operation: String,
|
||||
left: Expression<T>,
|
||||
right: Expression<T>,
|
||||
): Expression<T> = Expression { arguments ->
|
||||
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
|
||||
}
|
||||
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||
{ left, right ->
|
||||
Expression { arguments ->
|
||||
algebra.binaryOperation(operation)(left.invoke(arguments), right.invoke(arguments))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||
*/
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments ->
|
||||
algebra.unaryOperation(operation, arg.invoke(arguments))
|
||||
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
|
||||
Expression { arguments -> algebra.unaryOperation(operation)(arg.invoke(arguments)) }
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,7 +50,7 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||
* Builds an Expression of addition of two another expressions.
|
||||
*/
|
||||
public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
|
||||
|
||||
/**
|
||||
* Builds an Expression of multiplication of expression by number.
|
||||
@ -66,11 +64,11 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
||||
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
||||
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
|
||||
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionAlgebra>.unaryOperation(operation)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionAlgebra>.binaryOperation(operation)
|
||||
}
|
||||
|
||||
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
|
||||
@ -82,16 +80,16 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
|
||||
* Builds an Expression of multiplication of two expressions.
|
||||
*/
|
||||
public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
|
||||
|
||||
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
|
||||
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionSpace>.unaryOperation(operation)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionSpace>.binaryOperation(operation)
|
||||
}
|
||||
|
||||
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||
@ -101,49 +99,49 @@ public open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||
* Builds an Expression of division an expression by another one.
|
||||
*/
|
||||
public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||
binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
|
||||
|
||||
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
|
||||
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionRing>.unaryOperation(operation)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionRing>.binaryOperation(operation)
|
||||
}
|
||||
|
||||
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
||||
FunctionalExpressionField<T, A>(algebra),
|
||||
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
||||
public override fun sin(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||
unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg)
|
||||
|
||||
public override fun cos(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||
unaryOperation(TrigonometricOperations.COS_OPERATION)(arg)
|
||||
|
||||
public override fun asin(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||
unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg)
|
||||
|
||||
public override fun acos(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||
unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg)
|
||||
|
||||
public override fun atan(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||
unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg)
|
||||
|
||||
public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
|
||||
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||
binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))
|
||||
|
||||
public override fun exp(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||
unaryOperation(ExponentialOperations.EXP_OPERATION)(arg)
|
||||
|
||||
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION)(arg)
|
||||
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionField>.unaryOperation(operation, arg)
|
||||
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionField>.unaryOperation(operation)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionField>.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionField>.binaryOperation(operation)
|
||||
}
|
||||
|
||||
public inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =
|
||||
|
@ -18,10 +18,11 @@ public interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
|
||||
*/
|
||||
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Matrix<T>, right: Matrix<T>): Matrix<T> = when (operation) {
|
||||
"dot" -> left dot right
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
}
|
||||
public override fun binaryOperation(operation: String): (left: Matrix<T>, right: Matrix<T>) -> Matrix<T> =
|
||||
when (operation) {
|
||||
"dot" -> { left, right -> left dot right }
|
||||
else -> super.binaryOperation(operation)
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the dot product of this matrix and another one.
|
||||
|
@ -13,19 +13,19 @@ public annotation class KMathContext
|
||||
*/
|
||||
public interface Algebra<T> {
|
||||
/**
|
||||
* Wrap raw string or variable
|
||||
* Wraps raw string or variable.
|
||||
*/
|
||||
public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
|
||||
|
||||
/**
|
||||
* Dynamic call of unary operation with name [operation] on [arg]
|
||||
* Dynamically dispatches an unary operation with name [operation].
|
||||
*/
|
||||
public fun unaryOperation(operation: String, arg: T): T
|
||||
public fun unaryOperation(operation: String): (arg: T) -> T
|
||||
|
||||
/**
|
||||
* Dynamic call of binary operation [operation] on [left] and [right]
|
||||
* Dynamically dispatches a binary operation with name [operation].
|
||||
*/
|
||||
public fun binaryOperation(operation: String, left: T, right: T): T
|
||||
public fun binaryOperation(operation: String): (left: T, right: T) -> T
|
||||
}
|
||||
|
||||
/**
|
||||
@ -40,16 +40,28 @@ public interface NumericAlgebra<T> : Algebra<T> {
|
||||
public fun number(value: Number): T
|
||||
|
||||
/**
|
||||
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number].
|
||||
* Dynamically dispatches a binary operation with name [operation] where the left argument is [Number].
|
||||
*/
|
||||
public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
||||
binaryOperation(operation, number(left), right)
|
||||
public fun leftSideNumberOperation(operation: String): (left: Number, right: T) -> T =
|
||||
{ l, r -> binaryOperation(operation)(number(l), r) }
|
||||
|
||||
// /**
|
||||
// * Dynamically calls a binary operation with name [operation] where the left argument is [Number].
|
||||
// */
|
||||
// public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
||||
// leftSideNumberOperation(operation)(left, right)
|
||||
|
||||
/**
|
||||
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
|
||||
* Dynamically dispatches a binary operation with name [operation] where the right argument is [Number].
|
||||
*/
|
||||
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||
leftSideNumberOperation(operation, right, left)
|
||||
public fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T =
|
||||
{ l, r -> binaryOperation(operation)(l, number(r)) }
|
||||
|
||||
// /**
|
||||
// * Dynamically calls a binary operation with name [operation] where the right argument is [Number].
|
||||
// */
|
||||
// public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||
// rightSideNumberOperation(operation)(left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -146,15 +158,15 @@ public interface SpaceOperations<T> : Algebra<T> {
|
||||
*/
|
||||
public operator fun Number.times(b: T): T = b * this
|
||||
|
||||
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||
PLUS_OPERATION -> arg
|
||||
MINUS_OPERATION -> -arg
|
||||
override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
|
||||
PLUS_OPERATION -> { arg -> arg }
|
||||
MINUS_OPERATION -> { arg -> -arg }
|
||||
else -> error("Unary operation $operation not defined in $this")
|
||||
}
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
PLUS_OPERATION -> add(left, right)
|
||||
MINUS_OPERATION -> left - right
|
||||
override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||
PLUS_OPERATION -> ::add
|
||||
MINUS_OPERATION -> { left, right -> left - right }
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
}
|
||||
|
||||
@ -207,9 +219,9 @@ public interface RingOperations<T> : SpaceOperations<T> {
|
||||
*/
|
||||
public operator fun T.times(b: T): T = multiply(this, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
TIMES_OPERATION -> multiply(left, right)
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||
TIMES_OPERATION -> ::multiply
|
||||
else -> super.binaryOperation(operation)
|
||||
}
|
||||
|
||||
public companion object {
|
||||
@ -234,20 +246,6 @@ public interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
|
||||
|
||||
override fun number(value: Number): T = one * value.toDouble()
|
||||
|
||||
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left + right
|
||||
SpaceOperations.MINUS_OPERATION -> left - right
|
||||
RingOperations.TIMES_OPERATION -> left * right
|
||||
else -> super.leftSideNumberOperation(operation, left, right)
|
||||
}
|
||||
|
||||
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left + right
|
||||
SpaceOperations.MINUS_OPERATION -> left - right
|
||||
RingOperations.TIMES_OPERATION -> left * right
|
||||
else -> super.rightSideNumberOperation(operation, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Addition of element and scalar.
|
||||
*
|
||||
@ -308,9 +306,9 @@ public interface FieldOperations<T> : RingOperations<T> {
|
||||
*/
|
||||
public operator fun T.div(b: T): T = divide(this, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
DIV_OPERATION -> divide(left, right)
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||
DIV_OPERATION -> ::divide
|
||||
else -> super.binaryOperation(operation)
|
||||
}
|
||||
|
||||
public companion object {
|
||||
|
@ -15,23 +15,23 @@ public interface ExtendedFieldOperations<T> :
|
||||
public override fun tan(arg: T): T = sin(arg) / cos(arg)
|
||||
public override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
|
||||
|
||||
public override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||
TrigonometricOperations.COS_OPERATION -> cos(arg)
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(arg)
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(arg)
|
||||
TrigonometricOperations.ACOS_OPERATION -> acos(arg)
|
||||
TrigonometricOperations.ASIN_OPERATION -> asin(arg)
|
||||
TrigonometricOperations.ATAN_OPERATION -> atan(arg)
|
||||
HyperbolicOperations.COSH_OPERATION -> cosh(arg)
|
||||
HyperbolicOperations.SINH_OPERATION -> sinh(arg)
|
||||
HyperbolicOperations.TANH_OPERATION -> tanh(arg)
|
||||
HyperbolicOperations.ACOSH_OPERATION -> acosh(arg)
|
||||
HyperbolicOperations.ASINH_OPERATION -> asinh(arg)
|
||||
HyperbolicOperations.ATANH_OPERATION -> atanh(arg)
|
||||
PowerOperations.SQRT_OPERATION -> sqrt(arg)
|
||||
ExponentialOperations.EXP_OPERATION -> exp(arg)
|
||||
ExponentialOperations.LN_OPERATION -> ln(arg)
|
||||
else -> super.unaryOperation(operation, arg)
|
||||
public override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
|
||||
TrigonometricOperations.COS_OPERATION -> ::cos
|
||||
TrigonometricOperations.SIN_OPERATION -> ::sin
|
||||
TrigonometricOperations.TAN_OPERATION -> ::tan
|
||||
TrigonometricOperations.ACOS_OPERATION -> ::acos
|
||||
TrigonometricOperations.ASIN_OPERATION -> ::asin
|
||||
TrigonometricOperations.ATAN_OPERATION -> ::atan
|
||||
HyperbolicOperations.COSH_OPERATION -> ::cosh
|
||||
HyperbolicOperations.SINH_OPERATION -> ::sinh
|
||||
HyperbolicOperations.TANH_OPERATION -> ::tanh
|
||||
HyperbolicOperations.ACOSH_OPERATION -> ::acosh
|
||||
HyperbolicOperations.ASINH_OPERATION -> ::asinh
|
||||
HyperbolicOperations.ATANH_OPERATION -> ::atanh
|
||||
PowerOperations.SQRT_OPERATION -> ::sqrt
|
||||
ExponentialOperations.EXP_OPERATION -> ::exp
|
||||
ExponentialOperations.LN_OPERATION -> ::ln
|
||||
else -> super.unaryOperation(operation)
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,10 +46,11 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
||||
public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
|
||||
public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
|
||||
|
||||
public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> power(left, right)
|
||||
else -> super.rightSideNumberOperation(operation, left, right)
|
||||
}
|
||||
public override fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T =
|
||||
when (operation) {
|
||||
PowerOperations.POW_OPERATION -> ::power
|
||||
else -> super.rightSideNumberOperation(operation)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -80,10 +81,11 @@ public object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
public override val one: Double
|
||||
get() = 1.0
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> left pow right
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
}
|
||||
public override fun binaryOperation(operation: String): (left: Double, right: Double) -> Double =
|
||||
when (operation) {
|
||||
PowerOperations.POW_OPERATION -> ::power
|
||||
else -> super.binaryOperation(operation)
|
||||
}
|
||||
|
||||
public override inline fun add(a: Double, b: Double): Double = a + b
|
||||
public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
|
||||
@ -130,9 +132,9 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
public override val one: Float
|
||||
get() = 1.0f
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> left pow right
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperation(operation: String): (left: Float, right: Float) -> Float = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> ::power
|
||||
else -> super.binaryOperation(operation)
|
||||
}
|
||||
|
||||
public override inline fun add(a: Float, b: Float): Float = a + b
|
||||
|
Loading…
Reference in New Issue
Block a user