From 3e7c9d8dce15fd5c17747aba953147e05c13c4f4 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Sat, 28 Nov 2020 13:33:05 +0700 Subject: [PATCH 1/6] Rework unary/binary operation API --- .../kotlin/kscience/kmath/ast/MST.kt | 19 ++- .../kotlin/kscience/kmath/ast/MstAlgebra.kt | 126 +++++++++--------- .../kscience/kmath/ast/MstExpression.kt | 6 +- .../kmath/asm/TestAsmSpecialization.kt | 15 +-- .../kotlin/kscience/kmath/ast/ParserTest.kt | 16 +-- .../FunctionalExpressionAlgebra.kt | 76 +++++------ .../kscience/kmath/linear/MatrixContext.kt | 9 +- .../kscience/kmath/operations/Algebra.kt | 72 +++++----- .../kmath/operations/NumberAlgebra.kt | 58 ++++---- 9 files changed, 194 insertions(+), 203 deletions(-) diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt index f312323b9..cc3d38f9c 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt @@ -55,24 +55,23 @@ public sealed class MST { public fun Algebra.evaluate(node: MST): T = when (node) { is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) ?: error("Numeric nodes are not supported by $this") + is MST.Symbolic -> symbol(node.value) - is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) + is MST.Unary -> unaryOperation(node.operation)(evaluate(node.value)) + is MST.Binary -> when { - this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + this !is NumericAlgebra -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right)) node.left is MST.Numeric && node.right is MST.Numeric -> { - val number = RealField.binaryOperation( - node.operation, - node.left.value.toDouble(), - node.right.value.toDouble() - ) + val number = RealField + .binaryOperation(node.operation)(node.left.value.toDouble(), node.right.value.toDouble()) number(number) } - node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) - node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) - else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + node.left is MST.Numeric -> leftSideNumberOperation(node.operation)(node.left.value, evaluate(node.right)) + node.right is MST.Numeric -> rightSideNumberOperation(node.operation)(evaluate(node.left), node.right.value) + else -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right)) } } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt index 6ee6ab9af..cea708342 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -6,115 +6,111 @@ import kscience.kmath.operations.* * [Algebra] over [MST] nodes. */ public object MstAlgebra : NumericAlgebra { - override fun number(value: Number): MST.Numeric = MST.Numeric(value) + public override fun number(value: Number): MST.Numeric = MST.Numeric(value) - override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) + public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) - override fun unaryOperation(operation: String, arg: MST): MST.Unary = - MST.Unary(operation, arg) + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = { arg -> MST.Unary(operation, arg) } - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MST.Binary(operation, left, right) + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = + { left, right -> MST.Binary(operation, left, right) } } /** * [Space] over [MST] nodes. */ public object MstSpace : Space, NumericAlgebra { - override val zero: MST.Numeric by lazy { number(0.0) } + public override val zero: MST.Numeric by lazy { number(0.0) } - override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) - override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) - override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) + public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) + override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b) + override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k)) - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstAlgebra.binaryOperation(operation, left, right) + override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + MstAlgebra.binaryOperation(operation) - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation) } /** * [Ring] over [MST] nodes. */ public object MstRing : Ring, NumericAlgebra { - override val zero: MST.Numeric + override val zero: MST get() = MstSpace.zero + override val one: MST = number(1.0) - override val one: MST.Numeric by lazy { number(1.0) } + public override fun number(value: Number): MST = MstSpace.number(value) + public override fun symbol(value: String): MST = MstSpace.symbol(value) + public override fun add(a: MST, b: MST): MST = MstSpace.add(a, b) + public override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, b) - override fun number(value: Number): MST.Numeric = MstSpace.number(value) - override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) - override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) - override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) - override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + MstSpace.binaryOperation(operation) - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstSpace.binaryOperation(operation, left, right) - - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation) } /** * [Field] over [MST] nodes. */ public object MstField : Field { - public override val zero: MST.Numeric + public override val zero: MST get() = MstRing.zero - public override val one: MST.Numeric + public override val one: MST get() = MstRing.one - public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value) - public override fun number(value: Number): MST.Numeric = MstRing.number(value) - public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) - public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) - public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b) + public override fun symbol(value: String): MST = MstRing.symbol(value) + public override fun number(value: Number): MST = MstRing.number(value) + public override fun add(a: MST, b: MST): MST = MstRing.add(a, b) + public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) + public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION)(a, b) - public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstRing.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + MstRing.binaryOperation(operation) - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstRing.unaryOperation(operation) } /** * [ExtendedField] over [MST] nodes. */ public object MstExtendedField : ExtendedField { - override val zero: MST.Numeric + public override val zero: MST get() = MstField.zero - override val one: MST.Numeric + public override val one: MST get() = MstField.one - override fun symbol(value: String): MST.Symbolic = MstField.symbol(value) - override fun number(value: Number): MST.Numeric = MstField.number(value) - override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) - override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) - override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) - override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) - override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) - override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) - override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) - override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) - override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) - override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) - override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) - override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) - override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) - override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) - override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) - override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) + public override fun symbol(value: String): MST = MstField.symbol(value) + public override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg) + public override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg) + public override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg) + public override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg) + public override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg) + public override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg) + public override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg) + public override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg) + public override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg) + public override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg) + public override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg) + public override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg) + public override fun add(a: MST, b: MST): MST = MstField.add(a, b) + public override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) + public override fun divide(a: MST, b: MST): MST = MstField.divide(a, b) + public override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow)) + public override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg) + public override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION)(arg) - override fun power(arg: MST, pow: Number): MST.Binary = - binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + MstField.binaryOperation(operation) - override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) - override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstField.binaryOperation(operation, left, right) - - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstField.unaryOperation(operation) + override fun unaryOperation(operation: String, arg: MST): MST.Unary = + MST.Unary(operation, arg) } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt index f68e3f5f8..8a99fd2f4 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt @@ -16,10 +16,8 @@ import kotlin.contracts.contract public class MstExpression>(public val algebra: A, public val mst: MST) : Expression { private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) - override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: T, right: T): T = - algebra.binaryOperation(operation, left, right) + override fun unaryOperation(operation: String): (arg: T) -> T = algebra.unaryOperation(operation) + override fun binaryOperation(operation: String): (left: T, right: T) -> T = algebra.binaryOperation(operation) @Suppress("UNCHECKED_CAST") override fun number(value: Number): T = if (algebra is NumericAlgebra<*>) diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt index 8f8175acd..eb5a1c8f3 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt @@ -1,6 +1,5 @@ package kscience.kmath.asm -import kscience.kmath.asm.compile import kscience.kmath.ast.mstInField import kscience.kmath.expressions.invoke import kscience.kmath.operations.RealField @@ -10,44 +9,44 @@ import kotlin.test.assertEquals internal class TestAsmSpecialization { @Test fun testUnaryPlus() { - val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() + val expr = RealField.mstInField { unaryOperation("+")(symbol("x")) }.compile() assertEquals(2.0, expr("x" to 2.0)) } @Test fun testUnaryMinus() { - val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() + val expr = RealField.mstInField { unaryOperation("-")(symbol("x")) }.compile() assertEquals(-2.0, expr("x" to 2.0)) } @Test fun testAdd() { - val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() + val expr = RealField.mstInField { binaryOperation("+")(symbol("x"), symbol("x")) }.compile() assertEquals(4.0, expr("x" to 2.0)) } @Test fun testSine() { - val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() + val expr = RealField.mstInField { unaryOperation("sin")(symbol("x")) }.compile() assertEquals(0.0, expr("x" to 0.0)) } @Test fun testMinus() { - val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() + val expr = RealField.mstInField { binaryOperation("-")(symbol("x"), symbol("x")) }.compile() assertEquals(0.0, expr("x" to 2.0)) } @Test fun testDivide() { - val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() + val expr = RealField.mstInField { binaryOperation("/")(symbol("x"), symbol("x")) }.compile() assertEquals(1.0, expr("x" to 2.0)) } @Test fun testPower() { val expr = RealField - .mstInField { binaryOperation("power", symbol("x"), number(2)) } + .mstInField { binaryOperation("power")(symbol("x"), number(2)) } .compile() assertEquals(4.0, expr("x" to 2.0)) diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt index 2dc24597e..e2029ce19 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt @@ -1,8 +1,5 @@ package kscience.kmath.ast -import kscience.kmath.ast.evaluate -import kscience.kmath.ast.mstInField -import kscience.kmath.ast.parseMath import kscience.kmath.expressions.invoke import kscience.kmath.operations.Algebra import kscience.kmath.operations.Complex @@ -45,12 +42,15 @@ internal class ParserTest { val magicalAlgebra = object : Algebra { override fun symbol(value: String): String = value - override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError() - - override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) { - "magic" -> "$left ★ $right" - else -> throw NotImplementedError() + override fun unaryOperation(operation: String): (arg: String) -> String { + throw NotImplementedError() } + + override fun binaryOperation(operation: String): (left: String, right: String) -> String = + when (operation) { + "magic" -> { left, right -> "$left ★ $right" } + else -> throw NotImplementedError() + } } val mst = "magic(a, b)".parseMath() diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 0630e8e4b..e2413dea8 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -7,9 +7,8 @@ import kscience.kmath.operations.* * * @param algebra The algebra to provide for Expressions built. */ -public abstract class FunctionalExpressionAlgebra>( - public val algebra: A, -) : ExpressionAlgebra> { +public abstract class FunctionalExpressionAlgebra>(public val algebra: A) : + ExpressionAlgebra> { /** * Builds an Expression of constant expression which does not depend on arguments. */ @@ -25,19 +24,18 @@ public abstract class FunctionalExpressionAlgebra>( /** * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. */ - public override fun binaryOperation( - operation: String, - left: Expression, - right: Expression, - ): Expression = Expression { arguments -> - algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments)) - } + public override fun binaryOperation(operation: String): (left: Expression, right: Expression) -> Expression = + { left, right -> + Expression { arguments -> + algebra.binaryOperation(operation)(left.invoke(arguments), right.invoke(arguments)) + } + } /** * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. */ - public override fun unaryOperation(operation: String, arg: Expression): Expression = Expression { arguments -> - algebra.unaryOperation(operation, arg.invoke(arguments)) + public override fun unaryOperation(operation: String): (arg: Expression) -> Expression = { arg -> + Expression { arguments -> algebra.unaryOperation(operation)(arg.invoke(arguments)) } } } @@ -52,7 +50,7 @@ public open class FunctionalExpressionSpace>(algebra: A) : * Builds an Expression of addition of two another expressions. */ public override fun add(a: Expression, b: Expression): Expression = - binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b) /** * Builds an Expression of multiplication of expression by number. @@ -66,11 +64,11 @@ public open class FunctionalExpressionSpace>(algebra: A) : public operator fun T.plus(arg: Expression): Expression = arg + this public operator fun T.minus(arg: Expression): Expression = arg - this - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: Expression) -> Expression = + super.unaryOperation(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperation(operation) } public open class FunctionalExpressionRing(algebra: A) : FunctionalExpressionSpace(algebra), @@ -82,16 +80,16 @@ public open class FunctionalExpressionRing(algebra: A) : FunctionalExpress * Builds an Expression of multiplication of two expressions. */ public override fun multiply(a: Expression, b: Expression): Expression = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) + binaryOperation(RingOperations.TIMES_OPERATION)(a, b) public operator fun Expression.times(arg: T): Expression = this * const(arg) public operator fun T.times(arg: Expression): Expression = arg * this - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: Expression) -> Expression = + super.unaryOperation(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperation(operation) } public open class FunctionalExpressionField(algebra: A) : @@ -101,49 +99,49 @@ public open class FunctionalExpressionField(algebra: A) : * Builds an Expression of division an expression by another one. */ public override fun divide(a: Expression, b: Expression): Expression = - binaryOperation(FieldOperations.DIV_OPERATION, a, b) + binaryOperation(FieldOperations.DIV_OPERATION)(a, b) public operator fun Expression.div(arg: T): Expression = this / const(arg) public operator fun T.div(arg: Expression): Expression = arg / this - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: Expression) -> Expression = + super.unaryOperation(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperation(operation) } public open class FunctionalExpressionExtendedField(algebra: A) : FunctionalExpressionField(algebra), ExtendedField> where A : ExtendedField, A : NumericAlgebra { public override fun sin(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) + unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg) public override fun cos(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.COS_OPERATION, arg) + unaryOperation(TrigonometricOperations.COS_OPERATION)(arg) public override fun asin(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) + unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg) public override fun acos(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) + unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg) public override fun atan(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) + unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg) public override fun power(arg: Expression, pow: Number): Expression = - binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) + binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow)) public override fun exp(arg: Expression): Expression = - unaryOperation(ExponentialOperations.EXP_OPERATION, arg) + unaryOperation(ExponentialOperations.EXP_OPERATION)(arg) - public override fun ln(arg: Expression): Expression = unaryOperation(ExponentialOperations.LN_OPERATION, arg) + public override fun ln(arg: Expression): Expression = unaryOperation(ExponentialOperations.LN_OPERATION)(arg) - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: Expression) -> Expression = + super.unaryOperation(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperation(operation) } public inline fun > A.expressionInSpace(block: FunctionalExpressionSpace.() -> Expression): Expression = diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt index f4dbce89a..174f62c12 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt @@ -18,10 +18,11 @@ public interface MatrixContext : SpaceOperations> { */ public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix - public override fun binaryOperation(operation: String, left: Matrix, right: Matrix): Matrix = when (operation) { - "dot" -> left dot right - else -> super.binaryOperation(operation, left, right) - } + public override fun binaryOperation(operation: String): (left: Matrix, right: Matrix) -> Matrix = + when (operation) { + "dot" -> { left, right -> left dot right } + else -> super.binaryOperation(operation) + } /** * Computes the dot product of this matrix and another one. diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt index 12a45615a..19ca1a4cc 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt @@ -13,19 +13,19 @@ public annotation class KMathContext */ public interface Algebra { /** - * Wrap raw string or variable + * Wraps raw string or variable. */ public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") /** - * Dynamic call of unary operation with name [operation] on [arg] + * Dynamically dispatches an unary operation with name [operation]. */ - public fun unaryOperation(operation: String, arg: T): T + public fun unaryOperation(operation: String): (arg: T) -> T /** - * Dynamic call of binary operation [operation] on [left] and [right] + * Dynamically dispatches a binary operation with name [operation]. */ - public fun binaryOperation(operation: String, left: T, right: T): T + public fun binaryOperation(operation: String): (left: T, right: T) -> T } /** @@ -40,16 +40,28 @@ public interface NumericAlgebra : Algebra { public fun number(value: Number): T /** - * Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number]. + * Dynamically dispatches a binary operation with name [operation] where the left argument is [Number]. */ - public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = - binaryOperation(operation, number(left), right) + public fun leftSideNumberOperation(operation: String): (left: Number, right: T) -> T = + { l, r -> binaryOperation(operation)(number(l), r) } + +// /** +// * Dynamically calls a binary operation with name [operation] where the left argument is [Number]. +// */ +// public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = +// leftSideNumberOperation(operation)(left, right) /** - * Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number]. + * Dynamically dispatches a binary operation with name [operation] where the right argument is [Number]. */ - public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = - leftSideNumberOperation(operation, right, left) + public fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T = + { l, r -> binaryOperation(operation)(l, number(r)) } + +// /** +// * Dynamically calls a binary operation with name [operation] where the right argument is [Number]. +// */ +// public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = +// rightSideNumberOperation(operation)(left, right) } /** @@ -146,15 +158,15 @@ public interface SpaceOperations : Algebra { */ public operator fun Number.times(b: T): T = b * this - override fun unaryOperation(operation: String, arg: T): T = when (operation) { - PLUS_OPERATION -> arg - MINUS_OPERATION -> -arg + override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) { + PLUS_OPERATION -> { arg -> arg } + MINUS_OPERATION -> { arg -> -arg } else -> error("Unary operation $operation not defined in $this") } - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - PLUS_OPERATION -> add(left, right) - MINUS_OPERATION -> left - right + override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) { + PLUS_OPERATION -> ::add + MINUS_OPERATION -> { left, right -> left - right } else -> error("Binary operation $operation not defined in $this") } @@ -207,9 +219,9 @@ public interface RingOperations : SpaceOperations { */ public operator fun T.times(b: T): T = multiply(this, b) - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - TIMES_OPERATION -> multiply(left, right) - else -> super.binaryOperation(operation, left, right) + override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) { + TIMES_OPERATION -> ::multiply + else -> super.binaryOperation(operation) } public companion object { @@ -234,20 +246,6 @@ public interface Ring : Space, RingOperations, NumericAlgebra { override fun number(value: Number): T = one * value.toDouble() - override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) { - SpaceOperations.PLUS_OPERATION -> left + right - SpaceOperations.MINUS_OPERATION -> left - right - RingOperations.TIMES_OPERATION -> left * right - else -> super.leftSideNumberOperation(operation, left, right) - } - - override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { - SpaceOperations.PLUS_OPERATION -> left + right - SpaceOperations.MINUS_OPERATION -> left - right - RingOperations.TIMES_OPERATION -> left * right - else -> super.rightSideNumberOperation(operation, left, right) - } - /** * Addition of element and scalar. * @@ -308,9 +306,9 @@ public interface FieldOperations : RingOperations { */ public operator fun T.div(b: T): T = divide(this, b) - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - DIV_OPERATION -> divide(left, right) - else -> super.binaryOperation(operation, left, right) + override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) { + DIV_OPERATION -> ::divide + else -> super.binaryOperation(operation) } public companion object { diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt index a2b33a0c4..611222d34 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt @@ -15,23 +15,23 @@ public interface ExtendedFieldOperations : public override fun tan(arg: T): T = sin(arg) / cos(arg) public override fun tanh(arg: T): T = sinh(arg) / cosh(arg) - public override fun unaryOperation(operation: String, arg: T): T = when (operation) { - TrigonometricOperations.COS_OPERATION -> cos(arg) - TrigonometricOperations.SIN_OPERATION -> sin(arg) - TrigonometricOperations.TAN_OPERATION -> tan(arg) - TrigonometricOperations.ACOS_OPERATION -> acos(arg) - TrigonometricOperations.ASIN_OPERATION -> asin(arg) - TrigonometricOperations.ATAN_OPERATION -> atan(arg) - HyperbolicOperations.COSH_OPERATION -> cosh(arg) - HyperbolicOperations.SINH_OPERATION -> sinh(arg) - HyperbolicOperations.TANH_OPERATION -> tanh(arg) - HyperbolicOperations.ACOSH_OPERATION -> acosh(arg) - HyperbolicOperations.ASINH_OPERATION -> asinh(arg) - HyperbolicOperations.ATANH_OPERATION -> atanh(arg) - PowerOperations.SQRT_OPERATION -> sqrt(arg) - ExponentialOperations.EXP_OPERATION -> exp(arg) - ExponentialOperations.LN_OPERATION -> ln(arg) - else -> super.unaryOperation(operation, arg) + public override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) { + TrigonometricOperations.COS_OPERATION -> ::cos + TrigonometricOperations.SIN_OPERATION -> ::sin + TrigonometricOperations.TAN_OPERATION -> ::tan + TrigonometricOperations.ACOS_OPERATION -> ::acos + TrigonometricOperations.ASIN_OPERATION -> ::asin + TrigonometricOperations.ATAN_OPERATION -> ::atan + HyperbolicOperations.COSH_OPERATION -> ::cosh + HyperbolicOperations.SINH_OPERATION -> ::sinh + HyperbolicOperations.TANH_OPERATION -> ::tanh + HyperbolicOperations.ACOSH_OPERATION -> ::acosh + HyperbolicOperations.ASINH_OPERATION -> ::asinh + HyperbolicOperations.ATANH_OPERATION -> ::atanh + PowerOperations.SQRT_OPERATION -> ::sqrt + ExponentialOperations.EXP_OPERATION -> ::exp + ExponentialOperations.LN_OPERATION -> ::ln + else -> super.unaryOperation(operation) } } @@ -46,10 +46,11 @@ public interface ExtendedField : ExtendedFieldOperations, Field { public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2 - public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { - PowerOperations.POW_OPERATION -> power(left, right) - else -> super.rightSideNumberOperation(operation, left, right) - } + public override fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.rightSideNumberOperation(operation) + } } /** @@ -80,10 +81,11 @@ public object RealField : ExtendedField, Norm { public override val one: Double get() = 1.0 - public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { - PowerOperations.POW_OPERATION -> left pow right - else -> super.binaryOperation(operation, left, right) - } + public override fun binaryOperation(operation: String): (left: Double, right: Double) -> Double = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.binaryOperation(operation) + } public override inline fun add(a: Double, b: Double): Double = a + b public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() @@ -130,9 +132,9 @@ public object FloatField : ExtendedField, Norm { public override val one: Float get() = 1.0f - public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) { - PowerOperations.POW_OPERATION -> left pow right - else -> super.binaryOperation(operation, left, right) + public override fun binaryOperation(operation: String): (left: Float, right: Float) -> Float = when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.binaryOperation(operation) } public override inline fun add(a: Float, b: Float): Float = a + b From e447a15304f2c1b7df9ad00c48a721db9f92edc6 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Sat, 28 Nov 2020 13:42:18 +0700 Subject: [PATCH 2/6] Make MstAlgebra operations return specific types --- .../kotlin/kscience/kmath/ast/MstAlgebra.kt | 101 +++++++++--------- 1 file changed, 52 insertions(+), 49 deletions(-) diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt index cea708342..44a7f20b9 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -7,10 +7,10 @@ import kscience.kmath.operations.* */ public object MstAlgebra : NumericAlgebra { public override fun number(value: Number): MST.Numeric = MST.Numeric(value) - public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) - public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = { arg -> MST.Unary(operation, arg) } + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = + { arg -> MST.Unary(operation, arg) } public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = { left, right -> MST.Binary(operation, left, right) } @@ -25,92 +25,95 @@ public object MstSpace : Space, NumericAlgebra { public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b) - override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k)) + override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k)) - override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstAlgebra.binaryOperation(operation) - override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation) + override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstAlgebra.unaryOperation(operation) } /** * [Ring] over [MST] nodes. */ public object MstRing : Ring, NumericAlgebra { - override val zero: MST + override val zero: MST.Numeric get() = MstSpace.zero - override val one: MST = number(1.0) - public override fun number(value: Number): MST = MstSpace.number(value) - public override fun symbol(value: String): MST = MstSpace.symbol(value) - public override fun add(a: MST, b: MST): MST = MstSpace.add(a, b) - public override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION)(a, b) + override val one: MST.Numeric by lazy { number(1.0) } - public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + public override fun number(value: Number): MST.Numeric = MstSpace.number(value) + public override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) + public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, b) + + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstSpace.binaryOperation(operation) - public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstAlgebra.unaryOperation(operation) + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = + MstAlgebra.unaryOperation(operation) } /** * [Field] over [MST] nodes. */ public object MstField : Field { - public override val zero: MST + public override val zero: MST.Numeric get() = MstRing.zero - public override val one: MST + public override val one: MST.Numeric get() = MstRing.one - public override fun symbol(value: String): MST = MstRing.symbol(value) - public override fun number(value: Number): MST = MstRing.number(value) - public override fun add(a: MST, b: MST): MST = MstRing.add(a, b) - public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) - public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION)(a, b) + public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value) + public override fun number(value: Number): MST.Numeric = MstRing.number(value) + public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) + public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION)(a, b) - public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstRing.binaryOperation(operation) - public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstRing.unaryOperation(operation) + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstRing.unaryOperation(operation) } /** * [ExtendedField] over [MST] nodes. */ public object MstExtendedField : ExtendedField { - public override val zero: MST + public override val zero: MST.Numeric get() = MstField.zero - public override val one: MST + public override val one: MST.Numeric get() = MstField.one public override fun symbol(value: String): MST = MstField.symbol(value) - public override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg) - public override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg) - public override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg) - public override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg) - public override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg) - public override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg) - public override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg) - public override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg) - public override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg) - public override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg) - public override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg) - public override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg) - public override fun add(a: MST, b: MST): MST = MstField.add(a, b) - public override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) - public override fun divide(a: MST, b: MST): MST = MstField.divide(a, b) - public override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow)) - public override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg) - public override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION)(arg) + public override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg) + public override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg) + public override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg) + public override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg) + public override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg) + public override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg) + public override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg) + public override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg) + public override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg) + public override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg) + public override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg) + public override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg) + public override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) + public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) - public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST = + public override fun power(arg: MST, pow: Number): MST.Binary = + binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow)) + + public override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg) + public override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION)(arg) + + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstField.binaryOperation(operation) - public override fun unaryOperation(operation: String): (arg: MST) -> MST = MstField.unaryOperation(operation) - override fun unaryOperation(operation: String, arg: MST): MST.Unary = - MST.Unary(operation, arg) + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstField.unaryOperation(operation) } From e62cf4fc6586643f6a5adddd2e1f25deb1b756cc Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Tue, 8 Dec 2020 14:42:42 +0700 Subject: [PATCH 3/6] Rewrite ASM codegen to use curried operators, fix bugs, update benchmarks --- .../ast/ExpressionsInterpretersBenchmark.kt | 49 +-- .../kmath/structures/ArrayBenchmark.kt | 8 +- .../kmath/structures/ViktorBenchmark.kt | 2 +- .../kotlin/kscience/kmath/ast/MST.kt | 3 +- .../kotlin/kscience/kmath/ast/MstAlgebra.kt | 15 +- .../kscience/kmath/ast/MstExpression.kt | 7 +- .../jvmMain/kotlin/kscience/kmath/asm/asm.kt | 49 +-- .../kscience/kmath/asm/internal/AsmBuilder.kt | 300 ++++-------------- .../kscience/kmath/asm/internal/MstType.kt | 20 -- .../kmath/asm/internal/codegenUtils.kt | 120 ------- .../kscience/kmath/asm/TestAsmAlgebras.kt | 36 +-- .../kscience/kmath/asm/TestAsmExpressions.kt | 12 +- .../kmath/asm/TestAsmSpecialization.kt | 2 +- .../kscience/kmath/asm/TestAsmVariables.kt | 2 +- .../kscience/kmath/operations/Algebra.kt | 10 +- .../kmath/operations/NumberAlgebra.kt | 2 +- 16 files changed, 158 insertions(+), 479 deletions(-) rename examples/src/{main => benchmarks}/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt (66%) delete mode 100644 kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt diff --git a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt similarity index 66% rename from examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt rename to examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt index a4806ed68..6acaca84d 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -4,13 +4,19 @@ import kscience.kmath.asm.compile import kscience.kmath.expressions.Expression import kscience.kmath.expressions.expressionInField import kscience.kmath.expressions.invoke +import kscience.kmath.expressions.symbol import kscience.kmath.operations.Field import kscience.kmath.operations.RealField +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State import kotlin.random.Random -import kotlin.system.measureTimeMillis +@State(Scope.Benchmark) internal class ExpressionsInterpretersBenchmark { private val algebra: Field = RealField + + @Benchmark fun functionalExpression() { val expr = algebra.expressionInField { symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0) @@ -19,6 +25,7 @@ internal class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } + @Benchmark fun mstExpression() { val expr = algebra.mstInField { symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) @@ -27,6 +34,7 @@ internal class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } + @Benchmark fun asmExpression() { val expr = algebra.mstInField { symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) @@ -35,6 +43,13 @@ internal class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } + @Benchmark + fun rawExpression() { + val x by symbol + val expr = Expression { args -> args.getValue(x) * 2.0 + 2.0 / args.getValue(x) - 16.0 } + invokeAndSum(expr) + } + private fun invokeAndSum(expr: Expression) { val random = Random(0) var sum = 0.0 @@ -46,35 +61,3 @@ internal class ExpressionsInterpretersBenchmark { println(sum) } } - -/** - * This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and - * core FunctionalExpressions API. - * - * The expected rating is: - * - * 1. ASM. - * 2. MST. - * 3. FE. - */ -fun main() { - val benchmark = ExpressionsInterpretersBenchmark() - - val fe = measureTimeMillis { - benchmark.functionalExpression() - } - - println("fe=$fe") - - val mst = measureTimeMillis { - benchmark.mstExpression() - } - - println("mst=$mst") - - val asm = measureTimeMillis { - benchmark.asmExpression() - } - - println("asm=$asm") -} diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt index a91d02253..f7e0d05c6 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt @@ -16,17 +16,13 @@ internal class ArrayBenchmark { @Benchmark fun benchmarkBufferRead() { var res = 0 - for (i in 1..size) res += arrayBuffer.get( - size - i - ) + for (i in 1..size) res += arrayBuffer[size - i] } @Benchmark fun nativeBufferRead() { var res = 0 - for (i in 1..size) res += nativeBuffer.get( - size - i - ) + for (i in 1..size) res += nativeBuffer[size - i] } companion object { diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt index 464925ca0..df35b0b78 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt @@ -36,7 +36,7 @@ internal class ViktorBenchmark { @Benchmark fun rawViktor() { - val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) + val one = F64Array.full(init = 1.0, shape = intArrayOf(dim, dim)) var res = one repeat(n) { res = res + one } } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt index cc3d38f9c..4ff4d7354 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt @@ -64,7 +64,8 @@ public fun Algebra.evaluate(node: MST): T = when (node) { node.left is MST.Numeric && node.right is MST.Numeric -> { val number = RealField - .binaryOperation(node.operation)(node.left.value.toDouble(), node.right.value.toDouble()) + .binaryOperation(node.operation) + .invoke(node.left.value.toDouble(), node.right.value.toDouble()) number(number) } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt index 44a7f20b9..96703ffb8 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -24,13 +24,17 @@ public object MstSpace : Space, NumericAlgebra { public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) - override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b) - override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k)) + public override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b) + public override fun MST.unaryMinus(): MST = unaryOperation(SpaceOperations.MINUS_OPERATION)(this) - override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = + public override fun multiply(a: MST, k: Number): MST.Binary = + binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k)) + + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstAlgebra.binaryOperation(operation) - override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstAlgebra.unaryOperation(operation) + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = + MstAlgebra.unaryOperation(operation) } /** @@ -47,6 +51,7 @@ public object MstRing : Ring, NumericAlgebra { public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) public override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, b) + public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this) public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstSpace.binaryOperation(operation) @@ -71,6 +76,7 @@ public object MstField : Field { public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION)(a, b) + public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this) public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstRing.binaryOperation(operation) @@ -105,6 +111,7 @@ public object MstExtendedField : ExtendedField { public override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) + public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this) public override fun power(arg: MST, pow: Number): MST.Binary = binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow)) diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt index 8a99fd2f4..4d1b65621 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt @@ -15,7 +15,12 @@ import kotlin.contracts.contract */ public class MstExpression>(public val algebra: A, public val mst: MST) : Expression { private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { - override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) + override fun symbol(value: String): T = try { + algebra.symbol(value) + } catch (ignored: IllegalStateException) { + null + } ?: arguments.getValue(StringSymbol(value)) + override fun unaryOperation(operation: String): (arg: T) -> T = algebra.unaryOperation(operation) override fun binaryOperation(operation: String): (left: T, right: T) -> T = algebra.binaryOperation(operation) diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt index 9ccfa464c..932813182 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt @@ -1,13 +1,13 @@ package kscience.kmath.asm import kscience.kmath.asm.internal.AsmBuilder -import kscience.kmath.asm.internal.MstType -import kscience.kmath.asm.internal.buildAlgebraOperationCall import kscience.kmath.asm.internal.buildName import kscience.kmath.ast.MST import kscience.kmath.ast.MstExpression import kscience.kmath.expressions.Expression import kscience.kmath.operations.Algebra +import kscience.kmath.operations.NumericAlgebra +import kscience.kmath.operations.RealField /** * Compiles given MST to an Expression using AST compiler. @@ -23,37 +23,46 @@ internal fun MST.compileWith(type: Class, algebra: Algebra): Exp is MST.Symbolic -> { val symbol = try { algebra.symbol(node.value) - } catch (ignored: Throwable) { + } catch (ignored: IllegalStateException) { null } if (symbol != null) - loadTConstant(symbol) + loadObjectConstant(symbol as Any) else loadVariable(node.value) } - is MST.Numeric -> loadNumeric(node.value) + is MST.Numeric -> loadNumberConstant(node.value) + is MST.Unary -> buildCall(algebra.unaryOperation(node.operation)) { visit(node.value) } - is MST.Unary -> buildAlgebraOperationCall( - context = algebra, - name = node.operation, - fallbackMethodName = "unaryOperation", - parameterTypes = arrayOf(MstType.fromMst(node.value)) - ) { visit(node.value) } + is MST.Binary -> when { + algebra is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant( + algebra.number( + RealField + .binaryOperation(node.operation) + .invoke(node.left.value.toDouble(), node.right.value.toDouble()) + ) + ) - is MST.Binary -> buildAlgebraOperationCall( - context = algebra, - name = node.operation, - fallbackMethodName = "binaryOperation", - parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right)) - ) { - visit(node.left) - visit(node.right) + algebra is NumericAlgebra && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperation(node.operation)) { + visit(node.left) + visit(node.right) + } + + algebra is NumericAlgebra && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperation(node.operation)) { + visit(node.left) + visit(node.right) + } + + else -> buildCall(algebra.binaryOperation(node.operation)) { + visit(node.left) + visit(node.right) + } } } - return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() + return AsmBuilder(type, buildName(this)) { visit(this@compileWith) }.instance } /** diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt index a1e482103..792034232 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -4,26 +4,24 @@ import kscience.kmath.asm.internal.AsmBuilder.ClassLoader import kscience.kmath.ast.MST import kscience.kmath.expressions.Expression import kscience.kmath.operations.Algebra -import kscience.kmath.operations.NumericAlgebra import org.objectweb.asm.* import org.objectweb.asm.Opcodes.* import org.objectweb.asm.commons.InstructionAdapter -import java.util.* -import java.util.stream.Collectors +import java.util.stream.Collectors.toMap +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * * @property T the type of AsmExpression to unwrap. - * @property algebra the algebra the applied AsmExpressions use. * @property className the unique class name of new loaded class. * @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. * @author Iaroslav Postovalov */ internal class AsmBuilder internal constructor( - private val classOfT: Class<*>, - private val algebra: Algebra, + classOfT: Class<*>, private val className: String, private val invokeLabel0Visitor: AsmBuilder.() -> Unit, ) { @@ -39,15 +37,10 @@ internal class AsmBuilder internal constructor( */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - /** - * ASM Type for [algebra]. - */ - private val tAlgebraType: Type = algebra.javaClass.asm - /** * ASM type for [T]. */ - internal val tType: Type = classOfT.asm + private val tType: Type = classOfT.asm /** * ASM type for new class. @@ -69,51 +62,13 @@ internal class AsmBuilder internal constructor( */ private var hasConstants: Boolean = true - /** - * States whether [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. - */ - internal var primitiveMode: Boolean = false - - /** - * Primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. - */ - internal var primitiveMask: Type = OBJECT_TYPE - - /** - * Boxed primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. - */ - internal var primitiveMaskBoxed: Type = OBJECT_TYPE - - /** - * Stack of useful objects types on stack to verify types. - */ - private val typeStack: ArrayDeque = ArrayDeque() - - /** - * Stack of useful objects types on stack expected by algebra calls. - */ - internal val expectationStack: ArrayDeque = ArrayDeque(1).also { it.push(tType) } - - /** - * The cache for instance built by this builder. - */ - private var generatedInstance: Expression? = null - /** * Subclasses, loads and instantiates [Expression] for given parameters. * * The built instance is cached. */ @Suppress("UNCHECKED_CAST") - internal fun getInstance(): Expression { - generatedInstance?.let { return it } - - if (SIGNATURE_LETTERS.containsKey(classOfT)) { - primitiveMode = true - primitiveMask = SIGNATURE_LETTERS.getValue(classOfT) - primitiveMaskBoxed = tType - } - + val instance: Expression by lazy { val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( V1_8, @@ -192,15 +147,6 @@ internal class AsmBuilder internal constructor( hasConstants = constants.isNotEmpty() - visitField( - access = ACC_PRIVATE or ACC_FINAL, - name = "algebra", - descriptor = tAlgebraType.descriptor, - signature = null, - value = null, - block = FieldVisitor::visitEnd - ) - if (hasConstants) visitField( access = ACC_PRIVATE or ACC_FINAL, @@ -214,25 +160,17 @@ internal class AsmBuilder internal constructor( visitMethod( ACC_PUBLIC, "", - - Type.getMethodDescriptor( - Type.VOID_TYPE, - tAlgebraType, - *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), - + Type.getMethodDescriptor(Type.VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), null, null ).instructionAdapter { val thisVar = 0 - val algebraVar = 1 - val constantsVar = 2 + val constantsVar = 1 val l0 = label() load(thisVar, classType) invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) label() load(thisVar, classType) - load(algebraVar, tAlgebraType) - putfield(classType.internalName, "algebra", tAlgebraType.descriptor) if (hasConstants) { label() @@ -246,15 +184,6 @@ internal class AsmBuilder internal constructor( val l4 = label() visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) - visitLocalVariable( - "algebra", - tAlgebraType.descriptor, - null, - l0, - l4, - algebraVar - ) - if (hasConstants) visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) @@ -265,33 +194,55 @@ internal class AsmBuilder internal constructor( visitEnd() } - val new = classLoader +// java.io.File("dump.class").writeBytes(classWriter.toByteArray()) + + classLoader .defineClass(className, classWriter.toByteArray()) .constructors .first() - .newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression - - generatedInstance = new - return new + .newInstance(*(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression } /** - * Loads a [T] constant from [constants]. + * Loads [java.lang.Object] constant from constants. */ - internal fun loadTConstant(value: T) { - if (classOfT in INLINABLE_NUMBERS) { - val expectedType = expectationStack.pop() - val mustBeBoxed = expectedType.sort == Type.OBJECT - loadNumberConstant(value as Number, mustBeBoxed) + fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run { + val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex + loadThis() + getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + iconst(idx) + visitInsn(AALOAD) + checkcast(type) + } - if (mustBeBoxed) - invokeMethodVisitor.checkcast(tType) + /** + * Loads `this` variable. + */ + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) - if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) + /** + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive + * constant from the constant pool. + */ + fun loadNumberConstant(value: Number) { + val boxed = value.javaClass.asm + val primitive = BOXED_TO_PRIMITIVES[boxed] + + if (primitive != null) { + when (primitive) { + Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + } + + box(primitive) return } - loadObjectConstant(value as Any, tType) + loadObjectConstant(value, boxed) } /** @@ -309,77 +260,9 @@ internal class AsmBuilder internal constructor( } /** - * Unboxes the current boxed value and pushes it. + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. */ - private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual( - NUMBER_TYPE.internalName, - NUMBER_CONVERTER_METHODS.getValue(primitive), - Type.getMethodDescriptor(primitive), - false - ) - - /** - * Loads [java.lang.Object] constant from constants. - */ - private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { - val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - loadThis() - getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) - iconst(idx) - visitInsn(AALOAD) - checkcast(type) - } - - internal fun loadNumeric(value: Number) { - if (expectationStack.peek() == NUMBER_TYPE) { - loadNumberConstant(value, true) - expectationStack.pop() - typeStack.push(NUMBER_TYPE) - } else (algebra as? NumericAlgebra)?.number(value)?.let { loadTConstant(it) } - ?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.") - } - - /** - * Loads this variable. - */ - private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) - - /** - * Either loads a numeric constant [value] from the class's constants field or boxes a primitive - * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded - * from it). - */ - private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) { - val boxed = value.javaClass.asm - val primitive = BOXED_TO_PRIMITIVES[boxed] - - if (primitive != null) { - when (primitive) { - Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) - Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) - Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) - Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - } - - if (mustBeBoxed) - box(primitive) - - return - } - - loadObjectConstant(value, boxed) - - if (!mustBeBoxed) - unboxTo(primitiveMask) - } - - /** - * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be - * provided. - */ - internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run { + fun loadVariable(name: String): Unit = invokeMethodVisitor.run { load(invokeArgumentsVar, MAP_TYPE) aconst(name) @@ -391,70 +274,28 @@ internal class AsmBuilder internal constructor( ) checkcast(tType) - val expectedType = expectationStack.pop() - - if (expectedType.sort == Type.OBJECT) - typeStack.push(tType) - else { - unboxTo(primitiveMask) - typeStack.push(primitiveMask) - } } - /** - * Loads algebra from according field of the class and casts it to class of [algebra] provided. - */ - internal fun loadAlgebra() { - loadThis() - invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor) - } + inline fun buildCall(function: Function, parameters: AsmBuilder.() -> Unit) { + contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } + val `interface` = function.javaClass.interfaces.first { it.interfaces.contains(Function::class.java) } - /** - * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is - * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be - * called before the arguments and this operation. - * - * The result is casted to [T] automatically. - */ - internal fun invokeAlgebraOperation( - owner: String, - method: String, - descriptor: String, - expectedArity: Int, - opcode: Int = INVOKEINTERFACE, - ) { - run loop@{ - repeat(expectedArity) { - if (typeStack.isEmpty()) return@loop - typeStack.pop() - } - } + val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount + ?: error("Provided function object doesn't contain invoke method") - invokeMethodVisitor.visitMethodInsn( - opcode, - owner, - method, - descriptor, - opcode == INVOKEINTERFACE + val type = Type.getType(`interface`) + loadObjectConstant(function, type) + parameters(this) + + invokeMethodVisitor.invokeinterface( + type.internalName, + "invoke", + Type.getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE}), ) invokeMethodVisitor.checkcast(tType) - val isLastExpr = expectationStack.size == 1 - val expectedType = expectationStack.pop() - - if (expectedType.sort == Type.OBJECT || isLastExpr) - typeStack.push(tType) - else { - unboxTo(primitiveMask) - typeStack.push(primitiveMask) - } } - /** - * Writes a LDC Instruction with string constant provided. - */ - internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string) - internal companion object { /** * Index of `this` variable in invoke method of the built subclass. @@ -490,32 +331,13 @@ internal class AsmBuilder internal constructor( */ private val PRIMITIVES_TO_BOXED: Map by lazy { BOXED_TO_PRIMITIVES.entries.stream().collect( - Collectors.toMap( + toMap( Map.Entry::value, Map.Entry::key ) ) } - /** - * Maps primitive ASM types to [Number] functions unboxing them. - */ - private val NUMBER_CONVERTER_METHODS: Map by lazy { - hashMapOf( - Type.BYTE_TYPE to "byteValue", - Type.SHORT_TYPE to "shortValue", - Type.INT_TYPE to "intValue", - Type.LONG_TYPE to "longValue", - Type.FLOAT_TYPE to "floatValue", - Type.DOUBLE_TYPE to "doubleValue" - ) - } - - /** - * Provides boxed number types values of which can be stored in JVM bytecode constant pool. - */ - private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } - /** * ASM type for [Expression]. */ diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt deleted file mode 100644 index 526c27305..000000000 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt +++ /dev/null @@ -1,20 +0,0 @@ -package kscience.kmath.asm.internal - -import kscience.kmath.ast.MST - -/** - * Represents types known in [MST], numbers and general values. - */ -internal enum class MstType { - GENERAL, - NUMBER; - - companion object { - fun fromMst(mst: MST): MstType { - if (mst is MST.Numeric) - return NUMBER - - return GENERAL - } - } -} diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt index ef9751502..d301ddf15 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt @@ -2,29 +2,11 @@ package kscience.kmath.asm.internal import kscience.kmath.ast.MST import kscience.kmath.expressions.Expression -import kscience.kmath.operations.Algebra -import kscience.kmath.operations.FieldOperations -import kscience.kmath.operations.RingOperations -import kscience.kmath.operations.SpaceOperations import org.objectweb.asm.* -import org.objectweb.asm.Opcodes.INVOKEVIRTUAL import org.objectweb.asm.commons.InstructionAdapter -import java.lang.reflect.Method -import java.util.* import kotlin.contracts.InvocationKind import kotlin.contracts.contract -private val methodNameAdapters: Map, String> by lazy { - hashMapOf( - SpaceOperations.PLUS_OPERATION to 2 to "add", - RingOperations.TIMES_OPERATION to 2 to "multiply", - FieldOperations.DIV_OPERATION to 2 to "divide", - SpaceOperations.PLUS_OPERATION to 1 to "unaryPlus", - SpaceOperations.MINUS_OPERATION to 1 to "unaryMinus", - SpaceOperations.MINUS_OPERATION to 2 to "minus" - ) -} - /** * Returns ASM [Type] for given [Class]. * @@ -110,106 +92,4 @@ internal inline fun ClassWriter.visitField( return visitField(access, name, descriptor, signature, value).apply(block) } -private fun AsmBuilder.findSpecific(context: Algebra, name: String, parameterTypes: Array): Method? = - context.javaClass.methods.find { method -> - val nameValid = method.name == name - val arityValid = method.parameters.size == parameterTypes.size - val notBridgeInPrimitive = !(primitiveMode && method.isBridge) - val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) -> - !(mstType != MstType.NUMBER && type == java.lang.Number::class.java) - } - - nameValid && arityValid && notBridgeInPrimitive && paramsValid - } - -/** - * Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds - * type expectation stack for needed arity. - * - * @author Iaroslav Postovalov - */ -private fun AsmBuilder.buildExpectationStack( - context: Algebra, - name: String, - parameterTypes: Array -): Boolean { - val arity = parameterTypes.size - val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes) - - if (specific != null) - mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) } - else - expectationStack.addAll(Collections.nCopies(arity, tType)) - - return specific != null -} - -private fun AsmBuilder.mapTypes(method: Method, parameterTypes: Array): List = method - .parameterTypes - .zip(parameterTypes) - .map { (type, mstType) -> - when { - type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE - else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed - } - } - -/** - * Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts - * [AsmBuilder.invokeAlgebraOperation] of this method. - * - * @author Iaroslav Postovalov - */ -private fun AsmBuilder.tryInvokeSpecific( - context: Algebra, - name: String, - parameterTypes: Array -): Boolean { - val arity = parameterTypes.size - val theName = methodNameAdapters[name to arity] ?: name - val spec = findSpecific(context, theName, parameterTypes) ?: return false - val owner = context.javaClass.asm - - invokeAlgebraOperation( - owner = owner.internalName, - method = theName, - descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()), - expectedArity = arity, - opcode = INVOKEVIRTUAL - ) - - return true -} - -/** - * Builds specialized [context] call with option to fallback to generic algebra operation accepting [String]. - * - * @author Iaroslav Postovalov - */ -internal inline fun AsmBuilder.buildAlgebraOperationCall( - context: Algebra, - name: String, - fallbackMethodName: String, - parameterTypes: Array, - parameters: AsmBuilder.() -> Unit -) { - contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } - val arity = parameterTypes.size - loadAlgebra() - if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name) - parameters() - - if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = fallbackMethodName, - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - *Array(arity) { AsmBuilder.OBJECT_TYPE } - ), - - expectedArity = arity - ) -} diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt index 5eebfe43d..a1687c1c7 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt @@ -10,15 +10,11 @@ import kotlin.test.Test import kotlin.test.assertEquals internal class TestAsmAlgebras { - @Test fun space() { val res1 = ByteRing.mstInSpace { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperation("+")( + unaryOperation("+")( number(3.toByte()) - (number(2.toByte()) + (multiply( add(number(1), number(1)), 2 @@ -30,11 +26,8 @@ internal class TestAsmAlgebras { }("x" to 2.toByte()) val res2 = ByteRing.mstInSpace { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperation("+")( + unaryOperation("+")( number(3.toByte()) - (number(2.toByte()) + (multiply( add(number(1), number(1)), 2 @@ -51,11 +44,8 @@ internal class TestAsmAlgebras { @Test fun ring() { val res1 = ByteRing.mstInRing { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperation("+")( + unaryOperation("+")( (symbol("x") - (2.toByte() + (multiply( add(number(1), number(1)), 2 @@ -67,17 +57,13 @@ internal class TestAsmAlgebras { }("x" to 3.toByte()) val res2 = ByteRing.mstInRing { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperation("+")( + unaryOperation("+")( (symbol("x") - (2.toByte() + (multiply( add(number(1), number(1)), 2 ) + 1.toByte()))) * 3.0 - 1.toByte() ), - number(1) ) * number(2) }.compile()("x" to 3.toByte()) @@ -88,8 +74,7 @@ internal class TestAsmAlgebras { @Test fun field() { val res1 = RealField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( - "+", + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation("+")( (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one @@ -97,8 +82,7 @@ internal class TestAsmAlgebras { }("x" to 2.0) val res2 = RealField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( - "+", + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation("+")( (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt index 7f4d453f7..acd9e21ea 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt @@ -1,10 +1,11 @@ package kscience.kmath.asm -import kscience.kmath.asm.compile +import kscience.kmath.ast.mstInExtendedField import kscience.kmath.ast.mstInField import kscience.kmath.ast.mstInSpace import kscience.kmath.expressions.invoke import kscience.kmath.operations.RealField +import kotlin.random.Random import kotlin.test.Test import kotlin.test.assertEquals @@ -28,4 +29,13 @@ internal class TestAsmExpressions { val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) assertEquals(4.0, res) } + + @Test + fun testMultipleCalls() { + val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile() + val r = Random(0) + var s = 0.0 + repeat(1000000) { s += e("x" to r.nextDouble()) } + println(s) + } } diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt index eb5a1c8f3..ce9e602d3 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt @@ -46,7 +46,7 @@ internal class TestAsmSpecialization { @Test fun testPower() { val expr = RealField - .mstInField { binaryOperation("power")(symbol("x"), number(2)) } + .mstInField { binaryOperation("pow")(symbol("x"), number(2)) } .compile() assertEquals(4.0, expr("x" to 2.0)) diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt index 7b89c74fa..1011515c8 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt @@ -17,6 +17,6 @@ internal class TestAsmVariables { @Test fun testVariableWithoutDefaultFails() { val expr = ByteRing.mstInRing { symbol("x") } - assertFailsWith { expr() } + assertFailsWith { expr() } } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt index 19ca1a4cc..2661525fb 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt @@ -20,12 +20,14 @@ public interface Algebra { /** * Dynamically dispatches an unary operation with name [operation]. */ - public fun unaryOperation(operation: String): (arg: T) -> T + public fun unaryOperation(operation: String): (arg: T) -> T = + error("Unary operation $operation not defined in $this") /** * Dynamically dispatches a binary operation with name [operation]. */ - public fun binaryOperation(operation: String): (left: T, right: T) -> T + public fun binaryOperation(operation: String): (left: T, right: T) -> T = + error("Binary operation $operation not defined in $this") } /** @@ -161,13 +163,13 @@ public interface SpaceOperations : Algebra { override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) { PLUS_OPERATION -> { arg -> arg } MINUS_OPERATION -> { arg -> -arg } - else -> error("Unary operation $operation not defined in $this") + else -> super.unaryOperation(operation) } override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) { PLUS_OPERATION -> ::add MINUS_OPERATION -> { left, right -> left - right } - else -> error("Binary operation $operation not defined in $this") + else -> super.binaryOperation(operation) } public companion object { diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt index 611222d34..cd2bd0417 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt @@ -31,7 +31,7 @@ public interface ExtendedFieldOperations : PowerOperations.SQRT_OPERATION -> ::sqrt ExponentialOperations.EXP_OPERATION -> ::exp ExponentialOperations.LN_OPERATION -> ::ln - else -> super.unaryOperation(operation) + else -> super.unaryOperation(operation) } } From cc45e3683b702b8c8e637af66d7cf07d67dcbd24 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Tue, 8 Dec 2020 16:16:32 +0700 Subject: [PATCH 4/6] Refactor ASM builder --- .../kscience/kmath/asm/internal/AsmBuilder.kt | 129 +++++++----------- .../kmath/asm/internal/codegenUtils.kt | 2 - 2 files changed, 53 insertions(+), 78 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt index 792034232..46051e158 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -6,6 +6,7 @@ import kscience.kmath.expressions.Expression import kscience.kmath.operations.Algebra import org.objectweb.asm.* import org.objectweb.asm.Opcodes.* +import org.objectweb.asm.Type.* import org.objectweb.asm.commons.InstructionAdapter import java.util.stream.Collectors.toMap import kotlin.contracts.InvocationKind @@ -17,13 +18,13 @@ import kotlin.contracts.contract * * @property T the type of AsmExpression to unwrap. * @property className the unique class name of new loaded class. - * @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. + * @property callbackAtInvokeL0 the function to apply to this object when generating invoke method, label 0. * @author Iaroslav Postovalov */ -internal class AsmBuilder internal constructor( +internal class AsmBuilder( classOfT: Class<*>, private val className: String, - private val invokeLabel0Visitor: AsmBuilder.() -> Unit, + private val callbackAtInvokeL0: AsmBuilder.() -> Unit, ) { /** * Internal classloader of [AsmBuilder] with alias to define class from byte array. @@ -45,7 +46,7 @@ internal class AsmBuilder internal constructor( /** * ASM type for new class. */ - private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! + private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/')) /** * List of constants to provide to the subclass. @@ -57,11 +58,6 @@ internal class AsmBuilder internal constructor( */ private lateinit var invokeMethodVisitor: InstructionAdapter - /** - * States whether this [AsmBuilder] needs to generate constants field. - */ - private var hasConstants: Boolean = true - /** * Subclasses, loads and instantiates [Expression] for given parameters. * @@ -69,6 +65,8 @@ internal class AsmBuilder internal constructor( */ @Suppress("UNCHECKED_CAST") val instance: Expression by lazy { + val hasConstants: Boolean + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( V1_8, @@ -82,14 +80,14 @@ internal class AsmBuilder internal constructor( visitMethod( ACC_PUBLIC or ACC_FINAL, "invoke", - Type.getMethodDescriptor(tType, MAP_TYPE), + getMethodDescriptor(tType, MAP_TYPE), "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", null ).instructionAdapter { invokeMethodVisitor = this visitCode() val l0 = label() - invokeLabel0Visitor() + callbackAtInvokeL0() areturn(tType) val l1 = label() @@ -99,7 +97,7 @@ internal class AsmBuilder internal constructor( null, l0, l1, - invokeThisVar + 0 ) visitLocalVariable( @@ -108,7 +106,7 @@ internal class AsmBuilder internal constructor( "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", l0, l1, - invokeArgumentsVar + 1 ) visitMaxs(0, 2) @@ -118,7 +116,7 @@ internal class AsmBuilder internal constructor( visitMethod( ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, "invoke", - Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), null, null ).instructionAdapter { @@ -128,7 +126,7 @@ internal class AsmBuilder internal constructor( val l0 = label() load(thisVar, OBJECT_TYPE) load(argumentsVar, MAP_TYPE) - invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false) + invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false) areturn(tType) val l1 = label() @@ -160,32 +158,30 @@ internal class AsmBuilder internal constructor( visitMethod( ACC_PUBLIC, "", - Type.getMethodDescriptor(Type.VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), + getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), null, null ).instructionAdapter { - val thisVar = 0 - val constantsVar = 1 val l0 = label() - load(thisVar, classType) - invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) + load(0, classType) + invokespecial(OBJECT_TYPE.internalName, "", getMethodDescriptor(VOID_TYPE), false) label() - load(thisVar, classType) + load(0, classType) if (hasConstants) { label() - load(thisVar, classType) - load(constantsVar, OBJECT_ARRAY_TYPE) + load(0, classType) + load(1, OBJECT_ARRAY_TYPE) putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) } label() visitInsn(RETURN) val l4 = label() - visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) + visitLocalVariable("this", classType.descriptor, null, l0, l4, 0) if (hasConstants) - visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, 1) visitMaxs(0, 3) visitEnd() @@ -218,7 +214,7 @@ internal class AsmBuilder internal constructor( /** * Loads `this` variable. */ - private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) + private fun loadThis(): Unit = invokeMethodVisitor.load(0, classType) /** * Either loads a numeric constant [value] from the class's constants field or boxes a primitive @@ -230,12 +226,12 @@ internal class AsmBuilder internal constructor( if (primitive != null) { when (primitive) { - Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) - Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) - Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) - Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) } box(primitive) @@ -254,7 +250,7 @@ internal class AsmBuilder internal constructor( invokeMethodVisitor.invokestatic( r.internalName, "valueOf", - Type.getMethodDescriptor(r, primitive), + getMethodDescriptor(r, primitive), false ) } @@ -263,13 +259,13 @@ internal class AsmBuilder internal constructor( * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. */ fun loadVariable(name: String): Unit = invokeMethodVisitor.run { - load(invokeArgumentsVar, MAP_TYPE) + load(1, MAP_TYPE) aconst(name) invokestatic( MAP_INTRINSICS_TYPE.internalName, "getOrFail", - Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), false ) @@ -283,100 +279,81 @@ internal class AsmBuilder internal constructor( val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount ?: error("Provided function object doesn't contain invoke method") - val type = Type.getType(`interface`) + val type = getType(`interface`) loadObjectConstant(function, type) parameters(this) invokeMethodVisitor.invokeinterface( type.internalName, "invoke", - Type.getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE}), + getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE }), ) invokeMethodVisitor.checkcast(tType) } - internal companion object { - /** - * Index of `this` variable in invoke method of the built subclass. - */ - private const val invokeThisVar: Int = 0 - - /** - * Index of `arguments` variable in invoke method of the built subclass. - */ - private const val invokeArgumentsVar: Int = 1 - - /** - * Maps JVM primitive numbers boxed types to their primitive ASM types. - */ - private val SIGNATURE_LETTERS: Map, Type> by lazy { - hashMapOf( - java.lang.Byte::class.java to Type.BYTE_TYPE, - java.lang.Short::class.java to Type.SHORT_TYPE, - java.lang.Integer::class.java to Type.INT_TYPE, - java.lang.Long::class.java to Type.LONG_TYPE, - java.lang.Float::class.java to Type.FLOAT_TYPE, - java.lang.Double::class.java to Type.DOUBLE_TYPE - ) - } - + companion object { /** * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. */ - private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } + private val BOXED_TO_PRIMITIVES: Map by lazy { + hashMapOf( + Byte::class.java.asm to BYTE_TYPE, + Short::class.java.asm to SHORT_TYPE, + Integer::class.java.asm to INT_TYPE, + Long::class.java.asm to LONG_TYPE, + Float::class.java.asm to FLOAT_TYPE, + Double::class.java.asm to DOUBLE_TYPE, + ) + } /** * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. */ private val PRIMITIVES_TO_BOXED: Map by lazy { BOXED_TO_PRIMITIVES.entries.stream().collect( - toMap( - Map.Entry::value, - Map.Entry::key - ) + toMap(Map.Entry::value, Map.Entry::key), ) } /** * ASM type for [Expression]. */ - internal val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/expressions/Expression") } + val EXPRESSION_TYPE: Type by lazy { getObjectType("kscience/kmath/expressions/Expression") } /** * ASM type for [java.lang.Number]. */ - internal val NUMBER_TYPE: Type by lazy { Type.getObjectType("java/lang/Number") } + val NUMBER_TYPE: Type by lazy { getObjectType("java/lang/Number") } /** * ASM type for [java.util.Map]. */ - internal val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") } + val MAP_TYPE: Type by lazy { getObjectType("java/util/Map") } /** * ASM type for [java.lang.Object]. */ - internal val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") } + val OBJECT_TYPE: Type by lazy { getObjectType("java/lang/Object") } /** * ASM type for array of [java.lang.Object]. */ - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") - internal val OBJECT_ARRAY_TYPE: Type by lazy { Type.getType("[Ljava/lang/Object;") } + val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") } /** * ASM type for [Algebra]. */ - internal val ALGEBRA_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/operations/Algebra") } + val ALGEBRA_TYPE: Type by lazy { getObjectType("kscience/kmath/operations/Algebra") } /** * ASM type for [java.lang.String]. */ - internal val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") } + val STRING_TYPE: Type by lazy { getObjectType("java/lang/String") } /** * ASM type for MapIntrinsics. */ - internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/asm/internal/MapIntrinsics") } + val MAP_INTRINSICS_TYPE: Type by lazy { getObjectType("kscience/kmath/asm/internal/MapIntrinsics") } } } diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt index d301ddf15..6d5d19d42 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt @@ -91,5 +91,3 @@ internal inline fun ClassWriter.visitField( contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return visitField(access, name, descriptor, signature, value).apply(block) } - - From 95c1504c00006eb4f5096b569cede64788032f93 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Wed, 9 Dec 2020 11:41:37 +0700 Subject: [PATCH 5/6] Add cast microoptimization to AsmBuilder --- .../kscience/kmath/asm/internal/AsmBuilder.kt | 35 ++++++------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt index 46051e158..eb0f60f4d 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -3,7 +3,6 @@ package kscience.kmath.asm.internal import kscience.kmath.asm.internal.AsmBuilder.ClassLoader import kscience.kmath.ast.MST import kscience.kmath.expressions.Expression -import kscience.kmath.operations.Algebra import org.objectweb.asm.* import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Type.* @@ -74,7 +73,7 @@ internal class AsmBuilder( classType.internalName, "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", OBJECT_TYPE.internalName, - arrayOf(EXPRESSION_TYPE.internalName) + arrayOf(EXPRESSION_TYPE.internalName), ) visitMethod( @@ -82,7 +81,7 @@ internal class AsmBuilder( "invoke", getMethodDescriptor(tType, MAP_TYPE), "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", - null + null, ).instructionAdapter { invokeMethodVisitor = this visitCode() @@ -97,7 +96,7 @@ internal class AsmBuilder( null, l0, l1, - 0 + 0, ) visitLocalVariable( @@ -106,7 +105,7 @@ internal class AsmBuilder( "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", l0, l1, - 1 + 1, ) visitMaxs(0, 2) @@ -118,14 +117,12 @@ internal class AsmBuilder( "invoke", getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), null, - null + null, ).instructionAdapter { - val thisVar = 0 - val argumentsVar = 1 visitCode() val l0 = label() - load(thisVar, OBJECT_TYPE) - load(argumentsVar, MAP_TYPE) + load(0, OBJECT_TYPE) + load(1, MAP_TYPE) invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false) areturn(tType) val l1 = label() @@ -136,7 +133,7 @@ internal class AsmBuilder( null, l0, l1, - thisVar + 0, ) visitMaxs(0, 2) @@ -152,7 +149,7 @@ internal class AsmBuilder( descriptor = OBJECT_ARRAY_TYPE.descriptor, signature = null, value = null, - block = FieldVisitor::visitEnd + block = FieldVisitor::visitEnd, ) visitMethod( @@ -208,7 +205,7 @@ internal class AsmBuilder( getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) iconst(idx) visitInsn(AALOAD) - checkcast(type) + if (type != OBJECT_TYPE) checkcast(type) } /** @@ -266,7 +263,7 @@ internal class AsmBuilder( MAP_INTRINSICS_TYPE.internalName, "getOrFail", getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), - false + false, ) checkcast(tType) @@ -321,11 +318,6 @@ internal class AsmBuilder( */ val EXPRESSION_TYPE: Type by lazy { getObjectType("kscience/kmath/expressions/Expression") } - /** - * ASM type for [java.lang.Number]. - */ - val NUMBER_TYPE: Type by lazy { getObjectType("java/lang/Number") } - /** * ASM type for [java.util.Map]. */ @@ -341,11 +333,6 @@ internal class AsmBuilder( */ val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") } - /** - * ASM type for [Algebra]. - */ - val ALGEBRA_TYPE: Type by lazy { getObjectType("kscience/kmath/operations/Algebra") } - /** * ASM type for [java.lang.String]. */ From 07d6d891928d89d3eb14917cfbf78d790e4883b8 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Wed, 9 Dec 2020 14:43:37 +0700 Subject: [PATCH 6/6] Replace reflective constructor invocation with method handle --- .../kscience/kmath/asm/internal/AsmBuilder.kt | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt index eb0f60f4d..d146afeee 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -7,6 +7,8 @@ import org.objectweb.asm.* import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Type.* import org.objectweb.asm.commons.InstructionAdapter +import java.lang.invoke.MethodHandles +import java.lang.invoke.MethodType import java.util.stream.Collectors.toMap import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -187,13 +189,14 @@ internal class AsmBuilder( visitEnd() } -// java.io.File("dump.class").writeBytes(classWriter.toByteArray()) + val cls = classLoader.defineClass(className, classWriter.toByteArray()) + val l = MethodHandles.publicLookup() - classLoader - .defineClass(className, classWriter.toByteArray()) - .constructors - .first() - .newInstance(*(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression + if (hasConstants) + l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array::class.java)) + .invoke(constants.toTypedArray()) as Expression + else + l.findConstructor(cls, MethodType.methodType(Void.TYPE)).invoke() as Expression } /** @@ -248,7 +251,7 @@ internal class AsmBuilder( r.internalName, "valueOf", getMethodDescriptor(r, primitive), - false + false, ) }