Rework unary/binary operation API

This commit is contained in:
Iaroslav Postovalov 2020-11-28 13:33:05 +07:00
parent abe68a4fb6
commit 3e7c9d8dce
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
9 changed files with 194 additions and 203 deletions

View File

@ -55,24 +55,23 @@ public sealed class MST {
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Unary -> unaryOperation(node.operation)(evaluate(node.value))
is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
this !is NumericAlgebra -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation(
node.operation,
node.left.value.toDouble(),
node.right.value.toDouble()
)
val number = RealField
.binaryOperation(node.operation)(node.left.value.toDouble(), node.right.value.toDouble())
number(number)
}
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric -> leftSideNumberOperation(node.operation)(node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation)(evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right))
}
}

View File

@ -6,115 +6,111 @@ import kscience.kmath.operations.*
* [Algebra] over [MST] nodes.
*/
public object MstAlgebra : NumericAlgebra<MST> {
override fun number(value: Number): MST.Numeric = MST.Numeric(value)
public override fun number(value: Number): MST.Numeric = MST.Numeric(value)
override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST.Unary =
MST.Unary(operation, arg)
public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = { arg -> MST.Unary(operation, arg) }
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MST.Binary(operation, left, right)
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
{ left, right -> MST.Binary(operation, left, right) }
}
/**
* [Space] over [MST] nodes.
*/
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
override val zero: MST.Numeric by lazy { number(0.0) }
public override val zero: MST.Numeric by lazy { number(0.0) }
override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k))
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstAlgebra.binaryOperation(operation, left, right)
override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
MstAlgebra.binaryOperation(operation)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg)
override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation)
}
/**
* [Ring] over [MST] nodes.
*/
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
override val zero: MST.Numeric
override val zero: MST
get() = MstSpace.zero
override val one: MST = number(1.0)
override val one: MST.Numeric by lazy { number(1.0) }
public override fun number(value: Number): MST = MstSpace.number(value)
public override fun symbol(value: String): MST = MstSpace.symbol(value)
public override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
public override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
override fun number(value: Number): MST.Numeric = MstSpace.number(value)
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
MstSpace.binaryOperation(operation)
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstSpace.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation)
}
/**
* [Field] over [MST] nodes.
*/
public object MstField : Field<MST> {
public override val zero: MST.Numeric
public override val zero: MST
get() = MstRing.zero
public override val one: MST.Numeric
public override val one: MST
get() = MstRing.one
public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value)
public override fun number(value: Number): MST.Numeric = MstRing.number(value)
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
public override fun symbol(value: String): MST = MstRing.symbol(value)
public override fun number(value: Number): MST = MstRing.number(value)
public override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstRing.binaryOperation(operation, left, right)
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
MstRing.binaryOperation(operation)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg)
public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstRing.unaryOperation(operation)
}
/**
* [ExtendedField] over [MST] nodes.
*/
public object MstExtendedField : ExtendedField<MST> {
override val zero: MST.Numeric
public override val zero: MST
get() = MstField.zero
override val one: MST.Numeric
public override val one: MST
get() = MstField.one
override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
override fun number(value: Number): MST.Numeric = MstField.number(value)
override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
public override fun symbol(value: String): MST = MstField.symbol(value)
public override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg)
public override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg)
public override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg)
public override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg)
public override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg)
public override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg)
public override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg)
public override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg)
public override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg)
public override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg)
public override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg)
public override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg)
public override fun add(a: MST, b: MST): MST = MstField.add(a, b)
public override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
public override fun divide(a: MST, b: MST): MST = MstField.divide(a, b)
public override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))
public override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg)
public override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION)(arg)
override fun power(arg: MST, pow: Number): MST.Binary =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST =
MstField.binaryOperation(operation)
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstField.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstField.unaryOperation(operation)
override fun unaryOperation(operation: String, arg: MST): MST.Unary =
MST.Unary(operation, arg)
}

View File

@ -16,10 +16,8 @@ import kotlin.contracts.contract
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String): (arg: T) -> T = algebra.unaryOperation(operation)
override fun binaryOperation(operation: String): (left: T, right: T) -> T = algebra.binaryOperation(operation)
@Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)

View File

@ -1,6 +1,5 @@
package kscience.kmath.asm
import kscience.kmath.asm.compile
import kscience.kmath.ast.mstInField
import kscience.kmath.expressions.invoke
import kscience.kmath.operations.RealField
@ -10,44 +9,44 @@ import kotlin.test.assertEquals
internal class TestAsmSpecialization {
@Test
fun testUnaryPlus() {
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
val expr = RealField.mstInField { unaryOperation("+")(symbol("x")) }.compile()
assertEquals(2.0, expr("x" to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
val expr = RealField.mstInField { unaryOperation("-")(symbol("x")) }.compile()
assertEquals(-2.0, expr("x" to 2.0))
}
@Test
fun testAdd() {
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
val expr = RealField.mstInField { binaryOperation("+")(symbol("x"), symbol("x")) }.compile()
assertEquals(4.0, expr("x" to 2.0))
}
@Test
fun testSine() {
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile()
val expr = RealField.mstInField { unaryOperation("sin")(symbol("x")) }.compile()
assertEquals(0.0, expr("x" to 0.0))
}
@Test
fun testMinus() {
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
val expr = RealField.mstInField { binaryOperation("-")(symbol("x"), symbol("x")) }.compile()
assertEquals(0.0, expr("x" to 2.0))
}
@Test
fun testDivide() {
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
val expr = RealField.mstInField { binaryOperation("/")(symbol("x"), symbol("x")) }.compile()
assertEquals(1.0, expr("x" to 2.0))
}
@Test
fun testPower() {
val expr = RealField
.mstInField { binaryOperation("power", symbol("x"), number(2)) }
.mstInField { binaryOperation("power")(symbol("x"), number(2)) }
.compile()
assertEquals(4.0, expr("x" to 2.0))

View File

@ -1,8 +1,5 @@
package kscience.kmath.ast
import kscience.kmath.ast.evaluate
import kscience.kmath.ast.mstInField
import kscience.kmath.ast.parseMath
import kscience.kmath.expressions.invoke
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.Complex
@ -45,12 +42,15 @@ internal class ParserTest {
val magicalAlgebra = object : Algebra<String> {
override fun symbol(value: String): String = value
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
"magic" -> "$left$right"
else -> throw NotImplementedError()
override fun unaryOperation(operation: String): (arg: String) -> String {
throw NotImplementedError()
}
override fun binaryOperation(operation: String): (left: String, right: String) -> String =
when (operation) {
"magic" -> { left, right -> "$left$right" }
else -> throw NotImplementedError()
}
}
val mst = "magic(a, b)".parseMath()

View File

@ -7,9 +7,8 @@ import kscience.kmath.operations.*
*
* @param algebra The algebra to provide for Expressions built.
*/
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
public val algebra: A,
) : ExpressionAlgebra<T, Expression<T>> {
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
ExpressionAlgebra<T, Expression<T>> {
/**
* Builds an Expression of constant expression which does not depend on arguments.
*/
@ -25,19 +24,18 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
/**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/
public override fun binaryOperation(
operation: String,
left: Expression<T>,
right: Expression<T>,
): Expression<T> = Expression { arguments ->
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
}
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
{ left, right ->
Expression { arguments ->
algebra.binaryOperation(operation)(left.invoke(arguments), right.invoke(arguments))
}
}
/**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments ->
algebra.unaryOperation(operation, arg.invoke(arguments))
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
Expression { arguments -> algebra.unaryOperation(operation)(arg.invoke(arguments)) }
}
}
@ -52,7 +50,7 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
* Builds an Expression of addition of two another expressions.
*/
public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
/**
* Builds an Expression of multiplication of expression by number.
@ -66,11 +64,11 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionAlgebra>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionAlgebra>.binaryOperation(operation)
}
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
@ -82,16 +80,16 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
* Builds an Expression of multiplication of two expressions.
*/
public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionSpace>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionSpace>.binaryOperation(operation)
}
public open class FunctionalExpressionField<T, A>(algebra: A) :
@ -101,49 +99,49 @@ public open class FunctionalExpressionField<T, A>(algebra: A) :
* Builds an Expression of division an expression by another one.
*/
public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionRing>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionRing>.binaryOperation(operation)
}
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
FunctionalExpressionField<T, A>(algebra),
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
public override fun sin(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg)
public override fun cos(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
unaryOperation(TrigonometricOperations.COS_OPERATION)(arg)
public override fun asin(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg)
public override fun acos(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg)
public override fun atan(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg)
public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))
public override fun exp(arg: Expression<T>): Expression<T> =
unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
unaryOperation(ExponentialOperations.EXP_OPERATION)(arg)
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION)(arg)
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionField>.unaryOperation(operation, arg)
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionField>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionField>.binaryOperation(operation, left, right)
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionField>.binaryOperation(operation)
}
public inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =

View File

@ -18,10 +18,11 @@ public interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
*/
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
public override fun binaryOperation(operation: String, left: Matrix<T>, right: Matrix<T>): Matrix<T> = when (operation) {
"dot" -> left dot right
else -> super.binaryOperation(operation, left, right)
}
public override fun binaryOperation(operation: String): (left: Matrix<T>, right: Matrix<T>) -> Matrix<T> =
when (operation) {
"dot" -> { left, right -> left dot right }
else -> super.binaryOperation(operation)
}
/**
* Computes the dot product of this matrix and another one.

View File

@ -13,19 +13,19 @@ public annotation class KMathContext
*/
public interface Algebra<T> {
/**
* Wrap raw string or variable
* Wraps raw string or variable.
*/
public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
/**
* Dynamic call of unary operation with name [operation] on [arg]
* Dynamically dispatches an unary operation with name [operation].
*/
public fun unaryOperation(operation: String, arg: T): T
public fun unaryOperation(operation: String): (arg: T) -> T
/**
* Dynamic call of binary operation [operation] on [left] and [right]
* Dynamically dispatches a binary operation with name [operation].
*/
public fun binaryOperation(operation: String, left: T, right: T): T
public fun binaryOperation(operation: String): (left: T, right: T) -> T
}
/**
@ -40,16 +40,28 @@ public interface NumericAlgebra<T> : Algebra<T> {
public fun number(value: Number): T
/**
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number].
* Dynamically dispatches a binary operation with name [operation] where the left argument is [Number].
*/
public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
binaryOperation(operation, number(left), right)
public fun leftSideNumberOperation(operation: String): (left: Number, right: T) -> T =
{ l, r -> binaryOperation(operation)(number(l), r) }
// /**
// * Dynamically calls a binary operation with name [operation] where the left argument is [Number].
// */
// public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
// leftSideNumberOperation(operation)(left, right)
/**
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
* Dynamically dispatches a binary operation with name [operation] where the right argument is [Number].
*/
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
leftSideNumberOperation(operation, right, left)
public fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T =
{ l, r -> binaryOperation(operation)(l, number(r)) }
// /**
// * Dynamically calls a binary operation with name [operation] where the right argument is [Number].
// */
// public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
// rightSideNumberOperation(operation)(left, right)
}
/**
@ -146,15 +158,15 @@ public interface SpaceOperations<T> : Algebra<T> {
*/
public operator fun Number.times(b: T): T = b * this
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
PLUS_OPERATION -> arg
MINUS_OPERATION -> -arg
override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
PLUS_OPERATION -> { arg -> arg }
MINUS_OPERATION -> { arg -> -arg }
else -> error("Unary operation $operation not defined in $this")
}
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
PLUS_OPERATION -> add(left, right)
MINUS_OPERATION -> left - right
override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
PLUS_OPERATION -> ::add
MINUS_OPERATION -> { left, right -> left - right }
else -> error("Binary operation $operation not defined in $this")
}
@ -207,9 +219,9 @@ public interface RingOperations<T> : SpaceOperations<T> {
*/
public operator fun T.times(b: T): T = multiply(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
TIMES_OPERATION -> multiply(left, right)
else -> super.binaryOperation(operation, left, right)
override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
TIMES_OPERATION -> ::multiply
else -> super.binaryOperation(operation)
}
public companion object {
@ -234,20 +246,6 @@ public interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = one * value.toDouble()
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.leftSideNumberOperation(operation, left, right)
}
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.rightSideNumberOperation(operation, left, right)
}
/**
* Addition of element and scalar.
*
@ -308,9 +306,9 @@ public interface FieldOperations<T> : RingOperations<T> {
*/
public operator fun T.div(b: T): T = divide(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
DIV_OPERATION -> divide(left, right)
else -> super.binaryOperation(operation, left, right)
override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
DIV_OPERATION -> ::divide
else -> super.binaryOperation(operation)
}
public companion object {

View File

@ -15,23 +15,23 @@ public interface ExtendedFieldOperations<T> :
public override fun tan(arg: T): T = sin(arg) / cos(arg)
public override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
public override fun unaryOperation(operation: String, arg: T): T = when (operation) {
TrigonometricOperations.COS_OPERATION -> cos(arg)
TrigonometricOperations.SIN_OPERATION -> sin(arg)
TrigonometricOperations.TAN_OPERATION -> tan(arg)
TrigonometricOperations.ACOS_OPERATION -> acos(arg)
TrigonometricOperations.ASIN_OPERATION -> asin(arg)
TrigonometricOperations.ATAN_OPERATION -> atan(arg)
HyperbolicOperations.COSH_OPERATION -> cosh(arg)
HyperbolicOperations.SINH_OPERATION -> sinh(arg)
HyperbolicOperations.TANH_OPERATION -> tanh(arg)
HyperbolicOperations.ACOSH_OPERATION -> acosh(arg)
HyperbolicOperations.ASINH_OPERATION -> asinh(arg)
HyperbolicOperations.ATANH_OPERATION -> atanh(arg)
PowerOperations.SQRT_OPERATION -> sqrt(arg)
ExponentialOperations.EXP_OPERATION -> exp(arg)
ExponentialOperations.LN_OPERATION -> ln(arg)
else -> super.unaryOperation(operation, arg)
public override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
TrigonometricOperations.COS_OPERATION -> ::cos
TrigonometricOperations.SIN_OPERATION -> ::sin
TrigonometricOperations.TAN_OPERATION -> ::tan
TrigonometricOperations.ACOS_OPERATION -> ::acos
TrigonometricOperations.ASIN_OPERATION -> ::asin
TrigonometricOperations.ATAN_OPERATION -> ::atan
HyperbolicOperations.COSH_OPERATION -> ::cosh
HyperbolicOperations.SINH_OPERATION -> ::sinh
HyperbolicOperations.TANH_OPERATION -> ::tanh
HyperbolicOperations.ACOSH_OPERATION -> ::acosh
HyperbolicOperations.ASINH_OPERATION -> ::asinh
HyperbolicOperations.ATANH_OPERATION -> ::atanh
PowerOperations.SQRT_OPERATION -> ::sqrt
ExponentialOperations.EXP_OPERATION -> ::exp
ExponentialOperations.LN_OPERATION -> ::ln
else -> super.unaryOperation(operation)
}
}
@ -46,10 +46,11 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
PowerOperations.POW_OPERATION -> power(left, right)
else -> super.rightSideNumberOperation(operation, left, right)
}
public override fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T =
when (operation) {
PowerOperations.POW_OPERATION -> ::power
else -> super.rightSideNumberOperation(operation)
}
}
/**
@ -80,10 +81,11 @@ public object RealField : ExtendedField<Double>, Norm<Double, Double> {
public override val one: Double
get() = 1.0
public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
public override fun binaryOperation(operation: String): (left: Double, right: Double) -> Double =
when (operation) {
PowerOperations.POW_OPERATION -> ::power
else -> super.binaryOperation(operation)
}
public override inline fun add(a: Double, b: Double): Double = a + b
public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
@ -130,9 +132,9 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
public override val one: Float
get() = 1.0f
public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
public override fun binaryOperation(operation: String): (left: Float, right: Float) -> Float = when (operation) {
PowerOperations.POW_OPERATION -> ::power
else -> super.binaryOperation(operation)
}
public override inline fun add(a: Float, b: Float): Float = a + b