From 4a26c1080e2b16ab3a6d88675e7604ae2141f775 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 14 May 2020 20:30:43 +0300 Subject: [PATCH 001/104] A prototype for advanced expressoins --- .../kmath/expressions/Expression.kt | 67 +--------- .../expressions/functionalExpressions.kt | 125 ++++++++++++++++++ .../kmath/expressions/syntaxTree.kt | 31 +++++ .../kmath/expressions/ExpressionFieldTest.kt | 12 +- 4 files changed, 163 insertions(+), 72 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/syntaxTree.kt diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index aa7407c0a..bc3d70b4f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -1,9 +1,5 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space - /** * An elementary function that could be invoked on a map of arguments */ @@ -26,67 +22,6 @@ interface ExpressionContext { * A constant expression which does not depend on arguments */ fun const(value: T): Expression -} -internal class VariableExpression(val name: String, val default: T? = null) : Expression { - override fun invoke(arguments: Map): T = - arguments[name] ?: default ?: error("Parameter not found: $name") -} - -internal class ConstantExpression(val value: T) : Expression { - override fun invoke(arguments: Map): T = value -} - -internal class SumExpression(val context: Space, val first: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = context.add(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ProductExpression(val context: Ring, val first: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = - context.multiply(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : - Expression { - override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) -} - -internal class DivExpession(val context: Field, val expr: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) -} - -open class ExpressionSpace(val space: Space) : Space>, ExpressionContext { - override val zero: Expression = ConstantExpression(space.zero) - - override fun const(value: T): Expression = ConstantExpression(value) - - override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) - - override fun add(a: Expression, b: Expression): Expression = SumExpression(space, a, b) - - override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(space, a, k) - - - operator fun Expression.plus(arg: T) = this + const(arg) - operator fun Expression.minus(arg: T) = this - const(arg) - - operator fun T.plus(arg: Expression) = arg + this - operator fun T.minus(arg: Expression) = arg - this -} - - -class ExpressionField(val field: Field) : Field>, ExpressionSpace(field) { - override val one: Expression = ConstantExpression(field.one) - override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) - - override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) - - operator fun Expression.times(arg: T) = this * const(arg) - operator fun Expression.div(arg: T) = this / const(arg) - - operator fun T.times(arg: Expression) = arg * this - operator fun T.div(arg: Expression) = arg / this + fun produce(node: SyntaxTreeNode): Expression } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt new file mode 100644 index 000000000..64b3ad72c --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt @@ -0,0 +1,125 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.Space + +internal class VariableExpression(val name: String, val default: T? = null) : Expression { + override fun invoke(arguments: Map): T = + arguments[name] ?: default ?: error("Parameter not found: $name") +} + +internal class ConstantExpression(val value: T) : Expression { + override fun invoke(arguments: Map): T = value +} + +internal class SumExpression( + val context: Space, + val first: Expression, + val second: Expression +) : Expression { + override fun invoke(arguments: Map): T = context.add(first.invoke(arguments), second.invoke(arguments)) +} + +internal class ProductExpression(val context: Ring, val first: Expression, val second: Expression) : + Expression { + override fun invoke(arguments: Map): T = + context.multiply(first.invoke(arguments), second.invoke(arguments)) +} + +internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : + Expression { + override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) +} + +internal class DivExpession(val context: Field, val expr: Expression, val second: Expression) : + Expression { + override fun invoke(arguments: Map): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) +} + +open class FunctionalExpressionSpace( + val space: Space, + one: T +) : Space>, ExpressionContext { + + override val zero: Expression = ConstantExpression(space.zero) + + val one: Expression = ConstantExpression(one) + + override fun const(value: T): Expression = ConstantExpression(value) + + override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) + + override fun add(a: Expression, b: Expression): Expression = SumExpression(space, a, b) + + override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(space, a, k) + + + operator fun Expression.plus(arg: T) = this + const(arg) + operator fun Expression.minus(arg: T) = this - const(arg) + + operator fun T.plus(arg: Expression) = arg + this + operator fun T.minus(arg: Expression) = arg - this + + fun const(value: Double): Expression = one.times(value) + + open fun produceSingular(value: String): Expression { + val numberValue = value.toDoubleOrNull() + return if (numberValue == null) { + variable(value) + } else { + const(numberValue) + } + } + + open fun produceUnary(operation: String, value: Expression): Expression { + return when (operation) { + UnaryNode.PLUS_OPERATION -> value + UnaryNode.MINUS_OPERATION -> -value + else -> error("Unary operation $operation is not supported by $this") + } + } + + open fun produceBinary(operation: String, left: Expression, right: Expression): Expression { + return when (operation) { + BinaryNode.PLUS_OPERATION -> left + right + BinaryNode.MINUS_OPERATION -> left - right + else -> error("Binary operation $operation is not supported by $this") + } + } + + override fun produce(node: SyntaxTreeNode): Expression { + return when (node) { + is SingularNode -> produceSingular(node.value) + is UnaryNode -> produceUnary(node.operation, produce(node.value)) + is BinaryNode -> produceBinary(node.operation, produce(node.left), produce(node.right)) + } + } +} + +open class FunctionalExpressionField( + val field: Field +) : Field>, FunctionalExpressionSpace(field, field.one) { + override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) + + override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) + + operator fun Expression.times(arg: T) = this * const(arg) + operator fun Expression.div(arg: T) = this / const(arg) + + operator fun T.times(arg: Expression) = arg * this + operator fun T.div(arg: Expression) = arg / this + + override fun produce(node: SyntaxTreeNode): Expression { + //TODO bring together numeric and typed expressions + return super.produce(node) + } + + override fun produceBinary(operation: String, left: Expression, right: Expression): Expression { + return when (operation) { + BinaryNode.TIMES_OPERATION -> left * right + BinaryNode.DIV_OPERATION -> left / right + else -> super.produceBinary(operation, left, right) + } + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/syntaxTree.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/syntaxTree.kt new file mode 100644 index 000000000..daa17977d --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/syntaxTree.kt @@ -0,0 +1,31 @@ +package scientifik.kmath.expressions + +sealed class SyntaxTreeNode + +data class SingularNode(val value: String) : SyntaxTreeNode() + +data class UnaryNode(val operation: String, val value: SyntaxTreeNode): SyntaxTreeNode(){ + companion object{ + const val PLUS_OPERATION = "+" + const val MINUS_OPERATION = "-" + const val NOT_OPERATION = "!" + const val ABS_OPERATION = "abs" + const val SIN_OPERATION = "sin" + const val cos_OPERATION = "cos" + //TODO add operations + } +} + +data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode): SyntaxTreeNode(){ + companion object{ + const val PLUS_OPERATION = "+" + const val MINUS_OPERATION = "-" + const val TIMES_OPERATION = "*" + const val DIV_OPERATION = "/" + //TODO add operations + } +} + +//TODO add a function with positional arguments + +//TODO add a function with named arguments \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt index 033b2792f..b933f6a17 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt @@ -9,7 +9,7 @@ import kotlin.test.assertEquals class ExpressionFieldTest { @Test fun testExpression() { - val context = ExpressionField(RealField) + val context = FunctionalExpressionField(RealField) val expression = with(context) { val x = variable("x", 2.0) x * x + 2 * x + one @@ -20,7 +20,7 @@ class ExpressionFieldTest { @Test fun testComplex() { - val context = ExpressionField(ComplexField) + val context = FunctionalExpressionField(ComplexField) val expression = with(context) { val x = variable("x", Complex(2.0, 0.0)) x * x + 2 * x + one @@ -31,23 +31,23 @@ class ExpressionFieldTest { @Test fun separateContext() { - fun ExpressionField.expression(): Expression { + fun FunctionalExpressionField.expression(): Expression { val x = variable("x") return x * x + 2 * x + one } - val expression = ExpressionField(RealField).expression() + val expression = FunctionalExpressionField(RealField).expression() assertEquals(expression("x" to 1.0), 4.0) } @Test fun valueExpression() { - val expressionBuilder: ExpressionField.() -> Expression = { + val expressionBuilder: FunctionalExpressionField.() -> Expression = { val x = variable("x") x * x + 2 * x + one } - val expression = ExpressionField(RealField).expressionBuilder() + val expression = FunctionalExpressionField(RealField).expressionBuilder() assertEquals(expression("x" to 1.0), 4.0) } } \ No newline at end of file From 1a869ace0e23659eca398d7a768a54191bc084d7 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 15 May 2020 16:10:19 +0300 Subject: [PATCH 002/104] Refactored Expression tree API --- .../commons/expressions/DiffExpression.kt | 5 +- .../kmath/expressions/Expression.kt | 92 ++++++++++++++++++- .../{syntaxTree.kt => SyntaxTreeNode.kt} | 14 ++- .../expressions/functionalExpressions.kt | 60 ++---------- 4 files changed, 107 insertions(+), 64 deletions(-) rename kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/{syntaxTree.kt => SyntaxTreeNode.kt} (67%) diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index d5c038dc4..65e915ef4 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -2,9 +2,8 @@ package scientifik.kmath.commons.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.expressions.ExpressionField import scientifik.kmath.operations.ExtendedField -import scientifik.kmath.operations.Field import kotlin.properties.ReadOnlyProperty import kotlin.reflect.KProperty @@ -113,7 +112,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1) /** * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) */ -object DiffExpressionContext : ExpressionContext, Field { +object DiffExpressionContext : ExpressionField { override fun variable(name: String, default: Double?) = DiffExpression { variable(name, default?.const()) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index bc3d70b4f..94bc81413 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -1,5 +1,8 @@ package scientifik.kmath.expressions +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Space + /** * An elementary function that could be invoked on a map of arguments */ @@ -12,16 +15,97 @@ operator fun Expression.invoke(vararg pairs: Pair): T = invoke /** * A context for expression construction */ -interface ExpressionContext { +interface ExpressionContext> { /** * Introduce a variable into expression context */ - fun variable(name: String, default: T? = null): Expression + fun variable(name: String, default: T? = null): E /** * A constant expression which does not depend on arguments */ - fun const(value: T): Expression + fun const(value: T): E + + fun produce(node: SyntaxTreeNode): E +} + +interface ExpressionSpace> : Space, ExpressionContext { + + open fun produceSingular(value: String): E = variable(value) + + open fun produceUnary(operation: String, value: E): E { + return when (operation) { + UnaryNode.PLUS_OPERATION -> value + UnaryNode.MINUS_OPERATION -> -value + else -> error("Unary operation $operation is not supported by $this") + } + } + + open fun produceBinary(operation: String, left: E, right: E): E { + return when (operation) { + BinaryNode.PLUS_OPERATION -> left + right + BinaryNode.MINUS_OPERATION -> left - right + else -> error("Binary operation $operation is not supported by $this") + } + } + + override fun produce(node: SyntaxTreeNode): E { + return when (node) { + is NumberNode -> error("Single number nodes are not supported") + is SingularNode -> produceSingular(node.value) + is UnaryNode -> produceUnary(node.operation, produce(node.value)) + is BinaryNode -> { + when (node.operation) { + BinaryNode.TIMES_OPERATION -> { + if (node.left is NumberNode) { + return produce(node.right) * node.left.value + } else if (node.right is NumberNode) { + return produce(node.left) * node.right.value + } + } + BinaryNode.DIV_OPERATION -> { + if (node.right is NumberNode) { + return produce(node.left) / node.right.value + } + } + } + produceBinary(node.operation, produce(node.left), produce(node.right)) + } + } + } +} + +interface ExpressionField> : Field, ExpressionSpace { + fun const(value: Double): E = one.times(value) + + override fun produce(node: SyntaxTreeNode): E { + if (node is BinaryNode) { + when (node.operation) { + BinaryNode.PLUS_OPERATION -> { + if (node.left is NumberNode) { + return produce(node.right) + one * node.left.value + } else if (node.right is NumberNode) { + return produce(node.left) + one * node.right.value + } + } + BinaryNode.MINUS_OPERATION -> { + if (node.left is NumberNode) { + return one * node.left.value - produce(node.right) + } else if (node.right is NumberNode) { + return produce(node.left) - one * node.right.value + } + } + } + } + return super.produce(node) + } + + override fun produceBinary(operation: String, left: E, right: E): E { + return when (operation) { + BinaryNode.TIMES_OPERATION -> left * right + BinaryNode.DIV_OPERATION -> left / right + else -> super.produceBinary(operation, left, right) + } + } - fun produce(node: SyntaxTreeNode): Expression } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/syntaxTree.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt similarity index 67% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/syntaxTree.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt index daa17977d..fff3c5bc0 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/syntaxTree.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt @@ -4,20 +4,24 @@ sealed class SyntaxTreeNode data class SingularNode(val value: String) : SyntaxTreeNode() -data class UnaryNode(val operation: String, val value: SyntaxTreeNode): SyntaxTreeNode(){ - companion object{ +data class NumberNode(val value: Number) : SyntaxTreeNode() + +data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxTreeNode() { + companion object { const val PLUS_OPERATION = "+" const val MINUS_OPERATION = "-" const val NOT_OPERATION = "!" const val ABS_OPERATION = "abs" const val SIN_OPERATION = "sin" - const val cos_OPERATION = "cos" + const val COS_OPERATION = "cos" + const val EXP_OPERATION = "exp" + const val LN_OPERATION = "ln" //TODO add operations } } -data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode): SyntaxTreeNode(){ - companion object{ +data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode) : SyntaxTreeNode() { + companion object { const val PLUS_OPERATION = "+" const val MINUS_OPERATION = "-" const val TIMES_OPERATION = "*" diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt index 64b3ad72c..ad0acbc4a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt @@ -40,12 +40,10 @@ internal class DivExpession(val context: Field, val expr: Expression, v open class FunctionalExpressionSpace( val space: Space, one: T -) : Space>, ExpressionContext { +) : Space>, ExpressionSpace> { override val zero: Expression = ConstantExpression(space.zero) - val one: Expression = ConstantExpression(one) - override fun const(value: T): Expression = ConstantExpression(value) override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) @@ -60,46 +58,17 @@ open class FunctionalExpressionSpace( operator fun T.plus(arg: Expression) = arg + this operator fun T.minus(arg: Expression) = arg - this - - fun const(value: Double): Expression = one.times(value) - - open fun produceSingular(value: String): Expression { - val numberValue = value.toDoubleOrNull() - return if (numberValue == null) { - variable(value) - } else { - const(numberValue) - } - } - - open fun produceUnary(operation: String, value: Expression): Expression { - return when (operation) { - UnaryNode.PLUS_OPERATION -> value - UnaryNode.MINUS_OPERATION -> -value - else -> error("Unary operation $operation is not supported by $this") - } - } - - open fun produceBinary(operation: String, left: Expression, right: Expression): Expression { - return when (operation) { - BinaryNode.PLUS_OPERATION -> left + right - BinaryNode.MINUS_OPERATION -> left - right - else -> error("Binary operation $operation is not supported by $this") - } - } - - override fun produce(node: SyntaxTreeNode): Expression { - return when (node) { - is SingularNode -> produceSingular(node.value) - is UnaryNode -> produceUnary(node.operation, produce(node.value)) - is BinaryNode -> produceBinary(node.operation, produce(node.left), produce(node.right)) - } - } } open class FunctionalExpressionField( val field: Field -) : Field>, FunctionalExpressionSpace(field, field.one) { +) : ExpressionField>, FunctionalExpressionSpace(field, field.one) { + + override val one: Expression + get() = const(this.field.one) + + override fun const(value: Double): Expression = const(field.run { one*value}) + override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) @@ -109,17 +78,4 @@ open class FunctionalExpressionField( operator fun T.times(arg: Expression) = arg * this operator fun T.div(arg: Expression) = arg / this - - override fun produce(node: SyntaxTreeNode): Expression { - //TODO bring together numeric and typed expressions - return super.produce(node) - } - - override fun produceBinary(operation: String, left: Expression, right: Expression): Expression { - return when (operation) { - BinaryNode.TIMES_OPERATION -> left * right - BinaryNode.DIV_OPERATION -> left / right - else -> super.produceBinary(operation, left, right) - } - } } \ No newline at end of file From 3ea76d56a563175a4a66f3ea68ce7d1e5b6cd157 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 5 Jun 2020 22:05:16 +0700 Subject: [PATCH 003/104] Implement kmath-asm module stubs --- kmath-asm/build.gradle.kts | 9 + .../kmath/expressions/AsmExpressions.kt | 385 ++++++++++++++++++ .../kmath/expressions/MethodVisitors.kt | 17 + .../scientifik/kmath/expressions/AsmTest.kt | 23 ++ kmath-core/build.gradle.kts | 2 +- .../kmath/expressions/Expression.kt | 12 +- ...xpressions.kt => FunctionalExpressions.kt} | 0 .../{bigNumbers.kt => BigNumbers.kt} | 0 settings.gradle.kts | 3 +- 9 files changed, 443 insertions(+), 8 deletions(-) create mode 100644 kmath-asm/build.gradle.kts create mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt create mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt create mode 100644 kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt rename kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/{functionalExpressions.kt => FunctionalExpressions.kt} (100%) rename kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/{bigNumbers.kt => BigNumbers.kt} (100%) diff --git a/kmath-asm/build.gradle.kts b/kmath-asm/build.gradle.kts new file mode 100644 index 000000000..ab4684cdc --- /dev/null +++ b/kmath-asm/build.gradle.kts @@ -0,0 +1,9 @@ +plugins { + id("scientifik.jvm") +} + +dependencies { + api(project(path = ":kmath-core")) + api("org.ow2.asm:asm:8.0.1") + api("org.ow2.asm:asm-commons:8.0.1") +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt new file mode 100644 index 000000000..b9c357671 --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -0,0 +1,385 @@ +package scientifik.kmath.expressions + +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.Label +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Opcodes.* +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Space +import java.io.File + +abstract class AsmCompiled(@JvmField val algebra: Algebra, @JvmField val constants: MutableList) { + abstract fun evaluate(arguments: Map): T +} + +class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra, private val className: String) { + private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) + } + + private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + + @Suppress("PrivatePropertyName") + private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') + + @Suppress("PrivatePropertyName") + private val T_CLASS: String = classOfT.name.replace('.', '/') + private val constants: MutableList = mutableListOf() + private val asmCompiledClassWriter = ClassWriter(0) + private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + private val evaluateMethodVisitor: MethodVisitor + private val evaluateThisVar: Int = 0 + private val evaluateArgumentsVar: Int = 1 + private var evaluateL0: Label + private lateinit var evaluateL1: Label + var maxStack: Int = 0 + + init { + asmCompiledClassWriter.visit( + V1_8, + ACC_PUBLIC or ACC_FINAL or ACC_SUPER, + slashesClassName, + "L$ASM_COMPILED_CLASS;", + ASM_COMPILED_CLASS, + arrayOf() + ) + + asmCompiledClassWriter.run { + visitMethod(ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run { + val thisVar = 0 + val algebraVar = 1 + val constantsVar = 2 + val l0 = Label() + visitLabel(l0) + visitVarInsn(ALOAD, thisVar) + visitVarInsn(ALOAD, algebraVar) + visitVarInsn(ALOAD, constantsVar) + + visitMethodInsn( + INVOKESPECIAL, + ASM_COMPILED_CLASS, + "", + "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", + false + ) + + val l1 = Label() + visitLabel(l1) + visitInsn(RETURN) + val l2 = Label() + visitLabel(l2) + visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) + + visitLocalVariable( + "algebra", + "L$ALGEBRA_CLASS;", + "L$ALGEBRA_CLASS;", + l0, + l2, + algebraVar + ) + + visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS;", l0, l2, constantsVar) + visitMaxs(3, 3) + visitEnd() + } + + evaluateMethodVisitor = visitMethod( + ACC_PUBLIC or ACC_FINAL, + "evaluate", + "(L$MAP_CLASS;)L$T_CLASS;", + "(L$MAP_CLASS;)L$T_CLASS;", + null + ) + + evaluateMethodVisitor.run { + visitCode() + evaluateL0 = Label() + visitLabel(evaluateL0) + } + } + } + + @Suppress("UNCHECKED_CAST") + fun generate(): AsmCompiled { + evaluateMethodVisitor.run { + visitInsn(ARETURN) + evaluateL1 = Label() + visitLabel(evaluateL1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + evaluateL0, + evaluateL1, + evaluateThisVar + ) + + visitLocalVariable( + "arguments", + "L$MAP_CLASS;", + "L$MAP_CLASS;", + evaluateL0, + evaluateL1, + evaluateArgumentsVar + ) + + visitMaxs(maxStack + 1, 2) + visitEnd() + } + + asmCompiledClassWriter.visitMethod( + ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, + "evaluate", + "(L$MAP_CLASS;)L$OBJECT_CLASS;", + null, + null + ).run { + val thisVar = 0 + visitCode() + val l0 = Label() + visitLabel(l0) + visitVarInsn(ALOAD, 0) + visitVarInsn(ALOAD, 1) + visitMethodInsn(INVOKEVIRTUAL, slashesClassName, "evaluate", "(L$MAP_CLASS;)L$T_CLASS;", false) + visitInsn(ARETURN) + val l1 = Label() + visitLabel(l1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + l0, + l1, + thisVar + ) + + visitMaxs(2, 2) + visitEnd() + } + + asmCompiledClassWriter.visitEnd() + + return classLoader + .defineClass(className, asmCompiledClassWriter.toByteArray()) + .constructors + .first() + .newInstance(algebra, constants) as AsmCompiled + } + + fun visitLoadFromConstants(value: T) { + val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex + maxStack++ + + evaluateMethodVisitor.run { + visitLoadThis() + visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;") + visitLdcOrIConstInsn(idx) + visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true) + visitCastToT() + } + } + + private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar) + + fun visitNumberConstant(value: Number): Unit = evaluateMethodVisitor.visitLdcInsn(value) + + fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run { + maxStack++ + visitVarInsn(ALOAD, evaluateArgumentsVar) + + if (defaultValue != null) { + visitLdcInsn(name) + visitLoadFromConstants(defaultValue) + + visitMethodInsn( + INVOKEINTERFACE, + MAP_CLASS, + "getOrDefault", + "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;", + true + ) + + visitCastToT() + return + } + + visitLdcInsn(name) + visitMethodInsn(INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true) + visitCastToT() + } + + fun visitLoadAlgebra() { + evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar) + + evaluateMethodVisitor.visitFieldInsn( + GETFIELD, + ASM_COMPILED_CLASS, "algebra", "L$ALGEBRA_CLASS;" + ) + + evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_ALGEBRA_CLASS) + } + + fun visitInvokeAlgebraOperation(owner: String, method: String, descriptor: String) { + maxStack++ + evaluateMethodVisitor.visitMethodInsn(INVOKEINTERFACE, owner, method, descriptor, true) + visitCastToT() + } + + fun visitCastToT() { + evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS) + } + + companion object { + const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled" + const val LIST_CLASS = "java/util/List" + const val MAP_CLASS = "java/util/Map" + const val OBJECT_CLASS = "java/lang/Object" + const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" + const val SPACE_CLASS = "scientifik/kmath/operations/Space" + const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" + const val FIELD_CLASS = "scientifik/kmath/operations/Field" + const val STRING_CLASS = "java/lang/String" + } +} + +interface AsmExpression { + fun invoke(gen: AsmGenerationContext) +} + +internal class AsmVariableExpression(val name: String, val default: T? = null) : + AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadFromVariables(name, default) + } +} + +internal class AsmConstantExpression(val value: T) : + AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadFromConstants(value) + } +} + +internal class AsmSumExpression( + val first: AsmExpression, + val second: AsmExpression +) : AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + first.invoke(gen) + second.invoke(gen) + + gen.visitInvokeAlgebraOperation( + owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, + method = "add", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmProductExpression( + val first: AsmExpression, + val second: AsmExpression +) : AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + first.invoke(gen) + second.invoke(gen) + + gen.visitInvokeAlgebraOperation( + owner = AsmGenerationContext.SPACE_CLASS, + method = "times", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmConstProductExpression( + val expr: AsmExpression, + val const: Number +) : AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + expr.invoke(gen) + gen.visitNumberConstant(const) + + gen.visitInvokeAlgebraOperation( + owner = AsmGenerationContext.SPACE_CLASS, + method = "multiply", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmDivExpression( + val expr: AsmExpression, + val second: AsmExpression +) : AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + expr.invoke(gen) + second.invoke(gen) + + gen.visitInvokeAlgebraOperation( + owner = AsmGenerationContext.FIELD_CLASS, + method = "divide", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + + +open class AsmFunctionalExpressionSpace( + val space: Space, + one: T +) : Space>, + ExpressionSpace> { + override val zero: AsmExpression = + AsmConstantExpression(space.zero) + + override fun const(value: T): AsmExpression = + AsmConstantExpression(value) + + override fun variable(name: String, default: T?): AsmExpression = + AsmVariableExpression( + name, + default + ) + + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmSumExpression(a, b) + + override fun multiply(a: AsmExpression, k: Number): AsmExpression = + AsmConstProductExpression(a, k) + + + operator fun AsmExpression.plus(arg: T) = this + const(arg) + operator fun AsmExpression.minus(arg: T) = this - const(arg) + + operator fun T.plus(arg: AsmExpression) = arg + this + operator fun T.minus(arg: AsmExpression) = arg - this +} + +class AsmFunctionalExpressionField(val field: Field) : ExpressionField>, + AsmFunctionalExpressionSpace(field, field.one) { + override val one: AsmExpression + get() = const(this.field.one) + + override fun const(value: Double): AsmExpression = const(field.run { one * value }) + + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmProductExpression(a, b) + + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmDivExpression(a, b) + + operator fun AsmExpression.times(arg: T) = this * const(arg) + operator fun AsmExpression.div(arg: T) = this / const(arg) + + operator fun T.times(arg: AsmExpression) = arg * this + operator fun T.div(arg: AsmExpression) = arg / this +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt new file mode 100644 index 000000000..fdbc1062e --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt @@ -0,0 +1,17 @@ +package scientifik.kmath.expressions + +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Opcodes + +fun MethodVisitor.visitLdcOrIConstInsn(value: Int) { + when (value) { + -1 -> visitInsn(Opcodes.ICONST_M1) + 0 -> visitInsn(Opcodes.ICONST_0) + 1 -> visitInsn(Opcodes.ICONST_1) + 2 -> visitInsn(Opcodes.ICONST_2) + 3 -> visitInsn(Opcodes.ICONST_3) + 4 -> visitInsn(Opcodes.ICONST_4) + 5 -> visitInsn(Opcodes.ICONST_5) + else -> visitLdcInsn(value) + } +} diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt new file mode 100644 index 000000000..e95f8df76 --- /dev/null +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -0,0 +1,23 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +class AsmTest { + @Test + fun test() { + val expr = AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")) + + val gen = AsmGenerationContext( + java.lang.Double::class.java, + RealField, + "MyAsmCompiled" + ) + + expr.invoke(gen) + val compiled = gen.generate() + val value = compiled.evaluate(mapOf("x" to 25.0)) + assertEquals(26.0, value) + } +} diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index 092f3deb7..18c0cc771 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -8,4 +8,4 @@ kotlin.sourceSets { api(project(":kmath-memory")) } } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 94bc81413..7b781e4e1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -15,7 +15,7 @@ operator fun Expression.invoke(vararg pairs: Pair): T = invoke /** * A context for expression construction */ -interface ExpressionContext> { +interface ExpressionContext { /** * Introduce a variable into expression context */ @@ -29,11 +29,11 @@ interface ExpressionContext> { fun produce(node: SyntaxTreeNode): E } -interface ExpressionSpace> : Space, ExpressionContext { +interface ExpressionSpace : Space, ExpressionContext { - open fun produceSingular(value: String): E = variable(value) + fun produceSingular(value: String): E = variable(value) - open fun produceUnary(operation: String, value: E): E { + fun produceUnary(operation: String, value: E): E { return when (operation) { UnaryNode.PLUS_OPERATION -> value UnaryNode.MINUS_OPERATION -> -value @@ -41,7 +41,7 @@ interface ExpressionSpace> : Space, ExpressionContext left + right BinaryNode.MINUS_OPERATION -> left - right @@ -75,7 +75,7 @@ interface ExpressionSpace> : Space, ExpressionContext> : Field, ExpressionSpace { +interface ExpressionField : Field, ExpressionSpace { fun const(value: Double): E = one.times(value) override fun produce(node: SyntaxTreeNode): E { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt similarity index 100% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt diff --git a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt similarity index 100% rename from kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt rename to kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt diff --git a/settings.gradle.kts b/settings.gradle.kts index 57173250b..fcaa72698 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -44,5 +44,6 @@ include( ":kmath-dimensions", ":kmath-for-real", ":kmath-geometry", - ":examples" + ":examples", + ":kmath-asm" ) From fdd2551c3f8efd6efa76e4d7c09e74ab18a1aa35 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 5 Jun 2020 22:12:39 +0700 Subject: [PATCH 004/104] Minor refactor --- .../kmath/expressions/AsmExpressions.kt | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index b9c357671..36607d2b5 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -7,7 +7,6 @@ import org.objectweb.asm.Opcodes.* import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space -import java.io.File abstract class AsmCompiled(@JvmField val algebra: Algebra, @JvmField val constants: MutableList) { abstract fun evaluate(arguments: Map): T @@ -223,15 +222,13 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra { internal class AsmVariableExpression(val name: String, val default: T? = null) : AsmExpression { - override fun invoke(gen: AsmGenerationContext) { - gen.visitLoadFromVariables(name, default) - } + override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) } internal class AsmConstantExpression(val value: T) : AsmExpression { - override fun invoke(gen: AsmGenerationContext) { - gen.visitLoadFromConstants(value) - } + override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) } internal class AsmSumExpression( @@ -273,7 +267,7 @@ internal class AsmSumExpression( first.invoke(gen) second.invoke(gen) - gen.visitInvokeAlgebraOperation( + gen.visitAlgebraOperation( owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, method = "add", descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" @@ -290,8 +284,8 @@ internal class AsmProductExpression( first.invoke(gen) second.invoke(gen) - gen.visitInvokeAlgebraOperation( - owner = AsmGenerationContext.SPACE_CLASS, + gen.visitAlgebraOperation( + owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, method = "times", descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" ) @@ -307,7 +301,7 @@ internal class AsmConstProductExpression( expr.invoke(gen) gen.visitNumberConstant(const) - gen.visitInvokeAlgebraOperation( + gen.visitAlgebraOperation( owner = AsmGenerationContext.SPACE_CLASS, method = "multiply", descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" @@ -324,8 +318,8 @@ internal class AsmDivExpression( expr.invoke(gen) second.invoke(gen) - gen.visitInvokeAlgebraOperation( - owner = AsmGenerationContext.FIELD_CLASS, + gen.visitAlgebraOperation( + owner = AsmGenerationContext.FIELD_OPERATIONS_CLASS, method = "divide", descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" ) From 557142c2bab38bffe63de44e4f6dee79d3324291 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 5 Jun 2020 23:02:16 +0700 Subject: [PATCH 005/104] Add more tests, fix constant product and product operations impl. --- .../kmath/expressions/AsmExpressions.kt | 25 ++++--- .../kmath/expressions/MethodVisitors.kt | 33 ++++++--- .../scientifik/kmath/expressions/AsmTest.kt | 71 +++++++++++++++---- 3 files changed, 97 insertions(+), 32 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index 36607d2b5..9a7ad69f2 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -7,6 +7,7 @@ import org.objectweb.asm.Opcodes.* import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space +import java.io.File abstract class AsmCompiled(@JvmField val algebra: Algebra, @JvmField val constants: MutableList) { abstract fun evaluate(arguments: Map): T @@ -184,10 +185,13 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra( second.invoke(gen) gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, - method = "times", + owner = AsmGenerationContext.RING_OPERATIONS_CLASS, + method = "multiply", descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" ) } @@ -298,13 +302,13 @@ internal class AsmConstProductExpression( ) : AsmExpression { override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() - expr.invoke(gen) gen.visitNumberConstant(const) + expr.invoke(gen) gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_CLASS, + owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, method = "multiply", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" ) } } @@ -326,7 +330,6 @@ internal class AsmDivExpression( } } - open class AsmFunctionalExpressionSpace( val space: Space, one: T diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt index fdbc1062e..5b9a8afe8 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt @@ -1,17 +1,34 @@ package scientifik.kmath.expressions import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes +import org.objectweb.asm.Opcodes.* fun MethodVisitor.visitLdcOrIConstInsn(value: Int) { when (value) { - -1 -> visitInsn(Opcodes.ICONST_M1) - 0 -> visitInsn(Opcodes.ICONST_0) - 1 -> visitInsn(Opcodes.ICONST_1) - 2 -> visitInsn(Opcodes.ICONST_2) - 3 -> visitInsn(Opcodes.ICONST_3) - 4 -> visitInsn(Opcodes.ICONST_4) - 5 -> visitInsn(Opcodes.ICONST_5) + -1 -> visitInsn(ICONST_M1) + 0 -> visitInsn(ICONST_0) + 1 -> visitInsn(ICONST_1) + 2 -> visitInsn(ICONST_2) + 3 -> visitInsn(ICONST_3) + 4 -> visitInsn(ICONST_4) + 5 -> visitInsn(ICONST_5) else -> visitLdcInsn(value) } } + +private val signatureLetters = mapOf( + java.lang.Byte::class.java to "B", + java.lang.Short::class.java to "S", + java.lang.Integer::class.java to "I", + java.lang.Long::class.java to "J", + java.lang.Float::class.java to "F", + java.lang.Double::class.java to "D", + java.lang.Short::class.java to "S" +) + +fun MethodVisitor.visitBoxedNumberConstant(number: Number) { + val clazz = number.javaClass + val c = clazz.name.replace('.', '/') + visitLdcInsn(number) + visitMethodInsn(INVOKESTATIC, c, "valueOf", "(${signatureLetters[clazz]})L${c};", false) +} diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index e95f8df76..402307050 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -1,23 +1,68 @@ package scientifik.kmath.expressions +import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals class AsmTest { - @Test - fun test() { - val expr = AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")) - - val gen = AsmGenerationContext( - java.lang.Double::class.java, - RealField, - "MyAsmCompiled" + private fun testExpressionValue( + expectedValue: T, + expr: AsmExpression, + arguments: Map, + algebra: Algebra, + clazz: Class<*> + ) { + assertEquals( + expectedValue, AsmGenerationContext( + clazz, + algebra, + "TestAsmCompiled" + ).also(expr::invoke).generate().evaluate(arguments) ) - - expr.invoke(gen) - val compiled = gen.generate() - val value = compiled.evaluate(mapOf("x" to 25.0)) - assertEquals(26.0, value) } + + @Suppress("UNCHECKED_CAST") + private fun testDoubleExpressionValue( + expectedValue: Double, + expr: AsmExpression, + arguments: Map, + algebra: Algebra = RealField, + clazz: Class = java.lang.Double::class.java as Class + ) = testExpressionValue(expectedValue, expr, arguments, algebra, clazz) + + @Test + fun testSum() = testDoubleExpressionValue( + 25.0, + AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")), + mapOf("x" to 24.0) + ) + + @Test + fun testConst() = testDoubleExpressionValue( + 123.0, + AsmConstantExpression(123.0), + mapOf() + ) + + @Test + fun testDiv() = testDoubleExpressionValue( + 0.5, + AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)), + mapOf() + ) + + @Test + fun testProduct() = testDoubleExpressionValue( + 25.0, + AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")), + mapOf("x" to 5.0) + ) + + @Test + fun testCProduct() = testDoubleExpressionValue( + 25.0, + AsmConstProductExpression(AsmVariableExpression("x"), 5.0), + mapOf("x" to 5.0) + ) } From d6d9351c9c15bf8780a1bc16953d98b37d9650cb Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 5 Jun 2020 23:04:04 +0700 Subject: [PATCH 006/104] Add more tests --- .../kotlin/scientifik/kmath/expressions/AsmTest.kt | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 402307050..98aa3706b 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -65,4 +65,18 @@ class AsmTest { AsmConstProductExpression(AsmVariableExpression("x"), 5.0), mapOf("x" to 5.0) ) + + @Test + fun testVar() = testDoubleExpressionValue( + 10000.0, + AsmVariableExpression("x"), + mapOf("x" to 10000.0) + ) + + @Test + fun testVarWithDefault() = testDoubleExpressionValue( + 10000.0, + AsmVariableExpression("x", 10000.0), + mapOf() + ) } From 723bd84417e168f067d59d84ab20330b9efa8423 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 6 Jun 2020 21:06:15 +0700 Subject: [PATCH 007/104] Remove unused import --- .../main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt | 1 - 1 file changed, 1 deletion(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index 9a7ad69f2..f4951081b 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -7,7 +7,6 @@ import org.objectweb.asm.Opcodes.* import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space -import java.io.File abstract class AsmCompiled(@JvmField val algebra: Algebra, @JvmField val constants: MutableList) { abstract fun evaluate(arguments: Map): T From 30b32c15159118f1a93d1f2d1f466711c656c6ee Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 6 Jun 2020 21:45:15 +0700 Subject: [PATCH 008/104] Move initialization code to separate method to make AsmGenerationContext restartable --- .../kmath/expressions/AsmExpressions.kt | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index f4951081b..876e7abe2 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -24,17 +24,26 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra = mutableListOf() - private val asmCompiledClassWriter = ClassWriter(0) + private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') - private val evaluateMethodVisitor: MethodVisitor private val evaluateThisVar: Int = 0 private val evaluateArgumentsVar: Int = 1 - private var evaluateL0: Label + private var maxStack: Int = 0 + private lateinit var constants: MutableList + private lateinit var asmCompiledClassWriter: ClassWriter + private lateinit var evaluateMethodVisitor: MethodVisitor + private lateinit var evaluateL0: Label private lateinit var evaluateL1: Label - var maxStack: Int = 0 init { + start() + } + + fun start() { + constants = mutableListOf() + asmCompiledClassWriter = ClassWriter(0) + maxStack = 0 + asmCompiledClassWriter.visit( V1_8, ACC_PUBLIC or ACC_FINAL or ACC_SUPER, From 41094e63da244314b07bb7043d519600b018d9dc Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 6 Jun 2020 21:48:52 +0700 Subject: [PATCH 009/104] Optimize number constants placing to economize contant pool places --- .../kmath/expressions/MethodVisitors.kt | 42 +++++++++++++------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt index 5b9a8afe8..dc5d83ea5 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt @@ -3,17 +3,28 @@ package scientifik.kmath.expressions import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* -fun MethodVisitor.visitLdcOrIConstInsn(value: Int) { - when (value) { - -1 -> visitInsn(ICONST_M1) - 0 -> visitInsn(ICONST_0) - 1 -> visitInsn(ICONST_1) - 2 -> visitInsn(ICONST_2) - 3 -> visitInsn(ICONST_3) - 4 -> visitInsn(ICONST_4) - 5 -> visitInsn(ICONST_5) - else -> visitLdcInsn(value) - } +fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) { + -1 -> visitInsn(ICONST_M1) + 0 -> visitInsn(ICONST_0) + 1 -> visitInsn(ICONST_1) + 2 -> visitInsn(ICONST_2) + 3 -> visitInsn(ICONST_3) + 4 -> visitInsn(ICONST_4) + 5 -> visitInsn(ICONST_5) + else -> visitLdcInsn(value) +} + +fun MethodVisitor.visitLdcOrDConstInsn(value: Double) = when (value) { + 0.0 -> visitInsn(DCONST_0) + 1.0 -> visitInsn(DCONST_1) + else -> visitLdcInsn(value) +} + +fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) { + 0f -> visitInsn(FCONST_0) + 1f -> visitInsn(FCONST_1) + 2f -> visitInsn(FCONST_2) + else -> visitLdcInsn(value) } private val signatureLetters = mapOf( @@ -29,6 +40,13 @@ private val signatureLetters = mapOf( fun MethodVisitor.visitBoxedNumberConstant(number: Number) { val clazz = number.javaClass val c = clazz.name.replace('.', '/') - visitLdcInsn(number) + + when (number) { + is Int -> visitLdcOrIConstInsn(number) + is Double -> visitLdcOrDConstInsn(number) + is Float -> visitLdcOrFConstInsn(number) + else -> visitLdcInsn(number) + } + visitMethodInsn(INVOKESTATIC, c, "valueOf", "(${signatureLetters[clazz]})L${c};", false) } From cdb24ea8e2c798850e09c22f12c44cd03e16e2f0 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 6 Jun 2020 21:51:22 +0700 Subject: [PATCH 010/104] Remove duplicate Short key from signatureLetters map --- .../main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt index dc5d83ea5..7ca488846 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt @@ -33,8 +33,7 @@ private val signatureLetters = mapOf( java.lang.Integer::class.java to "I", java.lang.Long::class.java to "J", java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D", - java.lang.Short::class.java to "S" + java.lang.Double::class.java to "D" ) fun MethodVisitor.visitBoxedNumberConstant(number: Number) { From fb74e74b01b780cc7d389140ad61cb06d3d383f6 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 6 Jun 2020 22:05:25 +0700 Subject: [PATCH 011/104] Remove constant number allocation hack and support uncommon Number implementations to be available in constants --- .../kmath/expressions/AsmExpressions.kt | 36 ++++++++++++++++--- .../kmath/expressions/MethodVisitors.kt | 23 ------------ 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index 876e7abe2..1b44903cd 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -8,7 +8,7 @@ import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space -abstract class AsmCompiled(@JvmField val algebra: Algebra, @JvmField val constants: MutableList) { +abstract class AsmCompiled(@JvmField val algebra: Algebra, @JvmField val constants: MutableList) { abstract fun evaluate(arguments: Map): T } @@ -29,7 +29,7 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra + private lateinit var constants: MutableList private lateinit var asmCompiledClassWriter: ClassWriter private lateinit var evaluateMethodVisitor: MethodVisitor private lateinit var evaluateL0: Label @@ -178,7 +178,9 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra } - fun visitLoadFromConstants(value: T) { + fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any) + + fun visitLoadAnyFromConstants(value: Any) { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex maxStack++ @@ -195,7 +197,24 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra evaluateMethodVisitor.visitLdcOrIConstInsn(value) + is Double -> evaluateMethodVisitor.visitLdcOrDConstInsn(value) + is Float -> evaluateMethodVisitor.visitLdcOrFConstInsn(value) + else -> evaluateMethodVisitor.visitLdcInsn(value) + } + + evaluateMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) + return + } + + visitLoadAnyFromConstants(value) } fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run { @@ -243,6 +262,15 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra visitInsn(FCONST_2) else -> visitLdcInsn(value) } - -private val signatureLetters = mapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" -) - -fun MethodVisitor.visitBoxedNumberConstant(number: Number) { - val clazz = number.javaClass - val c = clazz.name.replace('.', '/') - - when (number) { - is Int -> visitLdcOrIConstInsn(number) - is Double -> visitLdcOrDConstInsn(number) - is Float -> visitLdcOrFConstInsn(number) - else -> visitLdcInsn(number) - } - - visitMethodInsn(INVOKESTATIC, c, "valueOf", "(${signatureLetters[clazz]})L${c};", false) -} From 6686144538268343e47856d39b26fc8a42d17e01 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 6 Jun 2020 22:09:18 +0700 Subject: [PATCH 012/104] Add type casts for constants --- .../kotlin/scientifik/kmath/expressions/AsmExpressions.kt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index 1b44903cd..8060674e5 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -178,9 +178,9 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra } - fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any) + fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS) - fun visitLoadAnyFromConstants(value: Any) { + fun visitLoadAnyFromConstants(value: Any, type: String) { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex maxStack++ @@ -189,7 +189,7 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra Date: Sun, 7 Jun 2020 15:57:23 +0700 Subject: [PATCH 013/104] Add more test for ASM const product --- .../scientifik/kmath/expressions/AsmTest.kt | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 98aa3706b..9ab6f26c9 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -66,6 +66,30 @@ class AsmTest { mapOf("x" to 5.0) ) + @Test + fun testCProductWithOtherTypeNumber() = testDoubleExpressionValue( + 25.0, + AsmConstProductExpression(AsmVariableExpression("x"), 5f), + mapOf("x" to 5.0) + ) + + object CustomZero : Number() { + override fun toByte(): Byte = 0 + override fun toChar(): Char = 0.toChar() + override fun toDouble(): Double = 0.0 + override fun toFloat(): Float = 0f + override fun toInt(): Int = 0 + override fun toLong(): Long = 0L + override fun toShort(): Short = 0 + } + + @Test + fun testCProductWithCustomTypeNumber() = testDoubleExpressionValue( + 0.0, + AsmConstProductExpression(AsmVariableExpression("x"), CustomZero), + mapOf("x" to 5.0) + ) + @Test fun testVar() = testDoubleExpressionValue( 10000.0, From 6ac0297530010799e2ede800f93a7a50e8c891b0 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Sun, 7 Jun 2020 19:04:39 +0700 Subject: [PATCH 014/104] Move asm dependency to implementation configuration; rename many ASM API classes, make AsmCompiledExpression implement functional Expression, fix typos, encapsulate AsmGenerationContext --- kmath-asm/build.gradle.kts | 4 +- .../kmath/expressions/AsmExpressions.kt | 199 +++++++++--------- .../scientifik/kmath/expressions/AsmTest.kt | 25 +-- .../kmath/expressions/Expression.kt | 2 +- .../expressions/FunctionalExpressions.kt | 46 ++-- 5 files changed, 129 insertions(+), 147 deletions(-) diff --git a/kmath-asm/build.gradle.kts b/kmath-asm/build.gradle.kts index ab4684cdc..4104349f8 100644 --- a/kmath-asm/build.gradle.kts +++ b/kmath-asm/build.gradle.kts @@ -4,6 +4,6 @@ plugins { dependencies { api(project(path = ":kmath-core")) - api("org.ow2.asm:asm:8.0.1") - api("org.ow2.asm:asm-commons:8.0.1") + implementation("org.ow2.asm:asm:8.0.1") + implementation("org.ow2.asm:asm-commons:8.0.1") } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index 8060674e5..f196efffe 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -8,11 +8,30 @@ import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space -abstract class AsmCompiled(@JvmField val algebra: Algebra, @JvmField val constants: MutableList) { - abstract fun evaluate(arguments: Map): T +abstract class AsmCompiledExpression internal constructor( + @JvmField val algebra: Algebra, + @JvmField val constants: MutableList +) : Expression { + abstract override fun invoke(arguments: Map): T } -class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra, private val className: String) { +internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { + val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(expression, collision + 1) +} + +class AsmGenerationContext( + classOfT: Class<*>, + private val algebra: Algebra, + private val className: String +) { private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) } @@ -26,30 +45,23 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra - private lateinit var asmCompiledClassWriter: ClassWriter - private lateinit var evaluateMethodVisitor: MethodVisitor - private lateinit var evaluateL0: Label - private lateinit var evaluateL1: Label + private val constants: MutableList = mutableListOf() + private val asmCompiledClassWriter: ClassWriter = ClassWriter(0) + private val invokeMethodVisitor: MethodVisitor + private val invokeL0: Label + private lateinit var invokeL1: Label + private var generatedInstance: AsmCompiledExpression? = null init { - start() - } - - fun start() { - constants = mutableListOf() - asmCompiledClassWriter = ClassWriter(0) - maxStack = 0 - asmCompiledClassWriter.visit( V1_8, ACC_PUBLIC or ACC_FINAL or ACC_SUPER, slashesClassName, - "L$ASM_COMPILED_CLASS;", - ASM_COMPILED_CLASS, + "L$ASM_COMPILED_EXPRESSION_CLASS;", + ASM_COMPILED_EXPRESSION_CLASS, arrayOf() ) @@ -66,7 +78,7 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", false @@ -93,45 +105,48 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra;)L$T_CLASS;", null ) - evaluateMethodVisitor.run { + invokeMethodVisitor.run { visitCode() - evaluateL0 = Label() - visitLabel(evaluateL0) + invokeL0 = Label() + visitLabel(invokeL0) } } } + @PublishedApi @Suppress("UNCHECKED_CAST") - fun generate(): AsmCompiled { - evaluateMethodVisitor.run { + internal fun generate(): AsmCompiledExpression { + generatedInstance?.let { return it } + + invokeMethodVisitor.run { visitInsn(ARETURN) - evaluateL1 = Label() - visitLabel(evaluateL1) + invokeL1 = Label() + visitLabel(invokeL1) visitLocalVariable( "this", "L$slashesClassName;", T_CLASS, - evaluateL0, - evaluateL1, - evaluateThisVar + invokeL0, + invokeL1, + invokeThisVar ) visitLocalVariable( "arguments", "L$MAP_CLASS;", "L$MAP_CLASS;", - evaluateL0, - evaluateL1, - evaluateArgumentsVar + invokeL0, + invokeL1, + invokeArgumentsVar ) visitMaxs(maxStack + 1, 2) @@ -140,7 +155,7 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra + .newInstance(algebra, constants) as AsmCompiledExpression + + generatedInstance = new + return new } - fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS) + internal fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS) - fun visitLoadAnyFromConstants(value: Any, type: String) { + private fun visitLoadAnyFromConstants(value: Any, type: String) { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex maxStack++ - evaluateMethodVisitor.run { + invokeMethodVisitor.run { visitLoadThis() visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;") visitLdcOrIConstInsn(idx) visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true) - evaluateMethodVisitor.visitTypeInsn(CHECKCAST, type) + invokeMethodVisitor.visitTypeInsn(CHECKCAST, type) } } - private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar) + private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(ALOAD, invokeThisVar) - fun visitNumberConstant(value: Number) { + internal fun visitNumberConstant(value: Number) { maxStack++ val clazz = value.javaClass val c = clazz.name.replace('.', '/') @@ -204,22 +222,22 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra evaluateMethodVisitor.visitLdcOrIConstInsn(value) - is Double -> evaluateMethodVisitor.visitLdcOrDConstInsn(value) - is Float -> evaluateMethodVisitor.visitLdcOrFConstInsn(value) - else -> evaluateMethodVisitor.visitLdcInsn(value) + is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value) + is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value) + is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value) + else -> invokeMethodVisitor.visitLdcInsn(value) } - evaluateMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) + invokeMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) return } visitLoadAnyFromConstants(value, c) } - fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run { + internal fun visitLoadFromVariables(name: String, defaultValue: T? = null) = invokeMethodVisitor.run { maxStack += 2 - visitVarInsn(ALOAD, evaluateArgumentsVar) + visitVarInsn(ALOAD, invokeArgumentsVar) if (defaultValue != null) { visitLdcInsn(name) @@ -242,26 +260,26 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra( } } -open class AsmFunctionalExpressionSpace( - val space: Space, - one: T -) : Space>, +open class AsmExpressionSpace(space: Space) : Space>, ExpressionSpace> { - override val zero: AsmExpression = - AsmConstantExpression(space.zero) - - override fun const(value: T): AsmExpression = - AsmConstantExpression(value) - - override fun variable(name: String, default: T?): AsmExpression = - AsmVariableExpression( - name, - default - ) - - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmSumExpression(a, b) - - override fun multiply(a: AsmExpression, k: Number): AsmExpression = - AsmConstProductExpression(a, k) - - + override val zero: AsmExpression = AsmConstantExpression(space.zero) + override fun const(value: T): AsmExpression = AsmConstantExpression(value) + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(a, b) + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(a, k) operator fun AsmExpression.plus(arg: T) = this + const(arg) operator fun AsmExpression.minus(arg: T) = this - const(arg) - operator fun T.plus(arg: AsmExpression) = arg + this operator fun T.minus(arg: AsmExpression) = arg - this } -class AsmFunctionalExpressionField(val field: Field) : ExpressionField>, - AsmFunctionalExpressionSpace(field, field.one) { +class AsmExpressionField(private val field: Field) : ExpressionField>, + AsmExpressionSpace(field) { override val one: AsmExpression get() = const(this.field.one) - override fun const(value: Double): AsmExpression = const(field.run { one * value }) + override fun number(value: Number): AsmExpression = const(field.run { one * value }) override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmProductExpression(a, b) @@ -412,7 +412,6 @@ class AsmFunctionalExpressionField(val field: Field) : ExpressionField.times(arg: T) = this * const(arg) operator fun AsmExpression.div(arg: T) = this / const(arg) - operator fun T.times(arg: AsmExpression) = arg * this operator fun T.div(arg: AsmExpression) = arg / this } diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 9ab6f26c9..875080616 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -12,15 +12,12 @@ class AsmTest { arguments: Map, algebra: Algebra, clazz: Class<*> - ) { - assertEquals( - expectedValue, AsmGenerationContext( - clazz, - algebra, - "TestAsmCompiled" - ).also(expr::invoke).generate().evaluate(arguments) - ) - } + ): Unit = assertEquals( + expectedValue, AsmGenerationContext(clazz, algebra, "TestAsmCompiled") + .also(expr::invoke) + .generate() + .invoke(arguments) + ) @Suppress("UNCHECKED_CAST") private fun testDoubleExpressionValue( @@ -29,7 +26,7 @@ class AsmTest { arguments: Map, algebra: Algebra = RealField, clazz: Class = java.lang.Double::class.java as Class - ) = testExpressionValue(expectedValue, expr, arguments, algebra, clazz) + ): Unit = testExpressionValue(expectedValue, expr, arguments, algebra, clazz) @Test fun testSum() = testDoubleExpressionValue( @@ -39,28 +36,28 @@ class AsmTest { ) @Test - fun testConst() = testDoubleExpressionValue( + fun testConst(): Unit = testDoubleExpressionValue( 123.0, AsmConstantExpression(123.0), mapOf() ) @Test - fun testDiv() = testDoubleExpressionValue( + fun testDiv(): Unit = testDoubleExpressionValue( 0.5, AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)), mapOf() ) @Test - fun testProduct() = testDoubleExpressionValue( + fun testProduct(): Unit = testDoubleExpressionValue( 25.0, AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")), mapOf("x" to 5.0) ) @Test - fun testCProduct() = testDoubleExpressionValue( + fun testCProduct(): Unit = testDoubleExpressionValue( 25.0, AsmConstProductExpression(AsmVariableExpression("x"), 5.0), mapOf("x" to 5.0) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 7b781e4e1..dae069879 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -76,7 +76,7 @@ interface ExpressionSpace : Space, ExpressionContext { } interface ExpressionField : Field, ExpressionSpace { - fun const(value: Double): E = one.times(value) + fun number(value: Number): E = one * value override fun produce(node: SyntaxTreeNode): E { if (node is BinaryNode) { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index ad0acbc4a..dbc2eac1e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -27,55 +27,41 @@ internal class ProductExpression(val context: Ring, val first: Expression< context.multiply(first.invoke(arguments), second.invoke(arguments)) } -internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : +internal class ConstProductExpression(val context: Space, val expr: Expression, val const: Number) : Expression { override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } -internal class DivExpession(val context: Field, val expr: Expression, val second: Expression) : +internal class DivExpression(val context: Field, val expr: Expression, val second: Expression) : Expression { override fun invoke(arguments: Map): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) } -open class FunctionalExpressionSpace( - val space: Space, - one: T -) : Space>, ExpressionSpace> { - +open class FunctionalExpressionSpace(val space: Space) : Space>, ExpressionSpace> { override val zero: Expression = ConstantExpression(space.zero) override fun const(value: T): Expression = ConstantExpression(value) - override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) - override fun add(a: Expression, b: Expression): Expression = SumExpression(space, a, b) - - override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(space, a, k) - - - operator fun Expression.plus(arg: T) = this + const(arg) - operator fun Expression.minus(arg: T) = this - const(arg) - - operator fun T.plus(arg: Expression) = arg + this - operator fun T.minus(arg: Expression) = arg - this + override fun multiply(a: Expression, k: Number): Expression = ConstProductExpression(space, a, k) + operator fun Expression.plus(arg: T): Expression = this + const(arg) + operator fun Expression.minus(arg: T): Expression = this - const(arg) + operator fun T.plus(arg: Expression): Expression = arg + this + operator fun T.minus(arg: Expression): Expression = arg - this } open class FunctionalExpressionField( val field: Field -) : ExpressionField>, FunctionalExpressionSpace(field, field.one) { +) : ExpressionField>, FunctionalExpressionSpace(field) { override val one: Expression get() = const(this.field.one) - override fun const(value: Double): Expression = const(field.run { one*value}) - + override fun number(value: Number): Expression = const(field.run { one * value }) override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) - - override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) - - operator fun Expression.times(arg: T) = this * const(arg) - operator fun Expression.div(arg: T) = this / const(arg) - - operator fun T.times(arg: Expression) = arg * this - operator fun T.div(arg: Expression) = arg / this -} \ No newline at end of file + override fun divide(a: Expression, b: Expression): Expression = DivExpression(field, a, b) + operator fun Expression.times(arg: T): Expression = this * const(arg) + operator fun Expression.div(arg: T): Expression = this / const(arg) + operator fun T.times(arg: Expression): Expression = arg * this + operator fun T.div(arg: Expression): Expression = arg / this +} From 013030951ee8cd0c8319c5504d07704a91a90901 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Sun, 7 Jun 2020 19:23:39 +0700 Subject: [PATCH 015/104] Make AsmCompiledExpression fields private, add builder functions not to expose AsmGenerationContext to public API, refactor code --- .../kmath/expressions/AsmExpressionSpaces.kt | 31 ++ .../kmath/expressions/AsmExpressions.kt | 329 +----------------- .../kmath/expressions/AsmGenerationContext.kt | 283 +++++++++++++++ .../scientifik/kmath/expressions/Builders.kt | 38 ++ .../scientifik/kmath/expressions/AsmTest.kt | 8 +- 5 files changed, 358 insertions(+), 331 deletions(-) create mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt create mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt create mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt new file mode 100644 index 000000000..5c93fd729 --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt @@ -0,0 +1,31 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Space + +open class AsmExpressionSpace(space: Space) : Space>, + ExpressionSpace> { + override val zero: AsmExpression = AsmConstantExpression(space.zero) + override fun const(value: T): AsmExpression = AsmConstantExpression(value) + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(a, b) + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(a, k) + operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) + operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) + operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this + operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this +} + +class AsmExpressionField(private val field: Field) : ExpressionField>, + AsmExpressionSpace(field) { + override val one: AsmExpression + get() = const(this.field.one) + + override fun number(value: Number): AsmExpression = const(field.run { one * value }) + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmProductExpression(a, b) + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmDivExpression(a, b) + operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) + operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) + operator fun T.times(arg: AsmExpression): AsmExpression = arg * this + operator fun T.div(arg: AsmExpression): AsmExpression = arg / this +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index f196efffe..8708a33e1 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -1,307 +1,14 @@ package scientifik.kmath.expressions -import org.objectweb.asm.ClassWriter -import org.objectweb.asm.Label -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes.* import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Space abstract class AsmCompiledExpression internal constructor( - @JvmField val algebra: Algebra, - @JvmField val constants: MutableList + @JvmField private val algebra: Algebra, + @JvmField private val constants: MutableList ) : Expression { abstract override fun invoke(arguments: Map): T } -internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { - val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(expression, collision + 1) -} - -class AsmGenerationContext( - classOfT: Class<*>, - private val algebra: Algebra, - private val className: String -) { - private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { - internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) - } - - private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - - @Suppress("PrivatePropertyName") - private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') - - @Suppress("PrivatePropertyName") - private val T_CLASS: String = classOfT.name.replace('.', '/') - - private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') - private val invokeThisVar: Int = 0 - private val invokeArgumentsVar: Int = 1 - private var maxStack: Int = 0 - private val constants: MutableList = mutableListOf() - private val asmCompiledClassWriter: ClassWriter = ClassWriter(0) - private val invokeMethodVisitor: MethodVisitor - private val invokeL0: Label - private lateinit var invokeL1: Label - private var generatedInstance: AsmCompiledExpression? = null - - init { - asmCompiledClassWriter.visit( - V1_8, - ACC_PUBLIC or ACC_FINAL or ACC_SUPER, - slashesClassName, - "L$ASM_COMPILED_EXPRESSION_CLASS;", - ASM_COMPILED_EXPRESSION_CLASS, - arrayOf() - ) - - asmCompiledClassWriter.run { - visitMethod(ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run { - val thisVar = 0 - val algebraVar = 1 - val constantsVar = 2 - val l0 = Label() - visitLabel(l0) - visitVarInsn(ALOAD, thisVar) - visitVarInsn(ALOAD, algebraVar) - visitVarInsn(ALOAD, constantsVar) - - visitMethodInsn( - INVOKESPECIAL, - ASM_COMPILED_EXPRESSION_CLASS, - "", - "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", - false - ) - - val l1 = Label() - visitLabel(l1) - visitInsn(RETURN) - val l2 = Label() - visitLabel(l2) - visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) - - visitLocalVariable( - "algebra", - "L$ALGEBRA_CLASS;", - "L$ALGEBRA_CLASS;", - l0, - l2, - algebraVar - ) - - visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS;", l0, l2, constantsVar) - visitMaxs(3, 3) - visitEnd() - } - - invokeMethodVisitor = visitMethod( - ACC_PUBLIC or ACC_FINAL, - "invoke", - "(L$MAP_CLASS;)L$T_CLASS;", - "(L$MAP_CLASS;)L$T_CLASS;", - null - ) - - invokeMethodVisitor.run { - visitCode() - invokeL0 = Label() - visitLabel(invokeL0) - } - } - } - - @PublishedApi - @Suppress("UNCHECKED_CAST") - internal fun generate(): AsmCompiledExpression { - generatedInstance?.let { return it } - - invokeMethodVisitor.run { - visitInsn(ARETURN) - invokeL1 = Label() - visitLabel(invokeL1) - - visitLocalVariable( - "this", - "L$slashesClassName;", - T_CLASS, - invokeL0, - invokeL1, - invokeThisVar - ) - - visitLocalVariable( - "arguments", - "L$MAP_CLASS;", - "L$MAP_CLASS;", - invokeL0, - invokeL1, - invokeArgumentsVar - ) - - visitMaxs(maxStack + 1, 2) - visitEnd() - } - - asmCompiledClassWriter.visitMethod( - ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, - "invoke", - "(L$MAP_CLASS;)L$OBJECT_CLASS;", - null, - null - ).run { - val thisVar = 0 - visitCode() - val l0 = Label() - visitLabel(l0) - visitVarInsn(ALOAD, 0) - visitVarInsn(ALOAD, 1) - visitMethodInsn(INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false) - visitInsn(ARETURN) - val l1 = Label() - visitLabel(l1) - - visitLocalVariable( - "this", - "L$slashesClassName;", - T_CLASS, - l0, - l1, - thisVar - ) - - visitMaxs(2, 2) - visitEnd() - } - - asmCompiledClassWriter.visitEnd() - - val new = classLoader - .defineClass(className, asmCompiledClassWriter.toByteArray()) - .constructors - .first() - .newInstance(algebra, constants) as AsmCompiledExpression - - generatedInstance = new - return new - } - - internal fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS) - - private fun visitLoadAnyFromConstants(value: Any, type: String) { - val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - maxStack++ - - invokeMethodVisitor.run { - visitLoadThis() - visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;") - visitLdcOrIConstInsn(idx) - visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true) - invokeMethodVisitor.visitTypeInsn(CHECKCAST, type) - } - } - - private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(ALOAD, invokeThisVar) - - internal fun visitNumberConstant(value: Number) { - maxStack++ - val clazz = value.javaClass - val c = clazz.name.replace('.', '/') - - val sigLetter = SIGNATURE_LETTERS[clazz] - - if (sigLetter != null) { - when (value) { - is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value) - is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value) - is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value) - else -> invokeMethodVisitor.visitLdcInsn(value) - } - - invokeMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) - return - } - - visitLoadAnyFromConstants(value, c) - } - - internal fun visitLoadFromVariables(name: String, defaultValue: T? = null) = invokeMethodVisitor.run { - maxStack += 2 - visitVarInsn(ALOAD, invokeArgumentsVar) - - if (defaultValue != null) { - visitLdcInsn(name) - visitLoadFromConstants(defaultValue) - - visitMethodInsn( - INVOKEINTERFACE, - MAP_CLASS, - "getOrDefault", - "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;", - true - ) - - visitCastToT() - return - } - - visitLdcInsn(name) - visitMethodInsn(INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true) - visitCastToT() - } - - internal fun visitLoadAlgebra() { - invokeMethodVisitor.visitVarInsn(ALOAD, invokeThisVar) - - invokeMethodVisitor.visitFieldInsn( - GETFIELD, - ASM_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" - ) - - invokeMethodVisitor.visitTypeInsn(CHECKCAST, T_ALGEBRA_CLASS) - } - - internal fun visitAlgebraOperation(owner: String, method: String, descriptor: String) { - maxStack++ - invokeMethodVisitor.visitMethodInsn(INVOKEINTERFACE, owner, method, descriptor, true) - visitCastToT() - } - - private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS) - - internal companion object { - private val SIGNATURE_LETTERS = mapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" - ) - - internal const val ASM_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/AsmCompiledExpression" - internal const val LIST_CLASS = "java/util/List" - internal const val MAP_CLASS = "java/util/Map" - internal const val OBJECT_CLASS = "java/lang/Object" - internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" - internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" - internal const val STRING_CLASS = "java/lang/String" - internal const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations" - internal const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations" - internal const val NUMBER_CLASS = "java/lang/Number" - } -} - interface AsmExpression { fun invoke(gen: AsmGenerationContext) } @@ -383,35 +90,3 @@ internal class AsmDivExpression( ) } } - -open class AsmExpressionSpace(space: Space) : Space>, - ExpressionSpace> { - override val zero: AsmExpression = AsmConstantExpression(space.zero) - override fun const(value: T): AsmExpression = AsmConstantExpression(value) - override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(a, b) - override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(a, k) - operator fun AsmExpression.plus(arg: T) = this + const(arg) - operator fun AsmExpression.minus(arg: T) = this - const(arg) - operator fun T.plus(arg: AsmExpression) = arg + this - operator fun T.minus(arg: AsmExpression) = arg - this -} - -class AsmExpressionField(private val field: Field) : ExpressionField>, - AsmExpressionSpace(field) { - override val one: AsmExpression - get() = const(this.field.one) - - override fun number(value: Number): AsmExpression = const(field.run { one * value }) - - override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmProductExpression(a, b) - - override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmDivExpression(a, b) - - operator fun AsmExpression.times(arg: T) = this * const(arg) - operator fun AsmExpression.div(arg: T) = this / const(arg) - operator fun T.times(arg: AsmExpression) = arg * this - operator fun T.div(arg: AsmExpression) = arg / this -} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt new file mode 100644 index 000000000..cedd5c0fd --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt @@ -0,0 +1,283 @@ +package scientifik.kmath.expressions + +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.Label +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Opcodes +import scientifik.kmath.operations.Algebra + +class AsmGenerationContext( + classOfT: Class<*>, + private val algebra: Algebra, + private val className: String +) { + private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) + } + + private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + + @Suppress("PrivatePropertyName") + private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') + + @Suppress("PrivatePropertyName") + private val T_CLASS: String = classOfT.name.replace('.', '/') + + private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + private val invokeThisVar: Int = 0 + private val invokeArgumentsVar: Int = 1 + private var maxStack: Int = 0 + private val constants: MutableList = mutableListOf() + private val asmCompiledClassWriter: ClassWriter = + ClassWriter(0) + private val invokeMethodVisitor: MethodVisitor + private val invokeL0: Label + private lateinit var invokeL1: Label + private var generatedInstance: AsmCompiledExpression? = null + + init { + asmCompiledClassWriter.visit( + Opcodes.V1_8, + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, + slashesClassName, + "L$ASM_COMPILED_EXPRESSION_CLASS;", + ASM_COMPILED_EXPRESSION_CLASS, + arrayOf() + ) + + asmCompiledClassWriter.run { + visitMethod(Opcodes.ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run { + val thisVar = 0 + val algebraVar = 1 + val constantsVar = 2 + val l0 = Label() + visitLabel(l0) + visitVarInsn(Opcodes.ALOAD, thisVar) + visitVarInsn(Opcodes.ALOAD, algebraVar) + visitVarInsn(Opcodes.ALOAD, constantsVar) + + visitMethodInsn( + Opcodes.INVOKESPECIAL, + ASM_COMPILED_EXPRESSION_CLASS, + "", + "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", + false + ) + + val l1 = Label() + visitLabel(l1) + visitInsn(Opcodes.RETURN) + val l2 = Label() + visitLabel(l2) + visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) + + visitLocalVariable( + "algebra", + "L$ALGEBRA_CLASS;", + "L$ALGEBRA_CLASS;", + l0, + l2, + algebraVar + ) + + visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS;", l0, l2, constantsVar) + visitMaxs(3, 3) + visitEnd() + } + + invokeMethodVisitor = visitMethod( + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + "invoke", + "(L$MAP_CLASS;)L$T_CLASS;", + "(L$MAP_CLASS;)L$T_CLASS;", + null + ) + + invokeMethodVisitor.run { + visitCode() + invokeL0 = Label() + visitLabel(invokeL0) + } + } + } + + @PublishedApi + @Suppress("UNCHECKED_CAST") + internal fun generate(): AsmCompiledExpression { + generatedInstance?.let { return it } + + invokeMethodVisitor.run { + visitInsn(Opcodes.ARETURN) + invokeL1 = Label() + visitLabel(invokeL1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + invokeL0, + invokeL1, + invokeThisVar + ) + + visitLocalVariable( + "arguments", + "L$MAP_CLASS;", + "L$MAP_CLASS;", + invokeL0, + invokeL1, + invokeArgumentsVar + ) + + visitMaxs(maxStack + 1, 2) + visitEnd() + } + + asmCompiledClassWriter.visitMethod( + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + "invoke", + "(L$MAP_CLASS;)L$OBJECT_CLASS;", + null, + null + ).run { + val thisVar = 0 + visitCode() + val l0 = Label() + visitLabel(l0) + visitVarInsn(Opcodes.ALOAD, 0) + visitVarInsn(Opcodes.ALOAD, 1) + visitMethodInsn(Opcodes.INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false) + visitInsn(Opcodes.ARETURN) + val l1 = Label() + visitLabel(l1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + l0, + l1, + thisVar + ) + + visitMaxs(2, 2) + visitEnd() + } + + asmCompiledClassWriter.visitEnd() + + val new = classLoader + .defineClass(className, asmCompiledClassWriter.toByteArray()) + .constructors + .first() + .newInstance(algebra, constants) as AsmCompiledExpression + + generatedInstance = new + return new + } + + internal fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS) + + private fun visitLoadAnyFromConstants(value: Any, type: String) { + val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex + maxStack++ + + invokeMethodVisitor.run { + visitLoadThis() + visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;") + visitLdcOrIConstInsn(idx) + visitMethodInsn(Opcodes.INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true) + invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type) + } + } + + private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) + + internal fun visitNumberConstant(value: Number) { + maxStack++ + val clazz = value.javaClass + val c = clazz.name.replace('.', '/') + + val sigLetter = SIGNATURE_LETTERS[clazz] + + if (sigLetter != null) { + when (value) { + is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value) + is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value) + is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value) + else -> invokeMethodVisitor.visitLdcInsn(value) + } + + invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) + return + } + + visitLoadAnyFromConstants(value, c) + } + + internal fun visitLoadFromVariables(name: String, defaultValue: T? = null) = invokeMethodVisitor.run { + maxStack += 2 + visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar) + + if (defaultValue != null) { + visitLdcInsn(name) + visitLoadFromConstants(defaultValue) + + visitMethodInsn( + Opcodes.INVOKEINTERFACE, + MAP_CLASS, + "getOrDefault", + "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;", + true + ) + + visitCastToT() + return + } + + visitLdcInsn(name) + visitMethodInsn(Opcodes.INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true) + visitCastToT() + } + + internal fun visitLoadAlgebra() { + invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) + + invokeMethodVisitor.visitFieldInsn( + Opcodes.GETFIELD, + ASM_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" + ) + + invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) + } + + internal fun visitAlgebraOperation(owner: String, method: String, descriptor: String) { + maxStack++ + invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE, owner, method, descriptor, true) + visitCastToT() + } + + private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS) + + internal companion object { + private val SIGNATURE_LETTERS = mapOf( + java.lang.Byte::class.java to "B", + java.lang.Short::class.java to "S", + java.lang.Integer::class.java to "I", + java.lang.Long::class.java to "J", + java.lang.Float::class.java to "F", + java.lang.Double::class.java to "D" + ) + + internal const val ASM_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/AsmCompiledExpression" + internal const val LIST_CLASS = "java/util/List" + internal const val MAP_CLASS = "java/util/Map" + internal const val OBJECT_CLASS = "java/lang/Object" + internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" + internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" + internal const val STRING_CLASS = "java/lang/String" + internal const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations" + internal const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations" + internal const val NUMBER_CLASS = "java/lang/Number" + } +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt new file mode 100644 index 000000000..bdbab2b42 --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt @@ -0,0 +1,38 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Space + +@PublishedApi +internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { + val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(expression, collision + 1) +} + + +inline fun asmSpace( + algebra: Space, + block: AsmExpressionSpace.() -> AsmExpression +): Expression { + val expression = AsmExpressionSpace(algebra).block() + val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) + expression.invoke(ctx) + return ctx.generate() +} + +inline fun asmField( + algebra: Field, + block: AsmExpressionField.() -> AsmExpression +): Expression { + val expression = AsmExpressionField(algebra).block() + val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) + expression.invoke(ctx) + return ctx.generate() +} diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 875080616..4370029f6 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -64,7 +64,7 @@ class AsmTest { ) @Test - fun testCProductWithOtherTypeNumber() = testDoubleExpressionValue( + fun testCProductWithOtherTypeNumber(): Unit = testDoubleExpressionValue( 25.0, AsmConstProductExpression(AsmVariableExpression("x"), 5f), mapOf("x" to 5.0) @@ -81,21 +81,21 @@ class AsmTest { } @Test - fun testCProductWithCustomTypeNumber() = testDoubleExpressionValue( + fun testCProductWithCustomTypeNumber(): Unit = testDoubleExpressionValue( 0.0, AsmConstProductExpression(AsmVariableExpression("x"), CustomZero), mapOf("x" to 5.0) ) @Test - fun testVar() = testDoubleExpressionValue( + fun testVar(): Unit = testDoubleExpressionValue( 10000.0, AsmVariableExpression("x"), mapOf("x" to 10000.0) ) @Test - fun testVarWithDefault() = testDoubleExpressionValue( + fun testVarWithDefault(): Unit = testDoubleExpressionValue( 10000.0, AsmVariableExpression("x", 10000.0), mapOf() From b7d1fe256023acf0b03d48fb77a55c5119b453e8 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 8 Jun 2020 14:38:37 +0700 Subject: [PATCH 016/104] Implement recursive constants evaluation, improve builders --- .../kmath/expressions/AsmExpressionSpaces.kt | 13 +++-- .../kmath/expressions/AsmExpressions.kt | 49 +++++++++++++++---- .../scientifik/kmath/expressions/Builders.kt | 20 ++++---- .../kmath/expressions/Optimization.kt | 6 +++ 4 files changed, 63 insertions(+), 25 deletions(-) create mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt index 5c93fd729..da788372b 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt @@ -3,13 +3,13 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space -open class AsmExpressionSpace(space: Space) : Space>, +open class AsmExpressionSpace(private val space: Space) : Space>, ExpressionSpace> { override val zero: AsmExpression = AsmConstantExpression(space.zero) override fun const(value: T): AsmExpression = AsmConstantExpression(value) override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(a, b) - override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(a, k) + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(space, a, b) + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(space, a, k) operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this @@ -22,8 +22,11 @@ class AsmExpressionField(private val field: Field) : ExpressionField = const(field.run { one * value }) - override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmProductExpression(a, b) - override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmDivExpression(a, b) + + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmProductExpression(field, a, b) + + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmDivExpression(field, a, b) operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) operator fun T.times(arg: AsmExpression): AsmExpression = arg * this diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index 8708a33e1..cc861d363 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -1,6 +1,6 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.* abstract class AsmCompiledExpression internal constructor( @JvmField private val algebra: Algebra, @@ -10,6 +10,7 @@ abstract class AsmCompiledExpression internal constructor( } interface AsmExpression { + fun tryEvaluate(): T? = null fun invoke(gen: AsmGenerationContext) } @@ -20,13 +21,22 @@ internal class AsmVariableExpression(val name: String, val default: T? = null internal class AsmConstantExpression(val value: T) : AsmExpression { + override fun tryEvaluate(): T = value override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) } internal class AsmSumExpression( - val first: AsmExpression, - val second: AsmExpression + private val algebra: SpaceOperations, + first: AsmExpression, + second: AsmExpression ) : AsmExpression { + private val first: AsmExpression = first.optimize() + private val second: AsmExpression = second.optimize() + + override fun tryEvaluate(): T? = algebra { + (first.tryEvaluate() ?: return@algebra null) + (second.tryEvaluate() ?: return@algebra null) + } + override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() first.invoke(gen) @@ -41,9 +51,17 @@ internal class AsmSumExpression( } internal class AsmProductExpression( - val first: AsmExpression, - val second: AsmExpression + private val algebra: RingOperations, + first: AsmExpression, + second: AsmExpression ) : AsmExpression { + private val first: AsmExpression = first.optimize() + private val second: AsmExpression = second.optimize() + + override fun tryEvaluate(): T? = algebra { + (first.tryEvaluate() ?: return@algebra null) * (second.tryEvaluate() ?: return@algebra null) + } + override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() first.invoke(gen) @@ -58,9 +76,14 @@ internal class AsmProductExpression( } internal class AsmConstProductExpression( - val expr: AsmExpression, - val const: Number + private val algebra: SpaceOperations, + expr: AsmExpression, + private val const: Number ) : AsmExpression { + private val expr: AsmExpression = expr.optimize() + + override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const } + override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() gen.visitNumberConstant(const) @@ -75,9 +98,17 @@ internal class AsmConstProductExpression( } internal class AsmDivExpression( - val expr: AsmExpression, - val second: AsmExpression + private val algebra: FieldOperations, + expr: AsmExpression, + second: AsmExpression ) : AsmExpression { + private val expr: AsmExpression = expr.optimize() + private val second: AsmExpression = second.optimize() + + override fun tryEvaluate(): T? = algebra { + (expr.tryEvaluate() ?: return@algebra null) / (second.tryEvaluate() ?: return@algebra null) + } + override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() expr.invoke(gen) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt index bdbab2b42..f8b4afd5b 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt @@ -1,5 +1,6 @@ package scientifik.kmath.expressions +import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space @@ -17,22 +18,19 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String } -inline fun asmSpace( - algebra: Space, - block: AsmExpressionSpace.() -> AsmExpression -): Expression { - val expression = AsmExpressionSpace(algebra).block() +inline fun asm(i: I, algebra: Algebra, block: I.() -> AsmExpression): Expression { + val expression = i.block().optimize() val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) expression.invoke(ctx) return ctx.generate() } +inline fun asmSpace( + algebra: Space, + block: AsmExpressionSpace.() -> AsmExpression +): Expression = asm(AsmExpressionSpace(algebra), algebra, block) + inline fun asmField( algebra: Field, block: AsmExpressionField.() -> AsmExpression -): Expression { - val expression = AsmExpressionField(algebra).block() - val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) - expression.invoke(ctx) - return ctx.generate() -} +): Expression = asm(AsmExpressionField(algebra), algebra, block) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt new file mode 100644 index 000000000..bb7d0476d --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt @@ -0,0 +1,6 @@ +package scientifik.kmath.expressions + +fun AsmExpression.optimize(): AsmExpression { + val a = tryEvaluate() + return if (a == null) this else AsmConstantExpression(a) +} From c576e46020b9a10ce2b6d3769032cc57e6522bbc Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 8 Jun 2020 14:46:00 +0700 Subject: [PATCH 017/104] Minor refactor --- .../kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt | 3 ++- .../scientifik/kmath/expressions/FunctionalExpressions.kt | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt index da788372b..8147bcd5c 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt @@ -2,6 +2,7 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke open class AsmExpressionSpace(private val space: Space) : Space>, ExpressionSpace> { @@ -21,7 +22,7 @@ class AsmExpressionField(private val field: Field) : ExpressionField get() = const(this.field.one) - override fun number(value: Number): AsmExpression = const(field.run { one * value }) + override fun number(value: Number): AsmExpression = const(field { one * value }) override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmProductExpression(field, a, b) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index dbc2eac1e..0304a665f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -3,6 +3,7 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.Field import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke internal class VariableExpression(val name: String, val default: T? = null) : Expression { override fun invoke(arguments: Map): T = @@ -57,7 +58,7 @@ open class FunctionalExpressionField( override val one: Expression get() = const(this.field.one) - override fun number(value: Number): Expression = const(field.run { one * value }) + override fun number(value: Number): Expression = const(field { one * value }) override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) override fun divide(a: Expression, b: Expression): Expression = DivExpression(field, a, b) operator fun Expression.times(arg: T): Expression = this * const(arg) From 8f1cf0179a95345c3d9e0bd5323eb0b320f4f9dc Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 8 Jun 2020 14:48:44 +0700 Subject: [PATCH 018/104] Minor refactor --- .../scientifik/kmath/expressions/FunctionalExpressions.kt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index 0304a665f..7b5a93465 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -51,9 +51,9 @@ open class FunctionalExpressionSpace(val space: Space) : Space): Expression = arg - this } -open class FunctionalExpressionField( - val field: Field -) : ExpressionField>, FunctionalExpressionSpace(field) { +open class FunctionalExpressionField(val field: Field) : + ExpressionField>, + FunctionalExpressionSpace(field) { override val one: Expression get() = const(this.field.one) From 774b1123f729123e94461aee1ddc9517ba8d10df Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 8 Jun 2020 12:07:27 +0300 Subject: [PATCH 019/104] Minor change in grid builders --- .../kmath/linear/LUPDecomposition.kt | 6 +-- .../kotlin/scientifik/kmath/misc/Grids.kt | 44 ++++++++++++------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt index 87e0ef027..75d7da169 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt @@ -128,14 +128,14 @@ fun , F : Field> GenericMatrixContext.lup( luRow[col] = sum // maintain best permutation choice - if (abs(sum) > largest) { - largest = abs(sum) + if (this@lup.abs(sum) > largest) { + largest = this@lup.abs(sum) max = row } } // Singularity check - if (checkSingular(abs(lu[max, col]))) { + if (checkSingular(this@lup.abs(lu[max, col]))) { error("The matrix is singular") } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt index 90ce5da68..f040fb8d4 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt @@ -1,5 +1,7 @@ package scientifik.kmath.misc +import kotlin.math.abs + /** * Convert double range to sequence. * @@ -8,28 +10,36 @@ package scientifik.kmath.misc * * If step is negative, the same goes from upper boundary downwards */ -fun ClosedFloatingPointRange.toSequence(step: Double): Sequence = - when { - step == 0.0 -> error("Zero step in double progression") - step > 0 -> sequence { - var current = start - while (current <= endInclusive) { - yield(current) - current += step - } - } - else -> sequence { - var current = endInclusive - while (current >= start) { - yield(current) - current += step - } - } +fun ClosedFloatingPointRange.toSequenceWithStep(step: Double): Sequence = when { + step == 0.0 -> error("Zero step in double progression") + step > 0 -> sequence { + var current = start + while (current <= endInclusive) { + yield(current) + current += step } + } + else -> sequence { + var current = endInclusive + while (current >= start) { + yield(current) + current += step + } + } +} + +/** + * Convert double range to sequence with the fixed number of points + */ +fun ClosedFloatingPointRange.toSequenceWithPoints(numPoints: Int): Sequence { + require(numPoints > 1) { "The number of points should be more than 2" } + return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1)) +} /** * Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints] */ +@Deprecated("Replace by 'toSequenceWithPoints'") fun ClosedFloatingPointRange.toGrid(numPoints: Int): DoubleArray { if (numPoints < 2) error("Can't create generic grid with less than two points") return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } From 65370f93fb0f425d9c9ed0bc6bb5cbd9a049e47b Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 8 Jun 2020 23:18:08 +0700 Subject: [PATCH 020/104] Make algebra and constants protected, fix tests --- .../scientifik/kmath/expressions/AsmExpressions.kt | 4 ++-- .../kotlin/scientifik/kmath/expressions/AsmTest.kt | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index cc861d363..fc7788589 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -3,8 +3,8 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.* abstract class AsmCompiledExpression internal constructor( - @JvmField private val algebra: Algebra, - @JvmField private val constants: MutableList + @JvmField protected val algebra: Algebra, + @JvmField protected val constants: MutableList ) : Expression { abstract override fun invoke(arguments: Map): T } diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 4370029f6..437fd158f 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -31,7 +31,7 @@ class AsmTest { @Test fun testSum() = testDoubleExpressionValue( 25.0, - AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")), + AsmSumExpression(RealField, AsmConstantExpression(1.0), AsmVariableExpression("x")), mapOf("x" to 24.0) ) @@ -45,28 +45,28 @@ class AsmTest { @Test fun testDiv(): Unit = testDoubleExpressionValue( 0.5, - AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)), + AsmDivExpression(RealField, AsmConstantExpression(1.0), AsmConstantExpression(2.0)), mapOf() ) @Test fun testProduct(): Unit = testDoubleExpressionValue( 25.0, - AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")), + AsmProductExpression(RealField,AsmVariableExpression("x"), AsmVariableExpression("x")), mapOf("x" to 5.0) ) @Test fun testCProduct(): Unit = testDoubleExpressionValue( 25.0, - AsmConstProductExpression(AsmVariableExpression("x"), 5.0), + AsmConstProductExpression(RealField,AsmVariableExpression("x"), 5.0), mapOf("x" to 5.0) ) @Test fun testCProductWithOtherTypeNumber(): Unit = testDoubleExpressionValue( 25.0, - AsmConstProductExpression(AsmVariableExpression("x"), 5f), + AsmConstProductExpression(RealField,AsmVariableExpression("x"), 5f), mapOf("x" to 5.0) ) @@ -83,7 +83,7 @@ class AsmTest { @Test fun testCProductWithCustomTypeNumber(): Unit = testDoubleExpressionValue( 0.0, - AsmConstProductExpression(AsmVariableExpression("x"), CustomZero), + AsmConstProductExpression(RealField,AsmVariableExpression("x"), CustomZero), mapOf("x" to 5.0) ) From 2855ad29a409885aad5decd89e4391330d3820c3 Mon Sep 17 00:00:00 2001 From: darksnake Date: Tue, 9 Jun 2020 11:14:47 +0300 Subject: [PATCH 021/104] Remove unused argument in functionalExpressions.kt --- .../scientifik/kmath/expressions/functionalExpressions.kt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt index ad0acbc4a..90694b04a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt @@ -38,8 +38,7 @@ internal class DivExpession(val context: Field, val expr: Expression, v } open class FunctionalExpressionSpace( - val space: Space, - one: T + val space: Space ) : Space>, ExpressionSpace> { override val zero: Expression = ConstantExpression(space.zero) @@ -62,7 +61,7 @@ open class FunctionalExpressionSpace( open class FunctionalExpressionField( val field: Field -) : ExpressionField>, FunctionalExpressionSpace(field, field.one) { +) : ExpressionField>, FunctionalExpressionSpace(field) { override val one: Expression get() = const(this.field.one) From 5dc07febe344baff46c1a512622f8fb085a67a7f Mon Sep 17 00:00:00 2001 From: darksnake Date: Tue, 9 Jun 2020 13:07:40 +0300 Subject: [PATCH 022/104] Expression simplification --- .../commons/expressions/DiffExpression.kt | 5 +- .../kmath/expressions/Expression.kt | 91 ++----------------- .../kmath/expressions/SyntaxTreeNode.kt | 7 -- .../expressions/functionalExpressions.kt | 6 +- .../scientifik/kmath/operations/Algebra.kt | 45 ++++++++- .../kmath/linear/VectorSpaceTest.kt | 0 6 files changed, 56 insertions(+), 98 deletions(-) create mode 100644 kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index 65e915ef4..8c19395d3 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -2,8 +2,9 @@ package scientifik.kmath.commons.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionField +import scientifik.kmath.expressions.ExpressionContext import scientifik.kmath.operations.ExtendedField +import scientifik.kmath.operations.Field import kotlin.properties.ReadOnlyProperty import kotlin.reflect.KProperty @@ -112,7 +113,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1) /** * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) */ -object DiffExpressionContext : ExpressionField { +object DiffExpressionContext : ExpressionContext, Field { override fun variable(name: String, default: Double?) = DiffExpression { variable(name, default?.const()) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 94bc81413..73745f78f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -1,7 +1,6 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Space +import scientifik.kmath.operations.Algebra /** * An elementary function that could be invoked on a map of arguments @@ -15,7 +14,7 @@ operator fun Expression.invoke(vararg pairs: Pair): T = invoke /** * A context for expression construction */ -interface ExpressionContext> { +interface ExpressionContext : Algebra { /** * Introduce a variable into expression context */ @@ -25,87 +24,13 @@ interface ExpressionContext> { * A constant expression which does not depend on arguments */ fun const(value: T): E - - fun produce(node: SyntaxTreeNode): E } -interface ExpressionSpace> : Space, ExpressionContext { - - open fun produceSingular(value: String): E = variable(value) - - open fun produceUnary(operation: String, value: E): E { - return when (operation) { - UnaryNode.PLUS_OPERATION -> value - UnaryNode.MINUS_OPERATION -> -value - else -> error("Unary operation $operation is not supported by $this") - } +fun ExpressionContext.produce(node: SyntaxTreeNode): E { + return when (node) { + is NumberNode -> error("Single number nodes are not supported") + is SingularNode -> variable(node.value) + is UnaryNode -> unaryOperation(node.operation, produce(node.value)) + is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right)) } - - open fun produceBinary(operation: String, left: E, right: E): E { - return when (operation) { - BinaryNode.PLUS_OPERATION -> left + right - BinaryNode.MINUS_OPERATION -> left - right - else -> error("Binary operation $operation is not supported by $this") - } - } - - override fun produce(node: SyntaxTreeNode): E { - return when (node) { - is NumberNode -> error("Single number nodes are not supported") - is SingularNode -> produceSingular(node.value) - is UnaryNode -> produceUnary(node.operation, produce(node.value)) - is BinaryNode -> { - when (node.operation) { - BinaryNode.TIMES_OPERATION -> { - if (node.left is NumberNode) { - return produce(node.right) * node.left.value - } else if (node.right is NumberNode) { - return produce(node.left) * node.right.value - } - } - BinaryNode.DIV_OPERATION -> { - if (node.right is NumberNode) { - return produce(node.left) / node.right.value - } - } - } - produceBinary(node.operation, produce(node.left), produce(node.right)) - } - } - } -} - -interface ExpressionField> : Field, ExpressionSpace { - fun const(value: Double): E = one.times(value) - - override fun produce(node: SyntaxTreeNode): E { - if (node is BinaryNode) { - when (node.operation) { - BinaryNode.PLUS_OPERATION -> { - if (node.left is NumberNode) { - return produce(node.right) + one * node.left.value - } else if (node.right is NumberNode) { - return produce(node.left) + one * node.right.value - } - } - BinaryNode.MINUS_OPERATION -> { - if (node.left is NumberNode) { - return one * node.left.value - produce(node.right) - } else if (node.right is NumberNode) { - return produce(node.left) - one * node.right.value - } - } - } - } - return super.produce(node) - } - - override fun produceBinary(operation: String, left: E, right: E): E { - return when (operation) { - BinaryNode.TIMES_OPERATION -> left * right - BinaryNode.DIV_OPERATION -> left / right - else -> super.produceBinary(operation, left, right) - } - } - } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt index fff3c5bc0..b4038f680 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt @@ -8,9 +8,6 @@ data class NumberNode(val value: Number) : SyntaxTreeNode() data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxTreeNode() { companion object { - const val PLUS_OPERATION = "+" - const val MINUS_OPERATION = "-" - const val NOT_OPERATION = "!" const val ABS_OPERATION = "abs" const val SIN_OPERATION = "sin" const val COS_OPERATION = "cos" @@ -22,10 +19,6 @@ data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxT data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode) : SyntaxTreeNode() { companion object { - const val PLUS_OPERATION = "+" - const val MINUS_OPERATION = "-" - const val TIMES_OPERATION = "*" - const val DIV_OPERATION = "/" //TODO add operations } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt index 90694b04a..4f38e3a71 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt @@ -39,7 +39,7 @@ internal class DivExpession(val context: Field, val expr: Expression, v open class FunctionalExpressionSpace( val space: Space -) : Space>, ExpressionSpace> { +) : Space>, ExpressionContext> { override val zero: Expression = ConstantExpression(space.zero) @@ -61,12 +61,12 @@ open class FunctionalExpressionSpace( open class FunctionalExpressionField( val field: Field -) : ExpressionField>, FunctionalExpressionSpace(field) { +) : Field>, ExpressionContext>, FunctionalExpressionSpace(field) { override val one: Expression get() = const(this.field.one) - override fun const(value: Double): Expression = const(field.run { one*value}) + fun const(value: Double): Expression = const(field.run { one*value}) override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index 485185526..b3282e5f3 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -6,9 +6,12 @@ annotation class KMathContext /** * Marker interface for any algebra */ -interface Algebra +interface Algebra { + fun unaryOperation(operation: String, arg: T): T + fun binaryOperation(operation: String, left: T, right: T): T +} -inline operator fun , R> T.invoke(block: T.() -> R): R = run(block) +inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) /** * Space-like operations without neutral element @@ -24,7 +27,7 @@ interface SpaceOperations : Algebra { */ fun multiply(a: T, k: Number): T - //Operation to be performed in this context + //Operation to be performed in this context. Could be moved to extensions in case of KEEP-176 operator fun T.unaryMinus(): T = multiply(this, -1.0) operator fun T.plus(b: T): T = add(this, b) @@ -32,6 +35,24 @@ interface SpaceOperations : Algebra { operator fun T.times(k: Number) = multiply(this, k.toDouble()) operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble()) operator fun Number.times(b: T) = b * this + + override fun unaryOperation(operation: String, arg: T): T = when (operation) { + PLUS_OPERATION -> arg + MINUS_OPERATION -> -arg + else -> error("Unary operation $operation not defined in $this") + } + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + PLUS_OPERATION -> add(left, right) + MINUS_OPERATION -> left - right + else -> error("Binary operation $operation not defined in $this") + } + + companion object { + const val PLUS_OPERATION = "+" + const val MINUS_OPERATION = "-" + const val NOT_OPERATION = "!" + } } @@ -60,6 +81,15 @@ interface RingOperations : SpaceOperations { fun multiply(a: T, b: T): T operator fun T.times(b: T): T = multiply(this, b) + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + TIMES_OPERATION -> multiply(left, right) + else -> super.binaryOperation(operation, left, right) + } + + companion object{ + const val TIMES_OPERATION = "*" + } } /** @@ -85,6 +115,15 @@ interface FieldOperations : RingOperations { fun divide(a: T, b: T): T operator fun T.div(b: T): T = divide(this, b) + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + DIV_OPERATION -> divide(left, right) + else -> super.binaryOperation(operation, left, right) + } + + companion object{ + const val DIV_OPERATION = "/" + } } /** diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt new file mode 100644 index 000000000..e69de29bb From 33a519c10b10a39d99aad40c210354f6e4909d40 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Tue, 9 Jun 2020 22:22:15 +0700 Subject: [PATCH 023/104] Apply new interfaces structure to ASM Expression Field/Space --- .../kmath/expressions/AsmExpressionSpaces.kt | 11 +++++++---- .../kmath/expressions/FunctionalExpressions.kt | 6 +++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt index 8147bcd5c..38ed00605 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt @@ -4,8 +4,9 @@ import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space import scientifik.kmath.operations.invoke -open class AsmExpressionSpace(private val space: Space) : Space>, - ExpressionSpace> { +open class AsmExpressionSpace(private val space: Space) : + Space>, + ExpressionContext> { override val zero: AsmExpression = AsmConstantExpression(space.zero) override fun const(value: T): AsmExpression = AsmConstantExpression(value) override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) @@ -17,12 +18,14 @@ open class AsmExpressionSpace(private val space: Space) : Space): AsmExpression = arg - this } -class AsmExpressionField(private val field: Field) : ExpressionField>, +class AsmExpressionField(private val field: Field) : + ExpressionContext>, + Field>, AsmExpressionSpace(field) { override val one: AsmExpression get() = const(this.field.one) - override fun number(value: Number): AsmExpression = const(field { one * value }) + fun number(value: Number): AsmExpression = const(field { one * value }) override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmProductExpression(field, a, b) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index 4f38e3a71..97002f664 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -39,7 +39,7 @@ internal class DivExpession(val context: Field, val expr: Expression, v open class FunctionalExpressionSpace( val space: Space -) : Space>, ExpressionContext> { +) : Space>, ExpressionContext> { override val zero: Expression = ConstantExpression(space.zero) @@ -61,12 +61,12 @@ open class FunctionalExpressionSpace( open class FunctionalExpressionField( val field: Field -) : Field>, ExpressionContext>, FunctionalExpressionSpace(field) { +) : Field>, ExpressionContext>, FunctionalExpressionSpace(field) { override val one: Expression get() = const(this.field.one) - fun const(value: Double): Expression = const(field.run { one*value}) + fun number(value: Number): Expression = const(field.run { one * value }) override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) From 1b6a0a13d8308504d7df6b8673b9c321961d518f Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Wed, 10 Jun 2020 00:44:56 +0700 Subject: [PATCH 024/104] Completely rework Expression API to expose direct unaryOperation and binaryOperation, improve ASM API accordingly --- .../kmath/expressions/AsmExpressionSpaces.kt | 38 ------ .../kmath/expressions/AsmExpressions.kt | 123 ------------------ .../kmath/expressions/Optimization.kt | 6 - .../expressions/asm/AsmExpressionSpaces.kt | 74 +++++++++++ .../kmath/expressions/asm/AsmExpressions.kt | 106 +++++++++++++++ .../{ => asm}/AsmGenerationContext.kt | 24 ++-- .../kmath/expressions/{ => asm}/Builders.kt | 14 +- .../expressions/{ => asm}/MethodVisitors.kt | 2 +- .../kmath/expressions/asm/Optimization.kt | 9 ++ .../scientifik/kmath/expressions/AsmTest.kt | 97 +++----------- .../expressions/FunctionalExpressions.kt | 114 +++++++++------- 11 files changed, 301 insertions(+), 306 deletions(-) delete mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt delete mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt delete mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt create mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt create mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt rename kmath-asm/src/main/kotlin/scientifik/kmath/expressions/{ => asm}/AsmGenerationContext.kt (92%) rename kmath-asm/src/main/kotlin/scientifik/kmath/expressions/{ => asm}/Builders.kt (69%) rename kmath-asm/src/main/kotlin/scientifik/kmath/expressions/{ => asm}/MethodVisitors.kt (94%) create mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt deleted file mode 100644 index 38ed00605..000000000 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt +++ /dev/null @@ -1,38 +0,0 @@ -package scientifik.kmath.expressions - -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.invoke - -open class AsmExpressionSpace(private val space: Space) : - Space>, - ExpressionContext> { - override val zero: AsmExpression = AsmConstantExpression(space.zero) - override fun const(value: T): AsmExpression = AsmConstantExpression(value) - override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(space, a, b) - override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(space, a, k) - operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) - operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) - operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this - operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this -} - -class AsmExpressionField(private val field: Field) : - ExpressionContext>, - Field>, - AsmExpressionSpace(field) { - override val one: AsmExpression - get() = const(this.field.one) - - fun number(value: Number): AsmExpression = const(field { one * value }) - - override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmProductExpression(field, a, b) - - override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmDivExpression(field, a, b) - operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) - operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) - operator fun T.times(arg: AsmExpression): AsmExpression = arg * this - operator fun T.div(arg: AsmExpression): AsmExpression = arg / this -} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt deleted file mode 100644 index fc7788589..000000000 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ /dev/null @@ -1,123 +0,0 @@ -package scientifik.kmath.expressions - -import scientifik.kmath.operations.* - -abstract class AsmCompiledExpression internal constructor( - @JvmField protected val algebra: Algebra, - @JvmField protected val constants: MutableList -) : Expression { - abstract override fun invoke(arguments: Map): T -} - -interface AsmExpression { - fun tryEvaluate(): T? = null - fun invoke(gen: AsmGenerationContext) -} - -internal class AsmVariableExpression(val name: String, val default: T? = null) : - AsmExpression { - override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) -} - -internal class AsmConstantExpression(val value: T) : - AsmExpression { - override fun tryEvaluate(): T = value - override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) -} - -internal class AsmSumExpression( - private val algebra: SpaceOperations, - first: AsmExpression, - second: AsmExpression -) : AsmExpression { - private val first: AsmExpression = first.optimize() - private val second: AsmExpression = second.optimize() - - override fun tryEvaluate(): T? = algebra { - (first.tryEvaluate() ?: return@algebra null) + (second.tryEvaluate() ?: return@algebra null) - } - - override fun invoke(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() - first.invoke(gen) - second.invoke(gen) - - gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, - method = "add", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" - ) - } -} - -internal class AsmProductExpression( - private val algebra: RingOperations, - first: AsmExpression, - second: AsmExpression -) : AsmExpression { - private val first: AsmExpression = first.optimize() - private val second: AsmExpression = second.optimize() - - override fun tryEvaluate(): T? = algebra { - (first.tryEvaluate() ?: return@algebra null) * (second.tryEvaluate() ?: return@algebra null) - } - - override fun invoke(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() - first.invoke(gen) - second.invoke(gen) - - gen.visitAlgebraOperation( - owner = AsmGenerationContext.RING_OPERATIONS_CLASS, - method = "multiply", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" - ) - } -} - -internal class AsmConstProductExpression( - private val algebra: SpaceOperations, - expr: AsmExpression, - private val const: Number -) : AsmExpression { - private val expr: AsmExpression = expr.optimize() - - override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const } - - override fun invoke(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() - gen.visitNumberConstant(const) - expr.invoke(gen) - - gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, - method = "multiply", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" - ) - } -} - -internal class AsmDivExpression( - private val algebra: FieldOperations, - expr: AsmExpression, - second: AsmExpression -) : AsmExpression { - private val expr: AsmExpression = expr.optimize() - private val second: AsmExpression = second.optimize() - - override fun tryEvaluate(): T? = algebra { - (expr.tryEvaluate() ?: return@algebra null) / (second.tryEvaluate() ?: return@algebra null) - } - - override fun invoke(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() - expr.invoke(gen) - second.invoke(gen) - - gen.visitAlgebraOperation( - owner = AsmGenerationContext.FIELD_OPERATIONS_CLASS, - method = "divide", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" - ) - } -} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt deleted file mode 100644 index bb7d0476d..000000000 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt +++ /dev/null @@ -1,6 +0,0 @@ -package scientifik.kmath.expressions - -fun AsmExpression.optimize(): AsmExpression { - val a = tryEvaluate() - return if (a == null) this else AsmConstantExpression(a) -} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt new file mode 100644 index 000000000..17b6dc023 --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt @@ -0,0 +1,74 @@ +package scientifik.kmath.expressions.asm + +import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.operations.* + +open class AsmExpressionAlgebra(val algebra: Algebra) : + Algebra>, + ExpressionContext> { + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + override fun const(value: T): AsmExpression = AsmConstantExpression(value) + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) +} + +open class AsmExpressionSpace( + val space: Space +) : AsmExpressionAlgebra(space), Space> { + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + override val zero: AsmExpression = AsmConstantExpression(space.zero) + + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(space, a, k) + operator fun AsmExpression.plus(arg: T) = this + const(arg) + operator fun AsmExpression.minus(arg: T) = this - const(arg) + operator fun T.plus(arg: AsmExpression) = arg + this + operator fun T.minus(arg: AsmExpression) = arg - this +} + +open class AsmExpressionRing(private val ring: Ring) : AsmExpressionSpace(ring), Ring> { + override val one: AsmExpression + get() = const(this.ring.one) + + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + fun number(value: Number): AsmExpression = const(ring { one * value }) + + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) + + operator fun AsmExpression.times(arg: T) = this * const(arg) + operator fun T.times(arg: AsmExpression) = arg * this +} + +open class AsmExpressionField(private val field: Field) : + AsmExpressionRing(field), + Field> { + + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b) + + operator fun AsmExpression.div(arg: T) = this / const(arg) + operator fun T.div(arg: AsmExpression) = arg / this +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt new file mode 100644 index 000000000..d7d655d6e --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt @@ -0,0 +1,106 @@ +package scientifik.kmath.expressions.asm + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke + +interface AsmExpression { + fun tryEvaluate(): T? = null + fun invoke(gen: AsmGenerationContext) +} + +internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmExpression) : + AsmExpression { + private val expr: AsmExpression = expr.optimize() + + override fun tryEvaluate(): T? = context { + unaryOperation( + name, + expr.tryEvaluate() ?: return@context null + ) + } + + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + gen.visitStringConstant(name) + expr.invoke(gen) + + gen.visitAlgebraOperation( + owner = AsmGenerationContext.ALGEBRA_CLASS, + method = "unaryOperation", + descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + + "L${AsmGenerationContext.OBJECT_CLASS};)" + + "L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmBinaryOperation( + private val context: Algebra, + private val name: String, + first: AsmExpression, + second: AsmExpression +) : AsmExpression { + private val first: AsmExpression = first.optimize() + private val second: AsmExpression = second.optimize() + + override fun tryEvaluate(): T? = context { + binaryOperation( + name, + first.tryEvaluate() ?: return@context null, + second.tryEvaluate() ?: return@context null + ) + } + + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + gen.visitStringConstant(name) + first.invoke(gen) + second.invoke(gen) + + gen.visitAlgebraOperation( + owner = AsmGenerationContext.ALGEBRA_CLASS, + method = "binaryOperation", + descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + + "L${AsmGenerationContext.OBJECT_CLASS};" + + "L${AsmGenerationContext.OBJECT_CLASS};)" + + "L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmVariableExpression(private val name: String, private val default: T? = null) : AsmExpression { + override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) +} + +internal class AsmConstantExpression(private val value: T) : AsmExpression { + override fun tryEvaluate(): T = value + override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) +} + +internal class AsmConstProductExpression(private val context: Space, expr: AsmExpression, private val const: Number) : + AsmExpression { + private val expr: AsmExpression = expr.optimize() + + override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } + + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + gen.visitNumberConstant(const) + expr.invoke(gen) + + gen.visitAlgebraOperation( + owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, + method = "multiply", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal abstract class FunctionalCompiledExpression internal constructor( + @JvmField protected val algebra: Algebra, + @JvmField protected val constants: MutableList +) : Expression { + abstract override fun invoke(arguments: Map): T +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt similarity index 92% rename from kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt rename to kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt index cedd5c0fd..3e47da5e3 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.expressions +package scientifik.kmath.expressions.asm import org.objectweb.asm.ClassWriter import org.objectweb.asm.Label @@ -33,15 +33,15 @@ class AsmGenerationContext( private val invokeMethodVisitor: MethodVisitor private val invokeL0: Label private lateinit var invokeL1: Label - private var generatedInstance: AsmCompiledExpression? = null + private var generatedInstance: FunctionalCompiledExpression? = null init { asmCompiledClassWriter.visit( Opcodes.V1_8, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, slashesClassName, - "L$ASM_COMPILED_EXPRESSION_CLASS;", - ASM_COMPILED_EXPRESSION_CLASS, + "L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS;", + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, arrayOf() ) @@ -58,7 +58,7 @@ class AsmGenerationContext( visitMethodInsn( Opcodes.INVOKESPECIAL, - ASM_COMPILED_EXPRESSION_CLASS, + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", false @@ -103,7 +103,7 @@ class AsmGenerationContext( @PublishedApi @Suppress("UNCHECKED_CAST") - internal fun generate(): AsmCompiledExpression { + internal fun generate(): FunctionalCompiledExpression { generatedInstance?.let { return it } invokeMethodVisitor.run { @@ -170,7 +170,7 @@ class AsmGenerationContext( .defineClass(className, asmCompiledClassWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants) as AsmCompiledExpression + .newInstance(algebra, constants) as FunctionalCompiledExpression generatedInstance = new return new @@ -245,7 +245,7 @@ class AsmGenerationContext( invokeMethodVisitor.visitFieldInsn( Opcodes.GETFIELD, - ASM_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" ) invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) @@ -259,6 +259,10 @@ class AsmGenerationContext( private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS) + internal fun visitStringConstant(string: String) { + invokeMethodVisitor.visitLdcInsn(string) + } + internal companion object { private val SIGNATURE_LETTERS = mapOf( java.lang.Byte::class.java to "B", @@ -269,15 +273,13 @@ class AsmGenerationContext( java.lang.Double::class.java to "D" ) - internal const val ASM_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/AsmCompiledExpression" + internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/asm/FunctionalCompiledExpression" internal const val LIST_CLASS = "java/util/List" internal const val MAP_CLASS = "java/util/Map" internal const val OBJECT_CLASS = "java/lang/Object" internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" internal const val STRING_CLASS = "java/lang/String" - internal const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations" - internal const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations" internal const val NUMBER_CLASS = "java/lang/Number" } } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt similarity index 69% rename from kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt rename to kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt index f8b4afd5b..d55555567 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt @@ -1,7 +1,9 @@ -package scientifik.kmath.expressions +package scientifik.kmath.expressions.asm +import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space @PublishedApi @@ -25,11 +27,21 @@ inline fun asm(i: I, algebra: Algebra, block: I.() -> AsmExpre return ctx.generate() } +inline fun asmAlgebra( + algebra: Algebra, + block: AsmExpressionAlgebra.() -> AsmExpression +): Expression = asm(AsmExpressionAlgebra(algebra), algebra, block) + inline fun asmSpace( algebra: Space, block: AsmExpressionSpace.() -> AsmExpression ): Expression = asm(AsmExpressionSpace(algebra), algebra, block) +inline fun asmRing( + algebra: Ring, + block: AsmExpressionRing.() -> AsmExpression +): Expression = asm(AsmExpressionRing(algebra), algebra, block) + inline fun asmField( algebra: Field, block: AsmExpressionField.() -> AsmExpression diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/MethodVisitors.kt similarity index 94% rename from kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt rename to kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/MethodVisitors.kt index 9cdb0672f..9f697ab6d 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/MethodVisitors.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.expressions +package scientifik.kmath.expressions.asm import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt new file mode 100644 index 000000000..db57a690c --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt @@ -0,0 +1,9 @@ +package scientifik.kmath.expressions.asm + +import scientifik.kmath.expressions.asm.AsmConstantExpression +import scientifik.kmath.expressions.asm.AsmExpression + +fun AsmExpression.optimize(): AsmExpression { + val a = tryEvaluate() + return if (a == null) this else AsmConstantExpression(a) +} diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 437fd158f..55c240cef 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -1,103 +1,40 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Algebra +import scientifik.kmath.expressions.asm.AsmExpression +import scientifik.kmath.expressions.asm.AsmExpressionField +import scientifik.kmath.expressions.asm.asmField import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals class AsmTest { - private fun testExpressionValue( - expectedValue: T, - expr: AsmExpression, - arguments: Map, - algebra: Algebra, - clazz: Class<*> - ): Unit = assertEquals( - expectedValue, AsmGenerationContext(clazz, algebra, "TestAsmCompiled") - .also(expr::invoke) - .generate() - .invoke(arguments) - ) - - @Suppress("UNCHECKED_CAST") - private fun testDoubleExpressionValue( - expectedValue: Double, - expr: AsmExpression, - arguments: Map, - algebra: Algebra = RealField, - clazz: Class = java.lang.Double::class.java as Class - ): Unit = testExpressionValue(expectedValue, expr, arguments, algebra, clazz) + private fun testDoubleExpression( + expected: Double?, + arguments: Map = emptyMap(), + block: AsmExpressionField.() -> AsmExpression + ): Unit = assertEquals(expected = expected, actual = asmField(RealField, block)(arguments)) @Test - fun testSum() = testDoubleExpressionValue( - 25.0, - AsmSumExpression(RealField, AsmConstantExpression(1.0), AsmVariableExpression("x")), - mapOf("x" to 24.0) - ) + fun testConstantsSum() = testDoubleExpression(16.0) { const(8.0) + 8.0 } @Test - fun testConst(): Unit = testDoubleExpressionValue( - 123.0, - AsmConstantExpression(123.0), - mapOf() - ) + fun testVarsSum() = testDoubleExpression(1000.0, mapOf("x" to 500.0)) { variable("x") + 500.0 } @Test - fun testDiv(): Unit = testDoubleExpressionValue( - 0.5, - AsmDivExpression(RealField, AsmConstantExpression(1.0), AsmConstantExpression(2.0)), - mapOf() - ) + fun testProduct() = testDoubleExpression(24.0) { const(4.0) * const(6.0) } @Test - fun testProduct(): Unit = testDoubleExpressionValue( - 25.0, - AsmProductExpression(RealField,AsmVariableExpression("x"), AsmVariableExpression("x")), - mapOf("x" to 5.0) - ) + fun testConstantProduct() = testDoubleExpression(984.0) { const(8.0) * 123 } @Test - fun testCProduct(): Unit = testDoubleExpressionValue( - 25.0, - AsmConstProductExpression(RealField,AsmVariableExpression("x"), 5.0), - mapOf("x" to 5.0) - ) + fun testSubtraction() = testDoubleExpression(2.0) { const(4.0) - 2.0 } @Test - fun testCProductWithOtherTypeNumber(): Unit = testDoubleExpressionValue( - 25.0, - AsmConstProductExpression(RealField,AsmVariableExpression("x"), 5f), - mapOf("x" to 5.0) - ) - - object CustomZero : Number() { - override fun toByte(): Byte = 0 - override fun toChar(): Char = 0.toChar() - override fun toDouble(): Double = 0.0 - override fun toFloat(): Float = 0f - override fun toInt(): Int = 0 - override fun toLong(): Long = 0L - override fun toShort(): Short = 0 - } + fun testDivision() = testDoubleExpression(64.0) { const(128.0) / 2 } @Test - fun testCProductWithCustomTypeNumber(): Unit = testDoubleExpressionValue( - 0.0, - AsmConstProductExpression(RealField,AsmVariableExpression("x"), CustomZero), - mapOf("x" to 5.0) - ) + fun testDirectCall() = testDoubleExpression(4096.0) { binaryOperation("*", const(64.0), const(64.0)) } - @Test - fun testVar(): Unit = testDoubleExpressionValue( - 10000.0, - AsmVariableExpression("x"), - mapOf("x" to 10000.0) - ) - - @Test - fun testVarWithDefault(): Unit = testDoubleExpressionValue( - 10000.0, - AsmVariableExpression("x", 10000.0), - mapOf() - ) +// @Test +// fun testSine() = testDoubleExpression(0.0) { unaryOperation("sin", const(PI)) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index 97002f664..b36ea8b52 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -1,80 +1,102 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space +import scientifik.kmath.operations.* -internal class VariableExpression(val name: String, val default: T? = null) : Expression { +internal class FunctionalUnaryOperation(val context: Algebra, val name: String, val expr: Expression) : + Expression { + override fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) +} + +internal class FunctionalBinaryOperation( + val context: Algebra, + val name: String, + val first: Expression, + val second: Expression +) : Expression { + override fun invoke(arguments: Map): T = + context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) +} + +internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression { override fun invoke(arguments: Map): T = arguments[name] ?: default ?: error("Parameter not found: $name") } -internal class ConstantExpression(val value: T) : Expression { +internal class FunctionalConstantExpression(val value: T) : Expression { override fun invoke(arguments: Map): T = value } -internal class SumExpression( - val context: Space, - val first: Expression, - val second: Expression -) : Expression { - override fun invoke(arguments: Map): T = context.add(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ProductExpression(val context: Ring, val first: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = - context.multiply(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : +internal class FunctionalConstProductExpression(val context: Space, val expr: Expression, val const: Number) : Expression { override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } -internal class DivExpession(val context: Field, val expr: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) +open class FunctionalExpressionAlgebra(val algebra: Algebra) : + Algebra>, + ExpressionContext> { + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) + + override fun const(value: T): Expression = FunctionalConstantExpression(value) + override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default) } -open class FunctionalExpressionSpace( - val space: Space -) : Space>, ExpressionContext> { +open class FunctionalExpressionSpace(val space: Space) : + FunctionalExpressionAlgebra(space), + Space> { + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) - override val zero: Expression = ConstantExpression(space.zero) + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) - override fun const(value: T): Expression = ConstantExpression(value) - - override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) - - override fun add(a: Expression, b: Expression): Expression = SumExpression(space, a, b) - - override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(space, a, k) + override val zero: Expression = FunctionalConstantExpression(space.zero) + override fun add(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b) + override fun multiply(a: Expression, k: Number): Expression = FunctionalConstProductExpression(space, a, k) operator fun Expression.plus(arg: T) = this + const(arg) operator fun Expression.minus(arg: T) = this - const(arg) - operator fun T.plus(arg: Expression) = arg + this operator fun T.minus(arg: Expression) = arg - this } -open class FunctionalExpressionField( - val field: Field -) : Field>, ExpressionContext>, FunctionalExpressionSpace(field) { - +open class FunctionalExpressionRing(val ring: Ring) : FunctionalExpressionSpace(ring), Ring> { override val one: Expression - get() = const(this.field.one) + get() = const(this.ring.one) - fun number(value: Number): Expression = const(field.run { one * value }) + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) - override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) - override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) + fun number(value: Number): Expression = const(ring { one * value }) + + override fun multiply(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) operator fun Expression.times(arg: T) = this * const(arg) - operator fun Expression.div(arg: T) = this / const(arg) - operator fun T.times(arg: Expression) = arg * this +} + +open class FunctionalExpressionField(val field: Field) : + FunctionalExpressionRing(field), + Field> { + + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) + + override fun divide(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b) + + operator fun Expression.div(arg: T) = this / const(arg) operator fun T.div(arg: Expression) = arg / this -} \ No newline at end of file +} From 2e5c13aea99b8ecc82c867ca859353df80ea194c Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Wed, 10 Jun 2020 02:05:13 +0700 Subject: [PATCH 025/104] Improve support of string-identified operations API, rework trigonometric operations algebra part: introduce inverse trigonometric operations, rename tg to tan --- .../commons/expressions/DiffExpression.kt | 7 +- .../scientifik/kmath/operations/Complex.kt | 8 +- .../kmath/operations/NumberAlgebra.kt | 12 +- .../kmath/operations/OptionalOperations.kt | 34 +++-- .../kmath/structures/ComplexNDField.kt | 7 ++ .../kmath/structures/ExtendedNDField.kt | 2 +- .../kmath/structures/RealBufferField.kt | 117 ++++++++++++------ .../kmath/structures/RealNDField.kt | 7 ++ 8 files changed, 140 insertions(+), 54 deletions(-) diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index 8c19395d3..88d32378e 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -59,8 +59,10 @@ class DerivativeStructureField( override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() - override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() + override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() + override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() + override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { is Double -> arg.pow(pow) @@ -136,6 +138,3 @@ object DiffExpressionContext : ExpressionContext, Field< override fun divide(a: DiffExpression, b: DiffExpression) = DiffExpression { a.function(this) / b.function(this) } } - - - diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 6c529f55e..ea9425bef 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -8,6 +8,8 @@ import scientifik.memory.MemorySpec import scientifik.memory.MemoryWriter import kotlin.math.* +private val PI_DIV_2 = Complex(PI / 2, 0) + /** * A field for complex numbers */ @@ -30,9 +32,11 @@ object ComplexField : ExtendedFieldOperations, Field { return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm) } - override fun sin(arg: Complex): Complex = i / 2 * (exp(-i * arg) - exp(i * arg)) - + override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 + override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg) + override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg) + override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2 override fun power(arg: Complex, pow: Number): Complex = arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta)) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 9639e4c28..3d942beac 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -8,7 +8,7 @@ import kotlin.math.pow as kpow */ interface ExtendedFieldOperations : FieldOperations, - TrigonometricOperations, + InverseTrigonometricOperations, PowerOperations, ExponentialOperations @@ -44,6 +44,10 @@ object RealField : ExtendedField, Norm { override inline fun sin(arg: Double) = kotlin.math.sin(arg) override inline fun cos(arg: Double) = kotlin.math.cos(arg) + override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) + override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) + override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) + override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble()) @@ -75,6 +79,10 @@ object FloatField : ExtendedField, Norm { override inline fun sin(arg: Float) = kotlin.math.sin(arg) override inline fun cos(arg: Float) = kotlin.math.cos(arg) + override inline fun tan(arg: Float) = kotlin.math.tan(arg) + override inline fun acos(arg: Float) = kotlin.math.acos(arg) + override inline fun asin(arg: Float) = kotlin.math.asin(arg) + override inline fun atan(arg: Float) = kotlin.math.atan(arg) override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat()) @@ -180,4 +188,4 @@ object LongRing : Ring, Norm { override inline fun Long.minus(b: Long) = (this - b) override inline fun Long.times(b: Long) = (this * b) -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index bd83932e7..bbd7a110e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -13,16 +13,33 @@ package scientifik.kmath.operations interface TrigonometricOperations : FieldOperations { fun sin(arg: T): T fun cos(arg: T): T + fun tan(arg: T): T = sin(arg) / cos(arg) - fun tg(arg: T): T = sin(arg) / cos(arg) + companion object { + const val SIN_OPERATION = "sin" + const val COS_OPERATION = "cos" + const val TAN_OPERATION = "tan" + } +} - fun ctg(arg: T): T = cos(arg) / sin(arg) +interface InverseTrigonometricOperations : TrigonometricOperations { + fun asin(arg: T): T + fun acos(arg: T): T + fun atan(arg: T): T + + companion object { + const val ASIN_OPERATION = "asin" + const val ACOS_OPERATION = "acos" + const val ATAN_OPERATION = "atan" + } } fun >> sin(arg: T): T = arg.context.sin(arg) fun >> cos(arg: T): T = arg.context.cos(arg) -fun >> tg(arg: T): T = arg.context.tg(arg) -fun >> ctg(arg: T): T = arg.context.ctg(arg) +fun >> tan(arg: T): T = arg.context.tan(arg) +fun >> asin(arg: T): T = arg.context.asin(arg) +fun >> acos(arg: T): T = arg.context.acos(arg) +fun >> atan(arg: T): T = arg.context.atan(arg) /* Power and roots */ @@ -32,8 +49,11 @@ fun >> ctg(arg: T): T = arg.conte interface PowerOperations : Algebra { fun power(arg: T, pow: Number): T fun sqrt(arg: T) = power(arg, 0.5) - infix fun T.pow(pow: Number) = power(this, pow) + + companion object { + const val SQRT_OPERATION = "sqrt" + } } infix fun >> T.pow(power: Double): T = context.power(this, power) @@ -42,7 +62,7 @@ fun >> sqr(arg: T): T = arg pow 2.0 /* Exponential */ -interface ExponentialOperations: Algebra { +interface ExponentialOperations : Algebra { fun exp(arg: T): T fun ln(arg: T): T } @@ -54,4 +74,4 @@ interface Norm { fun norm(arg: T): R } -fun >, R> norm(arg: T): R = arg.context.norm(arg) \ No newline at end of file +fun >, R> norm(arg: T): R = arg.context.norm(arg) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt index a79366a99..c7e672c28 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt @@ -79,6 +79,13 @@ class ComplexNDField(override val shape: IntArray) : override fun cos(arg: NDBuffer) = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): NDBuffer = map(arg) { tan(it) } + + override fun asin(arg: NDBuffer): NDBuffer = map(arg) { asin(it) } + + override fun acos(arg: NDBuffer): NDBuffer = map(arg) {acos(it)} + + override fun atan(arg: NDBuffer): NDBuffer = map(arg) {atan(it)} } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt index 3437644ff..c986ff011 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt @@ -4,7 +4,7 @@ import scientifik.kmath.operations.* interface ExtendedNDField> : NDField, - TrigonometricOperations, + InverseTrigonometricOperations, PowerOperations, ExponentialOperations where F : ExtendedFieldOperations, F : Field diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 88c8c29db..2fb6d15d4 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -9,8 +9,10 @@ import kotlin.math.* * A simple field over linear buffers of [Double] */ object RealBufferFieldOperations : ExtendedFieldOperations> { + override fun add(a: Buffer, b: Buffer): DoubleBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + return if (a is DoubleBuffer && b is DoubleBuffer) { val aArray = a.array val bArray = b.array @@ -22,6 +24,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { override fun multiply(a: Buffer, k: Number): DoubleBuffer { val kValue = k.toDouble() + return if (a is DoubleBuffer) { val aArray = a.array DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue }) @@ -32,6 +35,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + return if (a is DoubleBuffer && b is DoubleBuffer) { val aArray = a.array val bArray = b.array @@ -43,6 +47,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { override fun divide(a: Buffer, b: Buffer): DoubleBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + return if (a is DoubleBuffer && b is DoubleBuffer) { val aArray = a.array val bArray = b.array @@ -52,49 +57,67 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { } } - override fun sin(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) - } + override fun sin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) } - override fun cos(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) - } + override fun cos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) - } + override fun tan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { tan(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { tan(arg[it]) }) } - override fun exp(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - } + override fun asin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { asin(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { asin(arg[it]) }) } - override fun ln(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) - } + override fun acos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { acos(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { acos(arg[it]) }) + } + + override fun atan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { atan(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { atan(arg[it]) }) + } + + override fun power(arg: Buffer, pow: Number): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + } + + override fun exp(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) + } + + override fun ln(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } } @@ -119,7 +142,6 @@ class RealBufferField(val size: Int) : ExtendedField> { return RealBufferFieldOperations.multiply(a, b) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.divide(a, b) @@ -135,6 +157,26 @@ class RealBufferField(val size: Int) : ExtendedField> { return RealBufferFieldOperations.cos(arg) } + override fun tan(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.tan(arg) + } + + override fun asin(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.asin(arg) + } + + override fun acos(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.acos(arg) + } + + override fun atan(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.atan(arg) + } + override fun power(arg: Buffer, pow: Number): DoubleBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) @@ -149,5 +191,4 @@ class RealBufferField(val size: Int) : ExtendedField> { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.ln(arg) } - -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 8c1bd4239..22b33aa4d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -74,6 +74,13 @@ class RealNDField(override val shape: IntArray) : override fun cos(arg: NDBuffer) = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): NDBuffer = map(arg) { tan(it) } + + override fun asin(arg: NDBuffer): NDBuffer = map(arg) { asin(it) } + + override fun acos(arg: NDBuffer): NDBuffer = map(arg) { acos(it) } + + override fun atan(arg: NDBuffer): NDBuffer = map(arg) { atan(it) } } From a0453da4b34706cd7853f49cd4b7de6fa69aa179 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Wed, 10 Jun 2020 08:57:17 +0700 Subject: [PATCH 026/104] Refactor, replace constants List with Array, create specification of named functions --- kmath-asm/build.gradle.kts | 1 + .../expressions/asm/AsmExpressionSpaces.kt | 16 ++--- .../kmath/expressions/asm/AsmExpressions.kt | 59 +++++++++++++++---- .../expressions/asm/AsmGenerationContext.kt | 17 +++--- .../kmath/expressions/asm/Builders.kt | 12 ++-- .../kmath/expressions/asm/Optimization.kt | 6 +- .../kmath/expressions/Expression.kt | 14 ++--- .../expressions/FunctionalExpressions.kt | 16 ++--- 8 files changed, 89 insertions(+), 52 deletions(-) diff --git a/kmath-asm/build.gradle.kts b/kmath-asm/build.gradle.kts index 4104349f8..3b004eb8e 100644 --- a/kmath-asm/build.gradle.kts +++ b/kmath-asm/build.gradle.kts @@ -6,4 +6,5 @@ dependencies { api(project(path = ":kmath-core")) implementation("org.ow2.asm:asm:8.0.1") implementation("org.ow2.asm:asm-commons:8.0.1") + implementation(kotlin("reflect")) } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt index 17b6dc023..24260b871 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt @@ -31,10 +31,10 @@ open class AsmExpressionSpace( AsmBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(space, a, k) - operator fun AsmExpression.plus(arg: T) = this + const(arg) - operator fun AsmExpression.minus(arg: T) = this - const(arg) - operator fun T.plus(arg: AsmExpression) = arg + this - operator fun T.minus(arg: AsmExpression) = arg - this + operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) + operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) + operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this + operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this } open class AsmExpressionRing(private val ring: Ring) : AsmExpressionSpace(ring), Ring> { @@ -52,8 +52,8 @@ open class AsmExpressionRing(private val ring: Ring) : AsmExpressionSpace< override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) - operator fun AsmExpression.times(arg: T) = this * const(arg) - operator fun T.times(arg: AsmExpression) = arg * this + operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) + operator fun T.times(arg: AsmExpression): AsmExpression = arg * this } open class AsmExpressionField(private val field: Field) : @@ -69,6 +69,6 @@ open class AsmExpressionField(private val field: Field) : override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b) - operator fun AsmExpression.div(arg: T) = this / const(arg) - operator fun T.div(arg: AsmExpression) = arg / this + operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) + operator fun T.div(arg: AsmExpression): AsmExpression = arg / this } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt index d7d655d6e..662e1279e 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt @@ -4,28 +4,55 @@ import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Space import scientifik.kmath.operations.invoke +import kotlin.reflect.full.memberFunctions +import kotlin.reflect.jvm.jvmName interface AsmExpression { fun tryEvaluate(): T? = null fun invoke(gen: AsmGenerationContext) } +internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { + context::class.memberFunctions.find { it.name == name && it.parameters.size == arity } + ?: return false + + return true +} + +internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { + context::class.memberFunctions.find { it.name == name && it.parameters.size == arity } + ?: return false + + val owner = context::class.jvmName.replace('.', '/') + + val sig = buildString { + append('(') + repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") } + append(')') + append("L${AsmGenerationContext.OBJECT_CLASS};") + } + + visitAlgebraOperation(owner = owner, method = name, descriptor = sig) + + return true +} + internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmExpression) : AsmExpression { private val expr: AsmExpression = expr.optimize() - - override fun tryEvaluate(): T? = context { - unaryOperation( - name, - expr.tryEvaluate() ?: return@context null - ) - } + override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() - gen.visitStringConstant(name) + + if (!hasSpecific(context, name, 1)) + gen.visitStringConstant(name) + expr.invoke(gen) + if (gen.tryInvokeSpecific(context, name, 1)) + return + gen.visitAlgebraOperation( owner = AsmGenerationContext.ALGEBRA_CLASS, method = "unaryOperation", @@ -55,10 +82,16 @@ internal class AsmBinaryOperation( override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() - gen.visitStringConstant(name) + + if (!hasSpecific(context, name, 1)) + gen.visitStringConstant(name) + first.invoke(gen) second.invoke(gen) + if (gen.tryInvokeSpecific(context, name, 1)) + return + gen.visitAlgebraOperation( owner = AsmGenerationContext.ALGEBRA_CLASS, method = "binaryOperation", @@ -79,7 +112,11 @@ internal class AsmConstantExpression(private val value: T) : AsmExpression override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) } -internal class AsmConstProductExpression(private val context: Space, expr: AsmExpression, private val const: Number) : +internal class AsmConstProductExpression( + private val context: Space, + expr: AsmExpression, + private val const: Number +) : AsmExpression { private val expr: AsmExpression = expr.optimize() @@ -100,7 +137,7 @@ internal class AsmConstProductExpression(private val context: Space, expr: internal abstract class FunctionalCompiledExpression internal constructor( @JvmField protected val algebra: Algebra, - @JvmField protected val constants: MutableList + @JvmField protected val constants: Array ) : Expression { abstract override fun invoke(arguments: Map): T } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt index 3e47da5e3..0375b7fc4 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt @@ -46,7 +46,7 @@ class AsmGenerationContext( ) asmCompiledClassWriter.run { - visitMethod(Opcodes.ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run { + visitMethod(Opcodes.ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", null, null).run { val thisVar = 0 val algebraVar = 1 val constantsVar = 2 @@ -60,7 +60,7 @@ class AsmGenerationContext( Opcodes.INVOKESPECIAL, FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "", - "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", + "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", false ) @@ -80,7 +80,7 @@ class AsmGenerationContext( algebraVar ) - visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS;", l0, l2, constantsVar) + visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) visitMaxs(3, 3) visitEnd() } @@ -170,7 +170,7 @@ class AsmGenerationContext( .defineClass(className, asmCompiledClassWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants) as FunctionalCompiledExpression + .newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression generatedInstance = new return new @@ -184,9 +184,9 @@ class AsmGenerationContext( invokeMethodVisitor.run { visitLoadThis() - visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;") + visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "[L$OBJECT_CLASS;") visitLdcOrIConstInsn(idx) - visitMethodInsn(Opcodes.INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true) + visitInsn(Opcodes.AALOAD) invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type) } } @@ -273,8 +273,9 @@ class AsmGenerationContext( java.lang.Double::class.java to "D" ) - internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/asm/FunctionalCompiledExpression" - internal const val LIST_CLASS = "java/util/List" + internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = + "scientifik/kmath/expressions/asm/FunctionalCompiledExpression" + internal const val MAP_CLASS = "java/util/Map" internal const val OBJECT_CLASS = "java/lang/Object" internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt index d55555567..663bd55ba 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt @@ -1,10 +1,8 @@ package scientifik.kmath.expressions.asm import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space +import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.operations.* @PublishedApi internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { @@ -20,7 +18,11 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String } -inline fun asm(i: I, algebra: Algebra, block: I.() -> AsmExpression): Expression { +inline fun >> asm( + i: E, + algebra: Algebra, + block: E.() -> AsmExpression +): Expression { val expression = i.block().optimize() val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) expression.invoke(ctx) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt index db57a690c..bcd182dc3 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt @@ -1,9 +1,7 @@ package scientifik.kmath.expressions.asm -import scientifik.kmath.expressions.asm.AsmConstantExpression -import scientifik.kmath.expressions.asm.AsmExpression - -fun AsmExpression.optimize(): AsmExpression { +@PublishedApi +internal fun AsmExpression.optimize(): AsmExpression { val a = tryEvaluate() return if (a == null) this else AsmConstantExpression(a) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 73745f78f..8d8bd3a7a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -26,11 +26,9 @@ interface ExpressionContext : Algebra { fun const(value: T): E } -fun ExpressionContext.produce(node: SyntaxTreeNode): E { - return when (node) { - is NumberNode -> error("Single number nodes are not supported") - is SingularNode -> variable(node.value) - is UnaryNode -> unaryOperation(node.operation, produce(node.value)) - is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right)) - } -} \ No newline at end of file +fun ExpressionContext.produce(node: SyntaxTreeNode): E = when (node) { + is NumberNode -> error("Single number nodes are not supported") + is SingularNode -> variable(node.value) + is UnaryNode -> unaryOperation(node.operation, produce(node.value)) + is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right)) +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index b36ea8b52..d20a5dbbf 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -59,10 +59,10 @@ open class FunctionalExpressionSpace(val space: Space) : FunctionalBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: Expression, k: Number): Expression = FunctionalConstProductExpression(space, a, k) - operator fun Expression.plus(arg: T) = this + const(arg) - operator fun Expression.minus(arg: T) = this - const(arg) - operator fun T.plus(arg: Expression) = arg + this - operator fun T.minus(arg: Expression) = arg - this + operator fun Expression.plus(arg: T): Expression = this + const(arg) + operator fun Expression.minus(arg: T): Expression = this - const(arg) + operator fun T.plus(arg: Expression): Expression = arg + this + operator fun T.minus(arg: Expression): Expression = arg - this } open class FunctionalExpressionRing(val ring: Ring) : FunctionalExpressionSpace(ring), Ring> { @@ -80,8 +80,8 @@ open class FunctionalExpressionRing(val ring: Ring) : FunctionalExpression override fun multiply(a: Expression, b: Expression): Expression = FunctionalBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) - operator fun Expression.times(arg: T) = this * const(arg) - operator fun T.times(arg: Expression) = arg * this + operator fun Expression.times(arg: T): Expression = this * const(arg) + operator fun T.times(arg: Expression): Expression = arg * this } open class FunctionalExpressionField(val field: Field) : @@ -97,6 +97,6 @@ open class FunctionalExpressionField(val field: Field) : override fun divide(a: Expression, b: Expression): Expression = FunctionalBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b) - operator fun Expression.div(arg: T) = this / const(arg) - operator fun T.div(arg: Expression) = arg / this + operator fun Expression.div(arg: T): Expression = this / const(arg) + operator fun T.div(arg: Expression): Expression = arg / this } From 927aa589ad4277d8a95aea2c402128a8183e59e7 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 11 Jun 2020 08:49:38 +0700 Subject: [PATCH 027/104] Add missing override modifiers --- .../scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt | 2 +- .../scientifik/kmath/expressions/FunctionalExpressions.kt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt index 24260b871..78f6ee351 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt @@ -47,7 +47,7 @@ open class AsmExpressionRing(private val ring: Ring) : AsmExpressionSpace< override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = AsmBinaryOperation(algebra, operation, left, right) - fun number(value: Number): AsmExpression = const(ring { one * value }) + override fun number(value: Number): AsmExpression = const(ring { one * value }) override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index d20a5dbbf..ce9b4756c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -75,7 +75,7 @@ open class FunctionalExpressionRing(val ring: Ring) : FunctionalExpression override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = FunctionalBinaryOperation(algebra, operation, left, right) - fun number(value: Number): Expression = const(ring { one * value }) + override fun number(value: Number): Expression = const(ring { one * value }) override fun multiply(a: Expression, b: Expression): Expression = FunctionalBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) From e6f97c532b645d019fd76dfe7fec99fac5d55899 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 11 Jun 2020 08:50:37 +0700 Subject: [PATCH 028/104] Minor refactor: replace space property with field --- .../scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt index 78f6ee351..232710d6c 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt @@ -67,7 +67,7 @@ open class AsmExpressionField(private val field: Field) : AsmBinaryOperation(algebra, operation, left, right) override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b) + AsmBinaryOperation(field, FieldOperations.DIV_OPERATION, a, b) operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) operator fun T.div(arg: AsmExpression): AsmExpression = arg / this From 89fae390134b1c5ff27c4864b9ef094356d49036 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 11 Jun 2020 09:27:23 +0700 Subject: [PATCH 029/104] Improve tests --- .../expressions/asm/AsmGenerationContext.kt | 3 ++- .../scientifik/kmath/expressions/AsmTest.kt | 23 +++++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt index 0375b7fc4..92663ba8f 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt @@ -215,7 +215,7 @@ class AsmGenerationContext( visitLoadAnyFromConstants(value, c) } - internal fun visitLoadFromVariables(name: String, defaultValue: T? = null) = invokeMethodVisitor.run { + internal fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { maxStack += 2 visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar) @@ -241,6 +241,7 @@ class AsmGenerationContext( } internal fun visitLoadAlgebra() { + maxStack++ invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) invokeMethodVisitor.visitFieldInsn( diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 55c240cef..93c52faac 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -3,6 +3,8 @@ package scientifik.kmath.expressions import scientifik.kmath.expressions.asm.AsmExpression import scientifik.kmath.expressions.asm.AsmExpressionField import scientifik.kmath.expressions.asm.asmField +import scientifik.kmath.expressions.asm.asmRing +import scientifik.kmath.operations.IntRing import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals @@ -15,26 +17,29 @@ class AsmTest { ): Unit = assertEquals(expected = expected, actual = asmField(RealField, block)(arguments)) @Test - fun testConstantsSum() = testDoubleExpression(16.0) { const(8.0) + 8.0 } + fun testConstantsSum(): Unit = testDoubleExpression(16.0) { const(8.0) + 8.0 } @Test - fun testVarsSum() = testDoubleExpression(1000.0, mapOf("x" to 500.0)) { variable("x") + 500.0 } + fun testVarsSum(): Unit = testDoubleExpression(1000.0, mapOf("x" to 500.0)) { variable("x") + 500.0 } @Test - fun testProduct() = testDoubleExpression(24.0) { const(4.0) * const(6.0) } + fun testProduct(): Unit = testDoubleExpression(24.0) { const(4.0) * const(6.0) } @Test - fun testConstantProduct() = testDoubleExpression(984.0) { const(8.0) * 123 } + fun testConstantProduct(): Unit = testDoubleExpression(984.0) { const(8.0) * 123 } @Test - fun testSubtraction() = testDoubleExpression(2.0) { const(4.0) - 2.0 } + fun testVarsConstantProductVar(): Unit = testDoubleExpression(984.0, mapOf("x" to 8.0)) { variable("x") * 123 } @Test - fun testDivision() = testDoubleExpression(64.0) { const(128.0) / 2 } + fun testSubtraction(): Unit = testDoubleExpression(2.0) { const(4.0) - 2.0 } @Test - fun testDirectCall() = testDoubleExpression(4096.0) { binaryOperation("*", const(64.0), const(64.0)) } + fun testDivision(): Unit = testDoubleExpression(64.0) { const(128.0) / 2 } -// @Test -// fun testSine() = testDoubleExpression(0.0) { unaryOperation("sin", const(PI)) } + @Test + fun testDirectUnaryCall(): Unit = testDoubleExpression(64.0) { unaryOperation("+", const(64.0)) } + + @Test + fun testDirectBinaryCall(): Unit = testDoubleExpression(4096.0) { binaryOperation("*", const(64.0), const(64.0)) } } From 0507bfcc24ada2bdc52ad5ceadaafd4d7f0d0ccb Mon Sep 17 00:00:00 2001 From: darksnake Date: Thu, 11 Jun 2020 09:46:42 +0300 Subject: [PATCH 030/104] Expression simplification --- .../kmath/expressions/Expression.kt | 9 ----- .../kmath/expressions/SyntaxTreeNode.kt | 32 ++++++++++++++--- .../scientifik/kmath/operations/Algebra.kt | 36 ++++++++++++++++--- 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 73745f78f..aba7357e8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -24,13 +24,4 @@ interface ExpressionContext : Algebra { * A constant expression which does not depend on arguments */ fun const(value: T): E -} - -fun ExpressionContext.produce(node: SyntaxTreeNode): E { - return when (node) { - is NumberNode -> error("Single number nodes are not supported") - is SingularNode -> variable(node.value) - is UnaryNode -> unaryOperation(node.operation, produce(node.value)) - is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right)) - } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt index b4038f680..e56165aad 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt @@ -1,11 +1,25 @@ package scientifik.kmath.expressions +import scientifik.kmath.operations.NumericAlgebra + +/** + * A syntax tree node for mathematical expressions + */ sealed class SyntaxTreeNode +/** + * A node containing unparsed string + */ data class SingularNode(val value: String) : SyntaxTreeNode() +/** + * A node containing a number + */ data class NumberNode(val value: Number) : SyntaxTreeNode() +/** + * A node containing an unary operation + */ data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxTreeNode() { companion object { const val ABS_OPERATION = "abs" @@ -17,12 +31,22 @@ data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxT } } +/** + * A node containing binary operation + */ data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode) : SyntaxTreeNode() { - companion object { - //TODO add operations - } + companion object } //TODO add a function with positional arguments -//TODO add a function with named arguments \ No newline at end of file +//TODO add a function with named arguments + +fun NumericAlgebra.compile(node: SyntaxTreeNode): T{ + return when (node) { + is NumberNode -> number(node.value) + is SingularNode -> raw(node.value) + is UnaryNode -> unaryOperation(node.operation, compile(node.value)) + is BinaryNode -> binaryOperation(node.operation, compile(node.left), compile(node.right)) + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index b3282e5f3..c9fbc1c8b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -7,10 +7,35 @@ annotation class KMathContext * Marker interface for any algebra */ interface Algebra { + /** + * Wrap raw string or variable + */ + fun raw(value: String): T = error("Wrapping of '$value' is not supported in $this") + + /** + * Dynamic call of unary operation with name [operation] on [arg] + */ fun unaryOperation(operation: String, arg: T): T + + /** + * Dynamic call of binary operation [operation] on [left] and [right] + */ fun binaryOperation(operation: String, left: T, right: T): T } +/** + * An algebra with numeric representation of members + */ +interface NumericAlgebra : Algebra { + /** + * Wrap a number + */ + fun number(value: Number): T +} + +/** + * Call a block with an [Algebra] as receiver + */ inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) /** @@ -87,7 +112,7 @@ interface RingOperations : SpaceOperations { else -> super.binaryOperation(operation, left, right) } - companion object{ + companion object { const val TIMES_OPERATION = "*" } } @@ -95,13 +120,16 @@ interface RingOperations : SpaceOperations { /** * The same as {@link Space} but with additional multiplication operation */ -interface Ring : Space, RingOperations { +interface Ring : Space, RingOperations, NumericAlgebra { /** * neutral operation for multiplication */ val one: T -// operator fun T.plus(b: Number) = this.plus(b * one) + override fun number(value: Number): T = one * value.toDouble() + + // those operators are blocked by type conflict in RealField + // operator fun T.plus(b: Number) = this.plus(b * one) // operator fun Number.plus(b: T) = b + this // // operator fun T.minus(b: Number) = this.minus(b * one) @@ -121,7 +149,7 @@ interface FieldOperations : RingOperations { else -> super.binaryOperation(operation, left, right) } - companion object{ + companion object { const val DIV_OPERATION = "/" } } From f46615d3bc3e036bd33d59912ae8e3a82f6e1b9b Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 12 Jun 2020 08:43:47 +0300 Subject: [PATCH 031/104] Left and right-side operations in Algebra --- .../kmath/expressions/MathSyntaxTree.kt | 61 +++++++++++++++++++ .../kmath/expressions/SyntaxTreeNode.kt | 52 ---------------- .../scientifik/kmath/operations/Algebra.kt | 16 ++++- .../kmath/operations/NumberAlgebra.kt | 28 ++++++++- .../kmath/operations/OptionalOperations.kt | 26 +++++--- 5 files changed, 120 insertions(+), 63 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/MathSyntaxTree.kt delete mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/MathSyntaxTree.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/MathSyntaxTree.kt new file mode 100644 index 000000000..fbc055f80 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/MathSyntaxTree.kt @@ -0,0 +1,61 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.NumericAlgebra +import scientifik.kmath.operations.RealField + +/** + * A syntax tree node for mathematical expressions + */ +sealed class MathSyntaxTree + +/** + * A node containing unparsed string + */ +data class SingularNode(val value: String) : MathSyntaxTree() + +/** + * A node containing a number + */ +data class NumberNode(val value: Number) : MathSyntaxTree() + +/** + * A node containing an unary operation + */ +data class UnaryNode(val operation: String, val value: MathSyntaxTree) : MathSyntaxTree() { + companion object { + const val ABS_OPERATION = "abs" + //TODO add operations + } +} + +/** + * A node containing binary operation + */ +data class BinaryNode(val operation: String, val left: MathSyntaxTree, val right: MathSyntaxTree) : MathSyntaxTree() { + companion object +} + +//TODO add a function with positional arguments + +//TODO add a function with named arguments + +fun NumericAlgebra.compile(node: MathSyntaxTree): T { + return when (node) { + is NumberNode -> number(node.value) + is SingularNode -> raw(node.value) + is UnaryNode -> unaryOperation(node.operation, compile(node.value)) + is BinaryNode -> when { + node.left is NumberNode && node.right is NumberNode -> { + val number = RealField.binaryOperation( + node.operation, + node.left.value.toDouble(), + node.right.value.toDouble() + ) + number(number) + } + node.left is NumberNode -> leftSideNumberOperation(node.operation, node.left.value, compile(node.right)) + node.right is NumberNode -> rightSideNumberOperation(node.operation, compile(node.left), node.right.value) + else -> binaryOperation(node.operation, compile(node.left), compile(node.right)) + } + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt deleted file mode 100644 index e56165aad..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/SyntaxTreeNode.kt +++ /dev/null @@ -1,52 +0,0 @@ -package scientifik.kmath.expressions - -import scientifik.kmath.operations.NumericAlgebra - -/** - * A syntax tree node for mathematical expressions - */ -sealed class SyntaxTreeNode - -/** - * A node containing unparsed string - */ -data class SingularNode(val value: String) : SyntaxTreeNode() - -/** - * A node containing a number - */ -data class NumberNode(val value: Number) : SyntaxTreeNode() - -/** - * A node containing an unary operation - */ -data class UnaryNode(val operation: String, val value: SyntaxTreeNode) : SyntaxTreeNode() { - companion object { - const val ABS_OPERATION = "abs" - const val SIN_OPERATION = "sin" - const val COS_OPERATION = "cos" - const val EXP_OPERATION = "exp" - const val LN_OPERATION = "ln" - //TODO add operations - } -} - -/** - * A node containing binary operation - */ -data class BinaryNode(val operation: String, val left: SyntaxTreeNode, val right: SyntaxTreeNode) : SyntaxTreeNode() { - companion object -} - -//TODO add a function with positional arguments - -//TODO add a function with named arguments - -fun NumericAlgebra.compile(node: SyntaxTreeNode): T{ - return when (node) { - is NumberNode -> number(node.value) - is SingularNode -> raw(node.value) - is UnaryNode -> unaryOperation(node.operation, compile(node.value)) - is BinaryNode -> binaryOperation(node.operation, compile(node.left), compile(node.right)) - } -} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index c9fbc1c8b..166287ec7 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -31,6 +31,12 @@ interface NumericAlgebra : Algebra { * Wrap a number */ fun number(value: Number): T + + fun leftSideNumberOperation(operation: String, left: Number, right: T): T = + binaryOperation(operation, number(left), right) + + fun rightSideNumberOperation(operation: String, left: T, right: Number): T = + leftSideNumberOperation(operation, right, left) } /** @@ -128,8 +134,14 @@ interface Ring : Space, RingOperations, NumericAlgebra { override fun number(value: Number): T = one * value.toDouble() - // those operators are blocked by type conflict in RealField - // operator fun T.plus(b: Number) = this.plus(b * one) + override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) { + RingOperations.TIMES_OPERATION -> left * right + else -> super.leftSideNumberOperation(operation, left, right) + } + + //TODO those operators are blocked by type conflict in RealField + +// operator fun T.plus(b: Number) = this.plus(b * one) // operator fun Number.plus(b: T) = b + this // // operator fun T.minus(b: Number) = this.minus(b * one) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 9639e4c28..e844d404e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -10,9 +10,30 @@ interface ExtendedFieldOperations : FieldOperations, TrigonometricOperations, PowerOperations, - ExponentialOperations + ExponentialOperations { -interface ExtendedField : ExtendedFieldOperations, Field + override fun tan(arg: T): T = sin(arg) / cos(arg) + + override fun unaryOperation(operation: String, arg: T): T = when (operation) { + TrigonometricOperations.COS_OPERATION -> cos(arg) + TrigonometricOperations.SIN_OPERATION -> sin(arg) + PowerOperations.SQRT_OPERATION -> sqrt(arg) + ExponentialOperations.EXP_OPERATION -> exp(arg) + ExponentialOperations.LN_OPERATION -> ln(arg) + else -> super.unaryOperation(operation, arg) + } + +} + +interface ExtendedField : ExtendedFieldOperations, Field { + override fun rightSideNumberOperation(operation: String, left: T, right: Number): T { + return when (operation) { + PowerOperations.POW_OPERATION -> power(left, right) + else -> super.rightSideNumberOperation(operation, left, right) + } + + } +} /** * Real field element wrapping double. @@ -44,6 +65,7 @@ object RealField : ExtendedField, Norm { override inline fun sin(arg: Double) = kotlin.math.sin(arg) override inline fun cos(arg: Double) = kotlin.math.cos(arg) + override inline fun tan(arg: Double) = kotlin.math.tan(arg) override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble()) @@ -76,6 +98,8 @@ object FloatField : ExtendedField, Norm { override inline fun sin(arg: Float) = kotlin.math.sin(arg) override inline fun cos(arg: Float) = kotlin.math.cos(arg) + override inline fun tan(arg: Float): Float = kotlin.math.tan(arg) + override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat()) override inline fun exp(arg: Float) = kotlin.math.exp(arg) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index bd83932e7..1b58b7254 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -10,30 +10,37 @@ package scientifik.kmath.operations * It also allows to override behavior for optional operations * */ -interface TrigonometricOperations : FieldOperations { +interface TrigonometricOperations { fun sin(arg: T): T fun cos(arg: T): T - fun tg(arg: T): T = sin(arg) / cos(arg) + fun tan(arg: T): T - fun ctg(arg: T): T = cos(arg) / sin(arg) + companion object { + const val SIN_OPERATION = "sin" + const val COS_OPERATION = "cos" + } } fun >> sin(arg: T): T = arg.context.sin(arg) fun >> cos(arg: T): T = arg.context.cos(arg) -fun >> tg(arg: T): T = arg.context.tg(arg) -fun >> ctg(arg: T): T = arg.context.ctg(arg) +fun >> tan(arg: T): T = arg.context.tan(arg) /* Power and roots */ /** * A context extension to include power operations like square roots, etc */ -interface PowerOperations : Algebra { +interface PowerOperations { fun power(arg: T, pow: Number): T fun sqrt(arg: T) = power(arg, 0.5) infix fun T.pow(pow: Number) = power(this, pow) + + companion object { + const val POW_OPERATION = "pow" + const val SQRT_OPERATION = "sqrt" + } } infix fun >> T.pow(power: Double): T = context.power(this, power) @@ -42,9 +49,14 @@ fun >> sqr(arg: T): T = arg pow 2.0 /* Exponential */ -interface ExponentialOperations: Algebra { +interface ExponentialOperations { fun exp(arg: T): T fun ln(arg: T): T + + companion object { + const val EXP_OPERATION = "exp" + const val LN_OPERATION = "ln" + } } fun >> exp(arg: T): T = arg.context.exp(arg) From 5e92d85c4676c27c45fbba43d58a12569a80a0a7 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 12 Jun 2020 10:40:59 +0300 Subject: [PATCH 032/104] Separate ast module --- kmath-ast/build.gradle.kts | 16 +++++++++++++ .../scientifik/kmath/ast}/MathSyntaxTree.kt | 2 +- .../kotlin/scientifik/kmath/ast/asm.kt | 23 +++++++++++++++++++ .../commons/expressions/DiffExpression.kt | 4 ++-- .../kmath/expressions/Expression.kt | 2 +- .../expressions/functionalExpressions.kt | 13 +++++++---- .../scientifik/kmath/operations/bigNumbers.kt | 7 +++++- settings.gradle.kts | 1 + 8 files changed, 58 insertions(+), 10 deletions(-) create mode 100644 kmath-ast/build.gradle.kts rename {kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions => kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast}/MathSyntaxTree.kt (98%) create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts new file mode 100644 index 000000000..f3d37bf65 --- /dev/null +++ b/kmath-ast/build.gradle.kts @@ -0,0 +1,16 @@ +plugins { + id("scientifik.mpp") +} + +repositories{ + maven("https://dl.bintray.com/hotkeytlt/maven") +} + +kotlin.sourceSets { + commonMain { + dependencies { + api(project(":kmath-core")) + implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha3") + } + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/MathSyntaxTree.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MathSyntaxTree.kt similarity index 98% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/MathSyntaxTree.kt rename to kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MathSyntaxTree.kt index fbc055f80..b43ef705d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/MathSyntaxTree.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MathSyntaxTree.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.expressions +package scientifik.kmath.ast import scientifik.kmath.operations.NumericAlgebra import scientifik.kmath.operations.RealField diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt new file mode 100644 index 000000000..6c4325859 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt @@ -0,0 +1,23 @@ +package scientifik.kmath.ast + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.NumericAlgebra + +//TODO stubs for asm generation + +interface AsmExpression + +interface AsmExpressionAlgebra> : NumericAlgebra> { + val algebra: A +} + +fun AsmExpression.compile(): Expression = TODO() + +//TODO add converter for functional expressions + +inline fun > A.asm( + block: AsmExpressionAlgebra.() -> AsmExpression +): Expression = TODO() + +inline fun > A.asm(ast: MathSyntaxTree): Expression = asm { compile(ast) } \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index 8c19395d3..f6f61e08a 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -2,7 +2,7 @@ package scientifik.kmath.commons.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.Field import kotlin.properties.ReadOnlyProperty @@ -113,7 +113,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1) /** * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) */ -object DiffExpressionContext : ExpressionContext, Field { +object DiffExpressionAlgebra : ExpressionAlgebra, Field { override fun variable(name: String, default: Double?) = DiffExpression { variable(name, default?.const()) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index aba7357e8..df5c981f5 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -14,7 +14,7 @@ operator fun Expression.invoke(vararg pairs: Pair): T = invoke /** * A context for expression construction */ -interface ExpressionContext : Algebra { +interface ExpressionAlgebra : Algebra { /** * Introduce a variable into expression context */ diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt index 4f38e3a71..c3861d256 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt @@ -21,8 +21,11 @@ internal class SumExpression( override fun invoke(arguments: Map): T = context.add(first.invoke(arguments), second.invoke(arguments)) } -internal class ProductExpression(val context: Ring, val first: Expression, val second: Expression) : - Expression { +internal class ProductExpression( + val context: Ring, + val first: Expression, + val second: Expression +) : Expression { override fun invoke(arguments: Map): T = context.multiply(first.invoke(arguments), second.invoke(arguments)) } @@ -39,7 +42,7 @@ internal class DivExpession(val context: Field, val expr: Expression, v open class FunctionalExpressionSpace( val space: Space -) : Space>, ExpressionContext> { +) : Space>, ExpressionAlgebra> { override val zero: Expression = ConstantExpression(space.zero) @@ -61,12 +64,12 @@ open class FunctionalExpressionSpace( open class FunctionalExpressionField( val field: Field -) : Field>, ExpressionContext>, FunctionalExpressionSpace(field) { +) : Field>, ExpressionAlgebra>, FunctionalExpressionSpace(field) { override val one: Expression get() = const(this.field.one) - fun const(value: Double): Expression = const(field.run { one*value}) + fun const(value: Double): Expression = const(field.run { one * value }) override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) diff --git a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt index 76ca199c5..e6f09c040 100644 --- a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt @@ -1,10 +1,12 @@ package scientifik.kmath.operations -import scientifik.kmath.structures.* import java.math.BigDecimal import java.math.BigInteger import java.math.MathContext +/** + * A field wrapper for Java [BigInteger] + */ object JBigIntegerField : Field { override val zero: BigInteger = BigInteger.ZERO override val one: BigInteger = BigInteger.ONE @@ -18,6 +20,9 @@ object JBigIntegerField : Field { override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) } +/** + * A Field wrapper for Java [BigDecimal] + */ class JBigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field { override val zero: BigDecimal = BigDecimal.ZERO override val one: BigDecimal = BigDecimal.ONE diff --git a/settings.gradle.kts b/settings.gradle.kts index 57173250b..465ecfca8 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -44,5 +44,6 @@ include( ":kmath-dimensions", ":kmath-for-real", ":kmath-geometry", + ":kmath-ast", ":examples" ) From 047af8c172278cdf077e3e9db5f86aff3387f515 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 12 Jun 2020 11:11:13 +0300 Subject: [PATCH 033/104] Fix ND extendend fields --- kmath-ast/build.gradle.kts | 2 +- .../kotlin/scientifik/kmath/operations/Complex.kt | 2 +- .../scientifik/kmath/structures/ExtendedNDField.kt | 9 ++------- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index f3d37bf65..f17809381 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -10,7 +10,7 @@ kotlin.sourceSets { commonMain { dependencies { api(project(":kmath-core")) - implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha3") + implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3") } } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 6c529f55e..e33e4078e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -11,7 +11,7 @@ import kotlin.math.* /** * A field for complex numbers */ -object ComplexField : ExtendedFieldOperations, Field { +object ComplexField : ExtendedField { override val zero: Complex = Complex(0.0, 0.0) override val one: Complex = Complex(1.0, 0.0) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt index 3437644ff..776cff880 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt @@ -1,13 +1,8 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.* +import scientifik.kmath.operations.ExtendedField -interface ExtendedNDField> : - NDField, - TrigonometricOperations, - PowerOperations, - ExponentialOperations - where F : ExtendedFieldOperations, F : Field +interface ExtendedNDField, N : NDStructure> : NDField, ExtendedField ///** From 39907a1da239d0eba4f7affbb3438ddcfeb89b1f Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 12 Jun 2020 19:50:28 +0700 Subject: [PATCH 034/104] Add name adapting of *, +, / --- .../kmath/expressions/asm/AsmExpressions.kt | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt index 662e1279e..70f1fc28e 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt @@ -12,15 +12,21 @@ interface AsmExpression { fun invoke(gen: AsmGenerationContext) } +internal val methodNameAdapters = mapOf("+" to "add", "*" to "multiply", "/" to "divide") + internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { - context::class.memberFunctions.find { it.name == name && it.parameters.size == arity } + val aName = methodNameAdapters[name] ?: name + + context::class.memberFunctions.find { it.name == aName && it.parameters.size == arity } ?: return false return true } internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { - context::class.memberFunctions.find { it.name == name && it.parameters.size == arity } + val aName = methodNameAdapters[name] ?: name + + context::class.memberFunctions.find { it.name == aName && it.parameters.size == arity } ?: return false val owner = context::class.jvmName.replace('.', '/') @@ -32,7 +38,7 @@ internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, append("L${AsmGenerationContext.OBJECT_CLASS};") } - visitAlgebraOperation(owner = owner, method = name, descriptor = sig) + visitAlgebraOperation(owner = owner, method = aName, descriptor = sig) return true } From 2751cee9267b95a7370cba9e0565f56c3b29d7cf Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 12 Jun 2020 16:56:58 +0300 Subject: [PATCH 035/104] MST expression --- kmath-ast/build.gradle.kts | 5 ++ .../kotlin/scientifik/kmath/ast/MST.kt | 62 +++++++++++++++++++ .../scientifik/kmath/ast/MSTExpression.kt | 19 ++++++ .../scientifik/kmath/ast/MathSyntaxTree.kt | 61 ------------------ .../kotlin/scientifik/kmath/ast/asm.kt | 2 +- .../kotlin/scientifik/kmath/ast/parser.kt | 56 +++++++++++++++++ .../kotlin/scietifik/kmath/ast/ParserTest.kt | 17 +++++ .../kmath/expressions/Expression.kt | 6 ++ .../expressions/functionalExpressions.kt | 17 ++++- .../scientifik/kmath/operations/Complex.kt | 6 ++ 10 files changed, 186 insertions(+), 65 deletions(-) create mode 100644 kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt create mode 100644 kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt delete mode 100644 kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MathSyntaxTree.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt create mode 100644 kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index f17809381..88540e7b8 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -13,4 +13,9 @@ kotlin.sourceSets { implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3") } } + jvmMain{ + dependencies{ + implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3") + } + } } \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt new file mode 100644 index 000000000..eba2c3343 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt @@ -0,0 +1,62 @@ +package scientifik.kmath.ast + +import scientifik.kmath.operations.NumericAlgebra +import scientifik.kmath.operations.RealField + +/** + * A Mathematical Syntax Tree node for mathematical expressions + */ +sealed class MST { + + /** + * A node containing unparsed string + */ + data class Singular(val value: String) : MST() + + /** + * A node containing a number + */ + data class Numeric(val value: Number) : MST() + + /** + * A node containing an unary operation + */ + data class Unary(val operation: String, val value: MST) : MST() { + companion object { + const val ABS_OPERATION = "abs" + //TODO add operations + } + } + + /** + * A node containing binary operation + */ + data class Binary(val operation: String, val left: MST, val right: MST) : MST() { + companion object + } +} + +//TODO add a function with positional arguments + +//TODO add a function with named arguments + +fun NumericAlgebra.evaluate(node: MST): T { + return when (node) { + is MST.Numeric -> number(node.value) + is MST.Singular -> raw(node.value) + is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) + is MST.Binary -> when { + node.left is MST.Numeric && node.right is MST.Numeric -> { + val number = RealField.binaryOperation( + node.operation, + node.left.value.toDouble(), + node.right.value.toDouble() + ) + number(number) + } + node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) + node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) + else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + } + } +} \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt new file mode 100644 index 000000000..c0f124a35 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt @@ -0,0 +1,19 @@ +package scientifik.kmath.ast + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.NumericAlgebra + +/** + * The expression evaluates MST on-flight + */ +class MSTExpression(val algebra: NumericAlgebra, val mst: MST) : Expression { + + /** + * Substitute algebra raw value + */ + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra by algebra { + override fun raw(value: String): T = arguments[value] ?: super.raw(value) + } + + override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) +} \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MathSyntaxTree.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MathSyntaxTree.kt deleted file mode 100644 index b43ef705d..000000000 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MathSyntaxTree.kt +++ /dev/null @@ -1,61 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.operations.NumericAlgebra -import scientifik.kmath.operations.RealField - -/** - * A syntax tree node for mathematical expressions - */ -sealed class MathSyntaxTree - -/** - * A node containing unparsed string - */ -data class SingularNode(val value: String) : MathSyntaxTree() - -/** - * A node containing a number - */ -data class NumberNode(val value: Number) : MathSyntaxTree() - -/** - * A node containing an unary operation - */ -data class UnaryNode(val operation: String, val value: MathSyntaxTree) : MathSyntaxTree() { - companion object { - const val ABS_OPERATION = "abs" - //TODO add operations - } -} - -/** - * A node containing binary operation - */ -data class BinaryNode(val operation: String, val left: MathSyntaxTree, val right: MathSyntaxTree) : MathSyntaxTree() { - companion object -} - -//TODO add a function with positional arguments - -//TODO add a function with named arguments - -fun NumericAlgebra.compile(node: MathSyntaxTree): T { - return when (node) { - is NumberNode -> number(node.value) - is SingularNode -> raw(node.value) - is UnaryNode -> unaryOperation(node.operation, compile(node.value)) - is BinaryNode -> when { - node.left is NumberNode && node.right is NumberNode -> { - val number = RealField.binaryOperation( - node.operation, - node.left.value.toDouble(), - node.right.value.toDouble() - ) - number(number) - } - node.left is NumberNode -> leftSideNumberOperation(node.operation, node.left.value, compile(node.right)) - node.right is NumberNode -> rightSideNumberOperation(node.operation, compile(node.left), node.right.value) - else -> binaryOperation(node.operation, compile(node.left), compile(node.right)) - } - } -} \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt index 6c4325859..c01648fb0 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt @@ -20,4 +20,4 @@ inline fun > A.asm( block: AsmExpressionAlgebra.() -> AsmExpression ): Expression = TODO() -inline fun > A.asm(ast: MathSyntaxTree): Expression = asm { compile(ast) } \ No newline at end of file +inline fun > A.asm(ast: MST): Expression = asm { evaluate(ast) } \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt new file mode 100644 index 000000000..dcab1c972 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt @@ -0,0 +1,56 @@ +package scientifik.kmath.ast + +import com.github.h0tk3y.betterParse.combinators.* +import com.github.h0tk3y.betterParse.grammar.Grammar +import com.github.h0tk3y.betterParse.grammar.parseToEnd +import com.github.h0tk3y.betterParse.grammar.parser +import com.github.h0tk3y.betterParse.grammar.tryParseToEnd +import com.github.h0tk3y.betterParse.parser.ParseResult +import com.github.h0tk3y.betterParse.parser.Parser +import scientifik.kmath.operations.FieldOperations +import scientifik.kmath.operations.PowerOperations +import scientifik.kmath.operations.RingOperations +import scientifik.kmath.operations.SpaceOperations + +private object ArithmeticsEvaluator : Grammar() { + val num by token("-?[\\d.]+(?:[eE]-?\\d+)?") + val lpar by token("\\(") + val rpar by token("\\)") + val mul by token("\\*") + val pow by token("\\^") + val div by token("/") + val minus by token("-") + val plus by token("\\+") + val ws by token("\\s+", ignore = true) + + val number: Parser by num use { MST.Numeric(text.toDouble()) } + + val term: Parser by number or + (skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or + (skip(lpar) and parser(this::rootParser) and skip(rpar)) + + val powChain by leftAssociative(term, pow) { a, _, b -> + MST.Binary(PowerOperations.POW_OPERATION, a, b) + } + + val divMulChain: Parser by leftAssociative(powChain, div or mul use { type }) { a, op, b -> + if (op == div) { + MST.Binary(FieldOperations.DIV_OPERATION, a, b) + } else { + MST.Binary(RingOperations.TIMES_OPERATION, a, b) + } + } + + val subSumChain: Parser by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b -> + if (op == plus) { + MST.Binary(SpaceOperations.PLUS_OPERATION, a, b) + } else { + MST.Binary(SpaceOperations.MINUS_OPERATION, a, b) + } + } + + override val rootParser: Parser by subSumChain +} + +fun String.tryParseMath(): ParseResult = ArithmeticsEvaluator.tryParseToEnd(this) +fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this) \ No newline at end of file diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt new file mode 100644 index 000000000..6849d24b8 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt @@ -0,0 +1,17 @@ +package scietifik.kmath.ast + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import scientifik.kmath.ast.evaluate +import scientifik.kmath.ast.parseMath +import scientifik.kmath.operations.Complex +import scientifik.kmath.operations.ComplexField + +internal class ParserTest{ + @Test + fun parsedExpression(){ + val mst = "2+2*(2+2)".parseMath() + val res = ComplexField.evaluate(mst) + assertEquals(Complex(10.0,0.0), res) + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index df5c981f5..eaf3cd1d7 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -7,6 +7,12 @@ import scientifik.kmath.operations.Algebra */ interface Expression { operator fun invoke(arguments: Map): T + + companion object { + operator fun invoke(block: (Map) -> T): Expression = object : Expression { + override fun invoke(arguments: Map): T = block(arguments) + } + } } operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt index c3861d256..6d7676c25 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt @@ -66,8 +66,7 @@ open class FunctionalExpressionField( val field: Field ) : Field>, ExpressionAlgebra>, FunctionalExpressionSpace(field) { - override val one: Expression - get() = const(this.field.one) + override val one: Expression = ConstantExpression(this.field.one) fun const(value: Double): Expression = const(field.run { one * value }) @@ -80,4 +79,16 @@ open class FunctionalExpressionField( operator fun T.times(arg: Expression) = arg * this operator fun T.div(arg: Expression) = arg / this -} \ No newline at end of file +} + +/** + * Create a functional expression on this [Space] + */ +fun Space.buildExpression(block: FunctionalExpressionSpace.() -> Expression): Expression = + FunctionalExpressionSpace(this).run(block) + +/** + * Create a functional expression on this [Field] + */ +fun Field.buildExpression(block: FunctionalExpressionField.() -> Expression): Expression = + FunctionalExpressionField(this).run(block) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index e33e4078e..01aef824d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -50,6 +50,12 @@ object ComplexField : ExtendedField { operator fun Complex.minus(d: Double) = add(this, -d.toComplex()) operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this) + + override fun raw(value: String): Complex = if (value == "i") { + i + } else { + super.raw(value) + } } /** From 09641a5c9c8e4a6af739d984e4a0c634c702f934 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 12 Jun 2020 16:59:36 +0300 Subject: [PATCH 036/104] Documentation --- .../commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt | 2 +- kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt index c0f124a35..3ee454b2a 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt @@ -4,7 +4,7 @@ import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.NumericAlgebra /** - * The expression evaluates MST on-flight + * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. */ class MSTExpression(val algebra: NumericAlgebra, val mst: MST) : Expression { diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt index dcab1c972..cec61a8ff 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt @@ -12,6 +12,9 @@ import scientifik.kmath.operations.PowerOperations import scientifik.kmath.operations.RingOperations import scientifik.kmath.operations.SpaceOperations +/** + * TODO move to common + */ private object ArithmeticsEvaluator : Grammar() { val num by token("-?[\\d.]+(?:[eE]-?\\d+)?") val lpar by token("\\(") From b902f3980a7021a5bdbb0ddc36e45e545d750a95 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 12 Jun 2020 22:10:56 +0700 Subject: [PATCH 037/104] Fix issues found after merge --- .../expressions/asm/AsmExpressionSpaces.kt | 4 +-- .../kmath/expressions/asm/Builders.kt | 12 ++++---- .../scientifik/kmath/expressions/AsmTest.kt | 6 ++-- .../scientifik/kmath/expressions/Builders.kt | 30 +++++++++++++++++++ .../expressions/FunctionalExpressions.kt | 2 +- settings.gradle.kts | 1 + 6 files changed, 42 insertions(+), 13 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt index 232710d6c..0f8f685eb 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt @@ -1,11 +1,11 @@ package scientifik.kmath.expressions.asm -import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.* open class AsmExpressionAlgebra(val algebra: Algebra) : Algebra>, - ExpressionContext> { + ExpressionAlgebra> { override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = AsmUnaryOperation(algebra, operation, arg) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt index 663bd55ba..c56adb767 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt @@ -1,7 +1,7 @@ package scientifik.kmath.expressions.asm import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.* @PublishedApi @@ -18,7 +18,7 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String } -inline fun >> asm( +inline fun >> asm( i: E, algebra: Algebra, block: E.() -> AsmExpression @@ -29,22 +29,22 @@ inline fun >> asm( return ctx.generate() } -inline fun asmAlgebra( +inline fun buildAsmAlgebra( algebra: Algebra, block: AsmExpressionAlgebra.() -> AsmExpression ): Expression = asm(AsmExpressionAlgebra(algebra), algebra, block) -inline fun asmSpace( +inline fun buildAsmSpace( algebra: Space, block: AsmExpressionSpace.() -> AsmExpression ): Expression = asm(AsmExpressionSpace(algebra), algebra, block) -inline fun asmRing( +inline fun buildAsmRing( algebra: Ring, block: AsmExpressionRing.() -> AsmExpression ): Expression = asm(AsmExpressionRing(algebra), algebra, block) -inline fun asmField( +inline fun buildAsmField( algebra: Field, block: AsmExpressionField.() -> AsmExpression ): Expression = asm(AsmExpressionField(algebra), algebra, block) diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 93c52faac..034ba09c8 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -2,9 +2,7 @@ package scientifik.kmath.expressions import scientifik.kmath.expressions.asm.AsmExpression import scientifik.kmath.expressions.asm.AsmExpressionField -import scientifik.kmath.expressions.asm.asmField -import scientifik.kmath.expressions.asm.asmRing -import scientifik.kmath.operations.IntRing +import scientifik.kmath.expressions.asm.buildAsmField import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals @@ -14,7 +12,7 @@ class AsmTest { expected: Double?, arguments: Map = emptyMap(), block: AsmExpressionField.() -> AsmExpression - ): Unit = assertEquals(expected = expected, actual = asmField(RealField, block)(arguments)) + ): Unit = assertEquals(expected = expected, actual = buildAsmField(RealField, block)(arguments)) @Test fun testConstantsSum(): Unit = testDoubleExpression(16.0) { const(8.0) + 8.0 } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt new file mode 100644 index 000000000..b2e4d8a3d --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt @@ -0,0 +1,30 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.Space + +/** + * Create a functional expression on this [Algebra] + */ +fun Algebra.buildExpression(block: FunctionalExpressionAlgebra.() -> Expression): Expression = + FunctionalExpressionAlgebra(this).run(block) + +/** + * Create a functional expression on this [Space] + */ +fun Space.buildExpression(block: FunctionalExpressionSpace.() -> Expression): Expression = + FunctionalExpressionSpace(this).run(block) + +/** + * Create a functional expression on this [Ring] + */ +fun Ring.buildExpression(block: FunctionalExpressionRing.() -> Expression): Expression = + FunctionalExpressionRing(this).run(block) + +/** + * Create a functional expression on this [Field] + */ +fun Field.buildExpression(block: FunctionalExpressionField.() -> Expression): Expression = + FunctionalExpressionField(this).run(block) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index ce9b4756c..cbbdb99f5 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -33,7 +33,7 @@ internal class FunctionalConstProductExpression(val context: Space, val ex open class FunctionalExpressionAlgebra(val algebra: Algebra) : Algebra>, - ExpressionContext> { + ExpressionAlgebra> { override fun unaryOperation(operation: String, arg: Expression): Expression = FunctionalUnaryOperation(algebra, operation, arg) diff --git a/settings.gradle.kts b/settings.gradle.kts index 465ecfca8..2b2633060 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -45,5 +45,6 @@ include( ":kmath-for-real", ":kmath-geometry", ":kmath-ast", + ":kmath-asm", ":examples" ) From 3ec1f7b5f15ada3527902c448f0e9517213a4354 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 13 Jun 2020 02:26:12 +0700 Subject: [PATCH 038/104] Merge kmath-asm and kmath-ast modules, make all the ExpressionAlgebras concise and consistent, implement new-styled builders both for ASM and F. expressions --- kmath-asm/build.gradle.kts | 10 -- .../expressions/asm/AsmExpressionSpaces.kt | 74 --------------- .../kmath/expressions/asm/Builders.kt | 50 ---------- .../scientifik/kmath/expressions/AsmTest.kt | 43 --------- kmath-ast/build.gradle.kts | 28 ++++-- .../scientifik/kmath}/asm/AsmExpressions.kt | 92 +++++++++++++++++-- .../kmath}/asm/AsmGenerationContext.kt | 10 +- .../kotlin/scientifik/kmath/asm/Builders.kt | 53 +++++++++++ .../scientifik/kmath}/asm/MethodVisitors.kt | 2 +- .../scientifik/kmath}/asm/Optimization.kt | 2 +- .../kmath/ast/{parser.kt => Parser.kt} | 0 .../kotlin/scientifik/kmath/ast/asm.kt | 23 ----- .../kotlin/scietifik/kmath/ast/AsmTest.kt | 18 ++++ .../kotlin/scietifik/kmath/ast/ParserTest.kt | 10 +- .../scientifik/kmath/expressions/Builders.kt | 13 +-- .../expressions/FunctionalExpressions.kt | 51 +++++----- settings.gradle.kts | 1 - 17 files changed, 220 insertions(+), 260 deletions(-) delete mode 100644 kmath-asm/build.gradle.kts delete mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt delete mode 100644 kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt delete mode 100644 kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt rename {kmath-asm/src/main/kotlin/scientifik/kmath/expressions => kmath-ast/src/jvmMain/kotlin/scientifik/kmath}/asm/AsmExpressions.kt (52%) rename {kmath-asm/src/main/kotlin/scientifik/kmath/expressions => kmath-ast/src/jvmMain/kotlin/scientifik/kmath}/asm/AsmGenerationContext.kt (96%) create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Builders.kt rename {kmath-asm/src/main/kotlin/scientifik/kmath/expressions => kmath-ast/src/jvmMain/kotlin/scientifik/kmath}/asm/MethodVisitors.kt (94%) rename {kmath-asm/src/main/kotlin/scientifik/kmath/expressions => kmath-ast/src/jvmMain/kotlin/scientifik/kmath}/asm/Optimization.kt (80%) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/{parser.kt => Parser.kt} (100%) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt create mode 100644 kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt diff --git a/kmath-asm/build.gradle.kts b/kmath-asm/build.gradle.kts deleted file mode 100644 index 3b004eb8e..000000000 --- a/kmath-asm/build.gradle.kts +++ /dev/null @@ -1,10 +0,0 @@ -plugins { - id("scientifik.jvm") -} - -dependencies { - api(project(path = ":kmath-core")) - implementation("org.ow2.asm:asm:8.0.1") - implementation("org.ow2.asm:asm-commons:8.0.1") - implementation(kotlin("reflect")) -} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt deleted file mode 100644 index 0f8f685eb..000000000 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt +++ /dev/null @@ -1,74 +0,0 @@ -package scientifik.kmath.expressions.asm - -import scientifik.kmath.expressions.ExpressionAlgebra -import scientifik.kmath.operations.* - -open class AsmExpressionAlgebra(val algebra: Algebra) : - Algebra>, - ExpressionAlgebra> { - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - AsmUnaryOperation(algebra, operation, arg) - - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - AsmBinaryOperation(algebra, operation, left, right) - - override fun const(value: T): AsmExpression = AsmConstantExpression(value) - override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) -} - -open class AsmExpressionSpace( - val space: Space -) : AsmExpressionAlgebra(space), Space> { - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - AsmUnaryOperation(algebra, operation, arg) - - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - AsmBinaryOperation(algebra, operation, left, right) - - override val zero: AsmExpression = AsmConstantExpression(space.zero) - - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b) - - override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(space, a, k) - operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) - operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) - operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this - operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this -} - -open class AsmExpressionRing(private val ring: Ring) : AsmExpressionSpace(ring), Ring> { - override val one: AsmExpression - get() = const(this.ring.one) - - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - AsmUnaryOperation(algebra, operation, arg) - - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - AsmBinaryOperation(algebra, operation, left, right) - - override fun number(value: Number): AsmExpression = const(ring { one * value }) - - override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) - - operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) - operator fun T.times(arg: AsmExpression): AsmExpression = arg * this -} - -open class AsmExpressionField(private val field: Field) : - AsmExpressionRing(field), - Field> { - - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - AsmUnaryOperation(algebra, operation, arg) - - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - AsmBinaryOperation(algebra, operation, left, right) - - override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmBinaryOperation(field, FieldOperations.DIV_OPERATION, a, b) - - operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) - operator fun T.div(arg: AsmExpression): AsmExpression = arg / this -} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt deleted file mode 100644 index c56adb767..000000000 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt +++ /dev/null @@ -1,50 +0,0 @@ -package scientifik.kmath.expressions.asm - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionAlgebra -import scientifik.kmath.operations.* - -@PublishedApi -internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { - val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(expression, collision + 1) -} - - -inline fun >> asm( - i: E, - algebra: Algebra, - block: E.() -> AsmExpression -): Expression { - val expression = i.block().optimize() - val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) - expression.invoke(ctx) - return ctx.generate() -} - -inline fun buildAsmAlgebra( - algebra: Algebra, - block: AsmExpressionAlgebra.() -> AsmExpression -): Expression = asm(AsmExpressionAlgebra(algebra), algebra, block) - -inline fun buildAsmSpace( - algebra: Space, - block: AsmExpressionSpace.() -> AsmExpression -): Expression = asm(AsmExpressionSpace(algebra), algebra, block) - -inline fun buildAsmRing( - algebra: Ring, - block: AsmExpressionRing.() -> AsmExpression -): Expression = asm(AsmExpressionRing(algebra), algebra, block) - -inline fun buildAsmField( - algebra: Field, - block: AsmExpressionField.() -> AsmExpression -): Expression = asm(AsmExpressionField(algebra), algebra, block) diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt deleted file mode 100644 index 034ba09c8..000000000 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ /dev/null @@ -1,43 +0,0 @@ -package scientifik.kmath.expressions - -import scientifik.kmath.expressions.asm.AsmExpression -import scientifik.kmath.expressions.asm.AsmExpressionField -import scientifik.kmath.expressions.asm.buildAsmField -import scientifik.kmath.operations.RealField -import kotlin.test.Test -import kotlin.test.assertEquals - -class AsmTest { - private fun testDoubleExpression( - expected: Double?, - arguments: Map = emptyMap(), - block: AsmExpressionField.() -> AsmExpression - ): Unit = assertEquals(expected = expected, actual = buildAsmField(RealField, block)(arguments)) - - @Test - fun testConstantsSum(): Unit = testDoubleExpression(16.0) { const(8.0) + 8.0 } - - @Test - fun testVarsSum(): Unit = testDoubleExpression(1000.0, mapOf("x" to 500.0)) { variable("x") + 500.0 } - - @Test - fun testProduct(): Unit = testDoubleExpression(24.0) { const(4.0) * const(6.0) } - - @Test - fun testConstantProduct(): Unit = testDoubleExpression(984.0) { const(8.0) * 123 } - - @Test - fun testVarsConstantProductVar(): Unit = testDoubleExpression(984.0, mapOf("x" to 8.0)) { variable("x") * 123 } - - @Test - fun testSubtraction(): Unit = testDoubleExpression(2.0) { const(4.0) - 2.0 } - - @Test - fun testDivision(): Unit = testDoubleExpression(64.0) { const(128.0) / 2 } - - @Test - fun testDirectUnaryCall(): Unit = testDoubleExpression(64.0) { unaryOperation("+", const(64.0)) } - - @Test - fun testDirectBinaryCall(): Unit = testDoubleExpression(4096.0) { binaryOperation("*", const(64.0), const(64.0)) } -} diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index 88540e7b8..deb6a28ee 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -1,10 +1,5 @@ -plugins { - id("scientifik.mpp") -} - -repositories{ - maven("https://dl.bintray.com/hotkeytlt/maven") -} +plugins { id("scientifik.mpp") } +repositories { maven("https://dl.bintray.com/hotkeytlt/maven") } kotlin.sourceSets { commonMain { @@ -13,9 +8,22 @@ kotlin.sourceSets { implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3") } } - jvmMain{ - dependencies{ + + jvmMain { + dependencies { implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3") + implementation("org.ow2.asm:asm:8.0.1") + implementation("org.ow2.asm:asm-commons:8.0.1") + implementation(kotlin("reflect")) } } -} \ No newline at end of file + + jvmTest { + dependencies { + implementation(kotlin("test")) + implementation(kotlin("test-junit5")) + } + } +} + +tasks.withType { useJUnitPlatform() } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt similarity index 52% rename from kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index 70f1fc28e..75b5f26d0 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -1,9 +1,8 @@ -package scientifik.kmath.expressions.asm +package scientifik.kmath.asm import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.invoke +import scientifik.kmath.expressions.ExpressionAlgebra +import scientifik.kmath.operations.* import kotlin.reflect.full.memberFunctions import kotlin.reflect.jvm.jvmName @@ -109,11 +108,13 @@ internal class AsmBinaryOperation( } } -internal class AsmVariableExpression(private val name: String, private val default: T? = null) : AsmExpression { +internal class AsmVariableExpression(private val name: String, private val default: T? = null) : + AsmExpression { override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) } -internal class AsmConstantExpression(private val value: T) : AsmExpression { +internal class AsmConstantExpression(private val value: T) : + AsmExpression { override fun tryEvaluate(): T = value override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) } @@ -141,9 +142,88 @@ internal class AsmConstProductExpression( } } +internal class AsmNumberExpression(private val context: NumericAlgebra, private val value: Number) : AsmExpression { + override fun tryEvaluate(): T? = context.number(value) + + override fun invoke(gen: AsmGenerationContext): Unit = gen.visitNumberConstant(value) +} + internal abstract class FunctionalCompiledExpression internal constructor( @JvmField protected val algebra: Algebra, @JvmField protected val constants: Array ) : Expression { abstract override fun invoke(arguments: Map): T } + +interface AsmExpressionAlgebra> : NumericAlgebra>, + ExpressionAlgebra> { + val algebra: A + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun number(value: Number): AsmExpression = AsmNumberExpression(algebra, value) + override fun const(value: T): AsmExpression = AsmConstantExpression(value) + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) +} + +open class AsmExpressionSpace(override val algebra: A) : AsmExpressionAlgebra, + Space> where A : Space, A : NumericAlgebra { + override val zero: AsmExpression + get() = const(algebra.zero) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(algebra, a, k) + operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) + operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) + operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this + operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this +} + +open class AsmExpressionRing(override val algebra: A) : AsmExpressionSpace(algebra), + Ring> where A : Ring, A : NumericAlgebra { + override val one: AsmExpression + get() = const(algebra.one) + + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + override fun number(value: Number): AsmExpression = const(algebra { one * value }) + + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) + + operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) + operator fun T.times(arg: AsmExpression): AsmExpression = arg * this +} + +open class AsmExpressionField(override val algebra: A) : + AsmExpressionRing(algebra), + Field> where A : Field, A : NumericAlgebra { + + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) + + operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) + operator fun T.div(arg: AsmExpression): AsmExpression = arg / this +} + diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt similarity index 96% rename from kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt index 92663ba8f..d9dba746d 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.expressions.asm +package scientifik.kmath.asm import org.objectweb.asm.ClassWriter import org.objectweb.asm.Label @@ -15,7 +15,8 @@ class AsmGenerationContext( internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) } - private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + private val classLoader: ClassLoader = + ClassLoader(javaClass.classLoader) @Suppress("PrivatePropertyName") private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') @@ -236,7 +237,8 @@ class AsmGenerationContext( } visitLdcInsn(name) - visitMethodInsn(Opcodes.INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true) + visitMethodInsn(Opcodes.INVOKEINTERFACE, + MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true) visitCastToT() } @@ -275,7 +277,7 @@ class AsmGenerationContext( ) internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = - "scientifik/kmath/expressions/asm/FunctionalCompiledExpression" + "scientifik/kmath/asm/FunctionalCompiledExpression" internal const val MAP_CLASS = "java/util/Map" internal const val OBJECT_CLASS = "java/lang/Object" diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Builders.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Builders.kt new file mode 100644 index 000000000..67870e07e --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Builders.kt @@ -0,0 +1,53 @@ +package scientifik.kmath.asm + +import scientifik.kmath.ast.MST +import scientifik.kmath.ast.evaluate +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.* + +@PublishedApi +internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { + val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(expression, collision + 1) +} + +inline fun AsmExpression.compile(algebra: Algebra): Expression { + val ctx = AsmGenerationContext(T::class.java, algebra, buildName(this)) + invoke(ctx) + return ctx.generate() +} + +inline fun , E : AsmExpressionAlgebra> A.asm( + expressionAlgebra: E, + block: E.() -> AsmExpression +): Expression = expressionAlgebra.block().compile(expressionAlgebra.algebra) + +inline fun , E : AsmExpressionAlgebra> A.asm( + expressionAlgebra: E, + ast: MST +): Expression = asm(expressionAlgebra) { evaluate(ast) } + +inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmExpression): Expression where A : NumericAlgebra, A : Space = + AsmExpressionSpace(this).let { it.block().compile(it.algebra) } + +inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = + asmSpace { evaluate(ast) } + +inline fun A.asmRing(block: AsmExpressionRing.() -> AsmExpression): Expression where A : NumericAlgebra, A : Ring = + AsmExpressionRing(this).let { it.block().compile(it.algebra) } + +inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = + asmRing { evaluate(ast) } + +inline fun A.asmField(block: AsmExpressionField.() -> AsmExpression): Expression where A : NumericAlgebra, A : Field = + AsmExpressionField(this).let { it.block().compile(it.algebra) } + +inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = + asmRing { evaluate(ast) } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/MethodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt similarity index 94% rename from kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/MethodVisitors.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt index 9f697ab6d..42565b6ba 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/MethodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.expressions.asm +package scientifik.kmath.asm import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt similarity index 80% rename from kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt index bcd182dc3..e821bd206 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.expressions.asm +package scientifik.kmath.asm @PublishedApi internal fun AsmExpression.optimize(): AsmExpression { diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/Parser.kt similarity index 100% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/Parser.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt deleted file mode 100644 index c01648fb0..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt +++ /dev/null @@ -1,23 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.NumericAlgebra - -//TODO stubs for asm generation - -interface AsmExpression - -interface AsmExpressionAlgebra> : NumericAlgebra> { - val algebra: A -} - -fun AsmExpression.compile(): Expression = TODO() - -//TODO add converter for functional expressions - -inline fun > A.asm( - block: AsmExpressionAlgebra.() -> AsmExpression -): Expression = TODO() - -inline fun > A.asm(ast: MST): Expression = asm { evaluate(ast) } \ No newline at end of file diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt new file mode 100644 index 000000000..73058face --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -0,0 +1,18 @@ +package scietifik.kmath.ast + +import org.junit.jupiter.api.Test +import scientifik.kmath.asm.asmField +import scientifik.kmath.ast.parseMath +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.Complex +import scientifik.kmath.operations.ComplexField +import kotlin.test.assertEquals + +class AsmTest { + @Test + fun parsedExpression() { + val mst = "2+2*(2+2)".parseMath() + val res = ComplexField.asmField(mst)() + assertEquals(Complex(10.0, 0.0), res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt index 6849d24b8..ddac07786 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt @@ -1,17 +1,17 @@ package scietifik.kmath.ast -import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test import scientifik.kmath.ast.evaluate import scientifik.kmath.ast.parseMath import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField +import kotlin.test.assertEquals -internal class ParserTest{ +internal class ParserTest { @Test - fun parsedExpression(){ + fun parsedExpression() { val mst = "2+2*(2+2)".parseMath() val res = ComplexField.evaluate(mst) - assertEquals(Complex(10.0,0.0), res) + assertEquals(Complex(10.0, 0.0), res) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt index b2e4d8a3d..cdd6f695b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt @@ -1,30 +1,23 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space -/** - * Create a functional expression on this [Algebra] - */ -fun Algebra.buildExpression(block: FunctionalExpressionAlgebra.() -> Expression): Expression = - FunctionalExpressionAlgebra(this).run(block) - /** * Create a functional expression on this [Space] */ -fun Space.buildExpression(block: FunctionalExpressionSpace.() -> Expression): Expression = +fun Space.buildExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression = FunctionalExpressionSpace(this).run(block) /** * Create a functional expression on this [Ring] */ -fun Ring.buildExpression(block: FunctionalExpressionRing.() -> Expression): Expression = +fun Ring.buildExpression(block: FunctionalExpressionRing>.() -> Expression): Expression = FunctionalExpressionRing(this).run(block) /** * Create a functional expression on this [Field] */ -fun Field.buildExpression(block: FunctionalExpressionField.() -> Expression): Expression = +fun Field.buildExpression(block: FunctionalExpressionField>.() -> Expression): Expression = FunctionalExpressionField(this).run(block) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index cbbdb99f5..88d8c87b7 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -26,48 +26,55 @@ internal class FunctionalConstantExpression(val value: T) : Expression { override fun invoke(arguments: Map): T = value } -internal class FunctionalConstProductExpression(val context: Space, val expr: Expression, val const: Number) : +internal class FunctionalConstProductExpression( + val context: Space, + val expr: Expression, + val const: Number +) : Expression { override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } -open class FunctionalExpressionAlgebra(val algebra: Algebra) : - Algebra>, - ExpressionAlgebra> { - override fun unaryOperation(operation: String, arg: Expression): Expression = - FunctionalUnaryOperation(algebra, operation, arg) +interface FunctionalExpressionAlgebra> : ExpressionAlgebra> { + val algebra: A override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = FunctionalBinaryOperation(algebra, operation, left, right) + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) + override fun const(value: T): Expression = FunctionalConstantExpression(value) override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default) } -open class FunctionalExpressionSpace(val space: Space) : - FunctionalExpressionAlgebra(space), - Space> { - override fun unaryOperation(operation: String, arg: Expression): Expression = - FunctionalUnaryOperation(algebra, operation, arg) +open class FunctionalExpressionSpace(override val algebra: A) : FunctionalExpressionAlgebra, + Space> where A : Space { + override val zero: Expression + get() = const(algebra.zero) override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = FunctionalBinaryOperation(algebra, operation, left, right) - override val zero: Expression = FunctionalConstantExpression(space.zero) + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) override fun add(a: Expression, b: Expression): Expression = - FunctionalBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b) + FunctionalBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: Expression, k: Number): Expression = + FunctionalConstProductExpression(algebra, a, k) - override fun multiply(a: Expression, k: Number): Expression = FunctionalConstProductExpression(space, a, k) operator fun Expression.plus(arg: T): Expression = this + const(arg) operator fun Expression.minus(arg: T): Expression = this - const(arg) operator fun T.plus(arg: Expression): Expression = arg + this operator fun T.minus(arg: Expression): Expression = arg - this } -open class FunctionalExpressionRing(val ring: Ring) : FunctionalExpressionSpace(ring), Ring> { +open class FunctionalExpressionRing(override val algebra: A) : FunctionalExpressionSpace(algebra), + Ring> where A : Ring, A : NumericAlgebra { override val one: Expression - get() = const(this.ring.one) + get() = const(algebra.one) override fun unaryOperation(operation: String, arg: Expression): Expression = FunctionalUnaryOperation(algebra, operation, arg) @@ -75,18 +82,18 @@ open class FunctionalExpressionRing(val ring: Ring) : FunctionalExpression override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = FunctionalBinaryOperation(algebra, operation, left, right) - override fun number(value: Number): Expression = const(ring { one * value }) + override fun number(value: Number): Expression = const(algebra { one * value }) override fun multiply(a: Expression, b: Expression): Expression = - FunctionalBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) + FunctionalBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) operator fun Expression.times(arg: T): Expression = this * const(arg) operator fun T.times(arg: Expression): Expression = arg * this } -open class FunctionalExpressionField(val field: Field) : - FunctionalExpressionRing(field), - Field> { +open class FunctionalExpressionField(override val algebra: A) : + FunctionalExpressionRing(algebra), + Field> where A : Field, A : NumericAlgebra { override fun unaryOperation(operation: String, arg: Expression): Expression = FunctionalUnaryOperation(algebra, operation, arg) @@ -95,7 +102,7 @@ open class FunctionalExpressionField(val field: Field) : FunctionalBinaryOperation(algebra, operation, left, right) override fun divide(a: Expression, b: Expression): Expression = - FunctionalBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b) + FunctionalBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) operator fun Expression.div(arg: T): Expression = this / const(arg) operator fun T.div(arg: Expression): Expression = arg / this diff --git a/settings.gradle.kts b/settings.gradle.kts index 2b2633060..465ecfca8 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -45,6 +45,5 @@ include( ":kmath-for-real", ":kmath-geometry", ":kmath-ast", - ":kmath-asm", ":examples" ) From 36ad1fcf589e0f8e9b4701f463360d7d66c80c41 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 13 Jun 2020 15:44:54 +0700 Subject: [PATCH 039/104] Minor refactor and document --- .../scientifik/kmath/asm/AsmExpressions.kt | 101 +++++++++++++----- .../kmath/asm/AsmGenerationContext.kt | 12 ++- .../expressions/FunctionalExpressions.kt | 84 ++++++++++----- 3 files changed, 141 insertions(+), 56 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index 75b5f26d0..ef579278a 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -142,7 +142,8 @@ internal class AsmConstProductExpression( } } -internal class AsmNumberExpression(private val context: NumericAlgebra, private val value: Number) : AsmExpression { +internal class AsmNumberExpression(private val context: NumericAlgebra, private val value: Number) : + AsmExpression { override fun tryEvaluate(): T? = context.number(value) override fun invoke(gen: AsmGenerationContext): Unit = gen.visitNumberConstant(value) @@ -155,75 +156,121 @@ internal abstract class FunctionalCompiledExpression internal constructor( abstract override fun invoke(arguments: Map): T } +/** + * A context class for [AsmExpression] construction. + */ interface AsmExpressionAlgebra> : NumericAlgebra>, ExpressionAlgebra> { + /** + * The algebra to provide for AsmExpressions built. + */ val algebra: A + + /** + * Builds an AsmExpression to wrap a number. + */ + override fun number(value: Number): AsmExpression = AsmNumberExpression(algebra, value) + + /** + * Builds an AsmExpression of constant expression which does not depend on arguments. + */ + override fun const(value: T): AsmExpression = AsmConstantExpression(value) + + /** + * Builds an AsmExpression to access a variable. + */ + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) + + /** + * Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right]. + */ override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = AsmBinaryOperation(algebra, operation, left, right) + /** + * Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg]. + */ override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = AsmUnaryOperation(algebra, operation, arg) - - override fun number(value: Number): AsmExpression = AsmNumberExpression(algebra, value) - override fun const(value: T): AsmExpression = AsmConstantExpression(value) - override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) } +/** + * A context class for [AsmExpression] construction for [Space] algebras. + */ open class AsmExpressionSpace(override val algebra: A) : AsmExpressionAlgebra, Space> where A : Space, A : NumericAlgebra { override val zero: AsmExpression get() = const(algebra.zero) - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - AsmBinaryOperation(algebra, operation, left, right) - - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - AsmUnaryOperation(algebra, operation, arg) - + /** + * Builds an AsmExpression of addition of two another expressions. + */ override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) + /** + * Builds an AsmExpression of multiplication of expression by number. + */ override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(algebra, a, k) + operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this + + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + super.binaryOperation(operation, left, right) } +/** + * A context class for [AsmExpression] construction for [Ring] algebras. + */ open class AsmExpressionRing(override val algebra: A) : AsmExpressionSpace(algebra), Ring> where A : Ring, A : NumericAlgebra { override val one: AsmExpression get() = const(algebra.one) - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - AsmUnaryOperation(algebra, operation, arg) - - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - AsmBinaryOperation(algebra, operation, left, right) - - override fun number(value: Number): AsmExpression = const(algebra { one * value }) - + /** + * Builds an AsmExpression of multiplication of two expressions. + */ override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) operator fun T.times(arg: AsmExpression): AsmExpression = arg * this + + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + super.binaryOperation(operation, left, right) + + override fun number(value: Number): AsmExpression = super.number(value) } +/** + * A context class for [AsmExpression] construction for [Field] algebras. + */ open class AsmExpressionField(override val algebra: A) : AsmExpressionRing(algebra), Field> where A : Field, A : NumericAlgebra { - - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - AsmUnaryOperation(algebra, operation, arg) - - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - AsmBinaryOperation(algebra, operation, left, right) - + /** + * Builds an AsmExpression of division an expression by another one. + */ override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) operator fun T.div(arg: AsmExpression): AsmExpression = arg / this -} + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + super.binaryOperation(operation, left, right) + + override fun number(value: Number): AsmExpression = super.number(value) +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt index d9dba746d..69a2b6c85 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt @@ -6,6 +6,15 @@ import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes import scientifik.kmath.operations.Algebra +/** + * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java + * expression. This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new + * class. + * + * @param T the type of AsmExpression to unwrap. + * @param algebra the algebra the applied AsmExpressions use. + * @param className the unique class name of new loaded class. + */ class AsmGenerationContext( classOfT: Class<*>, private val algebra: Algebra, @@ -29,8 +38,7 @@ class AsmGenerationContext( private val invokeArgumentsVar: Int = 1 private var maxStack: Int = 0 private val constants: MutableList = mutableListOf() - private val asmCompiledClassWriter: ClassWriter = - ClassWriter(0) + private val asmCompiledClassWriter: ClassWriter = ClassWriter(0) private val invokeMethodVisitor: MethodVisitor private val invokeL0: Label private lateinit var invokeL1: Label diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index 88d8c87b7..c3f36a814 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -28,40 +28,61 @@ internal class FunctionalConstantExpression(val value: T) : Expression { internal class FunctionalConstProductExpression( val context: Space, - val expr: Expression, + private val expr: Expression, val const: Number -) : - Expression { +) : Expression { override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } +/** + * A context class for [Expression] construction. + */ interface FunctionalExpressionAlgebra> : ExpressionAlgebra> { + /** + * The algebra to provide for Expressions built. + */ val algebra: A + /** + * Builds an Expression of constant expression which does not depend on arguments. + */ + override fun const(value: T): Expression = FunctionalConstantExpression(value) + + /** + * Builds an Expression to access a variable. + */ + override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default) + + /** + * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. + */ override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = FunctionalBinaryOperation(algebra, operation, left, right) + /** + * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. + */ override fun unaryOperation(operation: String, arg: Expression): Expression = FunctionalUnaryOperation(algebra, operation, arg) - - override fun const(value: T): Expression = FunctionalConstantExpression(value) - override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default) } +/** + * A context class for [Expression] construction for [Space] algebras. + */ open class FunctionalExpressionSpace(override val algebra: A) : FunctionalExpressionAlgebra, Space> where A : Space { override val zero: Expression get() = const(algebra.zero) - override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - FunctionalBinaryOperation(algebra, operation, left, right) - - override fun unaryOperation(operation: String, arg: Expression): Expression = - FunctionalUnaryOperation(algebra, operation, arg) - + /** + * Builds an Expression of addition of two another expressions. + */ override fun add(a: Expression, b: Expression): Expression = FunctionalBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) + /** + * Builds an Expression of multiplication of expression by number. + */ override fun multiply(a: Expression, k: Number): Expression = FunctionalConstProductExpression(algebra, a, k) @@ -69,6 +90,12 @@ open class FunctionalExpressionSpace(override val algebra: A) : Functional operator fun Expression.minus(arg: T): Expression = this - const(arg) operator fun T.plus(arg: Expression): Expression = arg + this operator fun T.minus(arg: Expression): Expression = arg - this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) } open class FunctionalExpressionRing(override val algebra: A) : FunctionalExpressionSpace(algebra), @@ -76,34 +103,37 @@ open class FunctionalExpressionRing(override val algebra: A) : FunctionalE override val one: Expression get() = const(algebra.one) - override fun unaryOperation(operation: String, arg: Expression): Expression = - FunctionalUnaryOperation(algebra, operation, arg) - - override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - FunctionalBinaryOperation(algebra, operation, left, right) - - override fun number(value: Number): Expression = const(algebra { one * value }) - + /** + * Builds an Expression of multiplication of two expressions. + */ override fun multiply(a: Expression, b: Expression): Expression = FunctionalBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) operator fun Expression.times(arg: T): Expression = this * const(arg) operator fun T.times(arg: Expression): Expression = arg * this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) } open class FunctionalExpressionField(override val algebra: A) : FunctionalExpressionRing(algebra), Field> where A : Field, A : NumericAlgebra { - - override fun unaryOperation(operation: String, arg: Expression): Expression = - FunctionalUnaryOperation(algebra, operation, arg) - - override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - FunctionalBinaryOperation(algebra, operation, left, right) - + /** + * Builds an Expression of division an expression by another one. + */ override fun divide(a: Expression, b: Expression): Expression = FunctionalBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) operator fun Expression.div(arg: T): Expression = this / const(arg) operator fun T.div(arg: Expression): Expression = arg / this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) } From fec8c7f9d19118f38188d804fd8248aeb8fbafe8 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 13 Jun 2020 15:50:21 +0700 Subject: [PATCH 040/104] Minor refactor and encapsulation --- .../jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt | 6 ++++-- .../scientifik/kmath/expressions/FunctionalExpressions.kt | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index ef579278a..9e6efb66b 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -11,7 +11,7 @@ interface AsmExpression { fun invoke(gen: AsmGenerationContext) } -internal val methodNameAdapters = mapOf("+" to "add", "*" to "multiply", "/" to "divide") +private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name @@ -137,7 +137,9 @@ internal class AsmConstProductExpression( gen.visitAlgebraOperation( owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, method = "multiply", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};" + + "L${AsmGenerationContext.NUMBER_CLASS};)" + + "L${AsmGenerationContext.OBJECT_CLASS};" ) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index c3f36a814..a441a80c0 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -2,7 +2,7 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.* -internal class FunctionalUnaryOperation(val context: Algebra, val name: String, val expr: Expression) : +internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) : Expression { override fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) } From 866ae47239f93eea7625e65a94799ce7245f03c4 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sat, 13 Jun 2020 11:51:33 +0300 Subject: [PATCH 041/104] replace `raw` by `symbol` in algebra --- kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt | 2 +- .../commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt | 2 +- .../commonMain/kotlin/scientifik/kmath/operations/Algebra.kt | 5 ++++- .../commonMain/kotlin/scientifik/kmath/operations/Complex.kt | 4 ++-- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt index eba2c3343..3375ecdf3 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt @@ -43,7 +43,7 @@ sealed class MST { fun NumericAlgebra.evaluate(node: MST): T { return when (node) { is MST.Numeric -> number(node.value) - is MST.Singular -> raw(node.value) + is MST.Singular -> symbol(node.value) is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Binary -> when { node.left is MST.Numeric && node.right is MST.Numeric -> { diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt index 3ee454b2a..dbd5238e3 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt @@ -12,7 +12,7 @@ class MSTExpression(val algebra: NumericAlgebra, val mst: MST) : Expressio * Substitute algebra raw value */ private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra by algebra { - override fun raw(value: String): T = arguments[value] ?: super.raw(value) + override fun symbol(value: String): T = arguments[value] ?: super.symbol(value) } override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index 166287ec7..8ed3f329e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -10,7 +10,10 @@ interface Algebra { /** * Wrap raw string or variable */ - fun raw(value: String): T = error("Wrapping of '$value' is not supported in $this") + fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") + + @Deprecated("Symbol is more concise",replaceWith = ReplaceWith("symbol")) + fun raw(value: String): T = symbol(value) /** * Dynamic call of unary operation with name [operation] on [arg] diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 01aef824d..6ce1d929b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -51,10 +51,10 @@ object ComplexField : ExtendedField { operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this) - override fun raw(value: String): Complex = if (value == "i") { + override fun symbol(value: String): Complex = if (value == "i") { i } else { - super.raw(value) + super.symbol(value) } } From 1582fde0915332681884d6c50f5bf1f516ff5c6a Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 13 Jun 2020 15:51:42 +0700 Subject: [PATCH 042/104] Replace JUnit @Test with kotlin-test @Test --- kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt | 2 +- kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index 73058face..c92524b5d 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -1,12 +1,12 @@ package scietifik.kmath.ast -import org.junit.jupiter.api.Test import scientifik.kmath.asm.asmField import scientifik.kmath.ast.parseMath import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField import kotlin.test.assertEquals +import kotlin.test.Test class AsmTest { @Test diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt index ddac07786..06546302e 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt @@ -1,11 +1,11 @@ package scietifik.kmath.ast -import org.junit.jupiter.api.Test import scientifik.kmath.ast.evaluate import scientifik.kmath.ast.parseMath import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField import kotlin.test.assertEquals +import kotlin.test.Test internal class ParserTest { @Test From 834d1e1397e76943adab584c46a7526c47073da3 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 13 Jun 2020 15:53:25 +0700 Subject: [PATCH 043/104] Move specific optimization functions to Optimization --- .../scientifik/kmath/asm/AsmExpressions.kt | 31 ---------------- .../scientifik/kmath/asm/Optimization.kt | 35 +++++++++++++++++++ 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index 9e6efb66b..c4039da81 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -11,37 +11,6 @@ interface AsmExpression { fun invoke(gen: AsmGenerationContext) } -private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") - -internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { - val aName = methodNameAdapters[name] ?: name - - context::class.memberFunctions.find { it.name == aName && it.parameters.size == arity } - ?: return false - - return true -} - -internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { - val aName = methodNameAdapters[name] ?: name - - context::class.memberFunctions.find { it.name == aName && it.parameters.size == arity } - ?: return false - - val owner = context::class.jvmName.replace('.', '/') - - val sig = buildString { - append('(') - repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") } - append(')') - append("L${AsmGenerationContext.OBJECT_CLASS};") - } - - visitAlgebraOperation(owner = owner, method = aName, descriptor = sig) - - return true -} - internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmExpression) : AsmExpression { private val expr: AsmExpression = expr.optimize() diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt index e821bd206..4e094fdd8 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt @@ -1,5 +1,40 @@ package scientifik.kmath.asm +import scientifik.kmath.operations.Algebra +import kotlin.reflect.full.memberFunctions +import kotlin.reflect.jvm.jvmName + +private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") + +internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { + val aName = methodNameAdapters[name] ?: name + + context::class.memberFunctions.find { it.name == aName && it.parameters.size == arity } + ?: return false + + return true +} + +internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { + val aName = methodNameAdapters[name] ?: name + + context::class.memberFunctions.find { it.name == aName && it.parameters.size == arity } + ?: return false + + val owner = context::class.jvmName.replace('.', '/') + + val sig = buildString { + append('(') + repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") } + append(')') + append("L${AsmGenerationContext.OBJECT_CLASS};") + } + + visitAlgebraOperation(owner = owner, method = aName, descriptor = sig) + + return true +} + @PublishedApi internal fun AsmExpression.optimize(): AsmExpression { val a = tryEvaluate() From 223d238c435f6e7b34d5ae778ec35b64603f0e81 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 13 Jun 2020 15:53:54 +0700 Subject: [PATCH 044/104] Encapsulate MethodVisitor extensions --- .../jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt index 42565b6ba..aec5a123d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt @@ -3,7 +3,7 @@ package scientifik.kmath.asm import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* -fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) { +internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) { -1 -> visitInsn(ICONST_M1) 0 -> visitInsn(ICONST_0) 1 -> visitInsn(ICONST_1) @@ -14,13 +14,13 @@ fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) { else -> visitLdcInsn(value) } -fun MethodVisitor.visitLdcOrDConstInsn(value: Double) = when (value) { +internal fun MethodVisitor.visitLdcOrDConstInsn(value: Double) = when (value) { 0.0 -> visitInsn(DCONST_0) 1.0 -> visitInsn(DCONST_1) else -> visitLdcInsn(value) } -fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) { +internal fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) { 0f -> visitInsn(FCONST_0) 1f -> visitInsn(FCONST_1) 2f -> visitInsn(FCONST_2) From e65d1e43cf64a27a04bc20b3245814be372008b9 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 13 Jun 2020 16:16:57 +0700 Subject: [PATCH 045/104] Write tests --- .../kmath/ast/asm/TestAsmAlgebras.kt | 75 +++++++++++++++++++ .../kmath/ast/asm/TestAsmExpressions.kt | 21 ++++++ 2 files changed, 96 insertions(+) create mode 100644 kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt create mode 100644 kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt new file mode 100644 index 000000000..07f895c40 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt @@ -0,0 +1,75 @@ +package scietifik.kmath.ast.asm + +import scientifik.kmath.asm.asmField +import scientifik.kmath.asm.asmRing +import scientifik.kmath.asm.asmSpace +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +class TestAsmAlgebras { + @Test + fun space() { + val res = ByteRing.asmSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + 3.toByte() - (2.toByte() + (multiply( + add(const(1), const(1)), + 2 + ) + 1.toByte()) * 3.toByte() - 1.toByte()) + ), + + number(1) + ) + variable("x") + zero + }("x" to 2.toByte()) + + assertEquals(16, res) + } + + @Test + fun ring() { + val res = ByteRing.asmRing { + binaryOperation( + "+", + + unaryOperation( + "+", + (3.toByte() - (2.toByte() + (multiply( + add(const(1), const(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * const(2) + }() + + assertEquals(24, res) + } + + @Test + fun field() { + val res = RealField.asmField { + divide(binaryOperation( + "+", + + unaryOperation( + "+", + (3.0 - (2.0 + (multiply( + add(const(1.0), const(1.0)), + 2 + ) + 1.0))) * 3 - 1.0 + ), + + number(1) + ) / 2, const(2.0)) * one + }() + + assertEquals(3.0, res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt new file mode 100644 index 000000000..2839b8a25 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt @@ -0,0 +1,21 @@ +package scietifik.kmath.ast.asm + +import scientifik.kmath.asm.asmField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +class TestAsmExpressions { + @Test + fun testUnaryOperationInvocation() { + val res = RealField.asmField { unaryOperation("+", variable("x")) }("x" to 2.0) + assertEquals(2.0, res) + } + + @Test + fun testConstProductInvocation() { + val res = RealField.asmField { variable("x") * 2 }("x" to 2.0) + assertEquals(4.0, res) + } +} From f9835979ea22069b0f3f94ec436b665ecb508554 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 13 Jun 2020 16:48:45 +0700 Subject: [PATCH 046/104] Fix specification bug --- .../scientifik/kmath/asm/AsmExpressions.kt | 4 ++-- .../kmath/asm/AsmGenerationContext.kt | 17 +++++++++++++---- .../kotlin/scientifik/kmath/asm/Optimization.kt | 17 +++++++++++------ 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index c4039da81..83f3b7754 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -57,13 +57,13 @@ internal class AsmBinaryOperation( override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() - if (!hasSpecific(context, name, 1)) + if (!hasSpecific(context, name, 2)) gen.visitStringConstant(name) first.invoke(gen) second.invoke(gen) - if (gen.tryInvokeSpecific(context, name, 1)) + if (gen.tryInvokeSpecific(context, name, 2)) return gen.visitAlgebraOperation( diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt index 69a2b6c85..c10e56ba8 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt @@ -4,6 +4,7 @@ import org.objectweb.asm.ClassWriter import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes +import scientifik.kmath.asm.AsmGenerationContext.ClassLoader import scientifik.kmath.operations.Algebra /** @@ -245,8 +246,10 @@ class AsmGenerationContext( } visitLdcInsn(name) - visitMethodInsn(Opcodes.INVOKEINTERFACE, - MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true) + visitMethodInsn( + Opcodes.INVOKEINTERFACE, + MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true + ) visitCastToT() } @@ -262,9 +265,15 @@ class AsmGenerationContext( invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) } - internal fun visitAlgebraOperation(owner: String, method: String, descriptor: String) { + internal fun visitAlgebraOperation( + owner: String, + method: String, + descriptor: String, + opcode: Int = Opcodes.INVOKEINTERFACE, + isInterface: Boolean = true + ) { maxStack++ - invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE, owner, method, descriptor, true) + invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) visitCastToT() } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt index 4e094fdd8..3198fa8ec 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt @@ -1,15 +1,14 @@ package scientifik.kmath.asm +import org.objectweb.asm.Opcodes import scientifik.kmath.operations.Algebra -import kotlin.reflect.full.memberFunctions -import kotlin.reflect.jvm.jvmName private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context::class.memberFunctions.find { it.name == aName && it.parameters.size == arity } + context::class.java.methods.find { it.name == aName && it.parameters.size == arity } ?: return false return true @@ -18,10 +17,10 @@ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boo internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context::class.memberFunctions.find { it.name == aName && it.parameters.size == arity } + context::class.java.methods.find { it.name == aName && it.parameters.size == arity } ?: return false - val owner = context::class.jvmName.replace('.', '/') + val owner = context::class.java.name.replace('.', '/') val sig = buildString { append('(') @@ -30,7 +29,13 @@ internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, append("L${AsmGenerationContext.OBJECT_CLASS};") } - visitAlgebraOperation(owner = owner, method = aName, descriptor = sig) + visitAlgebraOperation( + owner = owner, + method = aName, + descriptor = sig, + opcode = Opcodes.INVOKEVIRTUAL, + isInterface = false + ) return true } From e91f6470d3d6fd4d21574b4006977ae7b4be9b2e Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 13 Jun 2020 17:07:22 +0700 Subject: [PATCH 047/104] Implement constants inlining --- .../kmath/asm/AsmGenerationContext.kt | 33 ++++++++++++------- .../scientifik/kmath/asm/Optimization.kt | 1 + 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt index c10e56ba8..cc7874e37 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt @@ -17,7 +17,7 @@ import scientifik.kmath.operations.Algebra * @param className the unique class name of new loaded class. */ class AsmGenerationContext( - classOfT: Class<*>, + private val classOfT: Class<*>, private val algebra: Algebra, private val className: String ) { @@ -186,7 +186,15 @@ class AsmGenerationContext( return new } - internal fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS) + internal fun visitLoadFromConstants(value: T) { + if (classOfT in INLINABLE_NUMBERS) { + visitNumberConstant(value as Number) + visitCastToT() + return + } + + visitLoadAnyFromConstants(value as Any, T_CLASS) + } private fun visitLoadAnyFromConstants(value: Any, type: String) { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex @@ -207,7 +215,6 @@ class AsmGenerationContext( maxStack++ val clazz = value.javaClass val c = clazz.name.replace('.', '/') - val sigLetter = SIGNATURE_LETTERS[clazz] if (sigLetter != null) { @@ -284,14 +291,18 @@ class AsmGenerationContext( } internal companion object { - private val SIGNATURE_LETTERS = mapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" - ) + private val SIGNATURE_LETTERS by lazy { + mapOf( + java.lang.Byte::class.java to "B", + java.lang.Short::class.java to "S", + java.lang.Integer::class.java to "I", + java.lang.Long::class.java to "J", + java.lang.Float::class.java to "F", + java.lang.Double::class.java to "D" + ) + } + + private val INLINABLE_NUMBERS by lazy { SIGNATURE_LETTERS.keys } internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/asm/FunctionalCompiledExpression" diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt index 3198fa8ec..51048c77c 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt @@ -2,6 +2,7 @@ package scientifik.kmath.asm import org.objectweb.asm.Opcodes import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.ByteRing private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") From 0950580b85ca02a72c13863efc72aae2367e3ab3 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sat, 13 Jun 2020 18:26:18 +0300 Subject: [PATCH 048/104] Moe better-parse to common. Watch for https://github.com/h0tk3y/better-parse/issues/27 --- kmath-ast/build.gradle.kts | 6 ++++++ .../kotlin/scientifik/kmath/ast/parser.kt | 0 2 files changed, 6 insertions(+) rename kmath-ast/src/{jvmMain => commonMain}/kotlin/scientifik/kmath/ast/parser.kt (100%) diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index 88540e7b8..65a1ac12c 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -11,6 +11,7 @@ kotlin.sourceSets { dependencies { api(project(":kmath-core")) implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3") + implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3") } } jvmMain{ @@ -18,4 +19,9 @@ kotlin.sourceSets { implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3") } } + jsMain{ + dependencies{ + implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3") + } + } } \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt similarity index 100% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/parser.kt rename to kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt From 48b688b6b1a6888d1812a36ab5c2bca12042f8d2 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 14 Jun 2020 00:06:12 +0700 Subject: [PATCH 049/104] Fix minor problems occured after merge --- .../scientifik/kmath/operations/NumberAlgebra.kt | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 830b1496d..a1b845ccc 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -20,20 +20,18 @@ interface ExtendedFieldOperations : PowerOperations.SQRT_OPERATION -> sqrt(arg) ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.LN_OPERATION -> ln(arg) - else -> super.unaryOperation(operation, arg) + else -> super.unaryOperation(operation, arg) } } interface ExtendedField : ExtendedFieldOperations, Field { - override fun rightSideNumberOperation(operation: String, left: T, right: Number): T { - return when (operation) { - PowerOperations.POW_OPERATION -> power(left, right) - else -> super.rightSideNumberOperation(operation, left, right) - } - + override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { + PowerOperations.POW_OPERATION -> power(left, right) + else -> super.rightSideNumberOperation(operation, left, right) } } + /** * Real field element wrapping double. * From af410dde70af10e7531809dd18862d5eaca8a204 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 14 Jun 2020 00:18:40 +0700 Subject: [PATCH 050/104] Apply the suggested changes --- .../kmath/asm/{Builders.kt => AsmBuilders.kt} | 5 ++- .../scientifik/kmath/asm/AsmExpressions.kt | 43 +++++++++++++------ .../kmath/asm/AsmGenerationContext.kt | 5 ++- .../asm/{ => internal}/MethodVisitors.kt | 2 +- .../kmath/asm/{ => internal}/Optimization.kt | 6 ++- 5 files changed, 42 insertions(+), 19 deletions(-) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/{Builders.kt => AsmBuilders.kt} (94%) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/{ => internal}/MethodVisitors.kt (95%) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/{ => internal}/Optimization.kt (88%) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Builders.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt similarity index 94% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Builders.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt index 67870e07e..14456a426 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Builders.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt @@ -18,9 +18,10 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String return buildName(expression, collision + 1) } -inline fun AsmExpression.compile(algebra: Algebra): Expression { +@PublishedApi +internal inline fun AsmExpression.compile(algebra: Algebra): Expression { val ctx = AsmGenerationContext(T::class.java, algebra, buildName(this)) - invoke(ctx) + compile(ctx) return ctx.generate() } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index 83f3b7754..cd1815f7c 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -1,14 +1,31 @@ package scientifik.kmath.asm +import scientifik.kmath.asm.internal.hasSpecific +import scientifik.kmath.asm.internal.optimize +import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.* -import kotlin.reflect.full.memberFunctions -import kotlin.reflect.jvm.jvmName +/** + * A function declaration that could be compiled to [AsmGenerationContext]. + * + * @param T the type the stored function returns. + */ interface AsmExpression { + /** + * Tries to evaluate this function without its variables. This method is intended for optimization. + * + * @return `null` if the function depends on its variables, the value if the function is a constant. + */ fun tryEvaluate(): T? = null - fun invoke(gen: AsmGenerationContext) + + /** + * Compiles this declaration. + * + * @param gen the target [AsmGenerationContext]. + */ + fun compile(gen: AsmGenerationContext) } internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmExpression) : @@ -16,13 +33,13 @@ internal class AsmUnaryOperation(private val context: Algebra, private val private val expr: AsmExpression = expr.optimize() override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } - override fun invoke(gen: AsmGenerationContext) { + override fun compile(gen: AsmGenerationContext) { gen.visitLoadAlgebra() if (!hasSpecific(context, name, 1)) gen.visitStringConstant(name) - expr.invoke(gen) + expr.compile(gen) if (gen.tryInvokeSpecific(context, name, 1)) return @@ -54,14 +71,14 @@ internal class AsmBinaryOperation( ) } - override fun invoke(gen: AsmGenerationContext) { + override fun compile(gen: AsmGenerationContext) { gen.visitLoadAlgebra() if (!hasSpecific(context, name, 2)) gen.visitStringConstant(name) - first.invoke(gen) - second.invoke(gen) + first.compile(gen) + second.compile(gen) if (gen.tryInvokeSpecific(context, name, 2)) return @@ -79,13 +96,13 @@ internal class AsmBinaryOperation( internal class AsmVariableExpression(private val name: String, private val default: T? = null) : AsmExpression { - override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) + override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) } internal class AsmConstantExpression(private val value: T) : AsmExpression { override fun tryEvaluate(): T = value - override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) + override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) } internal class AsmConstProductExpression( @@ -98,10 +115,10 @@ internal class AsmConstProductExpression( override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } - override fun invoke(gen: AsmGenerationContext) { + override fun compile(gen: AsmGenerationContext) { gen.visitLoadAlgebra() gen.visitNumberConstant(const) - expr.invoke(gen) + expr.compile(gen) gen.visitAlgebraOperation( owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, @@ -117,7 +134,7 @@ internal class AsmNumberExpression(private val context: NumericAlgebra, pr AsmExpression { override fun tryEvaluate(): T? = context.number(value) - override fun invoke(gen: AsmGenerationContext): Unit = gen.visitNumberConstant(value) + override fun compile(gen: AsmGenerationContext): Unit = gen.visitNumberConstant(value) } internal abstract class FunctionalCompiledExpression internal constructor( diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt index cc7874e37..be8da3eff 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt @@ -5,6 +5,9 @@ import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes import scientifik.kmath.asm.AsmGenerationContext.ClassLoader +import scientifik.kmath.asm.internal.visitLdcOrDConstInsn +import scientifik.kmath.asm.internal.visitLdcOrFConstInsn +import scientifik.kmath.asm.internal.visitLdcOrIConstInsn import scientifik.kmath.operations.Algebra /** @@ -16,7 +19,7 @@ import scientifik.kmath.operations.Algebra * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. */ -class AsmGenerationContext( +class AsmGenerationContext @PublishedApi internal constructor( private val classOfT: Class<*>, private val algebra: Algebra, private val className: String diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt similarity index 95% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt index aec5a123d..356de4765 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/MethodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.asm +package scientifik.kmath.asm.internal import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt similarity index 88% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt index 51048c77c..cd41ff07c 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/Optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt @@ -1,8 +1,10 @@ -package scientifik.kmath.asm +package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes +import scientifik.kmath.asm.AsmConstantExpression +import scientifik.kmath.asm.AsmExpression +import scientifik.kmath.asm.AsmGenerationContext import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.ByteRing private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") From 28cecde05bb16f35db04f31bd2585538e563eed1 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 14 Jun 2020 00:24:10 +0700 Subject: [PATCH 051/104] Fix compilation problems found after merge --- kmath-ast/build.gradle.kts | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index 65a1ac12c..9764fccc3 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -2,7 +2,7 @@ plugins { id("scientifik.mpp") } -repositories{ +repositories { maven("https://dl.bintray.com/hotkeytlt/maven") } @@ -14,13 +14,18 @@ kotlin.sourceSets { implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3") } } - jvmMain{ - dependencies{ + + jvmMain { + dependencies { implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3") + implementation("org.ow2.asm:asm:8.0.1") + implementation("org.ow2.asm:asm-commons:8.0.1") + implementation(kotlin("reflect")) } } - jsMain{ - dependencies{ + + jsMain { + dependencies { implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3") } } From d3d348620acb55feea6d915eca4523789594bba3 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 14 Jun 2020 00:30:39 +0700 Subject: [PATCH 052/104] Rename AsmExpression to AsmNode, encapsulate AsmGenerationContext, make AsmNode (ex-AsmExpression) an abstract class instead of interface --- .../scientifik/kmath/asm/AsmBuilders.kt | 16 +-- .../scientifik/kmath/asm/AsmExpressions.kt | 109 +++++++++--------- .../{ => internal}/AsmGenerationContext.kt | 13 +-- .../kmath/asm/internal/Optimization.kt | 5 +- 4 files changed, 72 insertions(+), 71 deletions(-) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/{ => internal}/AsmGenerationContext.kt (96%) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt index 14456a426..69e49f8b6 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt @@ -1,12 +1,13 @@ package scientifik.kmath.asm +import scientifik.kmath.asm.internal.AsmGenerationContext import scientifik.kmath.ast.MST import scientifik.kmath.ast.evaluate import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.* @PublishedApi -internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { +internal fun buildName(expression: AsmNode<*>, collision: Int = 0): String { val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" try { @@ -19,15 +20,16 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String } @PublishedApi -internal inline fun AsmExpression.compile(algebra: Algebra): Expression { - val ctx = AsmGenerationContext(T::class.java, algebra, buildName(this)) +internal inline fun AsmNode.compile(algebra: Algebra): Expression { + val ctx = + AsmGenerationContext(T::class.java, algebra, buildName(this)) compile(ctx) return ctx.generate() } inline fun , E : AsmExpressionAlgebra> A.asm( expressionAlgebra: E, - block: E.() -> AsmExpression + block: E.() -> AsmNode ): Expression = expressionAlgebra.block().compile(expressionAlgebra.algebra) inline fun , E : AsmExpressionAlgebra> A.asm( @@ -35,19 +37,19 @@ inline fun , E : AsmExpressionAlgebra> A. ast: MST ): Expression = asm(expressionAlgebra) { evaluate(ast) } -inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmExpression): Expression where A : NumericAlgebra, A : Space = +inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmNode): Expression where A : NumericAlgebra, A : Space = AsmExpressionSpace(this).let { it.block().compile(it.algebra) } inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = asmSpace { evaluate(ast) } -inline fun A.asmRing(block: AsmExpressionRing.() -> AsmExpression): Expression where A : NumericAlgebra, A : Ring = +inline fun A.asmRing(block: AsmExpressionRing.() -> AsmNode): Expression where A : NumericAlgebra, A : Ring = AsmExpressionRing(this).let { it.block().compile(it.algebra) } inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = asmRing { evaluate(ast) } -inline fun A.asmField(block: AsmExpressionField.() -> AsmExpression): Expression where A : NumericAlgebra, A : Field = +inline fun A.asmField(block: AsmExpressionField.() -> AsmNode): Expression where A : NumericAlgebra, A : Field = AsmExpressionField(this).let { it.block().compile(it.algebra) } inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index cd1815f7c..8f16a2710 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -1,5 +1,6 @@ package scientifik.kmath.asm +import scientifik.kmath.asm.internal.AsmGenerationContext import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.optimize import scientifik.kmath.asm.internal.tryInvokeSpecific @@ -12,25 +13,26 @@ import scientifik.kmath.operations.* * * @param T the type the stored function returns. */ -interface AsmExpression { +abstract class AsmNode internal constructor() { /** * Tries to evaluate this function without its variables. This method is intended for optimization. * * @return `null` if the function depends on its variables, the value if the function is a constant. */ - fun tryEvaluate(): T? = null + internal open fun tryEvaluate(): T? = null /** * Compiles this declaration. * * @param gen the target [AsmGenerationContext]. */ - fun compile(gen: AsmGenerationContext) + @PublishedApi + internal abstract fun compile(gen: AsmGenerationContext) } -internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmExpression) : - AsmExpression { - private val expr: AsmExpression = expr.optimize() +internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmNode) : + AsmNode() { + private val expr: AsmNode = expr.optimize() override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } override fun compile(gen: AsmGenerationContext) { @@ -57,11 +59,11 @@ internal class AsmUnaryOperation(private val context: Algebra, private val internal class AsmBinaryOperation( private val context: Algebra, private val name: String, - first: AsmExpression, - second: AsmExpression -) : AsmExpression { - private val first: AsmExpression = first.optimize() - private val second: AsmExpression = second.optimize() + first: AsmNode, + second: AsmNode +) : AsmNode() { + private val first: AsmNode = first.optimize() + private val second: AsmNode = second.optimize() override fun tryEvaluate(): T? = context { binaryOperation( @@ -95,23 +97,22 @@ internal class AsmBinaryOperation( } internal class AsmVariableExpression(private val name: String, private val default: T? = null) : - AsmExpression { + AsmNode() { override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) } internal class AsmConstantExpression(private val value: T) : - AsmExpression { + AsmNode() { override fun tryEvaluate(): T = value override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) } internal class AsmConstProductExpression( private val context: Space, - expr: AsmExpression, + expr: AsmNode, private val const: Number -) : - AsmExpression { - private val expr: AsmExpression = expr.optimize() +) : AsmNode() { + private val expr: AsmNode = expr.optimize() override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } @@ -131,7 +132,7 @@ internal class AsmConstProductExpression( } internal class AsmNumberExpression(private val context: NumericAlgebra, private val value: Number) : - AsmExpression { + AsmNode() { override fun tryEvaluate(): T? = context.number(value) override fun compile(gen: AsmGenerationContext): Unit = gen.visitNumberConstant(value) @@ -145,10 +146,10 @@ internal abstract class FunctionalCompiledExpression internal constructor( } /** - * A context class for [AsmExpression] construction. + * A context class for [AsmNode] construction. */ -interface AsmExpressionAlgebra> : NumericAlgebra>, - ExpressionAlgebra> { +interface AsmExpressionAlgebra> : NumericAlgebra>, + ExpressionAlgebra> { /** * The algebra to provide for AsmExpressions built. */ @@ -157,108 +158,108 @@ interface AsmExpressionAlgebra> : NumericAlgebra = AsmNumberExpression(algebra, value) + override fun number(value: Number): AsmNode = AsmNumberExpression(algebra, value) /** * Builds an AsmExpression of constant expression which does not depend on arguments. */ - override fun const(value: T): AsmExpression = AsmConstantExpression(value) + override fun const(value: T): AsmNode = AsmConstantExpression(value) /** * Builds an AsmExpression to access a variable. */ - override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) + override fun variable(name: String, default: T?): AsmNode = AsmVariableExpression(name, default) /** * Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right]. */ - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = AsmBinaryOperation(algebra, operation, left, right) /** * Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg]. */ - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = AsmUnaryOperation(algebra, operation, arg) } /** - * A context class for [AsmExpression] construction for [Space] algebras. + * A context class for [AsmNode] construction for [Space] algebras. */ open class AsmExpressionSpace(override val algebra: A) : AsmExpressionAlgebra, - Space> where A : Space, A : NumericAlgebra { - override val zero: AsmExpression + Space> where A : Space, A : NumericAlgebra { + override val zero: AsmNode get() = const(algebra.zero) /** * Builds an AsmExpression of addition of two another expressions. */ - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = + override fun add(a: AsmNode, b: AsmNode): AsmNode = AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) /** * Builds an AsmExpression of multiplication of expression by number. */ - override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(algebra, a, k) + override fun multiply(a: AsmNode, k: Number): AsmNode = AsmConstProductExpression(algebra, a, k) - operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) - operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) - operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this - operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this + operator fun AsmNode.plus(arg: T): AsmNode = this + const(arg) + operator fun AsmNode.minus(arg: T): AsmNode = this - const(arg) + operator fun T.plus(arg: AsmNode): AsmNode = arg + this + operator fun T.minus(arg: AsmNode): AsmNode = arg - this - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = super.binaryOperation(operation, left, right) } /** - * A context class for [AsmExpression] construction for [Ring] algebras. + * A context class for [AsmNode] construction for [Ring] algebras. */ open class AsmExpressionRing(override val algebra: A) : AsmExpressionSpace(algebra), - Ring> where A : Ring, A : NumericAlgebra { - override val one: AsmExpression + Ring> where A : Ring, A : NumericAlgebra { + override val one: AsmNode get() = const(algebra.one) /** * Builds an AsmExpression of multiplication of two expressions. */ - override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + override fun multiply(a: AsmNode, b: AsmNode): AsmNode = AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) - operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) - operator fun T.times(arg: AsmExpression): AsmExpression = arg * this + operator fun AsmNode.times(arg: T): AsmNode = this * const(arg) + operator fun T.times(arg: AsmNode): AsmNode = arg * this - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = super.binaryOperation(operation, left, right) - override fun number(value: Number): AsmExpression = super.number(value) + override fun number(value: Number): AsmNode = super.number(value) } /** - * A context class for [AsmExpression] construction for [Field] algebras. + * A context class for [AsmNode] construction for [Field] algebras. */ open class AsmExpressionField(override val algebra: A) : AsmExpressionRing(algebra), - Field> where A : Field, A : NumericAlgebra { + Field> where A : Field, A : NumericAlgebra { /** * Builds an AsmExpression of division an expression by another one. */ - override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = + override fun divide(a: AsmNode, b: AsmNode): AsmNode = AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) - operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) - operator fun T.div(arg: AsmExpression): AsmExpression = arg / this + operator fun AsmNode.div(arg: T): AsmNode = this / const(arg) + operator fun T.div(arg: AsmNode): AsmNode = arg / this - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = super.binaryOperation(operation, left, right) - override fun number(value: Number): AsmExpression = super.number(value) + override fun number(value: Number): AsmNode = super.number(value) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt similarity index 96% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt index be8da3eff..fa0f06579 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt @@ -1,17 +1,15 @@ -package scientifik.kmath.asm +package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes -import scientifik.kmath.asm.AsmGenerationContext.ClassLoader -import scientifik.kmath.asm.internal.visitLdcOrDConstInsn -import scientifik.kmath.asm.internal.visitLdcOrFConstInsn -import scientifik.kmath.asm.internal.visitLdcOrIConstInsn +import scientifik.kmath.asm.FunctionalCompiledExpression +import scientifik.kmath.asm.internal.AsmGenerationContext.ClassLoader import scientifik.kmath.operations.Algebra /** - * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java + * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java * expression. This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new * class. * @@ -19,7 +17,8 @@ import scientifik.kmath.operations.Algebra * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. */ -class AsmGenerationContext @PublishedApi internal constructor( +@PublishedApi +internal class AsmGenerationContext @PublishedApi internal constructor( private val classOfT: Class<*>, private val algebra: Algebra, private val className: String diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt index cd41ff07c..815f292ce 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt @@ -2,8 +2,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes import scientifik.kmath.asm.AsmConstantExpression -import scientifik.kmath.asm.AsmExpression -import scientifik.kmath.asm.AsmGenerationContext +import scientifik.kmath.asm.AsmNode import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") @@ -44,7 +43,7 @@ internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, } @PublishedApi -internal fun AsmExpression.optimize(): AsmExpression { +internal fun AsmNode.optimize(): AsmNode { val a = tryEvaluate() return if (a == null) this else AsmConstantExpression(a) } From 635aac5f30ce84516db3d3bcc1d4a065572a728c Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 14 Jun 2020 22:58:09 +0700 Subject: [PATCH 053/104] Refactor ex-AsmGenerationContext, introduce many bytecode utility functions to make its code readable, update compile method --- .../scientifik/kmath/asm/AsmBuilders.kt | 10 +- .../scientifik/kmath/asm/AsmExpressions.kt | 64 ++-- .../asm/internal/AsmGenerationContext.kt | 319 ------------------ .../kmath/asm/internal/AsmGenerator.kt | 311 +++++++++++++++++ .../kmath/asm/internal/ClassWriters.kt | 15 + .../kmath/asm/internal/MethodVisitors.kt | 32 +- .../kmath/asm/internal/Optimization.kt | 8 +- 7 files changed, 394 insertions(+), 365 deletions(-) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/ClassWriters.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt index 69e49f8b6..223278756 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt @@ -1,6 +1,6 @@ package scientifik.kmath.asm -import scientifik.kmath.asm.internal.AsmGenerationContext +import scientifik.kmath.asm.internal.AsmGenerator import scientifik.kmath.ast.MST import scientifik.kmath.ast.evaluate import scientifik.kmath.expressions.Expression @@ -20,12 +20,8 @@ internal fun buildName(expression: AsmNode<*>, collision: Int = 0): String { } @PublishedApi -internal inline fun AsmNode.compile(algebra: Algebra): Expression { - val ctx = - AsmGenerationContext(T::class.java, algebra, buildName(this)) - compile(ctx) - return ctx.generate() -} +internal inline fun AsmNode.compile(algebra: Algebra): Expression = + AsmGenerator(T::class.java, algebra, buildName(this), this).getInstance() inline fun , E : AsmExpressionAlgebra> A.asm( expressionAlgebra: E, diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index 8f16a2710..c1f71612f 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -1,6 +1,6 @@ package scientifik.kmath.asm -import scientifik.kmath.asm.internal.AsmGenerationContext +import scientifik.kmath.asm.internal.AsmGenerator import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.optimize import scientifik.kmath.asm.internal.tryInvokeSpecific @@ -9,7 +9,7 @@ import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.* /** - * A function declaration that could be compiled to [AsmGenerationContext]. + * A function declaration that could be compiled to [AsmGenerator]. * * @param T the type the stored function returns. */ @@ -24,10 +24,10 @@ abstract class AsmNode internal constructor() { /** * Compiles this declaration. * - * @param gen the target [AsmGenerationContext]. + * @param gen the target [AsmGenerator]. */ @PublishedApi - internal abstract fun compile(gen: AsmGenerationContext) + internal abstract fun compile(gen: AsmGenerator) } internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmNode) : @@ -35,23 +35,23 @@ internal class AsmUnaryOperation(private val context: Algebra, private val private val expr: AsmNode = expr.optimize() override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } - override fun compile(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() + override fun compile(gen: AsmGenerator) { + gen.loadAlgebra() if (!hasSpecific(context, name, 1)) - gen.visitStringConstant(name) + gen.loadStringConstant(name) expr.compile(gen) if (gen.tryInvokeSpecific(context, name, 1)) return - gen.visitAlgebraOperation( - owner = AsmGenerationContext.ALGEBRA_CLASS, + gen.invokeAlgebraOperation( + owner = AsmGenerator.ALGEBRA_CLASS, method = "unaryOperation", - descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + - "L${AsmGenerationContext.OBJECT_CLASS};)" + - "L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmGenerator.STRING_CLASS};" + + "L${AsmGenerator.OBJECT_CLASS};)" + + "L${AsmGenerator.OBJECT_CLASS};" ) } } @@ -73,11 +73,11 @@ internal class AsmBinaryOperation( ) } - override fun compile(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() + override fun compile(gen: AsmGenerator) { + gen.loadAlgebra() if (!hasSpecific(context, name, 2)) - gen.visitStringConstant(name) + gen.loadStringConstant(name) first.compile(gen) second.compile(gen) @@ -85,26 +85,26 @@ internal class AsmBinaryOperation( if (gen.tryInvokeSpecific(context, name, 2)) return - gen.visitAlgebraOperation( - owner = AsmGenerationContext.ALGEBRA_CLASS, + gen.invokeAlgebraOperation( + owner = AsmGenerator.ALGEBRA_CLASS, method = "binaryOperation", - descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + - "L${AsmGenerationContext.OBJECT_CLASS};" + - "L${AsmGenerationContext.OBJECT_CLASS};)" + - "L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmGenerator.STRING_CLASS};" + + "L${AsmGenerator.OBJECT_CLASS};" + + "L${AsmGenerator.OBJECT_CLASS};)" + + "L${AsmGenerator.OBJECT_CLASS};" ) } } internal class AsmVariableExpression(private val name: String, private val default: T? = null) : AsmNode() { - override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) + override fun compile(gen: AsmGenerator): Unit = gen.loadFromVariables(name, default) } internal class AsmConstantExpression(private val value: T) : AsmNode() { override fun tryEvaluate(): T = value - override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) + override fun compile(gen: AsmGenerator): Unit = gen.loadTConstant(value) } internal class AsmConstProductExpression( @@ -116,17 +116,17 @@ internal class AsmConstProductExpression( override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } - override fun compile(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() - gen.visitNumberConstant(const) + override fun compile(gen: AsmGenerator) { + gen.loadAlgebra() + gen.loadNumberConstant(const) expr.compile(gen) - gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, + gen.invokeAlgebraOperation( + owner = AsmGenerator.SPACE_OPERATIONS_CLASS, method = "multiply", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};" + - "L${AsmGenerationContext.NUMBER_CLASS};)" + - "L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmGenerator.OBJECT_CLASS};" + + "L${AsmGenerator.NUMBER_CLASS};)" + + "L${AsmGenerator.OBJECT_CLASS};" ) } } @@ -135,7 +135,7 @@ internal class AsmNumberExpression(private val context: NumericAlgebra, pr AsmNode() { override fun tryEvaluate(): T? = context.number(value) - override fun compile(gen: AsmGenerationContext): Unit = gen.visitNumberConstant(value) + override fun compile(gen: AsmGenerator): Unit = gen.loadNumberConstant(value) } internal abstract class FunctionalCompiledExpression internal constructor( diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt deleted file mode 100644 index fa0f06579..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt +++ /dev/null @@ -1,319 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.ClassWriter -import org.objectweb.asm.Label -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes -import scientifik.kmath.asm.FunctionalCompiledExpression -import scientifik.kmath.asm.internal.AsmGenerationContext.ClassLoader -import scientifik.kmath.operations.Algebra - -/** - * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java - * expression. This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new - * class. - * - * @param T the type of AsmExpression to unwrap. - * @param algebra the algebra the applied AsmExpressions use. - * @param className the unique class name of new loaded class. - */ -@PublishedApi -internal class AsmGenerationContext @PublishedApi internal constructor( - private val classOfT: Class<*>, - private val algebra: Algebra, - private val className: String -) { - private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { - internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) - } - - private val classLoader: ClassLoader = - ClassLoader(javaClass.classLoader) - - @Suppress("PrivatePropertyName") - private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') - - @Suppress("PrivatePropertyName") - private val T_CLASS: String = classOfT.name.replace('.', '/') - - private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') - private val invokeThisVar: Int = 0 - private val invokeArgumentsVar: Int = 1 - private var maxStack: Int = 0 - private val constants: MutableList = mutableListOf() - private val asmCompiledClassWriter: ClassWriter = ClassWriter(0) - private val invokeMethodVisitor: MethodVisitor - private val invokeL0: Label - private lateinit var invokeL1: Label - private var generatedInstance: FunctionalCompiledExpression? = null - - init { - asmCompiledClassWriter.visit( - Opcodes.V1_8, - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, - slashesClassName, - "L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS;", - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, - arrayOf() - ) - - asmCompiledClassWriter.run { - visitMethod(Opcodes.ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", null, null).run { - val thisVar = 0 - val algebraVar = 1 - val constantsVar = 2 - val l0 = Label() - visitLabel(l0) - visitVarInsn(Opcodes.ALOAD, thisVar) - visitVarInsn(Opcodes.ALOAD, algebraVar) - visitVarInsn(Opcodes.ALOAD, constantsVar) - - visitMethodInsn( - Opcodes.INVOKESPECIAL, - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, - "", - "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", - false - ) - - val l1 = Label() - visitLabel(l1) - visitInsn(Opcodes.RETURN) - val l2 = Label() - visitLabel(l2) - visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) - - visitLocalVariable( - "algebra", - "L$ALGEBRA_CLASS;", - "L$ALGEBRA_CLASS;", - l0, - l2, - algebraVar - ) - - visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) - visitMaxs(3, 3) - visitEnd() - } - - invokeMethodVisitor = visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, - "invoke", - "(L$MAP_CLASS;)L$T_CLASS;", - "(L$MAP_CLASS;)L$T_CLASS;", - null - ) - - invokeMethodVisitor.run { - visitCode() - invokeL0 = Label() - visitLabel(invokeL0) - } - } - } - - @PublishedApi - @Suppress("UNCHECKED_CAST") - internal fun generate(): FunctionalCompiledExpression { - generatedInstance?.let { return it } - - invokeMethodVisitor.run { - visitInsn(Opcodes.ARETURN) - invokeL1 = Label() - visitLabel(invokeL1) - - visitLocalVariable( - "this", - "L$slashesClassName;", - T_CLASS, - invokeL0, - invokeL1, - invokeThisVar - ) - - visitLocalVariable( - "arguments", - "L$MAP_CLASS;", - "L$MAP_CLASS;", - invokeL0, - invokeL1, - invokeArgumentsVar - ) - - visitMaxs(maxStack + 1, 2) - visitEnd() - } - - asmCompiledClassWriter.visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, - "invoke", - "(L$MAP_CLASS;)L$OBJECT_CLASS;", - null, - null - ).run { - val thisVar = 0 - visitCode() - val l0 = Label() - visitLabel(l0) - visitVarInsn(Opcodes.ALOAD, 0) - visitVarInsn(Opcodes.ALOAD, 1) - visitMethodInsn(Opcodes.INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false) - visitInsn(Opcodes.ARETURN) - val l1 = Label() - visitLabel(l1) - - visitLocalVariable( - "this", - "L$slashesClassName;", - T_CLASS, - l0, - l1, - thisVar - ) - - visitMaxs(2, 2) - visitEnd() - } - - asmCompiledClassWriter.visitEnd() - - val new = classLoader - .defineClass(className, asmCompiledClassWriter.toByteArray()) - .constructors - .first() - .newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression - - generatedInstance = new - return new - } - - internal fun visitLoadFromConstants(value: T) { - if (classOfT in INLINABLE_NUMBERS) { - visitNumberConstant(value as Number) - visitCastToT() - return - } - - visitLoadAnyFromConstants(value as Any, T_CLASS) - } - - private fun visitLoadAnyFromConstants(value: Any, type: String) { - val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - maxStack++ - - invokeMethodVisitor.run { - visitLoadThis() - visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "[L$OBJECT_CLASS;") - visitLdcOrIConstInsn(idx) - visitInsn(Opcodes.AALOAD) - invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type) - } - } - - private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) - - internal fun visitNumberConstant(value: Number) { - maxStack++ - val clazz = value.javaClass - val c = clazz.name.replace('.', '/') - val sigLetter = SIGNATURE_LETTERS[clazz] - - if (sigLetter != null) { - when (value) { - is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value) - is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value) - is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value) - else -> invokeMethodVisitor.visitLdcInsn(value) - } - - invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) - return - } - - visitLoadAnyFromConstants(value, c) - } - - internal fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { - maxStack += 2 - visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar) - - if (defaultValue != null) { - visitLdcInsn(name) - visitLoadFromConstants(defaultValue) - - visitMethodInsn( - Opcodes.INVOKEINTERFACE, - MAP_CLASS, - "getOrDefault", - "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;", - true - ) - - visitCastToT() - return - } - - visitLdcInsn(name) - visitMethodInsn( - Opcodes.INVOKEINTERFACE, - MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true - ) - visitCastToT() - } - - internal fun visitLoadAlgebra() { - maxStack++ - invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) - - invokeMethodVisitor.visitFieldInsn( - Opcodes.GETFIELD, - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" - ) - - invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) - } - - internal fun visitAlgebraOperation( - owner: String, - method: String, - descriptor: String, - opcode: Int = Opcodes.INVOKEINTERFACE, - isInterface: Boolean = true - ) { - maxStack++ - invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) - visitCastToT() - } - - private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS) - - internal fun visitStringConstant(string: String) { - invokeMethodVisitor.visitLdcInsn(string) - } - - internal companion object { - private val SIGNATURE_LETTERS by lazy { - mapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" - ) - } - - private val INLINABLE_NUMBERS by lazy { SIGNATURE_LETTERS.keys } - - internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = - "scientifik/kmath/asm/FunctionalCompiledExpression" - - internal const val MAP_CLASS = "java/util/Map" - internal const val OBJECT_CLASS = "java/lang/Object" - internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" - internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" - internal const val STRING_CLASS = "java/lang/String" - internal const val NUMBER_CLASS = "java/lang/Number" - } -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt new file mode 100644 index 000000000..f636fc289 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt @@ -0,0 +1,311 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.Label +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Opcodes +import scientifik.kmath.asm.AsmNode +import scientifik.kmath.asm.FunctionalCompiledExpression +import scientifik.kmath.asm.internal.AsmGenerator.ClassLoader +import scientifik.kmath.operations.Algebra + +/** + * ASM Generator is a structure that abstracts building a class that unwraps [AsmNode] to plain Java expression. + * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. + * + * @param T the type of AsmExpression to unwrap. + * @param algebra the algebra the applied AsmExpressions use. + * @param className the unique class name of new loaded class. + */ +@Suppress("UNCHECKED_CAST") +@PublishedApi +internal class AsmGenerator @PublishedApi internal constructor( + private val classOfT: Class<*>, + private val algebra: Algebra, + private val className: String, + private val root: AsmNode +) { + private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) + } + + private val classLoader: ClassLoader = + ClassLoader(javaClass.classLoader) + + @Suppress("PrivatePropertyName") + private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') + + @Suppress("PrivatePropertyName") + private val T_CLASS: String = classOfT.name.replace('.', '/') + + private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + private val invokeThisVar: Int = 0 + private val invokeArgumentsVar: Int = 1 + private val constants: MutableList = mutableListOf() + private lateinit var invokeMethodVisitor: MethodVisitor + private var generatedInstance: FunctionalCompiledExpression? = null + + fun getInstance(): FunctionalCompiledExpression { + generatedInstance?.let { return it } + + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { + visit( + Opcodes.V1_8, + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, + slashesClassName, + "L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS;", + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + arrayOf() + ) + + visitMethod( + access = Opcodes.ACC_PUBLIC, + name = "", + descriptor = "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", + signature = null, + exceptions = null + ) { + val thisVar = 0 + val algebraVar = 1 + val constantsVar = 2 + val l0 = Label() + visitLabel(l0) + visitLoadObjectVar(thisVar) + visitLoadObjectVar(algebraVar) + visitLoadObjectVar(constantsVar) + + visitInvokeSpecial( + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + "", + "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V" + ) + + val l1 = Label() + visitLabel(l1) + visitReturn() + val l2 = Label() + visitLabel(l2) + visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) + + visitLocalVariable( + "algebra", + "L$ALGEBRA_CLASS;", + "L$ALGEBRA_CLASS;", + l0, + l2, + algebraVar + ) + + visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) + visitMaxs(0, 3) + visitEnd() + } + + invokeMethodVisitor = visitMethod( + access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + name = "invoke", + descriptor = "(L$MAP_CLASS;)L$T_CLASS;", + signature = "(L$MAP_CLASS;)L$T_CLASS;", + exceptions = null + ) {} + + invokeMethodVisitor.run { + visitCode() + val l0 = Label() + visitLabel(l0) + root.compile(this@AsmGenerator) + visitReturnObject() + val l1 = Label() + visitLabel(l1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + l0, + l1, + invokeThisVar + ) + + visitLocalVariable( + "arguments", + "L$MAP_CLASS;", + "L$MAP_CLASS;", + l0, + l1, + invokeArgumentsVar + ) + + visitMaxs(0, 2) + visitEnd() + } + + visitMethod( + access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + name = "invoke", + descriptor = "(L$MAP_CLASS;)L$OBJECT_CLASS;", + signature = null, + exceptions = null + ) { + val thisVar = 0 + val argumentsVar = 1 + visitCode() + val l0 = Label() + visitLabel(l0) + visitLoadObjectVar(thisVar) + visitLoadObjectVar(argumentsVar) + visitInvokeVirtual(owner = slashesClassName, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;") + visitReturnObject() + val l1 = Label() + visitLabel(l1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + l0, + l1, + thisVar + ) + + visitMaxs(0, 2) + visitEnd() + } + + visitEnd() + } + + val new = classLoader + .defineClass(className, classWriter.toByteArray()) + .constructors + .first() + .newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression + + generatedInstance = new + return new + } + + internal fun loadTConstant(value: T) { + if (classOfT in INLINABLE_NUMBERS) { + loadNumberConstant(value as Number) + invokeMethodVisitor.visitCheckCast(T_CLASS) + return + } + + loadConstant(value as Any, T_CLASS) + } + + private fun loadConstant(value: Any, type: String) { + val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex + + invokeMethodVisitor.run { + loadThis() + visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") + visitLdcOrIntConstant(idx) + visitGetObjectArrayElement() + invokeMethodVisitor.visitCheckCast(type) + } + } + + private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) + + internal fun loadNumberConstant(value: Number) { + val clazz = value.javaClass + val c = clazz.name.replace('.', '/') + val sigLetter = SIGNATURE_LETTERS[clazz] + + if (sigLetter != null) { + when (value) { + is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value) + is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value) + is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value) + else -> invokeMethodVisitor.visitLdcInsn(value) + } + + invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") + return + } + + loadConstant(value, c) + } + + internal fun loadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { + visitLoadObjectVar(invokeArgumentsVar) + + if (defaultValue != null) { + visitLdcInsn(name) + loadTConstant(defaultValue) + + visitInvokeInterface( + owner = MAP_CLASS, + name = "getOrDefault", + descriptor = "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;" + ) + + invokeMethodVisitor.visitCheckCast(T_CLASS) + return + } + + visitLdcInsn(name) + + visitInvokeInterface( + owner = MAP_CLASS, + name = "get", + descriptor = "(L$OBJECT_CLASS;)L$OBJECT_CLASS;" + ) + + invokeMethodVisitor.visitCheckCast(T_CLASS) + } + + internal fun loadAlgebra() { + loadThis() + + invokeMethodVisitor.visitGetField( + owner = FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + name = "algebra", + descriptor = "L$ALGEBRA_CLASS;" + ) + + invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS) + } + + internal fun invokeAlgebraOperation( + owner: String, + method: String, + descriptor: String, + opcode: Int = Opcodes.INVOKEINTERFACE, + isInterface: Boolean = true + ) { + invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) + invokeMethodVisitor.visitCheckCast(T_CLASS) + } + + internal fun loadStringConstant(string: String) { + invokeMethodVisitor.visitLdcInsn(string) + } + + internal companion object { + private val SIGNATURE_LETTERS: Map, String> by lazy { + mapOf( + java.lang.Byte::class.java to "B", + java.lang.Short::class.java to "S", + java.lang.Integer::class.java to "I", + java.lang.Long::class.java to "J", + java.lang.Float::class.java to "F", + java.lang.Double::class.java to "D" + ) + } + + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + + internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = + "scientifik/kmath/asm/FunctionalCompiledExpression" + + internal const val MAP_CLASS = "java/util/Map" + internal const val OBJECT_CLASS = "java/lang/Object" + internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" + internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" + internal const val STRING_CLASS = "java/lang/String" + internal const val NUMBER_CLASS = "java/lang/Number" + } +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/ClassWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/ClassWriters.kt new file mode 100644 index 000000000..84f736ee7 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/ClassWriters.kt @@ -0,0 +1,15 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.MethodVisitor + +inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) + +inline fun ClassWriter.visitMethod( + access: Int, + name: String, + descriptor: String, + signature: String?, + exceptions: Array?, + block: MethodVisitor.() -> Unit +): MethodVisitor = visitMethod(access, name, descriptor, signature, exceptions).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt index 356de4765..d29d786fc 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt @@ -3,7 +3,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* -internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) { +internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value) { -1 -> visitInsn(ICONST_M1) 0 -> visitInsn(ICONST_0) 1 -> visitInsn(ICONST_1) @@ -14,15 +14,41 @@ internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) { else -> visitLdcInsn(value) } -internal fun MethodVisitor.visitLdcOrDConstInsn(value: Double) = when (value) { +internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when (value) { 0.0 -> visitInsn(DCONST_0) 1.0 -> visitInsn(DCONST_1) else -> visitLdcInsn(value) } -internal fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) { +internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) { 0f -> visitInsn(FCONST_0) 1f -> visitInsn(FCONST_1) 2f -> visitInsn(FCONST_2) else -> visitLdcInsn(value) } + +internal fun MethodVisitor.visitInvokeInterface(owner: String, name: String, descriptor: String): Unit = + visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true) + +internal fun MethodVisitor.visitInvokeVirtual(owner: String, name: String, descriptor: String): Unit = + visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, false) + +internal fun MethodVisitor.visitInvokeStatic(owner: String, name: String, descriptor: String): Unit = + visitMethodInsn(INVOKESTATIC, owner, name, descriptor, false) + +internal fun MethodVisitor.visitInvokeSpecial(owner: String, name: String, descriptor: String): Unit = + visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, false) + +internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CHECKCAST, type) + +internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit = + visitFieldInsn(GETFIELD, owner, name, descriptor) + +internal fun MethodVisitor.visitLoadObjectVar(`var`: Int) { + visitVarInsn(ALOAD, `var`) +} + +internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD) + +internal fun MethodVisitor.visitReturn(): Unit = visitInsn(RETURN) +internal fun MethodVisitor.visitReturnObject(): Unit = visitInsn(ARETURN) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt index 815f292ce..3b37f81e5 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt @@ -16,7 +16,7 @@ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boo return true } -internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { +internal fun AsmGenerator.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name context::class.java.methods.find { it.name == aName && it.parameters.size == arity } @@ -26,12 +26,12 @@ internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, val sig = buildString { append('(') - repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") } + repeat(arity) { append("L${AsmGenerator.OBJECT_CLASS};") } append(')') - append("L${AsmGenerationContext.OBJECT_CLASS};") + append("L${AsmGenerator.OBJECT_CLASS};") } - visitAlgebraOperation( + invokeAlgebraOperation( owner = owner, method = aName, descriptor = sig, From 4e28ad7d4e91eb26196907d5f28c618211f71a31 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 14 Jun 2020 23:00:26 +0700 Subject: [PATCH 054/104] Minor refactor --- .../scientifik/kmath/asm/internal/AsmGenerator.kt | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt index f636fc289..d2568eb53 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt @@ -17,7 +17,6 @@ import scientifik.kmath.operations.Algebra * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. */ -@Suppress("UNCHECKED_CAST") @PublishedApi internal class AsmGenerator @PublishedApi internal constructor( private val classOfT: Class<*>, @@ -45,6 +44,7 @@ internal class AsmGenerator @PublishedApi internal constructor( private lateinit var invokeMethodVisitor: MethodVisitor private var generatedInstance: FunctionalCompiledExpression? = null + @Suppress("UNCHECKED_CAST") fun getInstance(): FunctionalCompiledExpression { generatedInstance?.let { return it } @@ -101,15 +101,14 @@ internal class AsmGenerator @PublishedApi internal constructor( visitEnd() } - invokeMethodVisitor = visitMethod( + visitMethod( access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;", signature = "(L$MAP_CLASS;)L$T_CLASS;", exceptions = null - ) {} - - invokeMethodVisitor.run { + ) { + invokeMethodVisitor = this visitCode() val l0 = Label() visitLabel(l0) @@ -121,7 +120,7 @@ internal class AsmGenerator @PublishedApi internal constructor( visitLocalVariable( "this", "L$slashesClassName;", - T_CLASS, + null, l0, l1, invokeThisVar @@ -280,9 +279,7 @@ internal class AsmGenerator @PublishedApi internal constructor( invokeMethodVisitor.visitCheckCast(T_CLASS) } - internal fun loadStringConstant(string: String) { - invokeMethodVisitor.visitLdcInsn(string) - } + internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) internal companion object { private val SIGNATURE_LETTERS: Map, String> by lazy { From a8fa38549758b5602118c85d23ffdf061fbce4ba Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 14 Jun 2020 23:01:34 +0700 Subject: [PATCH 055/104] Rename loadFromVariables to loadVariable --- .../src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt | 2 +- .../kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index c1f71612f..72a53d899 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -98,7 +98,7 @@ internal class AsmBinaryOperation( internal class AsmVariableExpression(private val name: String, private val default: T? = null) : AsmNode() { - override fun compile(gen: AsmGenerator): Unit = gen.loadFromVariables(name, default) + override fun compile(gen: AsmGenerator): Unit = gen.loadVariable(name, default) } internal class AsmConstantExpression(private val value: T) : diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt index d2568eb53..0c6975cd9 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt @@ -228,7 +228,7 @@ internal class AsmGenerator @PublishedApi internal constructor( loadConstant(value, c) } - internal fun loadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { + internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { visitLoadObjectVar(invokeArgumentsVar) if (defaultValue != null) { From a7302f49fffdadf91e4255fa3f65e476d51af96d Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 14 Jun 2020 23:13:26 +0700 Subject: [PATCH 056/104] Convert to expression body --- .../kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt index d29d786fc..70fb3bd44 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt @@ -44,9 +44,7 @@ internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CH internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit = visitFieldInsn(GETFIELD, owner, name, descriptor) -internal fun MethodVisitor.visitLoadObjectVar(`var`: Int) { - visitVarInsn(ALOAD, `var`) -} +internal fun MethodVisitor.visitLoadObjectVar(`var`: Int): Unit = visitVarInsn(ALOAD, `var`) internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD) From 3434dde1d1eeec094f03a150f55a73ac81fa5f59 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 15 Jun 2020 11:02:13 +0300 Subject: [PATCH 057/104] ASM API simplification --- kmath-ast/README.md | 4 + .../kotlin/scientifik/kmath/ast/MST.kt | 4 +- .../scientifik/kmath/asm/AsmBuilders.kt | 56 ---- .../scientifik/kmath/asm/AsmExpressions.kt | 249 ++++++++++-------- .../kotlin/scientifik/kmath/asm/asm.kt | 47 ++++ ...{AsmGenerationContext.kt => AsmBuilder.kt} | 66 +++-- .../{MethodVisitors.kt => methodVisitors.kt} | 0 .../{Optimization.kt => optimization.kt} | 13 +- .../kmath/{ast => }/asm/TestAsmAlgebras.kt | 2 +- .../kmath/{ast => }/asm/TestAsmExpressions.kt | 2 +- ...ions.kt => FunctionalExpressionAlgebra.kt} | 0 .../kmath/expressions/ExpressionFieldTest.kt | 4 +- 12 files changed, 231 insertions(+), 216 deletions(-) create mode 100644 kmath-ast/README.md delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{AsmGenerationContext.kt => AsmBuilder.kt} (82%) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{MethodVisitors.kt => methodVisitors.kt} (100%) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{Optimization.kt => optimization.kt} (72%) rename kmath-ast/src/jvmTest/kotlin/scietifik/kmath/{ast => }/asm/TestAsmAlgebras.kt (98%) rename kmath-ast/src/jvmTest/kotlin/scietifik/kmath/{ast => }/asm/TestAsmExpressions.kt (94%) rename kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/{FunctionalExpressions.kt => FunctionalExpressionAlgebra.kt} (100%) diff --git a/kmath-ast/README.md b/kmath-ast/README.md new file mode 100644 index 000000000..f9dbd4663 --- /dev/null +++ b/kmath-ast/README.md @@ -0,0 +1,4 @@ +# AST based expression representation and operations + +## Dynamic expression code generation +Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt index 3375ecdf3..900b9297a 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt @@ -11,7 +11,7 @@ sealed class MST { /** * A node containing unparsed string */ - data class Singular(val value: String) : MST() + data class Symbolic(val value: String) : MST() /** * A node containing a number @@ -43,7 +43,7 @@ sealed class MST { fun NumericAlgebra.evaluate(node: MST): T { return when (node) { is MST.Numeric -> number(node.value) - is MST.Singular -> symbol(node.value) + is MST.Symbolic -> symbol(node.value) is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Binary -> when { node.left is MST.Numeric && node.right is MST.Numeric -> { diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt deleted file mode 100644 index 69e49f8b6..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt +++ /dev/null @@ -1,56 +0,0 @@ -package scientifik.kmath.asm - -import scientifik.kmath.asm.internal.AsmGenerationContext -import scientifik.kmath.ast.MST -import scientifik.kmath.ast.evaluate -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.* - -@PublishedApi -internal fun buildName(expression: AsmNode<*>, collision: Int = 0): String { - val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(expression, collision + 1) -} - -@PublishedApi -internal inline fun AsmNode.compile(algebra: Algebra): Expression { - val ctx = - AsmGenerationContext(T::class.java, algebra, buildName(this)) - compile(ctx) - return ctx.generate() -} - -inline fun , E : AsmExpressionAlgebra> A.asm( - expressionAlgebra: E, - block: E.() -> AsmNode -): Expression = expressionAlgebra.block().compile(expressionAlgebra.algebra) - -inline fun , E : AsmExpressionAlgebra> A.asm( - expressionAlgebra: E, - ast: MST -): Expression = asm(expressionAlgebra) { evaluate(ast) } - -inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmNode): Expression where A : NumericAlgebra, A : Space = - AsmExpressionSpace(this).let { it.block().compile(it.algebra) } - -inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = - asmSpace { evaluate(ast) } - -inline fun A.asmRing(block: AsmExpressionRing.() -> AsmNode): Expression where A : NumericAlgebra, A : Ring = - AsmExpressionRing(this).let { it.block().compile(it.algebra) } - -inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = - asmRing { evaluate(ast) } - -inline fun A.asmField(block: AsmExpressionField.() -> AsmNode): Expression where A : NumericAlgebra, A : Field = - AsmExpressionField(this).let { it.block().compile(it.algebra) } - -inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = - asmRing { evaluate(ast) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index 8f16a2710..a8b3ff976 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -1,19 +1,24 @@ package scientifik.kmath.asm -import scientifik.kmath.asm.internal.AsmGenerationContext +import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.optimize import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.* +import kotlin.reflect.KClass /** - * A function declaration that could be compiled to [AsmGenerationContext]. + * A function declaration that could be compiled to [AsmBuilder]. * * @param T the type the stored function returns. */ -abstract class AsmNode internal constructor() { +sealed class AsmExpression: Expression { + abstract val type: KClass + + abstract val algebra: Algebra + /** * Tries to evaluate this function without its variables. This method is intended for optimization. * @@ -24,118 +29,144 @@ abstract class AsmNode internal constructor() { /** * Compiles this declaration. * - * @param gen the target [AsmGenerationContext]. + * @param gen the target [AsmBuilder]. */ - @PublishedApi - internal abstract fun compile(gen: AsmGenerationContext) + internal abstract fun appendTo(gen: AsmBuilder) + + /** + * Compile and cache the expression + */ + private val compiledExpression by lazy{ + val builder = AsmBuilder(type.java, algebra, buildName(this)) + this.appendTo(builder) + builder.generate() + } + + override fun invoke(arguments: Map): T = compiledExpression.invoke(arguments) } -internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmNode) : - AsmNode() { - private val expr: AsmNode = expr.optimize() - override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } +internal class AsmUnaryOperation( + override val type: KClass, + override val algebra: Algebra, + private val name: String, + expr: AsmExpression +) : AsmExpression() { + private val expr: AsmExpression = expr.optimize() + override fun tryEvaluate(): T? = algebra { unaryOperation(name, expr.tryEvaluate() ?: return@algebra null) } - override fun compile(gen: AsmGenerationContext) { + override fun appendTo(gen: AsmBuilder) { gen.visitLoadAlgebra() - if (!hasSpecific(context, name, 1)) + if (!hasSpecific(algebra, name, 1)) gen.visitStringConstant(name) - expr.compile(gen) + expr.appendTo(gen) - if (gen.tryInvokeSpecific(context, name, 1)) + if (gen.tryInvokeSpecific(algebra, name, 1)) return gen.visitAlgebraOperation( - owner = AsmGenerationContext.ALGEBRA_CLASS, + owner = AsmBuilder.ALGEBRA_CLASS, method = "unaryOperation", - descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + - "L${AsmGenerationContext.OBJECT_CLASS};)" + - "L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmBuilder.STRING_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" ) } } -internal class AsmBinaryOperation( - private val context: Algebra, +internal class AsmBinaryOperation( + override val type: KClass, + override val algebra: Algebra, private val name: String, - first: AsmNode, - second: AsmNode -) : AsmNode() { - private val first: AsmNode = first.optimize() - private val second: AsmNode = second.optimize() + first: AsmExpression, + second: AsmExpression +) : AsmExpression() { + private val first: AsmExpression = first.optimize() + private val second: AsmExpression = second.optimize() - override fun tryEvaluate(): T? = context { + override fun tryEvaluate(): T? = algebra { binaryOperation( name, - first.tryEvaluate() ?: return@context null, - second.tryEvaluate() ?: return@context null + first.tryEvaluate() ?: return@algebra null, + second.tryEvaluate() ?: return@algebra null ) } - override fun compile(gen: AsmGenerationContext) { + override fun appendTo(gen: AsmBuilder) { gen.visitLoadAlgebra() - if (!hasSpecific(context, name, 2)) + if (!hasSpecific(algebra, name, 2)) gen.visitStringConstant(name) - first.compile(gen) - second.compile(gen) + first.appendTo(gen) + second.appendTo(gen) - if (gen.tryInvokeSpecific(context, name, 2)) + if (gen.tryInvokeSpecific(algebra, name, 2)) return gen.visitAlgebraOperation( - owner = AsmGenerationContext.ALGEBRA_CLASS, + owner = AsmBuilder.ALGEBRA_CLASS, method = "binaryOperation", - descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + - "L${AsmGenerationContext.OBJECT_CLASS};" + - "L${AsmGenerationContext.OBJECT_CLASS};)" + - "L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmBuilder.STRING_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" ) } } -internal class AsmVariableExpression(private val name: String, private val default: T? = null) : - AsmNode() { - override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) +internal class AsmVariableExpression( + override val type: KClass, + override val algebra: Algebra, + private val name: String, + private val default: T? = null +) : AsmExpression() { + override fun appendTo(gen: AsmBuilder): Unit = gen.visitLoadFromVariables(name, default) } -internal class AsmConstantExpression(private val value: T) : - AsmNode() { +internal class AsmConstantExpression( + override val type: KClass, + override val algebra: Algebra, + private val value: T +) : AsmExpression() { override fun tryEvaluate(): T = value - override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) + override fun appendTo(gen: AsmBuilder): Unit = gen.visitLoadFromConstants(value) } -internal class AsmConstProductExpression( - private val context: Space, - expr: AsmNode, +internal class AsmConstProductExpression( + override val type: KClass, + override val algebra: Space, + expr: AsmExpression, private val const: Number -) : AsmNode() { - private val expr: AsmNode = expr.optimize() +) : AsmExpression() { + private val expr: AsmExpression = expr.optimize() - override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } + override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const } - override fun compile(gen: AsmGenerationContext) { + override fun appendTo(gen: AsmBuilder) { gen.visitLoadAlgebra() gen.visitNumberConstant(const) - expr.compile(gen) + expr.appendTo(gen) gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, + owner = AsmBuilder.SPACE_OPERATIONS_CLASS, method = "multiply", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};" + - "L${AsmGenerationContext.NUMBER_CLASS};)" + - "L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmBuilder.OBJECT_CLASS};" + + "L${AsmBuilder.NUMBER_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" ) } } -internal class AsmNumberExpression(private val context: NumericAlgebra, private val value: Number) : - AsmNode() { - override fun tryEvaluate(): T? = context.number(value) +internal class AsmNumberExpression( + override val type: KClass, + override val algebra: NumericAlgebra, + private val value: Number +) : AsmExpression() { + override fun tryEvaluate(): T? = algebra.number(value) - override fun compile(gen: AsmGenerationContext): Unit = gen.visitNumberConstant(value) + override fun appendTo(gen: AsmBuilder): Unit = gen.visitNumberConstant(value) } internal abstract class FunctionalCompiledExpression internal constructor( @@ -146,120 +177,116 @@ internal abstract class FunctionalCompiledExpression internal constructor( } /** - * A context class for [AsmNode] construction. + * A context class for [AsmExpression] construction. + * + * @param algebra The algebra to provide for AsmExpressions built. */ -interface AsmExpressionAlgebra> : NumericAlgebra>, - ExpressionAlgebra> { - /** - * The algebra to provide for AsmExpressions built. - */ - val algebra: A +open class AsmExpressionAlgebra>(val type: KClass, val algebra: A) : + NumericAlgebra>, ExpressionAlgebra> { /** * Builds an AsmExpression to wrap a number. */ - override fun number(value: Number): AsmNode = AsmNumberExpression(algebra, value) + override fun number(value: Number): AsmExpression = AsmNumberExpression(type, algebra, value) /** * Builds an AsmExpression of constant expression which does not depend on arguments. */ - override fun const(value: T): AsmNode = AsmConstantExpression(value) + override fun const(value: T): AsmExpression = AsmConstantExpression(type, algebra, value) /** * Builds an AsmExpression to access a variable. */ - override fun variable(name: String, default: T?): AsmNode = AsmVariableExpression(name, default) + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(type, algebra, name, default) /** * Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right]. */ - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = - AsmBinaryOperation(algebra, operation, left, right) + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(type, algebra, operation, left, right) /** * Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg]. */ - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = - AsmUnaryOperation(algebra, operation, arg) + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(type, algebra, operation, arg) } /** - * A context class for [AsmNode] construction for [Space] algebras. + * A context class for [AsmExpression] construction for [Space] algebras. */ -open class AsmExpressionSpace(override val algebra: A) : AsmExpressionAlgebra, - Space> where A : Space, A : NumericAlgebra { - override val zero: AsmNode - get() = const(algebra.zero) +open class AsmExpressionSpace(type: KClass, algebra: A) : AsmExpressionAlgebra(type, algebra), + Space> where A : Space, A : NumericAlgebra { + override val zero: AsmExpression get() = const(algebra.zero) /** * Builds an AsmExpression of addition of two another expressions. */ - override fun add(a: AsmNode, b: AsmNode): AsmNode = - AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(type, algebra, SpaceOperations.PLUS_OPERATION, a, b) /** * Builds an AsmExpression of multiplication of expression by number. */ - override fun multiply(a: AsmNode, k: Number): AsmNode = AsmConstProductExpression(algebra, a, k) + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(type, algebra, a, k) - operator fun AsmNode.plus(arg: T): AsmNode = this + const(arg) - operator fun AsmNode.minus(arg: T): AsmNode = this - const(arg) - operator fun T.plus(arg: AsmNode): AsmNode = arg + this - operator fun T.minus(arg: AsmNode): AsmNode = arg - this + operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) + operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) + operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this + operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = super.binaryOperation(operation, left, right) } /** - * A context class for [AsmNode] construction for [Ring] algebras. + * A context class for [AsmExpression] construction for [Ring] algebras. */ -open class AsmExpressionRing(override val algebra: A) : AsmExpressionSpace(algebra), - Ring> where A : Ring, A : NumericAlgebra { - override val one: AsmNode - get() = const(algebra.one) +open class AsmExpressionRing(type: KClass, algebra: A) : AsmExpressionSpace(type, algebra), + Ring> where A : Ring, A : NumericAlgebra { + override val one: AsmExpression get() = const(algebra.one) /** * Builds an AsmExpression of multiplication of two expressions. */ - override fun multiply(a: AsmNode, b: AsmNode): AsmNode = - AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(type, algebra, RingOperations.TIMES_OPERATION, a, b) - operator fun AsmNode.times(arg: T): AsmNode = this * const(arg) - operator fun T.times(arg: AsmNode): AsmNode = arg * this + operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) + operator fun T.times(arg: AsmExpression): AsmExpression = arg * this - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = super.binaryOperation(operation, left, right) - override fun number(value: Number): AsmNode = super.number(value) + override fun number(value: Number): AsmExpression = super.number(value) } /** - * A context class for [AsmNode] construction for [Field] algebras. + * A context class for [AsmExpression] construction for [Field] algebras. */ -open class AsmExpressionField(override val algebra: A) : - AsmExpressionRing(algebra), - Field> where A : Field, A : NumericAlgebra { +open class AsmExpressionField(type: KClass, algebra: A) : + AsmExpressionRing(type, algebra), + Field> where A : Field, A : NumericAlgebra { /** * Builds an AsmExpression of division an expression by another one. */ - override fun divide(a: AsmNode, b: AsmNode): AsmNode = - AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(type, algebra, FieldOperations.DIV_OPERATION, a, b) - operator fun AsmNode.div(arg: T): AsmNode = this / const(arg) - operator fun T.div(arg: AsmNode): AsmNode = arg / this + operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) + operator fun T.div(arg: AsmExpression): AsmExpression = arg / this - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = super.binaryOperation(operation, left, right) - override fun number(value: Number): AsmNode = super.number(value) + override fun number(value: Number): AsmExpression = super.number(value) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt new file mode 100644 index 000000000..cc3e36e94 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -0,0 +1,47 @@ +package scientifik.kmath.asm + +import scientifik.kmath.ast.MST +import scientifik.kmath.ast.evaluate +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.NumericAlgebra +import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.Space + +internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { + val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(expression, collision + 1) +} + +inline fun , E : AsmExpressionAlgebra> A.asm( + expressionAlgebra: E, + block: E.() -> AsmExpression +): Expression = expressionAlgebra.block() + +inline fun NumericAlgebra.asm(ast: MST): Expression = + AsmExpressionAlgebra(T::class, this).evaluate(ast) + +inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmExpression): Expression where A : NumericAlgebra, A : Space = + AsmExpressionSpace(T::class, this).block() + +inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = + asmSpace { evaluate(ast) } + +inline fun A.asmRing(block: AsmExpressionRing.() -> AsmExpression): Expression where A : NumericAlgebra, A : Ring = + AsmExpressionRing(T::class, this).block() + +inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = + asmRing { evaluate(ast) } + +inline fun A.asmField(block: AsmExpressionField.() -> AsmExpression): Expression where A : NumericAlgebra, A : Field = + AsmExpressionField(T::class, this).block() + +inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = + asmRing { evaluate(ast) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt similarity index 82% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index fa0f06579..eefde66e4 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -5,30 +5,29 @@ import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes import scientifik.kmath.asm.FunctionalCompiledExpression -import scientifik.kmath.asm.internal.AsmGenerationContext.ClassLoader +import scientifik.kmath.asm.internal.AsmBuilder.AsmClassLoader import scientifik.kmath.operations.Algebra /** * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java - * expression. This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new + * expression. This class uses [AsmClassLoader] for loading the generated class, then it is able to instantiate the new * class. * * @param T the type of AsmExpression to unwrap. * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. */ -@PublishedApi -internal class AsmGenerationContext @PublishedApi internal constructor( +internal class AsmBuilder( private val classOfT: Class<*>, private val algebra: Algebra, private val className: String ) { - private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + private class AsmClassLoader(parent: ClassLoader) : ClassLoader(parent) { internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) } - private val classLoader: ClassLoader = - ClassLoader(javaClass.classLoader) + private val classLoader: AsmClassLoader = + AsmClassLoader(javaClass.classLoader) @Suppress("PrivatePropertyName") private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') @@ -113,9 +112,8 @@ internal class AsmGenerationContext @PublishedApi internal constructor( } } - @PublishedApi @Suppress("UNCHECKED_CAST") - internal fun generate(): FunctionalCompiledExpression { + fun generate(): FunctionalCompiledExpression { generatedInstance?.let { return it } invokeMethodVisitor.run { @@ -188,7 +186,7 @@ internal class AsmGenerationContext @PublishedApi internal constructor( return new } - internal fun visitLoadFromConstants(value: T) { + fun visitLoadFromConstants(value: T) { if (classOfT in INLINABLE_NUMBERS) { visitNumberConstant(value as Number) visitCastToT() @@ -213,7 +211,7 @@ internal class AsmGenerationContext @PublishedApi internal constructor( private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) - internal fun visitNumberConstant(value: Number) { + fun visitNumberConstant(value: Number) { maxStack++ val clazz = value.javaClass val c = clazz.name.replace('.', '/') @@ -234,7 +232,7 @@ internal class AsmGenerationContext @PublishedApi internal constructor( visitLoadAnyFromConstants(value, c) } - internal fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { + fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { maxStack += 2 visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar) @@ -262,7 +260,7 @@ internal class AsmGenerationContext @PublishedApi internal constructor( visitCastToT() } - internal fun visitLoadAlgebra() { + fun visitLoadAlgebra() { maxStack++ invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) @@ -274,7 +272,7 @@ internal class AsmGenerationContext @PublishedApi internal constructor( invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) } - internal fun visitAlgebraOperation( + fun visitAlgebraOperation( owner: String, method: String, descriptor: String, @@ -288,32 +286,28 @@ internal class AsmGenerationContext @PublishedApi internal constructor( private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS) - internal fun visitStringConstant(string: String) { + fun visitStringConstant(string: String) { invokeMethodVisitor.visitLdcInsn(string) } - internal companion object { - private val SIGNATURE_LETTERS by lazy { - mapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" - ) - } + companion object { + private val SIGNATURE_LETTERS = mapOf( + java.lang.Byte::class.java to "B", + java.lang.Short::class.java to "S", + java.lang.Integer::class.java to "I", + java.lang.Long::class.java to "J", + java.lang.Float::class.java to "F", + java.lang.Double::class.java to "D" + ) - private val INLINABLE_NUMBERS by lazy { SIGNATURE_LETTERS.keys } + private val INLINABLE_NUMBERS = SIGNATURE_LETTERS.keys - internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = - "scientifik/kmath/asm/FunctionalCompiledExpression" - - internal const val MAP_CLASS = "java/util/Map" - internal const val OBJECT_CLASS = "java/lang/Object" - internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" - internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" - internal const val STRING_CLASS = "java/lang/String" - internal const val NUMBER_CLASS = "java/lang/Number" + const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/asm/FunctionalCompiledExpression" + const val MAP_CLASS = "java/util/Map" + const val OBJECT_CLASS = "java/lang/Object" + const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" + const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" + const val STRING_CLASS = "java/lang/String" + const val NUMBER_CLASS = "java/lang/Number" } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt similarity index 100% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt similarity index 72% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt index 815f292ce..d9f950ceb 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt @@ -2,7 +2,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes import scientifik.kmath.asm.AsmConstantExpression -import scientifik.kmath.asm.AsmNode +import scientifik.kmath.asm.AsmExpression import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") @@ -16,7 +16,7 @@ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boo return true } -internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { +internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name context::class.java.methods.find { it.name == aName && it.parameters.size == arity } @@ -26,9 +26,9 @@ internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, val sig = buildString { append('(') - repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") } + repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") } append(')') - append("L${AsmGenerationContext.OBJECT_CLASS};") + append("L${AsmBuilder.OBJECT_CLASS};") } visitAlgebraOperation( @@ -42,8 +42,7 @@ internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, return true } -@PublishedApi -internal fun AsmNode.optimize(): AsmNode { +internal fun AsmExpression.optimize(): AsmExpression { val a = tryEvaluate() - return if (a == null) this else AsmConstantExpression(a) + return if (a == null) this else AsmConstantExpression(type, algebra, a) } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt similarity index 98% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt rename to kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 07f895c40..06f776597 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -1,4 +1,4 @@ -package scietifik.kmath.ast.asm +package scietifik.kmath.asm import scientifik.kmath.asm.asmField import scientifik.kmath.asm.asmRing diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt similarity index 94% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt rename to kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index 2839b8a25..40d990537 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -1,4 +1,4 @@ -package scietifik.kmath.ast.asm +package scietifik.kmath.asm import scientifik.kmath.asm.asmField import scientifik.kmath.expressions.invoke diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt similarity index 100% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt index b933f6a17..9eae60efc 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt @@ -31,7 +31,7 @@ class ExpressionFieldTest { @Test fun separateContext() { - fun FunctionalExpressionField.expression(): Expression { + fun FunctionalExpressionField.expression(): Expression { val x = variable("x") return x * x + 2 * x + one } @@ -42,7 +42,7 @@ class ExpressionFieldTest { @Test fun valueExpression() { - val expressionBuilder: FunctionalExpressionField.() -> Expression = { + val expressionBuilder: FunctionalExpressionField.() -> Expression = { val x = variable("x") x * x + 2 * x + one } From 521ea8bddcb906bd43b9d1ac82712f2fcb305a19 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 15 Jun 2020 17:36:30 +0700 Subject: [PATCH 058/104] Rename ClassWriters.kt to be consistent with local code style, rename AsmBuilders.kt to asm.kt, rename AsmNode back to AsmExpression, rename AsmGenerator to AsmBuilder --- .../scientifik/kmath/asm/AsmBuilders.kt | 52 ------ .../scientifik/kmath/asm/AsmExpressions.kt | 148 +++++++++--------- .../kotlin/scientifik/kmath/asm/asm.kt | 40 +++-- .../{AsmGenerator.kt => AsmBuilder.kt} | 12 +- .../{ClassWriters.kt => classWriters.kt} | 0 .../kmath/asm/internal/optimization.kt | 10 +- 6 files changed, 108 insertions(+), 154 deletions(-) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{AsmGenerator.kt => AsmBuilder.kt} (96%) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{ClassWriters.kt => classWriters.kt} (100%) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt deleted file mode 100644 index 223278756..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt +++ /dev/null @@ -1,52 +0,0 @@ -package scientifik.kmath.asm - -import scientifik.kmath.asm.internal.AsmGenerator -import scientifik.kmath.ast.MST -import scientifik.kmath.ast.evaluate -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.* - -@PublishedApi -internal fun buildName(expression: AsmNode<*>, collision: Int = 0): String { - val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(expression, collision + 1) -} - -@PublishedApi -internal inline fun AsmNode.compile(algebra: Algebra): Expression = - AsmGenerator(T::class.java, algebra, buildName(this), this).getInstance() - -inline fun , E : AsmExpressionAlgebra> A.asm( - expressionAlgebra: E, - block: E.() -> AsmNode -): Expression = expressionAlgebra.block().compile(expressionAlgebra.algebra) - -inline fun , E : AsmExpressionAlgebra> A.asm( - expressionAlgebra: E, - ast: MST -): Expression = asm(expressionAlgebra) { evaluate(ast) } - -inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmNode): Expression where A : NumericAlgebra, A : Space = - AsmExpressionSpace(this).let { it.block().compile(it.algebra) } - -inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = - asmSpace { evaluate(ast) } - -inline fun A.asmRing(block: AsmExpressionRing.() -> AsmNode): Expression where A : NumericAlgebra, A : Ring = - AsmExpressionRing(this).let { it.block().compile(it.algebra) } - -inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = - asmRing { evaluate(ast) } - -inline fun A.asmField(block: AsmExpressionField.() -> AsmNode): Expression where A : NumericAlgebra, A : Field = - AsmExpressionField(this).let { it.block().compile(it.algebra) } - -inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = - asmRing { evaluate(ast) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index 72a53d899..55d263117 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -1,6 +1,6 @@ package scientifik.kmath.asm -import scientifik.kmath.asm.internal.AsmGenerator +import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.optimize import scientifik.kmath.asm.internal.tryInvokeSpecific @@ -9,11 +9,11 @@ import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.* /** - * A function declaration that could be compiled to [AsmGenerator]. + * A function declaration that could be compiled to [AsmBuilder]. * * @param T the type the stored function returns. */ -abstract class AsmNode internal constructor() { +abstract class AsmExpression internal constructor() { /** * Tries to evaluate this function without its variables. This method is intended for optimization. * @@ -24,18 +24,18 @@ abstract class AsmNode internal constructor() { /** * Compiles this declaration. * - * @param gen the target [AsmGenerator]. + * @param gen the target [AsmBuilder]. */ @PublishedApi - internal abstract fun compile(gen: AsmGenerator) + internal abstract fun compile(gen: AsmBuilder) } -internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmNode) : - AsmNode() { - private val expr: AsmNode = expr.optimize() +internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmExpression) : + AsmExpression() { + private val expr: AsmExpression = expr.optimize() override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } - override fun compile(gen: AsmGenerator) { + override fun compile(gen: AsmBuilder) { gen.loadAlgebra() if (!hasSpecific(context, name, 1)) @@ -47,11 +47,11 @@ internal class AsmUnaryOperation(private val context: Algebra, private val return gen.invokeAlgebraOperation( - owner = AsmGenerator.ALGEBRA_CLASS, + owner = AsmBuilder.ALGEBRA_CLASS, method = "unaryOperation", - descriptor = "(L${AsmGenerator.STRING_CLASS};" + - "L${AsmGenerator.OBJECT_CLASS};)" + - "L${AsmGenerator.OBJECT_CLASS};" + descriptor = "(L${AsmBuilder.STRING_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" ) } } @@ -59,11 +59,11 @@ internal class AsmUnaryOperation(private val context: Algebra, private val internal class AsmBinaryOperation( private val context: Algebra, private val name: String, - first: AsmNode, - second: AsmNode -) : AsmNode() { - private val first: AsmNode = first.optimize() - private val second: AsmNode = second.optimize() + first: AsmExpression, + second: AsmExpression +) : AsmExpression() { + private val first: AsmExpression = first.optimize() + private val second: AsmExpression = second.optimize() override fun tryEvaluate(): T? = context { binaryOperation( @@ -73,7 +73,7 @@ internal class AsmBinaryOperation( ) } - override fun compile(gen: AsmGenerator) { + override fun compile(gen: AsmBuilder) { gen.loadAlgebra() if (!hasSpecific(context, name, 2)) @@ -86,56 +86,56 @@ internal class AsmBinaryOperation( return gen.invokeAlgebraOperation( - owner = AsmGenerator.ALGEBRA_CLASS, + owner = AsmBuilder.ALGEBRA_CLASS, method = "binaryOperation", - descriptor = "(L${AsmGenerator.STRING_CLASS};" + - "L${AsmGenerator.OBJECT_CLASS};" + - "L${AsmGenerator.OBJECT_CLASS};)" + - "L${AsmGenerator.OBJECT_CLASS};" + descriptor = "(L${AsmBuilder.STRING_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" ) } } internal class AsmVariableExpression(private val name: String, private val default: T? = null) : - AsmNode() { - override fun compile(gen: AsmGenerator): Unit = gen.loadVariable(name, default) + AsmExpression() { + override fun compile(gen: AsmBuilder): Unit = gen.loadVariable(name, default) } internal class AsmConstantExpression(private val value: T) : - AsmNode() { + AsmExpression() { override fun tryEvaluate(): T = value - override fun compile(gen: AsmGenerator): Unit = gen.loadTConstant(value) + override fun compile(gen: AsmBuilder): Unit = gen.loadTConstant(value) } internal class AsmConstProductExpression( private val context: Space, - expr: AsmNode, + expr: AsmExpression, private val const: Number -) : AsmNode() { - private val expr: AsmNode = expr.optimize() +) : AsmExpression() { + private val expr: AsmExpression = expr.optimize() override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } - override fun compile(gen: AsmGenerator) { + override fun compile(gen: AsmBuilder) { gen.loadAlgebra() gen.loadNumberConstant(const) expr.compile(gen) gen.invokeAlgebraOperation( - owner = AsmGenerator.SPACE_OPERATIONS_CLASS, + owner = AsmBuilder.SPACE_OPERATIONS_CLASS, method = "multiply", - descriptor = "(L${AsmGenerator.OBJECT_CLASS};" + - "L${AsmGenerator.NUMBER_CLASS};)" + - "L${AsmGenerator.OBJECT_CLASS};" + descriptor = "(L${AsmBuilder.OBJECT_CLASS};" + + "L${AsmBuilder.NUMBER_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" ) } } internal class AsmNumberExpression(private val context: NumericAlgebra, private val value: Number) : - AsmNode() { + AsmExpression() { override fun tryEvaluate(): T? = context.number(value) - override fun compile(gen: AsmGenerator): Unit = gen.loadNumberConstant(value) + override fun compile(gen: AsmBuilder): Unit = gen.loadNumberConstant(value) } internal abstract class FunctionalCompiledExpression internal constructor( @@ -146,10 +146,10 @@ internal abstract class FunctionalCompiledExpression internal constructor( } /** - * A context class for [AsmNode] construction. + * A context class for [AsmExpression] construction. */ -interface AsmExpressionAlgebra> : NumericAlgebra>, - ExpressionAlgebra> { +interface AsmExpressionAlgebra> : NumericAlgebra>, + ExpressionAlgebra> { /** * The algebra to provide for AsmExpressions built. */ @@ -158,108 +158,108 @@ interface AsmExpressionAlgebra> : NumericAlgebra = AsmNumberExpression(algebra, value) + override fun number(value: Number): AsmExpression = AsmNumberExpression(algebra, value) /** * Builds an AsmExpression of constant expression which does not depend on arguments. */ - override fun const(value: T): AsmNode = AsmConstantExpression(value) + override fun const(value: T): AsmExpression = AsmConstantExpression(value) /** * Builds an AsmExpression to access a variable. */ - override fun variable(name: String, default: T?): AsmNode = AsmVariableExpression(name, default) + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) /** * Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right]. */ - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = AsmBinaryOperation(algebra, operation, left, right) /** * Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg]. */ - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = AsmUnaryOperation(algebra, operation, arg) } /** - * A context class for [AsmNode] construction for [Space] algebras. + * A context class for [AsmExpression] construction for [Space] algebras. */ open class AsmExpressionSpace(override val algebra: A) : AsmExpressionAlgebra, - Space> where A : Space, A : NumericAlgebra { - override val zero: AsmNode + Space> where A : Space, A : NumericAlgebra { + override val zero: AsmExpression get() = const(algebra.zero) /** * Builds an AsmExpression of addition of two another expressions. */ - override fun add(a: AsmNode, b: AsmNode): AsmNode = + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) /** * Builds an AsmExpression of multiplication of expression by number. */ - override fun multiply(a: AsmNode, k: Number): AsmNode = AsmConstProductExpression(algebra, a, k) + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(algebra, a, k) - operator fun AsmNode.plus(arg: T): AsmNode = this + const(arg) - operator fun AsmNode.minus(arg: T): AsmNode = this - const(arg) - operator fun T.plus(arg: AsmNode): AsmNode = arg + this - operator fun T.minus(arg: AsmNode): AsmNode = arg - this + operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) + operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) + operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this + operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = super.binaryOperation(operation, left, right) } /** - * A context class for [AsmNode] construction for [Ring] algebras. + * A context class for [AsmExpression] construction for [Ring] algebras. */ open class AsmExpressionRing(override val algebra: A) : AsmExpressionSpace(algebra), - Ring> where A : Ring, A : NumericAlgebra { - override val one: AsmNode + Ring> where A : Ring, A : NumericAlgebra { + override val one: AsmExpression get() = const(algebra.one) /** * Builds an AsmExpression of multiplication of two expressions. */ - override fun multiply(a: AsmNode, b: AsmNode): AsmNode = + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) - operator fun AsmNode.times(arg: T): AsmNode = this * const(arg) - operator fun T.times(arg: AsmNode): AsmNode = arg * this + operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) + operator fun T.times(arg: AsmExpression): AsmExpression = arg * this - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = super.binaryOperation(operation, left, right) - override fun number(value: Number): AsmNode = super.number(value) + override fun number(value: Number): AsmExpression = super.number(value) } /** - * A context class for [AsmNode] construction for [Field] algebras. + * A context class for [AsmExpression] construction for [Field] algebras. */ open class AsmExpressionField(override val algebra: A) : AsmExpressionRing(algebra), - Field> where A : Field, A : NumericAlgebra { + Field> where A : Field, A : NumericAlgebra { /** * Builds an AsmExpression of division an expression by another one. */ - override fun divide(a: AsmNode, b: AsmNode): AsmNode = + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) - operator fun AsmNode.div(arg: T): AsmNode = this / const(arg) - operator fun T.div(arg: AsmNode): AsmNode = arg / this + operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) + operator fun T.div(arg: AsmExpression): AsmExpression = arg / this - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = super.binaryOperation(operation, left, right) - override fun number(value: Number): AsmNode = super.number(value) + override fun number(value: Number): AsmExpression = super.number(value) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index cc3e36e94..a2bbb254c 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,13 +1,12 @@ package scientifik.kmath.asm +import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.ast.MST import scientifik.kmath.ast.evaluate import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.NumericAlgebra -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space +import scientifik.kmath.operations.* +@PublishedApi internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" @@ -20,28 +19,35 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String return buildName(expression, collision + 1) } -inline fun , E : AsmExpressionAlgebra> A.asm( +@PublishedApi +internal inline fun AsmExpression.compile(algebra: Algebra): Expression = + AsmBuilder(T::class.java, algebra, buildName(this), this).getInstance() + +inline fun , E : AsmExpressionAlgebra> A.asm( expressionAlgebra: E, block: E.() -> AsmExpression -): Expression = expressionAlgebra.block() +): Expression = expressionAlgebra.block().compile(expressionAlgebra.algebra) -inline fun NumericAlgebra.asm(ast: MST): Expression = - AsmExpressionAlgebra(T::class, this).evaluate(ast) +inline fun , E : AsmExpressionAlgebra> A.asm( + expressionAlgebra: E, + ast: MST +): Expression = asm(expressionAlgebra) { evaluate(ast) } -inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmExpression): Expression where A : NumericAlgebra, A : Space = - AsmExpressionSpace(T::class, this).block() +inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmExpression): Expression where A : NumericAlgebra, A : Space = + AsmExpressionSpace(this).let { it.block().compile(it.algebra) } -inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = +inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = asmSpace { evaluate(ast) } -inline fun A.asmRing(block: AsmExpressionRing.() -> AsmExpression): Expression where A : NumericAlgebra, A : Ring = - AsmExpressionRing(T::class, this).block() +inline fun A.asmRing(block: AsmExpressionRing.() -> AsmExpression): Expression where A : NumericAlgebra, A : Ring = + AsmExpressionRing(this).let { it.block().compile(it.algebra) } -inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = +inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = asmRing { evaluate(ast) } -inline fun A.asmField(block: AsmExpressionField.() -> AsmExpression): Expression where A : NumericAlgebra, A : Field = - AsmExpressionField(T::class, this).block() +inline fun A.asmField(block: AsmExpressionField.() -> AsmExpression): Expression where A : NumericAlgebra, A : Field = + AsmExpressionField(this).let { it.block().compile(it.algebra) } -inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = +inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = asmRing { evaluate(ast) } + diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt similarity index 96% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 0c6975cd9..720d44b45 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerator.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -4,13 +4,13 @@ import org.objectweb.asm.ClassWriter import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes -import scientifik.kmath.asm.AsmNode +import scientifik.kmath.asm.AsmExpression import scientifik.kmath.asm.FunctionalCompiledExpression -import scientifik.kmath.asm.internal.AsmGenerator.ClassLoader +import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.operations.Algebra /** - * ASM Generator is a structure that abstracts building a class that unwraps [AsmNode] to plain Java expression. + * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * * @param T the type of AsmExpression to unwrap. @@ -18,11 +18,11 @@ import scientifik.kmath.operations.Algebra * @param className the unique class name of new loaded class. */ @PublishedApi -internal class AsmGenerator @PublishedApi internal constructor( +internal class AsmBuilder @PublishedApi internal constructor( private val classOfT: Class<*>, private val algebra: Algebra, private val className: String, - private val root: AsmNode + private val root: AsmExpression ) { private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) @@ -112,7 +112,7 @@ internal class AsmGenerator @PublishedApi internal constructor( visitCode() val l0 = Label() visitLabel(l0) - root.compile(this@AsmGenerator) + root.compile(this@AsmBuilder) visitReturnObject() val l1 = Label() visitLabel(l1) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/ClassWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt similarity index 100% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/ClassWriters.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt index 3b37f81e5..2fb23d137 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt @@ -2,7 +2,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes import scientifik.kmath.asm.AsmConstantExpression -import scientifik.kmath.asm.AsmNode +import scientifik.kmath.asm.AsmExpression import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") @@ -16,7 +16,7 @@ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boo return true } -internal fun AsmGenerator.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { +internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name context::class.java.methods.find { it.name == aName && it.parameters.size == arity } @@ -26,9 +26,9 @@ internal fun AsmGenerator.tryInvokeSpecific(context: Algebra, name: St val sig = buildString { append('(') - repeat(arity) { append("L${AsmGenerator.OBJECT_CLASS};") } + repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") } append(')') - append("L${AsmGenerator.OBJECT_CLASS};") + append("L${AsmBuilder.OBJECT_CLASS};") } invokeAlgebraOperation( @@ -43,7 +43,7 @@ internal fun AsmGenerator.tryInvokeSpecific(context: Algebra, name: St } @PublishedApi -internal fun AsmNode.optimize(): AsmNode { +internal fun AsmExpression.optimize(): AsmExpression { val a = tryEvaluate() return if (a == null) this else AsmConstantExpression(a) } From 2580ab347e0fc53c8c030e63c3eb8de96e52b839 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 15 Jun 2020 17:37:11 +0700 Subject: [PATCH 059/104] Make ClassWriter extensions internal --- .../kotlin/scientifik/kmath/asm/internal/classWriters.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt index 84f736ee7..95d713b18 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt @@ -3,9 +3,9 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter import org.objectweb.asm.MethodVisitor -inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) -inline fun ClassWriter.visitMethod( +internal inline fun ClassWriter.visitMethod( access: Int, name: String, descriptor: String, From 1e2460c5b3e4f03b8351fd5ad7b5feb63311923f Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Mon, 15 Jun 2020 21:02:38 +0700 Subject: [PATCH 060/104] Rename --- .../kmath/asm/internal/{MethodVisitors.kt => methodVisitors.kt} | 0 .../kmath/asm/internal/{Optimization.kt => optimization.kt} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{MethodVisitors.kt => methodVisitors.kt} (100%) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{Optimization.kt => optimization.kt} (100%) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt similarity index 100% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt similarity index 100% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt From 96550922cdbb582a556d6740ad51cd4c884b0684 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 15 Jun 2020 22:07:31 +0300 Subject: [PATCH 061/104] Removal of AsmExpression --- .../scientifik/kmath/structures/ComplexND.kt | 2 +- .../scientifik/kmath/structures/NDField.kt | 8 +- .../kotlin/scientifik/kmath/ast/MST.kt | 11 +- .../kotlin/scientifik/kmath/ast/MSTAlgebra.kt | 33 ++ .../scientifik/kmath/asm/AsmExpressions.kt | 292 ------------------ .../kotlin/scientifik/kmath/asm/asm.kt | 122 ++++++-- .../kmath/asm/internal/AsmBuilder.kt | 18 +- .../kmath/asm/internal/optimization.kt | 12 +- .../scietifik/kmath/asm/TestAsmAlgebras.kt | 137 ++++---- .../scietifik/kmath/asm/TestAsmExpressions.kt | 8 +- .../kotlin/scietifik/kmath/ast/AsmTest.kt | 6 +- .../commons/expressions/DiffExpression.kt | 8 +- .../kotlin/scientifik/kmath/misc/AutoDiff.kt | 18 +- .../scientifik/kmath/operations/Algebra.kt | 25 +- 14 files changed, 253 insertions(+), 447 deletions(-) create mode 100644 kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt index cc8b68d85..991cd34a1 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt @@ -27,7 +27,7 @@ fun main() { val complexTime = measureTimeMillis { complexField.run { - var res = one + var res: NDBuffer = one repeat(n) { res += 1.0 } diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt index cfd1206ff..2aafb504d 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt @@ -23,14 +23,14 @@ fun main() { measureAndPrint("Automatic field addition") { autoField.run { - var res = one + var res: NDBuffer = one repeat(n) { - res += 1.0 + res += number(1.0) } } } - measureAndPrint("Element addition"){ + measureAndPrint("Element addition") { var res = genericField.one repeat(n) { res += 1.0 @@ -63,7 +63,7 @@ fun main() { genericField.run { var res: NDBuffer = one repeat(n) { - res += 1.0 + res += one // con't avoid using `one` due to resolution ambiguity } } } diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt index 900b9297a..142d27f93 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt @@ -1,5 +1,6 @@ package scientifik.kmath.ast +import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.NumericAlgebra import scientifik.kmath.operations.RealField @@ -40,12 +41,14 @@ sealed class MST { //TODO add a function with named arguments -fun NumericAlgebra.evaluate(node: MST): T { +fun Algebra.evaluate(node: MST): T { return when (node) { - is MST.Numeric -> number(node.value) + is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) + ?: error("Numeric nodes are not supported by $this") is MST.Symbolic -> symbol(node.value) is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Binary -> when { + this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) node.left is MST.Numeric && node.right is MST.Numeric -> { val number = RealField.binaryOperation( node.operation, @@ -59,4 +62,6 @@ fun NumericAlgebra.evaluate(node: MST): T { else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) } } -} \ No newline at end of file +} + +fun MST.compile(algebra: Algebra): T = algebra.evaluate(this) \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt new file mode 100644 index 000000000..45693645c --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt @@ -0,0 +1,33 @@ +@file:Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE") + +package scientifik.kmath.ast + +import scientifik.kmath.operations.* + +object MSTAlgebra : NumericAlgebra { + override fun symbol(value: String): MST = MST.Symbolic(value) + + override fun unaryOperation(operation: String, arg: MST): MST = MST.Unary(operation, arg) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right) + + override fun number(value: Number): MST = MST.Numeric(value) +} + +object MSTSpace : Space, NumericAlgebra by MSTAlgebra { + override val zero: MST = number(0.0) + + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) +} + +object MSTRing : Ring, Space by MSTSpace { + override val one: MST = number(1.0) + + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) +} + +object MSTField : Field, Ring by MSTRing { + override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) +} \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt deleted file mode 100644 index a8b3ff976..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ /dev/null @@ -1,292 +0,0 @@ -package scientifik.kmath.asm - -import scientifik.kmath.asm.internal.AsmBuilder -import scientifik.kmath.asm.internal.hasSpecific -import scientifik.kmath.asm.internal.optimize -import scientifik.kmath.asm.internal.tryInvokeSpecific -import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionAlgebra -import scientifik.kmath.operations.* -import kotlin.reflect.KClass - -/** - * A function declaration that could be compiled to [AsmBuilder]. - * - * @param T the type the stored function returns. - */ -sealed class AsmExpression: Expression { - abstract val type: KClass - - abstract val algebra: Algebra - - /** - * Tries to evaluate this function without its variables. This method is intended for optimization. - * - * @return `null` if the function depends on its variables, the value if the function is a constant. - */ - internal open fun tryEvaluate(): T? = null - - /** - * Compiles this declaration. - * - * @param gen the target [AsmBuilder]. - */ - internal abstract fun appendTo(gen: AsmBuilder) - - /** - * Compile and cache the expression - */ - private val compiledExpression by lazy{ - val builder = AsmBuilder(type.java, algebra, buildName(this)) - this.appendTo(builder) - builder.generate() - } - - override fun invoke(arguments: Map): T = compiledExpression.invoke(arguments) -} - -internal class AsmUnaryOperation( - override val type: KClass, - override val algebra: Algebra, - private val name: String, - expr: AsmExpression -) : AsmExpression() { - private val expr: AsmExpression = expr.optimize() - override fun tryEvaluate(): T? = algebra { unaryOperation(name, expr.tryEvaluate() ?: return@algebra null) } - - override fun appendTo(gen: AsmBuilder) { - gen.visitLoadAlgebra() - - if (!hasSpecific(algebra, name, 1)) - gen.visitStringConstant(name) - - expr.appendTo(gen) - - if (gen.tryInvokeSpecific(algebra, name, 1)) - return - - gen.visitAlgebraOperation( - owner = AsmBuilder.ALGEBRA_CLASS, - method = "unaryOperation", - descriptor = "(L${AsmBuilder.STRING_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};)" + - "L${AsmBuilder.OBJECT_CLASS};" - ) - } -} - -internal class AsmBinaryOperation( - override val type: KClass, - override val algebra: Algebra, - private val name: String, - first: AsmExpression, - second: AsmExpression -) : AsmExpression() { - private val first: AsmExpression = first.optimize() - private val second: AsmExpression = second.optimize() - - override fun tryEvaluate(): T? = algebra { - binaryOperation( - name, - first.tryEvaluate() ?: return@algebra null, - second.tryEvaluate() ?: return@algebra null - ) - } - - override fun appendTo(gen: AsmBuilder) { - gen.visitLoadAlgebra() - - if (!hasSpecific(algebra, name, 2)) - gen.visitStringConstant(name) - - first.appendTo(gen) - second.appendTo(gen) - - if (gen.tryInvokeSpecific(algebra, name, 2)) - return - - gen.visitAlgebraOperation( - owner = AsmBuilder.ALGEBRA_CLASS, - method = "binaryOperation", - descriptor = "(L${AsmBuilder.STRING_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};)" + - "L${AsmBuilder.OBJECT_CLASS};" - ) - } -} - -internal class AsmVariableExpression( - override val type: KClass, - override val algebra: Algebra, - private val name: String, - private val default: T? = null -) : AsmExpression() { - override fun appendTo(gen: AsmBuilder): Unit = gen.visitLoadFromVariables(name, default) -} - -internal class AsmConstantExpression( - override val type: KClass, - override val algebra: Algebra, - private val value: T -) : AsmExpression() { - override fun tryEvaluate(): T = value - override fun appendTo(gen: AsmBuilder): Unit = gen.visitLoadFromConstants(value) -} - -internal class AsmConstProductExpression( - override val type: KClass, - override val algebra: Space, - expr: AsmExpression, - private val const: Number -) : AsmExpression() { - private val expr: AsmExpression = expr.optimize() - - override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const } - - override fun appendTo(gen: AsmBuilder) { - gen.visitLoadAlgebra() - gen.visitNumberConstant(const) - expr.appendTo(gen) - - gen.visitAlgebraOperation( - owner = AsmBuilder.SPACE_OPERATIONS_CLASS, - method = "multiply", - descriptor = "(L${AsmBuilder.OBJECT_CLASS};" + - "L${AsmBuilder.NUMBER_CLASS};)" + - "L${AsmBuilder.OBJECT_CLASS};" - ) - } -} - -internal class AsmNumberExpression( - override val type: KClass, - override val algebra: NumericAlgebra, - private val value: Number -) : AsmExpression() { - override fun tryEvaluate(): T? = algebra.number(value) - - override fun appendTo(gen: AsmBuilder): Unit = gen.visitNumberConstant(value) -} - -internal abstract class FunctionalCompiledExpression internal constructor( - @JvmField protected val algebra: Algebra, - @JvmField protected val constants: Array -) : Expression { - abstract override fun invoke(arguments: Map): T -} - -/** - * A context class for [AsmExpression] construction. - * - * @param algebra The algebra to provide for AsmExpressions built. - */ -open class AsmExpressionAlgebra>(val type: KClass, val algebra: A) : - NumericAlgebra>, ExpressionAlgebra> { - - /** - * Builds an AsmExpression to wrap a number. - */ - override fun number(value: Number): AsmExpression = AsmNumberExpression(type, algebra, value) - - /** - * Builds an AsmExpression of constant expression which does not depend on arguments. - */ - override fun const(value: T): AsmExpression = AsmConstantExpression(type, algebra, value) - - /** - * Builds an AsmExpression to access a variable. - */ - override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(type, algebra, name, default) - - /** - * Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right]. - */ - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - AsmBinaryOperation(type, algebra, operation, left, right) - - /** - * Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg]. - */ - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - AsmUnaryOperation(type, algebra, operation, arg) -} - -/** - * A context class for [AsmExpression] construction for [Space] algebras. - */ -open class AsmExpressionSpace(type: KClass, algebra: A) : AsmExpressionAlgebra(type, algebra), - Space> where A : Space, A : NumericAlgebra { - override val zero: AsmExpression get() = const(algebra.zero) - - /** - * Builds an AsmExpression of addition of two another expressions. - */ - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmBinaryOperation(type, algebra, SpaceOperations.PLUS_OPERATION, a, b) - - /** - * Builds an AsmExpression of multiplication of expression by number. - */ - override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(type, algebra, a, k) - - operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) - operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) - operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this - operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this - - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - super.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - super.binaryOperation(operation, left, right) -} - -/** - * A context class for [AsmExpression] construction for [Ring] algebras. - */ -open class AsmExpressionRing(type: KClass, algebra: A) : AsmExpressionSpace(type, algebra), - Ring> where A : Ring, A : NumericAlgebra { - override val one: AsmExpression get() = const(algebra.one) - - /** - * Builds an AsmExpression of multiplication of two expressions. - */ - override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmBinaryOperation(type, algebra, RingOperations.TIMES_OPERATION, a, b) - - operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) - operator fun T.times(arg: AsmExpression): AsmExpression = arg * this - - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - super.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - super.binaryOperation(operation, left, right) - - override fun number(value: Number): AsmExpression = super.number(value) -} - -/** - * A context class for [AsmExpression] construction for [Field] algebras. - */ -open class AsmExpressionField(type: KClass, algebra: A) : - AsmExpressionRing(type, algebra), - Field> where A : Field, A : NumericAlgebra { - /** - * Builds an AsmExpression of division an expression by another one. - */ - override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmBinaryOperation(type, algebra, FieldOperations.DIV_OPERATION, a, b) - - operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) - operator fun T.div(arg: AsmExpression): AsmExpression = arg / this - - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = - super.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = - super.binaryOperation(operation, left, right) - - override fun number(value: Number): AsmExpression = super.number(value) -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index cc3e36e94..48e368fc3 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,47 +1,103 @@ package scientifik.kmath.asm +import scientifik.kmath.asm.internal.AsmBuilder +import scientifik.kmath.asm.internal.hasSpecific +import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST -import scientifik.kmath.ast.evaluate +import scientifik.kmath.ast.MSTField +import scientifik.kmath.ast.MSTRing +import scientifik.kmath.ast.MSTSpace import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.NumericAlgebra -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space +import scientifik.kmath.operations.* +import kotlin.reflect.KClass -internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { - val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name +/** + * Compile given MST to an Expression using AST compiler + */ +fun MST.compileWith(type: KClass, algebra: Algebra): Expression { + + fun buildName(mst: MST, collision: Int = 0): String { + val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(mst, collision + 1) } - return buildName(expression, collision + 1) + fun AsmBuilder.visit(node: MST): Unit { + when (node) { + is MST.Symbolic -> visitLoadFromVariables(node.value) + is MST.Numeric -> { + val constant = if (algebra is NumericAlgebra) { + algebra.number(node.value) + } else { + error("Number literals are not supported in $algebra") + } + visitLoadFromConstants(constant) + } + is MST.Unary -> { + visitLoadAlgebra() + + if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation) + + visit(node.value) + + if (!tryInvokeSpecific(algebra, node.operation, 1)) { + visitAlgebraOperation( + owner = AsmBuilder.ALGEBRA_CLASS, + method = "unaryOperation", + descriptor = "(L${AsmBuilder.STRING_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" + ) + } + } + is MST.Binary -> { + visitLoadAlgebra() + + if (!hasSpecific(algebra, node.operation, 2)) + visitStringConstant(node.operation) + + visit(node.left) + visit(node.right) + + if (!tryInvokeSpecific(algebra, node.operation, 2)) { + + visitAlgebraOperation( + owner = AsmBuilder.ALGEBRA_CLASS, + method = "binaryOperation", + descriptor = "(L${AsmBuilder.STRING_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" + ) + } + } + } + } + + val builder = AsmBuilder(type.java, algebra, buildName(this)) + builder.visit(this) + return builder.generate() } -inline fun , E : AsmExpressionAlgebra> A.asm( - expressionAlgebra: E, - block: E.() -> AsmExpression -): Expression = expressionAlgebra.block() +inline fun Algebra.compile(mst: MST): Expression = mst.compileWith(T::class, this) -inline fun NumericAlgebra.asm(ast: MST): Expression = - AsmExpressionAlgebra(T::class, this).evaluate(ast) +inline fun , E : Algebra> A.asm( + mstAlgebra: E, + block: E.() -> MST +): Expression = mstAlgebra.block().compileWith(T::class, this) -inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmExpression): Expression where A : NumericAlgebra, A : Space = - AsmExpressionSpace(T::class, this).block() +inline fun > A.asmInSpace(block: MSTSpace.() -> MST): Expression = + MSTSpace.block().compileWith(T::class, this) -inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = - asmSpace { evaluate(ast) } +inline fun > A.asmInRing(block: MSTRing.() -> MST): Expression = + MSTRing.block().compileWith(T::class, this) -inline fun A.asmRing(block: AsmExpressionRing.() -> AsmExpression): Expression where A : NumericAlgebra, A : Ring = - AsmExpressionRing(T::class, this).block() - -inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = - asmRing { evaluate(ast) } - -inline fun A.asmField(block: AsmExpressionField.() -> AsmExpression): Expression where A : NumericAlgebra, A : Field = - AsmExpressionField(T::class, this).block() - -inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = - asmRing { evaluate(ast) } +inline fun > A.asmInField(block: MSTField.() -> MST): Expression = + MSTField.block().compileWith(T::class, this) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index eefde66e4..076093ee1 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -4,10 +4,19 @@ import org.objectweb.asm.ClassWriter import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes -import scientifik.kmath.asm.FunctionalCompiledExpression import scientifik.kmath.asm.internal.AsmBuilder.AsmClassLoader +import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra + +private abstract class FunctionalCompiledExpression internal constructor( + @JvmField protected val algebra: Algebra, + @JvmField protected val constants: Array +) : Expression { + abstract override fun invoke(arguments: Map): T +} + + /** * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java * expression. This class uses [AsmClassLoader] for loading the generated class, then it is able to instantiate the new @@ -16,6 +25,8 @@ import scientifik.kmath.operations.Algebra * @param T the type of AsmExpression to unwrap. * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. + * + * @author [Iaroslav Postovalov](https://github.com/CommanderTvis) */ internal class AsmBuilder( private val classOfT: Class<*>, @@ -44,7 +55,6 @@ internal class AsmBuilder( private val invokeMethodVisitor: MethodVisitor private val invokeL0: Label private lateinit var invokeL1: Label - private var generatedInstance: FunctionalCompiledExpression? = null init { asmCompiledClassWriter.visit( @@ -113,8 +123,7 @@ internal class AsmBuilder( } @Suppress("UNCHECKED_CAST") - fun generate(): FunctionalCompiledExpression { - generatedInstance?.let { return it } + fun generate(): Expression { invokeMethodVisitor.run { visitInsn(Opcodes.ARETURN) @@ -182,7 +191,6 @@ internal class AsmBuilder( .first() .newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression - generatedInstance = new return new } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt index d9f950ceb..029939f16 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt @@ -1,8 +1,6 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes -import scientifik.kmath.asm.AsmConstantExpression -import scientifik.kmath.asm.AsmExpression import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") @@ -41,8 +39,8 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri return true } - -internal fun AsmExpression.optimize(): AsmExpression { - val a = tryEvaluate() - return if (a == null) this else AsmConstantExpression(type, algebra, a) -} +// +//internal fun AsmExpression.optimize(): AsmExpression { +// val a = tryEvaluate() +// return if (a == null) this else AsmConstantExpression(type, algebra, a) +//} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 06f776597..a6cb1e247 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -1,75 +1,66 @@ package scietifik.kmath.asm -import scientifik.kmath.asm.asmField -import scientifik.kmath.asm.asmRing -import scientifik.kmath.asm.asmSpace -import scientifik.kmath.expressions.invoke -import scientifik.kmath.operations.ByteRing -import scientifik.kmath.operations.RealField -import kotlin.test.Test -import kotlin.test.assertEquals - -class TestAsmAlgebras { - @Test - fun space() { - val res = ByteRing.asmSpace { - binaryOperation( - "+", - - unaryOperation( - "+", - 3.toByte() - (2.toByte() + (multiply( - add(const(1), const(1)), - 2 - ) + 1.toByte()) * 3.toByte() - 1.toByte()) - ), - - number(1) - ) + variable("x") + zero - }("x" to 2.toByte()) - - assertEquals(16, res) - } - - @Test - fun ring() { - val res = ByteRing.asmRing { - binaryOperation( - "+", - - unaryOperation( - "+", - (3.toByte() - (2.toByte() + (multiply( - add(const(1), const(1)), - 2 - ) + 1.toByte()))) * 3.0 - 1.toByte() - ), - - number(1) - ) * const(2) - }() - - assertEquals(24, res) - } - - @Test - fun field() { - val res = RealField.asmField { - divide(binaryOperation( - "+", - - unaryOperation( - "+", - (3.0 - (2.0 + (multiply( - add(const(1.0), const(1.0)), - 2 - ) + 1.0))) * 3 - 1.0 - ), - - number(1) - ) / 2, const(2.0)) * one - }() - - assertEquals(3.0, res) - } -} +// +//class TestAsmAlgebras { +// @Test +// fun space() { +// val res = ByteRing.asmInRing { +// binaryOperation( +// "+", +// +// unaryOperation( +// "+", +// 3.toByte() - (2.toByte() + (multiply( +// add(number(1), number(1)), +// 2 +// ) + 1.toByte()) * 3.toByte() - 1.toByte()) +// ), +// +// number(1) +// ) + symbol("x") + zero +// }("x" to 2.toByte()) +// +// assertEquals(16, res) +// } +// +// @Test +// fun ring() { +// val res = ByteRing.asmInRing { +// binaryOperation( +// "+", +// +// unaryOperation( +// "+", +// (3.toByte() - (2.toByte() + (multiply( +// add(const(1), const(1)), +// 2 +// ) + 1.toByte()))) * 3.0 - 1.toByte() +// ), +// +// number(1) +// ) * const(2) +// }() +// +// assertEquals(24, res) +// } +// +// @Test +// fun field() { +// val res = RealField.asmInField { +// +(3 - 2 + 2*(number(1)+1.0) +// +// unaryOperation( +// "+", +// (3.0 - (2.0 + (multiply( +// add((1.0), const(1.0)), +// 2 +// ) + 1.0))) * 3 - 1.0 +// )+ +// +// number(1) +// ) / 2, const(2.0)) * one +// }() +// +// assertEquals(3.0, res) +// } +//} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index 40d990537..bfbf5e926 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -1,6 +1,6 @@ package scietifik.kmath.asm -import scientifik.kmath.asm.asmField +import scientifik.kmath.asm.asmInField import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.RealField import kotlin.test.Test @@ -9,13 +9,13 @@ import kotlin.test.assertEquals class TestAsmExpressions { @Test fun testUnaryOperationInvocation() { - val res = RealField.asmField { unaryOperation("+", variable("x")) }("x" to 2.0) - assertEquals(2.0, res) + val res = RealField.asmInField { -symbol("x") }("x" to 2.0) + assertEquals(-2.0, res) } @Test fun testConstProductInvocation() { - val res = RealField.asmField { variable("x") * 2 }("x" to 2.0) + val res = RealField.asmInField { symbol("x") * 2 }("x" to 2.0) assertEquals(4.0, res) } } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index c92524b5d..752b3b601 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -1,18 +1,18 @@ package scietifik.kmath.ast -import scientifik.kmath.asm.asmField +import scientifik.kmath.asm.compile import scientifik.kmath.ast.parseMath import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField -import kotlin.test.assertEquals import kotlin.test.Test +import kotlin.test.assertEquals class AsmTest { @Test fun parsedExpression() { val mst = "2+2*(2+2)".parseMath() - val res = ComplexField.asmField(mst)() + val res = ComplexField.compile(mst)() assertEquals(Complex(10.0, 0.0), res) } } diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index f6f61e08a..a38fd52a8 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -74,10 +74,10 @@ class DerivativeStructureField( override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() - operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble()) - operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble()) - operator fun Number.plus(s: DerivativeStructure) = s + this - operator fun Number.minus(s: DerivativeStructure) = s - this + override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) + override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) + override operator fun Number.plus(b: DerivativeStructure) = b + this + override operator fun Number.minus(b: DerivativeStructure) = b - this } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt index ed77054cf..076701a4f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt @@ -90,20 +90,20 @@ abstract class AutoDiffField> : Field> { // Overloads for Double constants - operator fun Number.plus(that: Variable): Variable = - derive(variable { this@plus.toDouble() * one + that.value }) { z -> - that.d += z.d + override operator fun Number.plus(b: Variable): Variable = + derive(variable { this@plus.toDouble() * one + b.value }) { z -> + b.d += z.d } - operator fun Variable.plus(b: Number): Variable = b.plus(this) + override operator fun Variable.plus(b: Number): Variable = b.plus(this) - operator fun Number.minus(that: Variable): Variable = - derive(variable { this@minus.toDouble() * one - that.value }) { z -> - that.d -= z.d + override operator fun Number.minus(b: Variable): Variable = + derive(variable { this@minus.toDouble() * one - b.value }) { z -> + b.d -= z.d } - operator fun Variable.minus(that: Number): Variable = - derive(variable { this@minus.value - one * that.toDouble() }) { z -> + override operator fun Variable.minus(b: Number): Variable = + derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index 8ed3f329e..52b6bba02 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -12,9 +12,6 @@ interface Algebra { */ fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") - @Deprecated("Symbol is more concise",replaceWith = ReplaceWith("symbol")) - fun raw(value: String): T = symbol(value) - /** * Dynamic call of unary operation with name [operation] on [arg] */ @@ -64,6 +61,8 @@ interface SpaceOperations : Algebra { //Operation to be performed in this context. Could be moved to extensions in case of KEEP-176 operator fun T.unaryMinus(): T = multiply(this, -1.0) + operator fun T.unaryPlus(): T = this + operator fun T.plus(b: T): T = add(this, b) operator fun T.minus(b: T): T = add(this, -b) operator fun T.times(k: Number) = multiply(this, k.toDouble()) @@ -138,17 +137,25 @@ interface Ring : Space, RingOperations, NumericAlgebra { override fun number(value: Number): T = one * value.toDouble() override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) { + SpaceOperations.PLUS_OPERATION -> left + right + SpaceOperations.MINUS_OPERATION -> left - right RingOperations.TIMES_OPERATION -> left * right else -> super.leftSideNumberOperation(operation, left, right) } - //TODO those operators are blocked by type conflict in RealField + override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { + SpaceOperations.PLUS_OPERATION -> left + right + SpaceOperations.MINUS_OPERATION -> left - right + RingOperations.TIMES_OPERATION -> left * right + else -> super.rightSideNumberOperation(operation, left, right) + } -// operator fun T.plus(b: Number) = this.plus(b * one) -// operator fun Number.plus(b: T) = b + this -// -// operator fun T.minus(b: Number) = this.minus(b * one) -// operator fun Number.minus(b: T) = -b + this + + operator fun T.plus(b: Number) = this.plus(number(b)) + operator fun Number.plus(b: T) = b + this + + operator fun T.minus(b: Number) = this.minus(number(b)) + operator fun Number.minus(b: T) = -b + this } /** From 15d7a20b43166b576df63a434bdcd0bc40632840 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Tue, 16 Jun 2020 14:16:36 +0700 Subject: [PATCH 062/104] Add removed AsmCompiledExpression, move buildName to buildName.kt, refactor compileWith --- .../kotlin/scientifik/kmath/asm/asm.kt | 37 ++++++------------- .../kmath/asm/internal/AsmBuilder.kt | 17 ++++----- .../asm/internal/AsmCompiledExpression.kt | 12 ++++++ .../kmath/asm/internal/buildName.kt | 15 ++++++++ .../kmath/asm/internal/optimization.kt | 7 +--- 5 files changed, 46 insertions(+), 42 deletions(-) create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 48e368fc3..e7644998e 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,6 +1,7 @@ package scientifik.kmath.asm import scientifik.kmath.asm.internal.AsmBuilder +import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST @@ -11,44 +12,30 @@ import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.* import kotlin.reflect.KClass - /** * Compile given MST to an Expression using AST compiler */ fun MST.compileWith(type: KClass, algebra: Algebra): Expression { - - fun buildName(mst: MST, collision: Int = 0): String { - val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(mst, collision + 1) - } - - fun AsmBuilder.visit(node: MST): Unit { + fun AsmBuilder.visit(node: MST) { when (node) { - is MST.Symbolic -> visitLoadFromVariables(node.value) + is MST.Symbolic -> loadVariable(node.value) is MST.Numeric -> { val constant = if (algebra is NumericAlgebra) { algebra.number(node.value) } else { error("Number literals are not supported in $algebra") } - visitLoadFromConstants(constant) + loadTConstant(constant) } is MST.Unary -> { - visitLoadAlgebra() + loadAlgebra() - if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation) + if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation) visit(node.value) if (!tryInvokeSpecific(algebra, node.operation, 1)) { - visitAlgebraOperation( + invokeAlgebraOperation( owner = AsmBuilder.ALGEBRA_CLASS, method = "unaryOperation", descriptor = "(L${AsmBuilder.STRING_CLASS};" + @@ -58,17 +45,17 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< } } is MST.Binary -> { - visitLoadAlgebra() + loadAlgebra() if (!hasSpecific(algebra, node.operation, 2)) - visitStringConstant(node.operation) + loadStringConstant(node.operation) visit(node.left) visit(node.right) if (!tryInvokeSpecific(algebra, node.operation, 2)) { - visitAlgebraOperation( + invokeAlgebraOperation( owner = AsmBuilder.ALGEBRA_CLASS, method = "binaryOperation", descriptor = "(L${AsmBuilder.STRING_CLASS};" + @@ -81,9 +68,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< } } - val builder = AsmBuilder(type.java, algebra, buildName(this)) - builder.visit(this) - return builder.generate() + return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() } inline fun Algebra.compile(mst: MST): Expression = mst.compileWith(T::class, this) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 720d44b45..73ef77cc5 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -4,9 +4,8 @@ import org.objectweb.asm.ClassWriter import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes -import scientifik.kmath.asm.AsmExpression -import scientifik.kmath.asm.FunctionalCompiledExpression import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader +import scientifik.kmath.ast.MST import scientifik.kmath.operations.Algebra /** @@ -22,7 +21,7 @@ internal class AsmBuilder @PublishedApi internal constructor( private val classOfT: Class<*>, private val algebra: Algebra, private val className: String, - private val root: AsmExpression + private val evaluateMethodVisitor: AsmBuilder.() -> Unit ) { private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) @@ -42,10 +41,10 @@ internal class AsmBuilder @PublishedApi internal constructor( private val invokeArgumentsVar: Int = 1 private val constants: MutableList = mutableListOf() private lateinit var invokeMethodVisitor: MethodVisitor - private var generatedInstance: FunctionalCompiledExpression? = null + private var generatedInstance: AsmCompiledExpression? = null @Suppress("UNCHECKED_CAST") - fun getInstance(): FunctionalCompiledExpression { + fun getInstance(): AsmCompiledExpression { generatedInstance?.let { return it } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { @@ -112,7 +111,7 @@ internal class AsmBuilder @PublishedApi internal constructor( visitCode() val l0 = Label() visitLabel(l0) - root.compile(this@AsmBuilder) + evaluateMethodVisitor() visitReturnObject() val l1 = Label() visitLabel(l1) @@ -178,7 +177,7 @@ internal class AsmBuilder @PublishedApi internal constructor( .defineClass(className, classWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression + .newInstance(algebra, constants.toTypedArray()) as AsmCompiledExpression generatedInstance = new return new @@ -296,13 +295,11 @@ internal class AsmBuilder @PublishedApi internal constructor( private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = - "scientifik/kmath/asm/FunctionalCompiledExpression" + "scientifik/kmath/asm/internal/AsmCompiledExpression" internal const val MAP_CLASS = "java/util/Map" internal const val OBJECT_CLASS = "java/lang/Object" internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" - internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" internal const val STRING_CLASS = "java/lang/String" - internal const val NUMBER_CLASS = "java/lang/Number" } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt new file mode 100644 index 000000000..0e113099a --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt @@ -0,0 +1,12 @@ +package scientifik.kmath.asm.internal + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra + +internal abstract class AsmCompiledExpression internal constructor( + @JvmField protected val algebra: Algebra, + @JvmField protected val constants: Array +) : Expression { + abstract override fun invoke(arguments: Map): T +} + diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt new file mode 100644 index 000000000..65652d72b --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt @@ -0,0 +1,15 @@ +package scientifik.kmath.asm.internal + +import scientifik.kmath.ast.MST + +internal fun buildName(mst: MST, collision: Int = 0): String { + val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(mst, collision + 1) +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt index 029939f16..a39dba6b1 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt @@ -29,7 +29,7 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri append("L${AsmBuilder.OBJECT_CLASS};") } - visitAlgebraOperation( + invokeAlgebraOperation( owner = owner, method = aName, descriptor = sig, @@ -39,8 +39,3 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri return true } -// -//internal fun AsmExpression.optimize(): AsmExpression { -// val a = tryEvaluate() -// return if (a == null) this else AsmConstantExpression(type, algebra, a) -//} From 91a9e2a5e9c66821c0e5b008710b6cb798dbf9bd Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Tue, 16 Jun 2020 14:20:17 +0700 Subject: [PATCH 063/104] Remove @PublishedApi --- .../jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 73ef77cc5..5525e623f 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -16,8 +16,7 @@ import scientifik.kmath.operations.Algebra * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. */ -@PublishedApi -internal class AsmBuilder @PublishedApi internal constructor( +internal class AsmBuilder internal constructor( private val classOfT: Class<*>, private val algebra: Algebra, private val className: String, From c3cecc5a1611d4793dfaf4dd059d96a9f8455971 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Tue, 16 Jun 2020 14:21:13 +0700 Subject: [PATCH 064/104] Rename variable --- .../kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 5525e623f..337219029 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -5,7 +5,6 @@ import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader -import scientifik.kmath.ast.MST import scientifik.kmath.operations.Algebra /** @@ -20,7 +19,7 @@ internal class AsmBuilder internal constructor( private val classOfT: Class<*>, private val algebra: Algebra, private val className: String, - private val evaluateMethodVisitor: AsmBuilder.() -> Unit + private val invokeLabel0Visitor: AsmBuilder.() -> Unit ) { private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) @@ -110,7 +109,7 @@ internal class AsmBuilder internal constructor( visitCode() val l0 = Label() visitLabel(l0) - evaluateMethodVisitor() + invokeLabel0Visitor() visitReturnObject() val l1 = Label() visitLabel(l1) From 3d5036c982389ca45c66933454a212edd423ce60 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 16 Jun 2020 10:27:54 +0300 Subject: [PATCH 065/104] Fix MSTAlgebra delegation --- .../kotlin/scientifik/kmath/ast/MSTAlgebra.kt | 22 +++++++++--- .../kotlin/scientifik/kmath/asm/asm.kt | 36 ++++++++++--------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt index 45693645c..4a5900a48 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt @@ -1,4 +1,3 @@ -@file:Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE") package scientifik.kmath.ast @@ -7,9 +6,11 @@ import scientifik.kmath.operations.* object MSTAlgebra : NumericAlgebra { override fun symbol(value: String): MST = MST.Symbolic(value) - override fun unaryOperation(operation: String, arg: MST): MST = MST.Unary(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = + MST.Unary(operation, arg) - override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right) + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MST.Binary(operation, left, right) override fun number(value: Number): MST = MST.Numeric(value) } @@ -17,17 +18,28 @@ object MSTAlgebra : NumericAlgebra { object MSTSpace : Space, NumericAlgebra by MSTAlgebra { override val zero: MST = number(0.0) - override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + override fun add(a: MST, b: MST): MST = + binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + override fun multiply(a: MST, k: Number): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MSTAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) } object MSTRing : Ring, Space by MSTSpace { override val one: MST = number(1.0) override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MSTSpace.binaryOperation(operation, left, right) } object MSTField : Field, Ring by MSTRing { override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MSTRing.binaryOperation(operation, left, right) } \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 48e368fc3..46739c49c 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -4,11 +4,10 @@ import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST -import scientifik.kmath.ast.MSTField -import scientifik.kmath.ast.MSTRing -import scientifik.kmath.ast.MSTSpace +import scientifik.kmath.ast.MSTExpression import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.* +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.NumericAlgebra import kotlin.reflect.KClass @@ -88,16 +87,21 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< inline fun Algebra.compile(mst: MST): Expression = mst.compileWith(T::class, this) -inline fun , E : Algebra> A.asm( - mstAlgebra: E, - block: E.() -> MST -): Expression = mstAlgebra.block().compileWith(T::class, this) +/** + * Optimize performance of an [MSTExpression] using ASM codegen + */ +inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) -inline fun > A.asmInSpace(block: MSTSpace.() -> MST): Expression = - MSTSpace.block().compileWith(T::class, this) - -inline fun > A.asmInRing(block: MSTRing.() -> MST): Expression = - MSTRing.block().compileWith(T::class, this) - -inline fun > A.asmInField(block: MSTField.() -> MST): Expression = - MSTField.block().compileWith(T::class, this) +//inline fun , E : Algebra> A.asm( +// mstAlgebra: E, +// block: E.() -> MST +//): Expression = mstAlgebra.block().compileWith(T::class, this) +// +//inline fun > A.asmInSpace(block: MSTSpace.() -> MST): Expression = +// MSTSpace.block().compileWith(T::class, this) +// +//inline fun > A.asmInRing(block: MSTRing.() -> MST): Expression = +// MSTRing.block().compileWith(T::class, this) +// +//inline fun > A.asmInField(block: MSTField.() -> MST): Expression = +// MSTField.block().compileWith(T::class, this) From 8264806958764be299d69a98d0804ea50d5e485e Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 16 Jun 2020 14:52:02 +0300 Subject: [PATCH 066/104] Algebra delegates update. --- kmath-ast/build.gradle.kts | 5 ++ .../kotlin/scientifik/kmath/ast/MSTAlgebra.kt | 55 +++++++++++++++---- .../scientifik/kmath/ast/MSTExpression.kt | 46 ++++++++++++++-- .../kotlin/scientifik/kmath/asm/asm.kt | 21 ++----- .../scietifik/kmath/asm/TestAsmExpressions.kt | 9 ++- .../kotlin/scietifik/kmath/ast/AsmTest.kt | 5 +- .../scientifik/kmath/expressions/Builders.kt | 6 +- .../kmath/expressions/Expression.kt | 13 +++-- .../FunctionalExpressionAlgebra.kt | 8 +-- 9 files changed, 117 insertions(+), 51 deletions(-) diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index 9764fccc3..511571fc9 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -7,6 +7,11 @@ repositories { } kotlin.sourceSets { +// all { +// languageSettings.apply{ +// enableLanguageFeature("NewInference") +// } +// } commonMain { dependencies { api(project(":kmath-core")) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt index 4a5900a48..07194a7bb 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt @@ -1,9 +1,10 @@ - package scientifik.kmath.ast import scientifik.kmath.operations.* object MSTAlgebra : NumericAlgebra { + override fun number(value: Number): MST = MST.Numeric(value) + override fun symbol(value: String): MST = MST.Symbolic(value) override fun unaryOperation(operation: String, arg: MST): MST = @@ -11,13 +12,15 @@ object MSTAlgebra : NumericAlgebra { override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right) - - override fun number(value: Number): MST = MST.Numeric(value) } -object MSTSpace : Space, NumericAlgebra by MSTAlgebra { +object MSTSpace : Space, NumericAlgebra { override val zero: MST = number(0.0) + override fun number(value: Number): MST = MST.Numeric(value) + + override fun symbol(value: String): MST = MST.Symbolic(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) @@ -30,16 +33,44 @@ object MSTSpace : Space, NumericAlgebra by MSTAlgebra { override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) } -object MSTRing : Ring, Space by MSTSpace { - override val one: MST = number(1.0) +object MSTRing : Ring, NumericAlgebra { + override fun number(value: Number): MST = MST.Numeric(value) + override fun symbol(value: String): MST = MST.Symbolic(value) + + override val zero: MST = MSTSpace.number(0.0) + override val one: MST = number(1.0) + override fun add(a: MST, b: MST): MST = + MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: MST, k: Number): MST = + MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + + override fun multiply(a: MST, b: MST): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, b) - override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTSpace.binaryOperation(operation, left, right) + MSTAlgebra.binaryOperation(operation, left, right) } -object MSTField : Field, Ring by MSTRing { - override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) +object MSTField : Field{ + override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MST.Numeric(value) + + override val zero: MST = MSTSpace.number(0.0) + override val one: MST = number(1.0) + override fun add(a: MST, b: MST): MST = + MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + + override fun multiply(a: MST, k: Number): MST = + MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + + override fun multiply(a: MST, b: MST): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, b) + + override fun divide(a: MST, b: MST): MST = + binaryOperation(FieldOperations.DIV_OPERATION, a, b) + override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTRing.binaryOperation(operation, left, right) -} \ No newline at end of file + MSTAlgebra.binaryOperation(operation, left, right) +} diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt index dbd5238e3..61703cac7 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt @@ -1,19 +1,55 @@ package scientifik.kmath.ast import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.NumericAlgebra +import scientifik.kmath.expressions.FunctionalExpressionField +import scientifik.kmath.expressions.FunctionalExpressionRing +import scientifik.kmath.expressions.FunctionalExpressionSpace +import scientifik.kmath.operations.* /** * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. */ -class MSTExpression(val algebra: NumericAlgebra, val mst: MST) : Expression { +class MSTExpression(val algebra: Algebra, val mst: MST) : Expression { /** * Substitute algebra raw value */ - private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra by algebra { - override fun symbol(value: String): T = arguments[value] ?: super.symbol(value) + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra{ + override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) + override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: T, right: T): T =algebra.binaryOperation(operation, left, right) + + override fun number(value: Number): T = if(algebra is NumericAlgebra){ + algebra.number(value) + } else{ + error("Numeric nodes are not supported by $this") + } } override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) -} \ No newline at end of file +} + + +inline fun , E : Algebra> A.mst( + mstAlgebra: E, + block: E.() -> MST +): MSTExpression = MSTExpression(this, mstAlgebra.block()) + +inline fun Space.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = + MSTExpression(this, MSTSpace.block()) + +inline fun Ring.mstInRing(block: MSTRing.() -> MST): MSTExpression = + MSTExpression(this, MSTRing.block()) + +inline fun Field.mstInField(block: MSTField.() -> MST): MSTExpression = + MSTExpression(this, MSTField.block()) + +inline fun > FunctionalExpressionSpace.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = + algebra.mstInSpace(block) + +inline fun > FunctionalExpressionRing.mstInRing(block: MSTRing.() -> MST): MSTExpression = + algebra.mstInRing(block) + +inline fun > FunctionalExpressionField.mstInField(block: MSTField.() -> MST): MSTExpression = + algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index e4bd705b4..bb091732e 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -70,23 +70,12 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() } -inline fun Algebra.compile(mst: MST): Expression = mst.compileWith(T::class, this) +/** + * Compile an [MST] to ASM using given algebra + */ +inline fun Algebra.expresion(mst: MST): Expression = mst.compileWith(T::class, this) /** * Optimize performance of an [MSTExpression] using ASM codegen */ -inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) - -//inline fun , E : Algebra> A.asm( -// mstAlgebra: E, -// block: E.() -> MST -//): Expression = mstAlgebra.block().compileWith(T::class, this) -// -//inline fun > A.asmInSpace(block: MSTSpace.() -> MST): Expression = -// MSTSpace.block().compileWith(T::class, this) -// -//inline fun > A.asmInRing(block: MSTRing.() -> MST): Expression = -// MSTRing.block().compileWith(T::class, this) -// -//inline fun > A.asmInField(block: MSTField.() -> MST): Expression = -// MSTField.block().compileWith(T::class, this) +inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) \ No newline at end of file diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index bfbf5e926..82af1a927 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -1,6 +1,8 @@ package scietifik.kmath.asm -import scientifik.kmath.asm.asmInField +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.mstInSpace import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.RealField import kotlin.test.Test @@ -9,13 +11,14 @@ import kotlin.test.assertEquals class TestAsmExpressions { @Test fun testUnaryOperationInvocation() { - val res = RealField.asmInField { -symbol("x") }("x" to 2.0) + val expression = RealField.mstInSpace { -symbol("x") }.compile() + val res = expression("x" to 2.0) assertEquals(-2.0, res) } @Test fun testConstProductInvocation() { - val res = RealField.asmInField { symbol("x") * 2 }("x" to 2.0) + val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) assertEquals(4.0, res) } } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index 752b3b601..f0f9a8bc1 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -1,8 +1,7 @@ package scietifik.kmath.ast -import scientifik.kmath.asm.compile +import scientifik.kmath.ast.evaluate import scientifik.kmath.ast.parseMath -import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField import kotlin.test.Test @@ -12,7 +11,7 @@ class AsmTest { @Test fun parsedExpression() { val mst = "2+2*(2+2)".parseMath() - val res = ComplexField.compile(mst)() + val res = ComplexField.evaluate(mst) assertEquals(Complex(10.0, 0.0), res) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt index cdd6f695b..9f1503285 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt @@ -7,17 +7,17 @@ import scientifik.kmath.operations.Space /** * Create a functional expression on this [Space] */ -fun Space.buildExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression = +fun Space.spaceExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression = FunctionalExpressionSpace(this).run(block) /** * Create a functional expression on this [Ring] */ -fun Ring.buildExpression(block: FunctionalExpressionRing>.() -> Expression): Expression = +fun Ring.ringExpression(block: FunctionalExpressionRing>.() -> Expression): Expression = FunctionalExpressionRing(this).run(block) /** * Create a functional expression on this [Field] */ -fun Field.buildExpression(block: FunctionalExpressionField>.() -> Expression): Expression = +fun Field.fieldExpression(block: FunctionalExpressionField>.() -> Expression): Expression = FunctionalExpressionField(this).run(block) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index eaf3cd1d7..e512b1cd8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -8,11 +8,14 @@ import scientifik.kmath.operations.Algebra interface Expression { operator fun invoke(arguments: Map): T - companion object { - operator fun invoke(block: (Map) -> T): Expression = object : Expression { - override fun invoke(arguments: Map): T = block(arguments) - } - } + companion object +} + +/** + * Create simple lazily evaluated expression inside given algebra + */ +fun Algebra.expression(block: Algebra.(arguments: Map) -> T): Expression = object: Expression { + override fun invoke(arguments: Map): T = block(arguments) } operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt index a441a80c0..c8d6e8eb0 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -69,10 +69,10 @@ interface FunctionalExpressionAlgebra> : ExpressionAlgebra(override val algebra: A) : FunctionalExpressionAlgebra, - Space> where A : Space { - override val zero: Expression - get() = const(algebra.zero) +open class FunctionalExpressionSpace>(override val algebra: A) : + FunctionalExpressionAlgebra, Space> { + + override val zero: Expression get() = const(algebra.zero) /** * Builds an Expression of addition of two another expressions. From a1f0188b8b0b3d5577c4ba8dec31ccf72ff2834e Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 16 Jun 2020 16:29:52 +0300 Subject: [PATCH 067/104] Functional expression builders --- .../FunctionalExpressionAlgebra.kt | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt index c8d6e8eb0..a8a26aa33 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -36,12 +36,10 @@ internal class FunctionalConstProductExpression( /** * A context class for [Expression] construction. + * + * @param algebra The algebra to provide for Expressions built. */ -interface FunctionalExpressionAlgebra> : ExpressionAlgebra> { - /** - * The algebra to provide for Expressions built. - */ - val algebra: A +abstract class FunctionalExpressionAlgebra>(val algebra: A) : ExpressionAlgebra> { /** * Builds an Expression of constant expression which does not depend on arguments. @@ -69,8 +67,8 @@ interface FunctionalExpressionAlgebra> : ExpressionAlgebra>(override val algebra: A) : - FunctionalExpressionAlgebra, Space> { +open class FunctionalExpressionSpace>(algebra: A) : + FunctionalExpressionAlgebra(algebra), Space> { override val zero: Expression get() = const(algebra.zero) @@ -98,7 +96,7 @@ open class FunctionalExpressionSpace>(override val algebra: A) : super.binaryOperation(operation, left, right) } -open class FunctionalExpressionRing(override val algebra: A) : FunctionalExpressionSpace(algebra), +open class FunctionalExpressionRing(algebra: A) : FunctionalExpressionSpace(algebra), Ring> where A : Ring, A : NumericAlgebra { override val one: Expression get() = const(algebra.one) @@ -119,7 +117,7 @@ open class FunctionalExpressionRing(override val algebra: A) : FunctionalE super.binaryOperation(operation, left, right) } -open class FunctionalExpressionField(override val algebra: A) : +open class FunctionalExpressionField(algebra: A) : FunctionalExpressionRing(algebra), Field> where A : Field, A : NumericAlgebra { /** @@ -137,3 +135,12 @@ open class FunctionalExpressionField(override val algebra: A) : override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = super.binaryOperation(operation, left, right) } + +inline fun > A.expressionInSpace(block: FunctionalExpressionSpace.() -> Expression): Expression = + FunctionalExpressionSpace(this).block() + +inline fun > A.expressionInRing(block: FunctionalExpressionRing.() -> Expression): Expression = + FunctionalExpressionRing(this).block() + +inline fun > A.expressionInField(block: FunctionalExpressionField.() -> Expression): Expression = + FunctionalExpressionField(this).block() \ No newline at end of file From d6e7eb8143029541d2ff6c5ef320c627ccb66973 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 18 Jun 2020 11:35:20 +0700 Subject: [PATCH 068/104] Add advanced specialization for primitive non-bridge methods --- .../kotlin/scientifik/kmath/asm/asm.kt | 7 +- .../kmath/asm/internal/AsmBuilder.kt | 65 ++++++++++++++++++- .../{optimization.kt => specialization.kt} | 4 +- .../scietifik/kmath/asm/TestAsmAlgebras.kt | 35 ++++++++-- 4 files changed, 98 insertions(+), 13 deletions(-) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{optimization.kt => specialization.kt} (90%) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index bb091732e..2345757f1 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -19,11 +19,10 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< when (node) { is MST.Symbolic -> loadVariable(node.value) is MST.Numeric -> { - val constant = if (algebra is NumericAlgebra) { - algebra.number(node.value) - } else { + val constant = if (algebra is NumericAlgebra) + algebra.number(node.value) else error("Number literals are not supported in $algebra") - } + loadTConstant(constant) } is MST.Unary -> { diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 337219029..351d5b754 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -6,6 +6,7 @@ import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.operations.Algebra +import java.io.File /** * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. @@ -39,12 +40,23 @@ internal class AsmBuilder internal constructor( private val invokeArgumentsVar: Int = 1 private val constants: MutableList = mutableListOf() private lateinit var invokeMethodVisitor: MethodVisitor + private var primitiveMode = false + internal var primitiveType = OBJECT_CLASS + internal var primitiveTypeSig = "L$OBJECT_CLASS;" + internal var primitiveTypeReturnSig = "L$OBJECT_CLASS;" private var generatedInstance: AsmCompiledExpression? = null @Suppress("UNCHECKED_CAST") fun getInstance(): AsmCompiledExpression { generatedInstance?.let { return it } + if (classOfT in SIGNATURE_LETTERS) { + primitiveMode = true + primitiveTypeSig = SIGNATURE_LETTERS.getValue(classOfT) + primitiveType = SIGNATURE_LETTERS.getValue(classOfT) + primitiveTypeReturnSig = "L$T_CLASS;" + } + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( Opcodes.V1_8, @@ -110,6 +122,10 @@ internal class AsmBuilder internal constructor( val l0 = Label() visitLabel(l0) invokeLabel0Visitor() + + if (primitiveMode) + boxPrimitiveToNumber() + visitReturnObject() val l1 = Label() visitLabel(l1) @@ -171,6 +187,8 @@ internal class AsmBuilder internal constructor( visitEnd() } + File("dump.class").writeBytes(classWriter.toByteArray()) + val new = classLoader .defineClass(className, classWriter.toByteArray()) .constructors @@ -182,6 +200,11 @@ internal class AsmBuilder internal constructor( } internal fun loadTConstant(value: T) { + if (primitiveMode) { + loadNumberConstant(value as Number) + return + } + if (classOfT in INLINABLE_NUMBERS) { loadNumberConstant(value as Number) invokeMethodVisitor.visitCheckCast(T_CLASS) @@ -191,6 +214,25 @@ internal class AsmBuilder internal constructor( loadConstant(value as Any, T_CLASS) } + private fun unboxInPrimitiveMode() { + if (!primitiveMode) + return + + unboxNumberToPrimitive() + } + + private fun boxPrimitiveToNumber() { + invokeMethodVisitor.visitInvokeStatic(T_CLASS, "valueOf", "($primitiveTypeSig)L$T_CLASS;") + } + + private fun unboxNumberToPrimitive() { + invokeMethodVisitor.visitInvokeVirtual( + owner = NUMBER_CLASS, + name = NUMBER_CONVERTER_METHODS.getValue(primitiveType), + descriptor = "()${primitiveTypeSig}" + ) + } + private fun loadConstant(value: Any, type: String) { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex @@ -199,13 +241,14 @@ internal class AsmBuilder internal constructor( visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") visitLdcOrIntConstant(idx) visitGetObjectArrayElement() - invokeMethodVisitor.visitCheckCast(type) + visitCheckCast(type) + unboxInPrimitiveMode() } } private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) - internal fun loadNumberConstant(value: Number) { + private fun loadNumberConstant(value: Number) { val clazz = value.javaClass val c = clazz.name.replace('.', '/') val sigLetter = SIGNATURE_LETTERS[clazz] @@ -218,7 +261,9 @@ internal class AsmBuilder internal constructor( else -> invokeMethodVisitor.visitLdcInsn(value) } - invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") + if (!primitiveMode) + invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") + return } @@ -251,6 +296,7 @@ internal class AsmBuilder internal constructor( ) invokeMethodVisitor.visitCheckCast(T_CLASS) + unboxInPrimitiveMode() } internal fun loadAlgebra() { @@ -274,6 +320,7 @@ internal class AsmBuilder internal constructor( ) { invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) invokeMethodVisitor.visitCheckCast(T_CLASS) + unboxInPrimitiveMode() } internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) @@ -290,11 +337,23 @@ internal class AsmBuilder internal constructor( ) } + private val NUMBER_CONVERTER_METHODS by lazy { + mapOf( + "B" to "byteValue", + "S" to "shortValue", + "I" to "intValue", + "J" to "longValue", + "F" to "floatValue", + "D" to "doubleValue" + ) + } + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/asm/internal/AsmCompiledExpression" + internal const val NUMBER_CLASS = "java/lang/Number" internal const val MAP_CLASS = "java/util/Map" internal const val OBJECT_CLASS = "java/lang/Object" internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt similarity index 90% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index a39dba6b1..f279c94f4 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -24,9 +24,9 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri val sig = buildString { append('(') - repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") } + repeat(arity) { append(primitiveTypeSig) } append(')') - append("L${AsmBuilder.OBJECT_CLASS};") + append(primitiveTypeReturnSig) } invokeAlgebraOperation( diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index a6cb1e247..294da32e9 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -1,10 +1,37 @@ package scietifik.kmath.asm -// -//class TestAsmAlgebras { +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInSpace +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals + +class TestAsmAlgebras { + @Test + fun space() { + val res = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }.compile()("x" to 2.toByte()) + + assertEquals(16, res) + } + // @Test // fun space() { -// val res = ByteRing.asmInRing { +// val res = ByteRing.asm { // binaryOperation( // "+", // @@ -63,4 +90,4 @@ package scietifik.kmath.asm // // assertEquals(3.0, res) // } -//} +} From e9ff33c4f95973aa3ebd5d376e847d453f2429f9 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 19 Jun 2020 23:56:35 +0700 Subject: [PATCH 069/104] Write KDoc comments for AsmBuilder, minimal refactor of it --- .../kmath/asm/internal/AsmBuilder.kt | 82 +++++++++++++++++-- .../kmath/asm/internal/optimization.kt | 3 +- 2 files changed, 75 insertions(+), 10 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 337219029..649d06277 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -11,9 +11,12 @@ import scientifik.kmath.operations.Algebra * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * - * @param T the type of AsmExpression to unwrap. - * @param algebra the algebra the applied AsmExpressions use. - * @param className the unique class name of new loaded class. + * @param T the type generated [AsmCompiledExpression] operates. + * @property classOfT the [Class] of T. + * @property algebra the algebra the applied AsmExpressions use. + * @property className the unique class name of new loaded class. + * @property invokeLabel0Visitor the function applied to this object when the L0 label of invoke implementation is + * being written. */ internal class AsmBuilder internal constructor( private val classOfT: Class<*>, @@ -21,10 +24,16 @@ internal class AsmBuilder internal constructor( private val className: String, private val invokeLabel0Visitor: AsmBuilder.() -> Unit ) { + /** + * Internal classloader of [AsmBuilder] with alias to define class from byte array. + */ private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) } + /** + * The instance of [ClassLoader] used by this builder. + */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) @@ -35,14 +44,39 @@ internal class AsmBuilder internal constructor( private val T_CLASS: String = classOfT.name.replace('.', '/') private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + + /** + * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. + */ private val invokeThisVar: Int = 0 + + /** + * Index of `arguments` variable in invoke method of [AsmCompiledExpression] built subclass. + */ private val invokeArgumentsVar: Int = 1 + + /** + * List of constants to provide to [AsmCompiledExpression] subclass. + */ private val constants: MutableList = mutableListOf() + + /** + * Method visitor of `invoke` method of [AsmCompiledExpression] subclass. + */ private lateinit var invokeMethodVisitor: MethodVisitor + + /** + * The cache of [AsmCompiledExpression] subclass built by this builder. + */ private var generatedInstance: AsmCompiledExpression? = null + /** + * Subclasses, loads and instantiates the [AsmCompiledExpression] for given parameters. + * + * The built instance is cached. + */ @Suppress("UNCHECKED_CAST") - fun getInstance(): AsmCompiledExpression { + internal fun getInstance(): AsmCompiledExpression { generatedInstance?.let { return it } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { @@ -181,6 +215,9 @@ internal class AsmBuilder internal constructor( return new } + /** + * Loads a constant from + */ internal fun loadTConstant(value: T) { if (classOfT in INLINABLE_NUMBERS) { loadNumberConstant(value as Number) @@ -191,6 +228,9 @@ internal class AsmBuilder internal constructor( loadConstant(value as Any, T_CLASS) } + /** + * Loads an object constant [value] stored in [AsmCompiledExpression.constants] and casts it to [type]. + */ private fun loadConstant(value: Any, type: String) { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex @@ -205,7 +245,12 @@ internal class AsmBuilder internal constructor( private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) - internal fun loadNumberConstant(value: Number) { + /** + * Either loads a numeric constant [value] from [AsmCompiledExpression.constants] field or boxes a primitive + * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded + * from it). + */ + private fun loadNumberConstant(value: Number) { val clazz = value.javaClass val c = clazz.name.replace('.', '/') val sigLetter = SIGNATURE_LETTERS[clazz] @@ -225,6 +270,9 @@ internal class AsmBuilder internal constructor( loadConstant(value, c) } + /** + * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. + */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { visitLoadObjectVar(invokeArgumentsVar) @@ -253,6 +301,9 @@ internal class AsmBuilder internal constructor( invokeMethodVisitor.visitCheckCast(T_CLASS) } + /** + * Loads algebra from according field of [AsmCompiledExpression] and casts it to class of [algebra] provided. + */ internal fun loadAlgebra() { loadThis() @@ -265,20 +316,32 @@ internal class AsmBuilder internal constructor( invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS) } + /** + * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is + * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interface. [loadAlgebra] should be + * called before the arguments and this operation. + * + * The result is casted to [T] automatically. + */ internal fun invokeAlgebraOperation( owner: String, method: String, descriptor: String, - opcode: Int = Opcodes.INVOKEINTERFACE, - isInterface: Boolean = true + opcode: Int = Opcodes.INVOKEINTERFACE ) { - invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) + invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, opcode == Opcodes.INVOKEINTERFACE) invokeMethodVisitor.visitCheckCast(T_CLASS) } + /** + * Writes a LDC Instruction with string constant provided. + */ internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) internal companion object { + /** + * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. + */ private val SIGNATURE_LETTERS: Map, String> by lazy { mapOf( java.lang.Byte::class.java to "B", @@ -290,6 +353,9 @@ internal class AsmBuilder internal constructor( ) } + /** + * Provides boxed number types values of which can be stored in JVM bytecode constant pool. + */ private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt index a39dba6b1..889f3accb 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt @@ -33,8 +33,7 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri owner = owner, method = aName, descriptor = sig, - opcode = Opcodes.INVOKEVIRTUAL, - isInterface = false + opcode = Opcodes.INVOKEVIRTUAL ) return true From ba499da2dac0b1487ac850974e6f8c5f721818e8 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 20 Jun 2020 00:05:00 +0700 Subject: [PATCH 070/104] More KDoc comments --- .../kmath/asm/internal/AsmCompiledExpression.kt | 7 +++++++ .../kotlin/scientifik/kmath/asm/internal/buildName.kt | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt index 0e113099a..a5236927c 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt @@ -3,6 +3,13 @@ package scientifik.kmath.asm.internal import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra +/** + * [Expression] partial implementation to have it subclassed by actual implementations. Provides unified storage for + * objects needed to implement the expression. + * + * @property algebra the algebra to delegate calls. + * @property constants the constants array to have persistent objects to reference in [invoke]. + */ internal abstract class AsmCompiledExpression internal constructor( @JvmField protected val algebra: Algebra, @JvmField protected val constants: Array diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt index 65652d72b..66bd039c3 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt @@ -2,7 +2,13 @@ package scientifik.kmath.asm.internal import scientifik.kmath.ast.MST -internal fun buildName(mst: MST, collision: Int = 0): String { +/** + * Creates a class name for [AsmCompiledExpression] subclassed to implement [mst] provided. + * + * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there + * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. + */ +internal tailrec fun buildName(mst: MST, collision: Int = 0): String { val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" try { From 635d708de5377d9e655136403eebb50ae886932a Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 20 Jun 2020 00:08:53 +0700 Subject: [PATCH 071/104] Add missing KDoc comments --- .../src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt | 2 +- .../scientifik/kmath/asm/internal/optimization.kt | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index bb091732e..43c1a377d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -73,7 +73,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< /** * Compile an [MST] to ASM using given algebra */ -inline fun Algebra.expresion(mst: MST): Expression = mst.compileWith(T::class, this) +inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class, this) /** * Optimize performance of an [MSTExpression] using ASM codegen diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt index 889f3accb..e81a7fd1a 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt @@ -5,6 +5,11 @@ import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") +/** + * Checks if the target [context] for code generation contains a method with needed [name] and [arity]. + * + * @return `true` if contains, else `false`. + */ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name @@ -14,6 +19,12 @@ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boo return true } +/** + * Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts + * [AsmBuilder.invokeAlgebraOperation] of this method. + * + * @return `true` if contains, else `false`. + */ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name From 62ebda33021e9fc7dbf41e5d23ea57b0db5d9442 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 21 Jun 2020 20:23:50 +0700 Subject: [PATCH 072/104] Update readme, accident documentation-related refactor --- kmath-ast/README.md | 63 ++++++++++++++++++- .../kotlin/scientifik/kmath/asm/asm.kt | 10 ++- .../kmath/asm/internal/AsmBuilder.kt | 4 +- .../kmath/asm/internal/optimization.kt | 12 +++- 4 files changed, 78 insertions(+), 11 deletions(-) diff --git a/kmath-ast/README.md b/kmath-ast/README.md index f9dbd4663..b5ca5886f 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -1,4 +1,61 @@ -# AST based expression representation and operations +# AST-based expression representation and operations (`kmath-ast`) -## Dynamic expression code generation -Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). \ No newline at end of file +This subproject implements the following features: + +- Expression Language and its parser. +- MST as expression language's syntax intermediate representation. +- Type-safe builder of MST. +- Evaluating expressions by traversing MST. + +## Dynamic expression code generation with OW2 ASM + +`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds +a special implementation of `Expression` with implemented `invoke` function. + +For example, the following builder: + +```kotlin + RealField.mstInField { symbol("x") + 2 }.compile() +``` + +… leads to generation of bytecode, which can be decompiled to the following Java class: + +```java +package scientifik.kmath.asm.generated; + +import java.util.Map; +import scientifik.kmath.asm.internal.AsmCompiledExpression; +import scientifik.kmath.operations.Algebra; +import scientifik.kmath.operations.RealField; + +// The class's name is build with MST's hash-code and collision fixing number. +public final class AsmCompiledExpression_45045_0 extends AsmCompiledExpression { + // Plain constructor + public AsmCompiledExpression_45045_0(Algebra algebra, Object[] constants) { + super(algebra, constants); + } + + // The actual dynamic code: + public final Double invoke(Map arguments) { + return (Double)((RealField)super.algebra).add((Double)arguments.get("x"), (Double)2.0D); + } +} +``` + +### Example Usage + +This API is an extension to MST and MSTExpression APIs. You may optimize both MST and MSTExpression: + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +RealField.expression("2+2".parseMath()) +``` + +### Known issues + +- Using numeric algebras causes boxing and calling bridge methods. +- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid +class loading overhead. +- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. + +Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 43c1a377d..350bb95bb 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -4,11 +4,11 @@ import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific -import scientifik.kmath.ast.MST -import scientifik.kmath.ast.MSTExpression +import scientifik.kmath.ast.* import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.NumericAlgebra +import scientifik.kmath.operations.RealField import kotlin.reflect.KClass /** @@ -78,4 +78,8 @@ inline fun Algebra.expression(mst: MST): Expression = ms /** * Optimize performance of an [MSTExpression] using ASM codegen */ -inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) \ No newline at end of file +inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) + +fun main() { + RealField.mstInField { symbol("x") + 2 }.compile() +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 649d06277..53fc4c4c1 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -8,7 +8,7 @@ import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.operations.Algebra /** - * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. + * ASM Builder is a structure that abstracts building a class for unwrapping [MST] to plain Java expression. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * * @param T the type generated [AsmCompiledExpression] operates. @@ -343,7 +343,7 @@ internal class AsmBuilder internal constructor( * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. */ private val SIGNATURE_LETTERS: Map, String> by lazy { - mapOf( + hashMapOf( java.lang.Byte::class.java to "B", java.lang.Short::class.java to "S", java.lang.Integer::class.java to "I", diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt index e81a7fd1a..efad97763 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt @@ -3,7 +3,13 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes import scientifik.kmath.operations.Algebra -private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") +private val methodNameAdapters: Map by lazy { + hashMapOf( + "+" to "add", + "*" to "multiply", + "/" to "divide" + ) +} /** * Checks if the target [context] for code generation contains a method with needed [name] and [arity]. @@ -13,7 +19,7 @@ private val methodNameAdapters: Map = mapOf("+" to "add", "*" to internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context::class.java.methods.find { it.name == aName && it.parameters.size == arity } + context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } ?: return false return true @@ -28,7 +34,7 @@ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boo internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context::class.java.methods.find { it.name == aName && it.parameters.size == arity } + context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } ?: return false val owner = context::class.java.name.replace('.', '/') From e99f7ad360d2c0efb448c7dcdd2825dc1c9d663c Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 22 Jun 2020 04:05:52 +0700 Subject: [PATCH 073/104] Optimize constant pooling --- kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt | 6 +++--- .../kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt | 1 + .../scientifik/kmath/asm/internal/methodVisitors.kt | 8 ++++++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 43c1a377d..feb3745b0 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -19,11 +19,11 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< when (node) { is MST.Symbolic -> loadVariable(node.value) is MST.Numeric -> { - val constant = if (algebra is NumericAlgebra) { + val constant = if (algebra is NumericAlgebra) algebra.number(node.value) - } else { + else error("Number literals are not supported in $algebra") - } + loadTConstant(constant) } is MST.Unary -> { diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 649d06277..8b6892235 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -260,6 +260,7 @@ internal class AsmBuilder internal constructor( is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value) is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value) is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value) + is Long -> invokeMethodVisitor.visitLdcOrLongConstant(value) else -> invokeMethodVisitor.visitLdcInsn(value) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt index 70fb3bd44..6ecce8833 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt @@ -11,6 +11,8 @@ internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value 3 -> visitInsn(ICONST_3) 4 -> visitInsn(ICONST_4) 5 -> visitInsn(ICONST_5) + in -128..127 -> visitIntInsn(BIPUSH, value) + in -32768..32767 -> visitIntInsn(SIPUSH, value) else -> visitLdcInsn(value) } @@ -20,6 +22,12 @@ internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when else -> visitLdcInsn(value) } +internal fun MethodVisitor.visitLdcOrLongConstant(value: Long): Unit = when (value) { + 0L -> visitInsn(LCONST_0) + 1L -> visitInsn(LCONST_1) + else -> visitLdcInsn(value) +} + internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) { 0f -> visitInsn(FCONST_0) 1f -> visitInsn(FCONST_1) From 29c6d2596733d25c4babc80aa90856b877153f42 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 22 Jun 2020 15:15:46 +0700 Subject: [PATCH 074/104] Optimize constant pooling for Byte and Short --- .../jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 8b6892235..ff66e69b0 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -257,6 +257,8 @@ internal class AsmBuilder internal constructor( if (sigLetter != null) { when (value) { + is Byte -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) + is Short -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value) is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value) is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value) From 675ace272c06f5bada1640d8c9219687f3e1ee91 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Tue, 23 Jun 2020 03:38:20 +0700 Subject: [PATCH 075/104] Minor Gradle settings modification, add benchmarks of different Expression implementatinos --- examples/build.gradle.kts | 10 ++-- .../ast/ExpressionsInterpretersBenchmark.kt | 54 +++++++++++++++++++ .../kmath/structures/ViktorBenchmark.kt | 17 ++---- settings.gradle.kts | 2 + 4 files changed, 64 insertions(+), 19 deletions(-) create mode 100644 examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 2fab47ac0..eadfc3f6b 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -1,16 +1,13 @@ -import org.jetbrains.kotlin.allopen.gradle.AllOpenExtension import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { java kotlin("jvm") - kotlin("plugin.allopen") version "1.3.71" - id("kotlinx.benchmark") version "0.2.0-dev-7" + kotlin("plugin.allopen") + id("kotlinx.benchmark") } -configure { - annotation("org.openjdk.jmh.annotations.State") -} +allOpen.annotation("org.openjdk.jmh.annotations.State") repositories { maven("http://dl.bintray.com/kyonifer/maven") @@ -24,6 +21,7 @@ sourceSets { } dependencies { + implementation(project(":kmath-ast")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt new file mode 100644 index 000000000..c5474e1d2 --- /dev/null +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -0,0 +1,54 @@ +package scientifik.kmath.ast + +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import scientifik.kmath.asm.compile +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.expressionInField +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.RealField +import kotlin.random.Random + +@State(Scope.Benchmark) +class ExpressionsInterpretersBenchmark { + private val algebra: Field = RealField + private val random: Random = Random(1) + + @Benchmark + fun functionalExpression() { + val expr = algebra.expressionInField { + variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) + } + + invokeAndSum(expr) + } + + @Benchmark + fun mstExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + } + + invokeAndSum(expr) + } + + @Benchmark + fun asmExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + }.compile() + + invokeAndSum(expr) + } + + private fun invokeAndSum(expr: Expression) { + var sum = 0.0 + + repeat(1000000) { + sum += expr("x" to random.nextDouble()) + } + + println(sum) + } +} diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt index be4115d81..5dc166cd9 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt @@ -23,9 +23,7 @@ class ViktorBenchmark { fun `Automatic field addition`() { autoField.run { var res = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += one } } } @@ -33,9 +31,7 @@ class ViktorBenchmark { fun `Viktor field addition`() { viktorField.run { var res = one - repeat(n) { - res += one - } + repeat(n) { res += one } } } @@ -43,9 +39,7 @@ class ViktorBenchmark { fun `Raw Viktor`() { val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) var res = one - repeat(n) { - res = res + one - } + repeat(n) { res = res + one } } @Benchmark @@ -53,10 +47,7 @@ class ViktorBenchmark { realField.run { val fortyTwo = produce { 42.0 } var res = one - - repeat(n) { - res = ln(fortyTwo) - } + repeat(n) { res = ln(fortyTwo) } } } diff --git a/settings.gradle.kts b/settings.gradle.kts index 465ecfca8..487e1d87f 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -3,10 +3,12 @@ pluginManagement { val toolsVersion = "0.5.0" plugins { + id("kotlinx.benchmark") version "0.2.0-dev-8" id("scientifik.mpp") version toolsVersion id("scientifik.jvm") version toolsVersion id("scientifik.atomic") version toolsVersion id("scientifik.publish") version toolsVersion + kotlin("plugin.allopen") version "1.3.72" } repositories { From 668d13c9d1b276d81f155d05d39e89ebf29b6c4a Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 23 Jun 2020 20:03:45 +0300 Subject: [PATCH 076/104] Minor refactoring + domains --- build.gradle.kts | 2 +- examples/build.gradle.kts | 8 +-- .../kmath/structures/ViktorBenchmark.kt | 10 +-- .../random/CMRandomGeneratorWrapper.kt | 38 +++++++++++ .../kotlin/scientifik/kmath/domains/Domain.kt | 15 +++++ .../kmath/domains/HyperSquareDomain.kt | 67 +++++++++++++++++++ .../scientifik/kmath/domains/RealDomain.kt | 65 ++++++++++++++++++ .../kmath/domains/UnconstrainedDomain.kt | 36 ++++++++++ .../kmath/domains/UnivariateDomain.kt | 48 +++++++++++++ .../scientifik/kmath/real/RealVector.kt | 25 +++---- .../scientifik/kmath/real/realMatrix.kt | 4 ++ .../scientifik/kmath/histogram/Histogram.kt | 10 +-- .../kotlin/scientifik/memory/MemorySpec.kt | 1 + 13 files changed, 294 insertions(+), 35 deletions(-) create mode 100644 kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt diff --git a/build.gradle.kts b/build.gradle.kts index 6d102a77a..052b457c5 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -2,7 +2,7 @@ plugins { id("scientifik.publish") apply false } -val kmathVersion by extra("0.1.4-dev-7") +val kmathVersion by extra("0.1.4-dev-8") val bintrayRepo by extra("scientifik") val githubProject by extra("kmath") diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 2fab47ac0..fb47c998f 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { java kotlin("jvm") - kotlin("plugin.allopen") version "1.3.71" - id("kotlinx.benchmark") version "0.2.0-dev-7" + kotlin("plugin.allopen") version "1.3.72" + id("kotlinx.benchmark") version "0.2.0-dev-8" } configure { @@ -33,8 +33,8 @@ dependencies { implementation(project(":kmath-dimensions")) implementation("com.kyonifer:koma-core-ejml:0.12") implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") - implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7") - "benchmarksCompile"(sourceSets.main.get().compileClasspath) + implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8") + "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath } // Configure benchmark diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt index be4115d81..54105f778 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt @@ -20,7 +20,7 @@ class ViktorBenchmark { final val viktorField = ViktorNDField(intArrayOf(dim, dim)) @Benchmark - fun `Automatic field addition`() { + fun automaticFieldAddition() { autoField.run { var res = one repeat(n) { @@ -30,7 +30,7 @@ class ViktorBenchmark { } @Benchmark - fun `Viktor field addition`() { + fun viktorFieldAddition() { viktorField.run { var res = one repeat(n) { @@ -40,7 +40,7 @@ class ViktorBenchmark { } @Benchmark - fun `Raw Viktor`() { + fun rawViktor() { val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) var res = one repeat(n) { @@ -49,7 +49,7 @@ class ViktorBenchmark { } @Benchmark - fun `Real field log`() { + fun realdFieldLog() { realField.run { val fortyTwo = produce { 42.0 } var res = one @@ -61,7 +61,7 @@ class ViktorBenchmark { } @Benchmark - fun `Raw Viktor log`() { + fun rawViktorLog() { val fortyTwo = F64Array.full(dim, dim, init = 42.0) var res: F64Array repeat(n) { diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt new file mode 100644 index 000000000..13e79d60e --- /dev/null +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -0,0 +1,38 @@ +package scientifik.kmath.commons.random + +import scientifik.kmath.prob.RandomGenerator + +class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : + org.apache.commons.math3.random.RandomGenerator { + private var generator = factory(intArrayOf()) + + override fun nextBoolean(): Boolean = generator.nextBoolean() + + override fun nextFloat(): Float = generator.nextDouble().toFloat() + + override fun setSeed(seed: Int) { + generator = factory(intArrayOf(seed)) + } + + override fun setSeed(seed: IntArray) { + generator = factory(seed) + } + + override fun setSeed(seed: Long) { + setSeed(seed.toInt()) + } + + override fun nextBytes(bytes: ByteArray) { + generator.fillBytes(bytes) + } + + override fun nextInt(): Int = generator.nextInt() + + override fun nextInt(n: Int): Int = generator.nextInt(n) + + override fun nextGaussian(): Double = TODO() + + override fun nextDouble(): Double = generator.nextDouble() + + override fun nextLong(): Long = generator.nextLong() +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt new file mode 100644 index 000000000..333b77cb4 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt @@ -0,0 +1,15 @@ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +/** + * A simple geometric domain + */ +interface Domain { + operator fun contains(point: Point): Boolean + + /** + * Number of hyperspace dimensions + */ + val dimension: Int +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt new file mode 100644 index 000000000..21912b87c --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt @@ -0,0 +1,67 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.indices + +/** + * + * HyperSquareDomain class. + * + * @author Alexander Nozik + */ +class HyperSquareDomain(private val lower: DoubleBuffer, private val upper: DoubleBuffer) : RealDomain { + + override operator fun contains(point: Point): Boolean = point.indices.all { i -> + point[i] in lower[i]..upper[i] + } + + override val dimension: Int get() = lower.size + + override fun getLowerBound(num: Int, point: Point): Double? = lower[num] + + override fun getLowerBound(num: Int): Double? = lower[num] + + override fun getUpperBound(num: Int, point: Point): Double? = upper[num] + + override fun getUpperBound(num: Int): Double? = upper[num] + + override fun nearestInDomain(point: Point): Point { + val res: DoubleArray = DoubleArray(point.size) { i -> + when { + point[i] < lower[i] -> lower[i] + point[i] > upper[i] -> upper[i] + else -> point[i] + } + } + return DoubleBuffer(*res) + } + + override fun volume(): Double { + var res = 1.0 + for (i in 0 until dimension) { + if (lower[i].isInfinite() || upper[i].isInfinite()) { + return Double.POSITIVE_INFINITY + } + if (upper[i] > lower[i]) { + res *= upper[i] - lower[i] + } + } + return res + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt new file mode 100644 index 000000000..89115887e --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt @@ -0,0 +1,65 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +/** + * n-dimensional volume + * + * @author Alexander Nozik + */ +interface RealDomain: Domain { + + fun nearestInDomain(point: Point): Point + + /** + * The lower edge for the domain going down from point + * @param num + * @param point + * @return + */ + fun getLowerBound(num: Int, point: Point): Double? + + /** + * The upper edge of the domain going up from point + * @param num + * @param point + * @return + */ + fun getUpperBound(num: Int, point: Point): Double? + + /** + * Global lower edge + * @param num + * @return + */ + fun getLowerBound(num: Int): Double? + + /** + * Global upper edge + * @param num + * @return + */ + fun getUpperBound(num: Int): Double? + + /** + * Hyper volume + * @return + */ + fun volume(): Double + +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt new file mode 100644 index 000000000..e49fd3b37 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt @@ -0,0 +1,36 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +class UnconstrainedDomain(override val dimension: Int) : RealDomain { + + override operator fun contains(point: Point): Boolean = true + + override fun getLowerBound(num: Int, point: Point): Double? = Double.NEGATIVE_INFINITY + + override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY + + override fun getUpperBound(num: Int, point: Point): Double? = Double.POSITIVE_INFINITY + + override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY + + override fun nearestInDomain(point: Point): Point = point + + override fun volume(): Double = Double.POSITIVE_INFINITY + +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt new file mode 100644 index 000000000..ef521d5ea --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt @@ -0,0 +1,48 @@ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.asBuffer + +inline class UnivariateDomain(val range: ClosedFloatingPointRange) : RealDomain { + + operator fun contains(d: Double): Boolean = range.contains(d) + + override operator fun contains(point: Point): Boolean { + require(point.size == 0) + return contains(point[0]) + } + + override fun nearestInDomain(point: Point): Point { + require(point.size == 1) + val value = point[0] + return when{ + value in range -> point + value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer() + else -> doubleArrayOf(range.start).asBuffer() + } + } + + override fun getLowerBound(num: Int, point: Point): Double? { + require(num == 0) + return range.start + } + + override fun getUpperBound(num: Int, point: Point): Double? { + require(num == 0) + return range.endInclusive + } + + override fun getLowerBound(num: Int): Double? { + require(num == 0) + return range.start + } + + override fun getUpperBound(num: Int): Double? { + require(num == 0) + return range.endInclusive + } + + override fun volume(): Double = range.endInclusive - range.start + + override val dimension: Int get() = 1 +} \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt index ff4c835ed..23c7e19cb 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt @@ -12,26 +12,23 @@ import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asIterable import kotlin.math.sqrt +typealias RealPoint = Point + fun DoubleArray.asVector() = RealVector(this.asBuffer()) fun List.asVector() = RealVector(this.asBuffer()) - object VectorL2Norm : Norm, Double> { override fun norm(arg: Point): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) } inline class RealVector(private val point: Point) : - SpaceElement, RealVector, VectorSpace>, Point { + SpaceElement>, RealPoint { - override val context: VectorSpace - get() = space( - point.size - ) + override val context: VectorSpace get() = space(point.size) - override fun unwrap(): Point = point + override fun unwrap(): RealPoint = point - override fun Point.wrap(): RealVector = - RealVector(this) + override fun RealPoint.wrap(): RealVector = RealVector(this) override val size: Int get() = point.size @@ -48,12 +45,8 @@ inline class RealVector(private val point: Point) : operator fun invoke(vararg values: Double): RealVector = values.asVector() - fun space(dim: Int): BufferVectorSpace = - spaceCache.getOrPut(dim) { - BufferVectorSpace( - dim, - RealField - ) { size, init -> Buffer.real(size, init) } - } + fun space(dim: Int): BufferVectorSpace = spaceCache.getOrPut(dim) { + BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) } + } } } \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt index 813d89577..0f4ccf2a8 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt @@ -27,6 +27,10 @@ typealias RealMatrix = Matrix fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = MatrixContext.real.produce(rowNum, colNum, initializer) +fun Array.toMatrix(): RealMatrix{ + return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] } +} + fun Sequence.toMatrix(): RealMatrix = toList().let { MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] } } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index 329af72a1..5199669f5 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -1,18 +1,10 @@ package scientifik.kmath.histogram +import scientifik.kmath.domains.Domain import scientifik.kmath.linear.Point import scientifik.kmath.structures.ArrayBuffer import scientifik.kmath.structures.DoubleBuffer -/** - * A simple geometric domain - * TODO move to geometry module - */ -interface Domain { - operator fun contains(vector: Point): Boolean - val dimension: Int -} - /** * The bin in the histogram. The histogram is by definition always done in the real space */ diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt index 0896f0dcb..7999aa2ab 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt +++ b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt @@ -10,6 +10,7 @@ interface MemorySpec { val objectSize: Int fun MemoryReader.read(offset: Int): T + //TODO consider thread safety fun MemoryWriter.write(offset: Int, value: T) } From ea8c0db85445cdcdaa06d8ce78f839cf0d2162d6 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 23 Jun 2020 21:46:05 +0300 Subject: [PATCH 077/104] Histogram bin fix --- .../kotlin/scientifik/kmath/histogram/RealHistogram.kt | 4 ++-- .../kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index 4438f5d60..f9d815421 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -1,8 +1,8 @@ package scientifik.kmath.histogram import scientifik.kmath.linear.Point -import scientifik.kmath.real.asVector import scientifik.kmath.operations.SpaceOperations +import scientifik.kmath.real.asVector import scientifik.kmath.structures.* import kotlin.math.floor @@ -21,7 +21,7 @@ data class BinDef>(val space: SpaceOperations>, val c class MultivariateBin>(val def: BinDef, override val value: Number) : Bin { - override fun contains(vector: Point): Boolean = def.contains(vector) + override fun contains(point: Point): Boolean = def.contains(point) override val dimension: Int get() = def.center.size diff --git a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt index dcc5ac0eb..af01205bf 100644 --- a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt +++ b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt @@ -16,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) - override fun contains(vector: Buffer): Boolean = contains(vector[0]) + override fun contains(point: Buffer): Boolean = contains(point[0]) internal operator fun inc() = this.also { counter.increment() } From 9a3709624d5b415f5ae0bc9c7c3208a5dccb651e Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Wed, 24 Jun 2020 15:54:17 +0700 Subject: [PATCH 078/104] Use hashMap instead of map --- .../kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 77a4ee9e9..95c84ca51 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -387,7 +387,7 @@ internal class AsmBuilder internal constructor( * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. */ private val SIGNATURE_LETTERS: Map, String> by lazy { - mapOf( + hashMapOf( java.lang.Byte::class.java to "B", java.lang.Short::class.java to "S", java.lang.Integer::class.java to "I", @@ -397,8 +397,8 @@ internal class AsmBuilder internal constructor( ) } - private val NUMBER_CONVERTER_METHODS by lazy { - mapOf( + private val NUMBER_CONVERTER_METHODS: Map by lazy { + hashMapOf( "B" to "byteValue", "S" to "shortValue", "I" to "intValue", From f16ebb1682b40a54b08e41a25be1412f170e873e Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Wed, 24 Jun 2020 15:55:25 +0700 Subject: [PATCH 079/104] Remove accidentally left debug main function --- kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 3ee071c77..2036ffec6 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -79,7 +79,3 @@ inline fun Algebra.expression(mst: MST): Expression = ms * Optimize performance of an [MSTExpression] using ASM codegen */ inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) - -fun main() { - RealField.mstInField { symbol("x") + 2 }.compile() -} From 02f42ee56a8047af9ab7846b5e0b271084c9192a Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Wed, 24 Jun 2020 20:55:48 +0700 Subject: [PATCH 080/104] Eliminate bridging --- .../kotlin/scientifik/kmath/asm/asm.kt | 63 ++- .../kmath/asm/internal/AsmBuilder.kt | 366 ++++++++++-------- .../asm/internal/AsmCompiledExpression.kt | 1 - .../scientifik/kmath/asm/internal/classes.kt | 7 + .../kmath/asm/internal/methodVisitors.kt | 59 +-- .../kmath/asm/internal/specialization.kt | 34 +- .../scietifik/kmath/asm/TestAsmAlgebras.kt | 20 +- 7 files changed, 274 insertions(+), 276 deletions(-) create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index a09a7ebe3..a3af80ccd 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,8 +1,9 @@ package scientifik.kmath.asm +import org.objectweb.asm.Type import scientifik.kmath.asm.internal.AsmBuilder +import scientifik.kmath.asm.internal.buildExpectationStack import scientifik.kmath.asm.internal.buildName -import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST import scientifik.kmath.ast.MSTExpression @@ -18,6 +19,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< fun AsmBuilder.visit(node: MST) { when (node) { is MST.Symbolic -> loadVariable(node.value) + is MST.Numeric -> { val constant = if (algebra is NumericAlgebra) algebra.number(node.value) @@ -26,48 +28,49 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< loadTConstant(constant) } + is MST.Unary -> { loadAlgebra() - - if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation) - + if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation) visit(node.value) - if (!tryInvokeSpecific(algebra, node.operation, 1)) { - invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_CLASS, - method = "unaryOperation", - descriptor = "(L${AsmBuilder.STRING_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};)" + - "L${AsmBuilder.OBJECT_CLASS};" - ) - } + if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = "unaryOperation", + + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + AsmBuilder.OBJECT_TYPE + ), + + tArity = 1 + ) } is MST.Binary -> { loadAlgebra() - - if (!hasSpecific(algebra, node.operation, 2)) - loadStringConstant(node.operation) - + if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) visit(node.left) visit(node.right) - if (!tryInvokeSpecific(algebra, node.operation, 2)) { + if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = "binaryOperation", - invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_CLASS, - method = "binaryOperation", - descriptor = "(L${AsmBuilder.STRING_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};)" + - "L${AsmBuilder.OBJECT_CLASS};" - ) - } + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + AsmBuilder.OBJECT_TYPE, + AsmBuilder.OBJECT_TYPE + ), + + tArity = 2 + ) } } } - return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() + return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() } /** @@ -79,7 +82,3 @@ inline fun Algebra.expression(mst: MST): Expression = ms * Optimize performance of an [MSTExpression] using ASM codegen */ inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) - -fun main() { - RealField.mstInField { symbol("x") + 2 }.compile() -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 95c84ca51..2c9cd48d9 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -1,14 +1,17 @@ package scientifik.kmath.asm.internal -import org.objectweb.asm.ClassWriter -import org.objectweb.asm.Label -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes +import org.objectweb.asm.* +import org.objectweb.asm.Opcodes.AALOAD +import org.objectweb.asm.Opcodes.RETURN +import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader +import scientifik.kmath.ast.MST import scientifik.kmath.operations.Algebra +import java.util.* +import kotlin.reflect.KClass /** - * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. + * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * * @param T the type of AsmExpression to unwrap. @@ -16,7 +19,7 @@ import scientifik.kmath.operations.Algebra * @param className the unique class name of new loaded class. */ internal class AsmBuilder internal constructor( - private val classOfT: Class<*>, + private val classOfT: KClass<*>, private val algebra: Algebra, private val className: String, private val invokeLabel0Visitor: AsmBuilder.() -> Unit @@ -31,16 +34,16 @@ internal class AsmBuilder internal constructor( /** * The instance of [ClassLoader] used by this builder. */ - private val classLoader: ClassLoader = - ClassLoader(javaClass.classLoader) + private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) @Suppress("PrivatePropertyName") - private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') + private val T_ALGEBRA_TYPE: Type = algebra::class.asm @Suppress("PrivatePropertyName") - private val T_CLASS: String = classOfT.name.replace('.', '/') + internal val T_TYPE: Type = classOfT.asm - private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + @Suppress("PrivatePropertyName") + private val CLASS_TYPE: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! /** * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. @@ -60,11 +63,16 @@ internal class AsmBuilder internal constructor( /** * Method visitor of `invoke` method of [AsmCompiledExpression] subclass. */ - private lateinit var invokeMethodVisitor: MethodVisitor - private var primitiveMode = false - internal var primitiveType = OBJECT_CLASS - internal var primitiveTypeSig = "L$OBJECT_CLASS;" - internal var primitiveTypeReturnSig = "L$OBJECT_CLASS;" + private lateinit var invokeMethodVisitor: InstructionAdapter + internal var primitiveMode = false + + @Suppress("PropertyName") + internal var PRIMITIVE_MASK: Type = OBJECT_TYPE + + @Suppress("PropertyName") + internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE + private val typeStack = Stack() + internal val expectationStack = Stack().apply { push(T_TYPE) } /** * The cache of [AsmCompiledExpression] subclass built by this builder. @@ -80,89 +88,85 @@ internal class AsmBuilder internal constructor( fun getInstance(): AsmCompiledExpression { generatedInstance?.let { return it } - if (classOfT in SIGNATURE_LETTERS) { + if (SIGNATURE_LETTERS.containsKey(classOfT.java)) { primitiveMode = true - primitiveTypeSig = SIGNATURE_LETTERS.getValue(classOfT) - primitiveType = SIGNATURE_LETTERS.getValue(classOfT) - primitiveTypeReturnSig = "L$T_CLASS;" + PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT.java) + PRIMITIVE_MASK_BOXED = T_TYPE } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( Opcodes.V1_8, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, - slashesClassName, - "L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS;", - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + CLASS_TYPE.internalName, + "L${ASM_COMPILED_EXPRESSION_TYPE.internalName}<${T_TYPE.descriptor}>;", + ASM_COMPILED_EXPRESSION_TYPE.internalName, arrayOf() ) visitMethod( - access = Opcodes.ACC_PUBLIC, - name = "", - descriptor = "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", - signature = null, - exceptions = null - ) { + Opcodes.ACC_PUBLIC, + "", + Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), + null, + null + ).instructionAdapter { val thisVar = 0 val algebraVar = 1 val constantsVar = 2 val l0 = Label() visitLabel(l0) - visitLoadObjectVar(thisVar) - visitLoadObjectVar(algebraVar) - visitLoadObjectVar(constantsVar) + load(thisVar, CLASS_TYPE) + load(algebraVar, ALGEBRA_TYPE) + load(constantsVar, OBJECT_ARRAY_TYPE) - visitInvokeSpecial( - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + invokespecial( + ASM_COMPILED_EXPRESSION_TYPE.internalName, "", - "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V" + Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), + false ) val l1 = Label() visitLabel(l1) - visitReturn() + visitInsn(RETURN) val l2 = Label() visitLabel(l2) - visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) + visitLocalVariable("this", CLASS_TYPE.descriptor, null, l0, l2, thisVar) visitLocalVariable( "algebra", - "L$ALGEBRA_CLASS;", - "L$ALGEBRA_CLASS;", + ALGEBRA_TYPE.descriptor, + "L${ALGEBRA_TYPE.internalName}<${T_TYPE.descriptor}>;", l0, l2, algebraVar ) - visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l2, constantsVar) visitMaxs(0, 3) visitEnd() } visitMethod( - access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, - name = "invoke", - descriptor = "(L$MAP_CLASS;)L$T_CLASS;", - signature = "(L$MAP_CLASS;)L$T_CLASS;", - exceptions = null - ) { + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + "invoke", + Type.getMethodDescriptor(T_TYPE, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;)${T_TYPE.descriptor}", + null + ).instructionAdapter { invokeMethodVisitor = this visitCode() val l0 = Label() visitLabel(l0) invokeLabel0Visitor() - - if (primitiveMode) - boxPrimitiveToNumber() - - visitReturnObject() + areturn(T_TYPE) val l1 = Label() visitLabel(l1) visitLocalVariable( "this", - "L$slashesClassName;", + CLASS_TYPE.descriptor, null, l0, l1, @@ -171,8 +175,8 @@ internal class AsmBuilder internal constructor( visitLocalVariable( "arguments", - "L$MAP_CLASS;", - "L$MAP_CLASS;", + MAP_TYPE.descriptor, + "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;", l0, l1, invokeArgumentsVar @@ -183,28 +187,28 @@ internal class AsmBuilder internal constructor( } visitMethod( - access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, - name = "invoke", - descriptor = "(L$MAP_CLASS;)L$OBJECT_CLASS;", - signature = null, - exceptions = null - ) { + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + "invoke", + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + null, + null + ).instructionAdapter { val thisVar = 0 val argumentsVar = 1 visitCode() val l0 = Label() visitLabel(l0) - visitLoadObjectVar(thisVar) - visitLoadObjectVar(argumentsVar) - visitInvokeVirtual(owner = slashesClassName, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;") - visitReturnObject() + load(thisVar, OBJECT_TYPE) + load(argumentsVar, MAP_TYPE) + invokevirtual(CLASS_TYPE.internalName, "invoke", Type.getMethodDescriptor(T_TYPE, MAP_TYPE), false) + areturn(T_TYPE) val l1 = Label() visitLabel(l1) visitLocalVariable( "this", - "L$slashesClassName;", - T_CLASS, + CLASS_TYPE.descriptor, + null, l0, l1, thisVar @@ -231,116 +235,111 @@ internal class AsmBuilder internal constructor( * Loads a constant from */ internal fun loadTConstant(value: T) { - if (primitiveMode) { - loadNumberConstant(value as Number) + if (classOfT.java in INLINABLE_NUMBERS) { + val expectedType = expectationStack.pop()!! + val mustBeBoxed = expectedType.sort == Type.OBJECT + loadNumberConstant(value as Number, mustBeBoxed) + if (mustBeBoxed) typeStack.push(T_TYPE) else typeStack.push(PRIMITIVE_MASK) return } - if (classOfT in INLINABLE_NUMBERS) { - loadNumberConstant(value as Number) - invokeMethodVisitor.visitCheckCast(T_CLASS) - return - } - - loadConstant(value as Any, T_CLASS) + loadConstant(value as Any, T_TYPE) } - /** - * Loads an object constant [value] stored in [AsmCompiledExpression.constants] and casts it to [type]. - */ - private fun unboxInPrimitiveMode() { - if (!primitiveMode) - return + private fun box(): Unit = invokeMethodVisitor.invokestatic( + T_TYPE.internalName, + "valueOf", + Type.getMethodDescriptor(T_TYPE, PRIMITIVE_MASK), + false + ) - unboxNumberToPrimitive() - } + private fun unbox(): Unit = invokeMethodVisitor.invokevirtual( + NUMBER_TYPE.internalName, + NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK), + Type.getMethodDescriptor(PRIMITIVE_MASK), + false + ) - private fun boxPrimitiveToNumber() { - invokeMethodVisitor.visitInvokeStatic(T_CLASS, "valueOf", "($primitiveTypeSig)L$T_CLASS;") - } - - private fun unboxNumberToPrimitive() { - invokeMethodVisitor.visitInvokeVirtual( - owner = NUMBER_CLASS, - name = NUMBER_CONVERTER_METHODS.getValue(primitiveType), - descriptor = "()${primitiveTypeSig}" - ) - } - - private fun loadConstant(value: Any, type: String) { + private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - - invokeMethodVisitor.run { - loadThis() - visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") - visitLdcOrIntConstant(idx) - visitGetObjectArrayElement() - invokeMethodVisitor.visitCheckCast(type) - } + loadThis() + getfield(CLASS_TYPE.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + iconst(idx) + visitInsn(AALOAD) + checkcast(type) } - private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, CLASS_TYPE) /** - * Either loads a numeric constant [value] from [AsmCompiledExpression.constants] field or boxes a primitive + * Either loads a numeric constant [value] from [AsmCompiledExpression] constants field or boxes a primitive * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded * from it). */ - private fun loadNumberConstant(value: Number) { - val clazz = value.javaClass - val c = clazz.name.replace('.', '/') - val sigLetter = SIGNATURE_LETTERS[clazz] + private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) { + val boxed = value::class.asm + val primitive = BOXED_TO_PRIMITIVES[boxed] - if (sigLetter != null) { - when (value) { - is Byte -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) - is Short -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) - is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value) - is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value) - is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value) - is Long -> invokeMethodVisitor.visitLdcOrLongConstant(value) - else -> invokeMethodVisitor.visitLdcInsn(value) + if (primitive != null) { + when (primitive) { + Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) } - if (!primitiveMode) - invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") + if (mustBeBoxed) { + box() + invokeMethodVisitor.checkcast(T_TYPE) + } return } - loadConstant(value, c) + loadConstant(value, boxed) + if (!mustBeBoxed) unbox() + else invokeMethodVisitor.checkcast(T_TYPE) } /** * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { - visitLoadObjectVar(invokeArgumentsVar) + load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) if (defaultValue != null) { - visitLdcInsn(name) + loadStringConstant(name) loadTConstant(defaultValue) - visitInvokeInterface( - owner = MAP_CLASS, - name = "getOrDefault", - descriptor = "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;" + invokeinterface( + MAP_TYPE.internalName, + "getOrDefault", + Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.visitCheckCast(T_CLASS) + invokeMethodVisitor.checkcast(T_TYPE) return } - visitLdcInsn(name) + loadStringConstant(name) - visitInvokeInterface( - owner = MAP_CLASS, - name = "get", - descriptor = "(L$OBJECT_CLASS;)L$OBJECT_CLASS;" + invokeinterface( + MAP_TYPE.internalName, + "get", + Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.visitCheckCast(T_CLASS) - unboxInPrimitiveMode() + invokeMethodVisitor.checkcast(T_TYPE) + val expectedType = expectationStack.pop()!! + + if (expectedType.sort == Type.OBJECT) + typeStack.push(T_TYPE) + else { + unbox() + typeStack.push(PRIMITIVE_MASK) + } } /** @@ -349,13 +348,10 @@ internal class AsmBuilder internal constructor( internal fun loadAlgebra() { loadThis() - invokeMethodVisitor.visitGetField( - owner = FUNCTIONAL_COMPILED_EXPRESSION_CLASS, - name = "algebra", - descriptor = "L$ALGEBRA_CLASS;" - ) - - invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS) + invokeMethodVisitor.run { + getfield(ASM_COMPILED_EXPRESSION_TYPE.internalName, "algebra", ALGEBRA_TYPE.descriptor) + checkcast(T_ALGEBRA_TYPE) + } } /** @@ -369,42 +365,70 @@ internal class AsmBuilder internal constructor( owner: String, method: String, descriptor: String, - opcode: Int = Opcodes.INVOKEINTERFACE, - isInterface: Boolean = true + tArity: Int, + opcode: Int = Opcodes.INVOKEINTERFACE ) { - invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) - invokeMethodVisitor.visitCheckCast(T_CLASS) - unboxInPrimitiveMode() + repeat(tArity) { typeStack.pop() } + + invokeMethodVisitor.visitMethodInsn( + opcode, + owner, + method, + descriptor, + opcode == Opcodes.INVOKEINTERFACE + ) + + invokeMethodVisitor.checkcast(T_TYPE) + val isLastExpr = expectationStack.size == 1 + val expectedType = expectationStack.pop()!! + + if (expectedType.sort == Type.OBJECT || isLastExpr) + typeStack.push(T_TYPE) + else { + unbox() + typeStack.push(PRIMITIVE_MASK) + } } /** * Writes a LDC Instruction with string constant provided. */ - internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) + internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string) internal companion object { /** * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. */ - private val SIGNATURE_LETTERS: Map, String> by lazy { + private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" + java.lang.Byte::class.java to Type.BYTE_TYPE, + java.lang.Short::class.java to Type.SHORT_TYPE, + java.lang.Integer::class.java to Type.INT_TYPE, + java.lang.Long::class.java to Type.LONG_TYPE, + java.lang.Float::class.java to Type.FLOAT_TYPE, + java.lang.Double::class.java to Type.DOUBLE_TYPE ) } - private val NUMBER_CONVERTER_METHODS: Map by lazy { + private val BOXED_TO_PRIMITIVES: Map by lazy { hashMapOf( - "B" to "byteValue", - "S" to "shortValue", - "I" to "intValue", - "J" to "longValue", - "F" to "floatValue", - "D" to "doubleValue" + java.lang.Byte::class.asm to Type.BYTE_TYPE, + java.lang.Short::class.asm to Type.SHORT_TYPE, + java.lang.Integer::class.asm to Type.INT_TYPE, + java.lang.Long::class.asm to Type.LONG_TYPE, + java.lang.Float::class.asm to Type.FLOAT_TYPE, + java.lang.Double::class.asm to Type.DOUBLE_TYPE + ) + } + + private val NUMBER_CONVERTER_METHODS: Map by lazy { + hashMapOf( + Type.BYTE_TYPE to "byteValue", + Type.SHORT_TYPE to "shortValue", + Type.INT_TYPE to "intValue", + Type.LONG_TYPE to "longValue", + Type.FLOAT_TYPE to "floatValue", + Type.DOUBLE_TYPE to "doubleValue" ) } @@ -412,14 +436,14 @@ internal class AsmBuilder internal constructor( * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + internal val ASM_COMPILED_EXPRESSION_TYPE: Type = AsmCompiledExpression::class.asm + internal val NUMBER_TYPE: Type = java.lang.Number::class.asm + internal val MAP_TYPE: Type = java.util.Map::class.asm + internal val OBJECT_TYPE: Type = java.lang.Object::class.asm - internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = - "scientifik/kmath/asm/internal/AsmCompiledExpression" - - internal const val NUMBER_CLASS = "java/lang/Number" - internal const val MAP_CLASS = "java/util/Map" - internal const val OBJECT_CLASS = "java/lang/Object" - internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" - internal const val STRING_CLASS = "java/lang/String" + @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") + internal val OBJECT_ARRAY_TYPE: Type = Array::class.asm + internal val ALGEBRA_TYPE: Type = Algebra::class.asm + internal val STRING_TYPE: Type = java.lang.String::class.asm } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt index a5236927c..7c4a9fc99 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt @@ -16,4 +16,3 @@ internal abstract class AsmCompiledExpression internal constructor( ) : Expression { abstract override fun invoke(arguments: Map): T } - diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt new file mode 100644 index 000000000..dc0b35531 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt @@ -0,0 +1,7 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.Type +import kotlin.reflect.KClass + +internal val KClass<*>.asm: Type + get() = Type.getType(java) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt index 6ecce8833..7b0d346b7 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt @@ -1,60 +1,9 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes.* +import org.objectweb.asm.commons.InstructionAdapter -internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value) { - -1 -> visitInsn(ICONST_M1) - 0 -> visitInsn(ICONST_0) - 1 -> visitInsn(ICONST_1) - 2 -> visitInsn(ICONST_2) - 3 -> visitInsn(ICONST_3) - 4 -> visitInsn(ICONST_4) - 5 -> visitInsn(ICONST_5) - in -128..127 -> visitIntInsn(BIPUSH, value) - in -32768..32767 -> visitIntInsn(SIPUSH, value) - else -> visitLdcInsn(value) -} +fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) -internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when (value) { - 0.0 -> visitInsn(DCONST_0) - 1.0 -> visitInsn(DCONST_1) - else -> visitLdcInsn(value) -} - -internal fun MethodVisitor.visitLdcOrLongConstant(value: Long): Unit = when (value) { - 0L -> visitInsn(LCONST_0) - 1L -> visitInsn(LCONST_1) - else -> visitLdcInsn(value) -} - -internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) { - 0f -> visitInsn(FCONST_0) - 1f -> visitInsn(FCONST_1) - 2f -> visitInsn(FCONST_2) - else -> visitLdcInsn(value) -} - -internal fun MethodVisitor.visitInvokeInterface(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true) - -internal fun MethodVisitor.visitInvokeVirtual(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, false) - -internal fun MethodVisitor.visitInvokeStatic(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKESTATIC, owner, name, descriptor, false) - -internal fun MethodVisitor.visitInvokeSpecial(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, false) - -internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CHECKCAST, type) - -internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit = - visitFieldInsn(GETFIELD, owner, name, descriptor) - -internal fun MethodVisitor.visitLoadObjectVar(`var`: Int): Unit = visitVarInsn(ALOAD, `var`) - -internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD) - -internal fun MethodVisitor.visitReturn(): Unit = visitInsn(RETURN) -internal fun MethodVisitor.visitReturnObject(): Unit = visitInsn(ARETURN) +fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = + instructionAdapter().apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index 51d056f8c..2e15a1a93 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -1,6 +1,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes +import org.objectweb.asm.Type import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map by lazy { @@ -12,17 +13,19 @@ private val methodNameAdapters: Map by lazy { } /** - * Checks if the target [context] for code generation contains a method with needed [name] and [arity]. + * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds + * type expectation stack for needed arity. * * @return `true` if contains, else `false`. */ -internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { +internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } - ?: return false + val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null + val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else T_TYPE + repeat(arity) { expectationStack.push(t) } - return true + return hasSpecific } /** @@ -34,22 +37,23 @@ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boo internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } - ?: return false + val method = + context.javaClass.methods.find { + var suitableSignature = it.name == aName && it.parameters.size == arity + + if (primitiveMode && it.isBridge) + suitableSignature = false + + suitableSignature + } ?: return false val owner = context::class.java.name.replace('.', '/') - val sig = buildString { - append('(') - repeat(arity) { append(primitiveTypeSig) } - append(')') - append(primitiveTypeReturnSig) - } - invokeAlgebraOperation( owner = owner, method = aName, - descriptor = sig, + descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }), + tArity = arity, opcode = Opcodes.INVOKEVIRTUAL ) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 294da32e9..ce5b7cbf9 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -10,7 +10,23 @@ import kotlin.test.assertEquals class TestAsmAlgebras { @Test fun space() { - val res = ByteRing.mstInSpace { + val res1 = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }("x" to 2.toByte()) + + val res2 = ByteRing.mstInSpace { binaryOperation( "+", @@ -26,7 +42,7 @@ class TestAsmAlgebras { ) + symbol("x") + zero }.compile()("x" to 2.toByte()) - assertEquals(16, res) + assertEquals(res1, res2) } // @Test From fffc75215313dd580e1b23ddd32eca04e84cc7b1 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Wed, 24 Jun 2020 21:17:06 +0700 Subject: [PATCH 081/104] Add more tests and improve current, fix type stack underflow exception --- .../kmath/asm/internal/AsmBuilder.kt | 2 +- .../scietifik/kmath/asm/TestAsmAlgebras.kt | 123 +++++++++--------- .../scietifik/kmath/asm/TestAsmExpressions.kt | 7 + .../kotlin/scietifik/kmath/ast/AsmTest.kt | 15 ++- .../kotlin/scietifik/kmath/ast/ParserTest.kt | 12 +- 5 files changed, 92 insertions(+), 67 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 2c9cd48d9..8f45c4044 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -368,7 +368,7 @@ internal class AsmBuilder internal constructor( tArity: Int, opcode: Int = Opcodes.INVOKEINTERFACE ) { - repeat(tArity) { typeStack.pop() } + repeat(tArity) { if (!typeStack.empty()) typeStack.pop() } invokeMethodVisitor.visitMethodInsn( opcode, diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index ce5b7cbf9..4c2be811e 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -1,9 +1,12 @@ package scietifik.kmath.asm import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.mstInRing import scientifik.kmath.ast.mstInSpace import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.ByteRing +import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals @@ -45,65 +48,63 @@ class TestAsmAlgebras { assertEquals(res1, res2) } -// @Test -// fun space() { -// val res = ByteRing.asm { -// binaryOperation( -// "+", -// -// unaryOperation( -// "+", -// 3.toByte() - (2.toByte() + (multiply( -// add(number(1), number(1)), -// 2 -// ) + 1.toByte()) * 3.toByte() - 1.toByte()) -// ), -// -// number(1) -// ) + symbol("x") + zero -// }("x" to 2.toByte()) -// -// assertEquals(16, res) -// } -// -// @Test -// fun ring() { -// val res = ByteRing.asmInRing { -// binaryOperation( -// "+", -// -// unaryOperation( -// "+", -// (3.toByte() - (2.toByte() + (multiply( -// add(const(1), const(1)), -// 2 -// ) + 1.toByte()))) * 3.0 - 1.toByte() -// ), -// -// number(1) -// ) * const(2) -// }() -// -// assertEquals(24, res) -// } -// -// @Test -// fun field() { -// val res = RealField.asmInField { -// +(3 - 2 + 2*(number(1)+1.0) -// -// unaryOperation( -// "+", -// (3.0 - (2.0 + (multiply( -// add((1.0), const(1.0)), -// 2 -// ) + 1.0))) * 3 - 1.0 -// )+ -// -// number(1) -// ) / 2, const(2.0)) * one -// }() -// -// assertEquals(3.0, res) -// } + @Test + fun ring() { + val res1 = ByteRing.mstInRing { + binaryOperation( + "+", + + unaryOperation( + "+", + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * number(2) + }("x" to 3.toByte()) + + val res2 = ByteRing.mstInRing { + binaryOperation( + "+", + + unaryOperation( + "+", + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * number(2) + }.compile()("x" to 3.toByte()) + + assertEquals(res1, res2) + } + + @Test + fun field() { + val res1 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( + "+", + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + 1 / 2 + number(2.0) * one + ) + }("x" to 2.0) + + val res2 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( + "+", + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + 1 / 2 + number(2.0) * one + ) + }.compile()("x" to 2.0) + + assertEquals(res1, res2) + } } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index 82af1a927..824201aa7 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -16,6 +16,13 @@ class TestAsmExpressions { assertEquals(-2.0, res) } + @Test + fun testBinaryOperationInvocation() { + val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile() + val res = expression("x" to 2.0) + assertEquals(-1.0, res) + } + @Test fun testConstProductInvocation() { val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index f0f9a8bc1..08d7fff47 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -1,7 +1,10 @@ package scietifik.kmath.ast -import scientifik.kmath.ast.evaluate +import scientifik.kmath.asm.compile +import scientifik.kmath.asm.expression +import scientifik.kmath.ast.mstInField import scientifik.kmath.ast.parseMath +import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField import kotlin.test.Test @@ -9,9 +12,15 @@ import kotlin.test.assertEquals class AsmTest { @Test - fun parsedExpression() { + fun `compile MST`() { val mst = "2+2*(2+2)".parseMath() - val res = ComplexField.evaluate(mst) + val res = ComplexField.expression(mst)() + assertEquals(Complex(10.0, 0.0), res) + } + + @Test + fun `compile MSTExpression`() { + val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()() assertEquals(Complex(10.0, 0.0), res) } } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt index 06546302e..5394a4b00 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt @@ -1,17 +1,25 @@ package scietifik.kmath.ast import scientifik.kmath.ast.evaluate +import scientifik.kmath.ast.mstInField import scientifik.kmath.ast.parseMath +import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField -import kotlin.test.assertEquals import kotlin.test.Test +import kotlin.test.assertEquals internal class ParserTest { @Test - fun parsedExpression() { + fun `evaluate MST`() { val mst = "2+2*(2+2)".parseMath() val res = ComplexField.evaluate(mst) assertEquals(Complex(10.0, 0.0), res) } + + @Test + fun `evaluate MSTExpression`() { + val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() + assertEquals(Complex(10.0, 0.0), res) + } } From e47ec1aeb9587b0f04f54f99d90c4da1df933152 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 25 Jun 2020 10:07:36 +0700 Subject: [PATCH 082/104] Delete AsmCompiledExpression abstract class, implement dynamic field generation to reduce quantity of cast instructions, minor refactor and renaming of internal APIs --- kmath-ast/README.md | 18 +- .../kmath/asm/internal/AsmBuilder.kt | 220 +++++++++--------- .../asm/internal/AsmCompiledExpression.kt | 18 -- .../kmath/asm/internal/buildName.kt | 3 +- .../kmath/asm/internal/classWriters.kt | 12 +- .../kmath/asm/internal/instructionAdapters.kt | 10 + .../kmath/asm/internal/methodVisitors.kt | 4 +- .../kmath/asm/internal/specialization.kt | 2 +- 8 files changed, 140 insertions(+), 147 deletions(-) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt diff --git a/kmath-ast/README.md b/kmath-ast/README.md index b5ca5886f..4563e17cf 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -24,20 +24,20 @@ For example, the following builder: package scientifik.kmath.asm.generated; import java.util.Map; -import scientifik.kmath.asm.internal.AsmCompiledExpression; -import scientifik.kmath.operations.Algebra; +import scientifik.kmath.expressions.Expression; import scientifik.kmath.operations.RealField; -// The class's name is build with MST's hash-code and collision fixing number. -public final class AsmCompiledExpression_45045_0 extends AsmCompiledExpression { - // Plain constructor - public AsmCompiledExpression_45045_0(Algebra algebra, Object[] constants) { - super(algebra, constants); +public final class AsmCompiledExpression_1073786867_0 implements Expression { + private final RealField algebra; + private final Object[] constants; + + public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) { + this.algebra = algebra; + this.constants = constants; } - // The actual dynamic code: public final Double invoke(Map arguments) { - return (Double)((RealField)super.algebra).add((Double)arguments.get("x"), (Double)2.0D); + return (Double)this.algebra.add(((Double)arguments.get("x")).doubleValue(), 2.0D); } } ``` diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 8f45c4044..536d6136d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -6,6 +6,7 @@ import org.objectweb.asm.Opcodes.RETURN import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import java.util.* import kotlin.reflect.KClass @@ -36,32 +37,27 @@ internal class AsmBuilder internal constructor( */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - @Suppress("PrivatePropertyName") - private val T_ALGEBRA_TYPE: Type = algebra::class.asm - - @Suppress("PrivatePropertyName") - internal val T_TYPE: Type = classOfT.asm - - @Suppress("PrivatePropertyName") - private val CLASS_TYPE: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! + private val tAlgebraType: Type = algebra::class.asm + internal val tType: Type = classOfT.asm + private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! /** - * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. + * Index of `this` variable in invoke method of the built subclass. */ private val invokeThisVar: Int = 0 /** - * Index of `arguments` variable in invoke method of [AsmCompiledExpression] built subclass. + * Index of `arguments` variable in invoke method of the built subclass. */ private val invokeArgumentsVar: Int = 1 /** - * List of constants to provide to [AsmCompiledExpression] subclass. + * List of constants to provide to the subclass. */ private val constants: MutableList = mutableListOf() /** - * Method visitor of `invoke` method of [AsmCompiledExpression] subclass. + * Method visitor of `invoke` method of the subclass. */ private lateinit var invokeMethodVisitor: InstructionAdapter internal var primitiveMode = false @@ -72,78 +68,92 @@ internal class AsmBuilder internal constructor( @Suppress("PropertyName") internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE private val typeStack = Stack() - internal val expectationStack = Stack().apply { push(T_TYPE) } + internal val expectationStack: Stack = Stack().apply { push(tType) } /** - * The cache of [AsmCompiledExpression] subclass built by this builder. + * The cache for instance built by this builder. */ - private var generatedInstance: AsmCompiledExpression? = null + private var generatedInstance: Expression? = null /** - * Subclasses, loads and instantiates the [AsmCompiledExpression] for given parameters. + * Subclasses, loads and instantiates [Expression] for given parameters. * * The built instance is cached. */ @Suppress("UNCHECKED_CAST") - fun getInstance(): AsmCompiledExpression { + fun getInstance(): Expression { generatedInstance?.let { return it } - if (SIGNATURE_LETTERS.containsKey(classOfT.java)) { + if (SIGNATURE_LETTERS.containsKey(classOfT)) { primitiveMode = true - PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT.java) - PRIMITIVE_MASK_BOXED = T_TYPE + PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT) + PRIMITIVE_MASK_BOXED = tType } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( Opcodes.V1_8, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, - CLASS_TYPE.internalName, - "L${ASM_COMPILED_EXPRESSION_TYPE.internalName}<${T_TYPE.descriptor}>;", - ASM_COMPILED_EXPRESSION_TYPE.internalName, - arrayOf() + classType.internalName, + "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", + OBJECT_TYPE.internalName, + arrayOf(EXPRESSION_TYPE.internalName) + ) + + visitField( + access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + name = "algebra", + descriptor = tAlgebraType.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd + ) + + visitField( + access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + name = "constants", + descriptor = OBJECT_ARRAY_TYPE.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd ) visitMethod( Opcodes.ACC_PUBLIC, "", - Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), + Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE), null, null ).instructionAdapter { val thisVar = 0 val algebraVar = 1 val constantsVar = 2 - val l0 = Label() - visitLabel(l0) - load(thisVar, CLASS_TYPE) - load(algebraVar, ALGEBRA_TYPE) + val l0 = label() + load(thisVar, classType) + invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) + label() + load(thisVar, classType) + load(algebraVar, tAlgebraType) + putfield(classType.internalName, "algebra", tAlgebraType.descriptor) + label() + load(thisVar, classType) load(constantsVar, OBJECT_ARRAY_TYPE) - - invokespecial( - ASM_COMPILED_EXPRESSION_TYPE.internalName, - "", - Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), - false - ) - - val l1 = Label() - visitLabel(l1) + putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + label() visitInsn(RETURN) - val l2 = Label() - visitLabel(l2) - visitLocalVariable("this", CLASS_TYPE.descriptor, null, l0, l2, thisVar) + val l4 = label() + visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) visitLocalVariable( "algebra", - ALGEBRA_TYPE.descriptor, - "L${ALGEBRA_TYPE.internalName}<${T_TYPE.descriptor}>;", + tAlgebraType.descriptor, + null, l0, - l2, + l4, algebraVar ) - visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l2, constantsVar) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) visitMaxs(0, 3) visitEnd() } @@ -151,22 +161,20 @@ internal class AsmBuilder internal constructor( visitMethod( Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, "invoke", - Type.getMethodDescriptor(T_TYPE, MAP_TYPE), - "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;)${T_TYPE.descriptor}", + Type.getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", null ).instructionAdapter { invokeMethodVisitor = this visitCode() - val l0 = Label() - visitLabel(l0) + val l0 = label() invokeLabel0Visitor() - areturn(T_TYPE) - val l1 = Label() - visitLabel(l1) + areturn(tType) + val l1 = label() visitLocalVariable( "this", - CLASS_TYPE.descriptor, + classType.descriptor, null, l0, l1, @@ -176,7 +184,7 @@ internal class AsmBuilder internal constructor( visitLocalVariable( "arguments", MAP_TYPE.descriptor, - "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;", + "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", l0, l1, invokeArgumentsVar @@ -196,18 +204,16 @@ internal class AsmBuilder internal constructor( val thisVar = 0 val argumentsVar = 1 visitCode() - val l0 = Label() - visitLabel(l0) + val l0 = label() load(thisVar, OBJECT_TYPE) load(argumentsVar, MAP_TYPE) - invokevirtual(CLASS_TYPE.internalName, "invoke", Type.getMethodDescriptor(T_TYPE, MAP_TYPE), false) - areturn(T_TYPE) - val l1 = Label() - visitLabel(l1) + invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false) + areturn(tType) + val l1 = label() visitLocalVariable( "this", - CLASS_TYPE.descriptor, + classType.descriptor, null, l0, l1, @@ -225,7 +231,7 @@ internal class AsmBuilder internal constructor( .defineClass(className, classWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants.toTypedArray()) as AsmCompiledExpression + .newInstance(algebra, constants.toTypedArray()) as Expression generatedInstance = new return new @@ -235,21 +241,21 @@ internal class AsmBuilder internal constructor( * Loads a constant from */ internal fun loadTConstant(value: T) { - if (classOfT.java in INLINABLE_NUMBERS) { + if (classOfT in INLINABLE_NUMBERS) { val expectedType = expectationStack.pop()!! val mustBeBoxed = expectedType.sort == Type.OBJECT loadNumberConstant(value as Number, mustBeBoxed) - if (mustBeBoxed) typeStack.push(T_TYPE) else typeStack.push(PRIMITIVE_MASK) + if (mustBeBoxed) typeStack.push(tType) else typeStack.push(PRIMITIVE_MASK) return } - loadConstant(value as Any, T_TYPE) + loadConstant(value as Any, tType) } private fun box(): Unit = invokeMethodVisitor.invokestatic( - T_TYPE.internalName, + tType.internalName, "valueOf", - Type.getMethodDescriptor(T_TYPE, PRIMITIVE_MASK), + Type.getMethodDescriptor(tType, PRIMITIVE_MASK), false ) @@ -263,16 +269,16 @@ internal class AsmBuilder internal constructor( private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex loadThis() - getfield(CLASS_TYPE.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) iconst(idx) visitInsn(AALOAD) checkcast(type) } - private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, CLASS_TYPE) + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) /** - * Either loads a numeric constant [value] from [AsmCompiledExpression] constants field or boxes a primitive + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded * from it). */ @@ -292,7 +298,7 @@ internal class AsmBuilder internal constructor( if (mustBeBoxed) { box() - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) } return @@ -300,11 +306,11 @@ internal class AsmBuilder internal constructor( loadConstant(value, boxed) if (!mustBeBoxed) unbox() - else invokeMethodVisitor.checkcast(T_TYPE) + else invokeMethodVisitor.checkcast(tType) } /** - * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. + * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) @@ -319,7 +325,7 @@ internal class AsmBuilder internal constructor( Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) return } @@ -331,11 +337,11 @@ internal class AsmBuilder internal constructor( Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT) - typeStack.push(T_TYPE) + typeStack.push(tType) else { unbox() typeStack.push(PRIMITIVE_MASK) @@ -343,15 +349,11 @@ internal class AsmBuilder internal constructor( } /** - * Loads algebra from according field of [AsmCompiledExpression] and casts it to class of [algebra] provided. + * Loads algebra from according field of the class and casts it to class of [algebra] provided. */ internal fun loadAlgebra() { loadThis() - - invokeMethodVisitor.run { - getfield(ASM_COMPILED_EXPRESSION_TYPE.internalName, "algebra", ALGEBRA_TYPE.descriptor) - checkcast(T_ALGEBRA_TYPE) - } + invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor) } /** @@ -368,7 +370,12 @@ internal class AsmBuilder internal constructor( tArity: Int, opcode: Int = Opcodes.INVOKEINTERFACE ) { - repeat(tArity) { if (!typeStack.empty()) typeStack.pop() } + run loop@{ + repeat(tArity) { + if (typeStack.empty()) return@loop + typeStack.pop() + } + } invokeMethodVisitor.visitMethodInsn( opcode, @@ -378,12 +385,12 @@ internal class AsmBuilder internal constructor( opcode == Opcodes.INVOKEINTERFACE ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) val isLastExpr = expectationStack.size == 1 val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT || isLastExpr) - typeStack.push(T_TYPE) + typeStack.push(tType) else { unbox() typeStack.push(PRIMITIVE_MASK) @@ -399,27 +406,18 @@ internal class AsmBuilder internal constructor( /** * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. */ - private val SIGNATURE_LETTERS: Map, Type> by lazy { + private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( - java.lang.Byte::class.java to Type.BYTE_TYPE, - java.lang.Short::class.java to Type.SHORT_TYPE, - java.lang.Integer::class.java to Type.INT_TYPE, - java.lang.Long::class.java to Type.LONG_TYPE, - java.lang.Float::class.java to Type.FLOAT_TYPE, - java.lang.Double::class.java to Type.DOUBLE_TYPE + java.lang.Byte::class to Type.BYTE_TYPE, + java.lang.Short::class to Type.SHORT_TYPE, + java.lang.Integer::class to Type.INT_TYPE, + java.lang.Long::class to Type.LONG_TYPE, + java.lang.Float::class to Type.FLOAT_TYPE, + java.lang.Double::class to Type.DOUBLE_TYPE ) } - private val BOXED_TO_PRIMITIVES: Map by lazy { - hashMapOf( - java.lang.Byte::class.asm to Type.BYTE_TYPE, - java.lang.Short::class.asm to Type.SHORT_TYPE, - java.lang.Integer::class.asm to Type.INT_TYPE, - java.lang.Long::class.asm to Type.LONG_TYPE, - java.lang.Float::class.asm to Type.FLOAT_TYPE, - java.lang.Double::class.asm to Type.DOUBLE_TYPE - ) - } + private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } private val NUMBER_CONVERTER_METHODS: Map by lazy { hashMapOf( @@ -435,15 +433,15 @@ internal class AsmBuilder internal constructor( /** * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ - private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } - internal val ASM_COMPILED_EXPRESSION_TYPE: Type = AsmCompiledExpression::class.asm - internal val NUMBER_TYPE: Type = java.lang.Number::class.asm - internal val MAP_TYPE: Type = java.util.Map::class.asm - internal val OBJECT_TYPE: Type = java.lang.Object::class.asm + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } + internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } + internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } + internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") - internal val OBJECT_ARRAY_TYPE: Type = Array::class.asm - internal val ALGEBRA_TYPE: Type = Algebra::class.asm - internal val STRING_TYPE: Type = java.lang.String::class.asm + internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } + internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } + internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt deleted file mode 100644 index 7c4a9fc99..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt +++ /dev/null @@ -1,18 +0,0 @@ -package scientifik.kmath.asm.internal - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra - -/** - * [Expression] partial implementation to have it subclassed by actual implementations. Provides unified storage for - * objects needed to implement the expression. - * - * @property algebra the algebra to delegate calls. - * @property constants the constants array to have persistent objects to reference in [invoke]. - */ -internal abstract class AsmCompiledExpression internal constructor( - @JvmField protected val algebra: Algebra, - @JvmField protected val constants: Array -) : Expression { - abstract override fun invoke(arguments: Map): T -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt index 66bd039c3..41dbf5807 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt @@ -1,9 +1,10 @@ package scientifik.kmath.asm.internal import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression /** - * Creates a class name for [AsmCompiledExpression] subclassed to implement [mst] provided. + * Creates a class name for [Expression] subclassed to implement [mst] provided. * * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt index 95d713b18..af5c1049d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt @@ -1,15 +1,17 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter +import org.objectweb.asm.FieldVisitor import org.objectweb.asm.MethodVisitor -internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = + ClassWriter(flags).apply(block) -internal inline fun ClassWriter.visitMethod( +internal inline fun ClassWriter.visitField( access: Int, name: String, descriptor: String, signature: String?, - exceptions: Array?, - block: MethodVisitor.() -> Unit -): MethodVisitor = visitMethod(access, name, descriptor, signature, exceptions).apply(block) + value: Any?, + block: FieldVisitor.() -> Unit +): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt new file mode 100644 index 000000000..f47293687 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt @@ -0,0 +1,10 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.Label +import org.objectweb.asm.commons.InstructionAdapter + +internal fun InstructionAdapter.label(): Label { + val l = Label() + visitLabel(l) + return l +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt index 7b0d346b7..aaae02ebb 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt @@ -3,7 +3,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.MethodVisitor import org.objectweb.asm.commons.InstructionAdapter -fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) +internal fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) -fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = +internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = instructionAdapter().apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index 2e15a1a93..4c7a0d57e 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -22,7 +22,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: val aName = methodNameAdapters[name] ?: name val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null - val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else T_TYPE + val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else tType repeat(arity) { expectationStack.push(t) } return hasSpecific From f7f9ce7817cb66c8b5af6e0a590c5332dc3dc1f0 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 25 Jun 2020 10:07:36 +0700 Subject: [PATCH 083/104] Delete AsmCompiledExpression abstract class, implement dynamic field generation to reduce quantity of cast instructions, minor refactor and renaming of internal APIs --- kmath-ast/README.md | 18 +- .../kmath/asm/internal/AsmBuilder.kt | 220 +++++++++--------- .../asm/internal/AsmCompiledExpression.kt | 18 -- .../kmath/asm/internal/buildName.kt | 3 +- .../kmath/asm/internal/classWriters.kt | 12 +- .../kmath/asm/internal/instructionAdapters.kt | 10 + .../kmath/asm/internal/methodVisitors.kt | 4 +- .../kmath/asm/internal/specialization.kt | 2 +- 8 files changed, 140 insertions(+), 147 deletions(-) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt diff --git a/kmath-ast/README.md b/kmath-ast/README.md index b5ca5886f..4563e17cf 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -24,20 +24,20 @@ For example, the following builder: package scientifik.kmath.asm.generated; import java.util.Map; -import scientifik.kmath.asm.internal.AsmCompiledExpression; -import scientifik.kmath.operations.Algebra; +import scientifik.kmath.expressions.Expression; import scientifik.kmath.operations.RealField; -// The class's name is build with MST's hash-code and collision fixing number. -public final class AsmCompiledExpression_45045_0 extends AsmCompiledExpression { - // Plain constructor - public AsmCompiledExpression_45045_0(Algebra algebra, Object[] constants) { - super(algebra, constants); +public final class AsmCompiledExpression_1073786867_0 implements Expression { + private final RealField algebra; + private final Object[] constants; + + public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) { + this.algebra = algebra; + this.constants = constants; } - // The actual dynamic code: public final Double invoke(Map arguments) { - return (Double)((RealField)super.algebra).add((Double)arguments.get("x"), (Double)2.0D); + return (Double)this.algebra.add(((Double)arguments.get("x")).doubleValue(), 2.0D); } } ``` diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 8f45c4044..536d6136d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -6,6 +6,7 @@ import org.objectweb.asm.Opcodes.RETURN import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import java.util.* import kotlin.reflect.KClass @@ -36,32 +37,27 @@ internal class AsmBuilder internal constructor( */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - @Suppress("PrivatePropertyName") - private val T_ALGEBRA_TYPE: Type = algebra::class.asm - - @Suppress("PrivatePropertyName") - internal val T_TYPE: Type = classOfT.asm - - @Suppress("PrivatePropertyName") - private val CLASS_TYPE: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! + private val tAlgebraType: Type = algebra::class.asm + internal val tType: Type = classOfT.asm + private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! /** - * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. + * Index of `this` variable in invoke method of the built subclass. */ private val invokeThisVar: Int = 0 /** - * Index of `arguments` variable in invoke method of [AsmCompiledExpression] built subclass. + * Index of `arguments` variable in invoke method of the built subclass. */ private val invokeArgumentsVar: Int = 1 /** - * List of constants to provide to [AsmCompiledExpression] subclass. + * List of constants to provide to the subclass. */ private val constants: MutableList = mutableListOf() /** - * Method visitor of `invoke` method of [AsmCompiledExpression] subclass. + * Method visitor of `invoke` method of the subclass. */ private lateinit var invokeMethodVisitor: InstructionAdapter internal var primitiveMode = false @@ -72,78 +68,92 @@ internal class AsmBuilder internal constructor( @Suppress("PropertyName") internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE private val typeStack = Stack() - internal val expectationStack = Stack().apply { push(T_TYPE) } + internal val expectationStack: Stack = Stack().apply { push(tType) } /** - * The cache of [AsmCompiledExpression] subclass built by this builder. + * The cache for instance built by this builder. */ - private var generatedInstance: AsmCompiledExpression? = null + private var generatedInstance: Expression? = null /** - * Subclasses, loads and instantiates the [AsmCompiledExpression] for given parameters. + * Subclasses, loads and instantiates [Expression] for given parameters. * * The built instance is cached. */ @Suppress("UNCHECKED_CAST") - fun getInstance(): AsmCompiledExpression { + fun getInstance(): Expression { generatedInstance?.let { return it } - if (SIGNATURE_LETTERS.containsKey(classOfT.java)) { + if (SIGNATURE_LETTERS.containsKey(classOfT)) { primitiveMode = true - PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT.java) - PRIMITIVE_MASK_BOXED = T_TYPE + PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT) + PRIMITIVE_MASK_BOXED = tType } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( Opcodes.V1_8, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, - CLASS_TYPE.internalName, - "L${ASM_COMPILED_EXPRESSION_TYPE.internalName}<${T_TYPE.descriptor}>;", - ASM_COMPILED_EXPRESSION_TYPE.internalName, - arrayOf() + classType.internalName, + "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", + OBJECT_TYPE.internalName, + arrayOf(EXPRESSION_TYPE.internalName) + ) + + visitField( + access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + name = "algebra", + descriptor = tAlgebraType.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd + ) + + visitField( + access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + name = "constants", + descriptor = OBJECT_ARRAY_TYPE.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd ) visitMethod( Opcodes.ACC_PUBLIC, "", - Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), + Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE), null, null ).instructionAdapter { val thisVar = 0 val algebraVar = 1 val constantsVar = 2 - val l0 = Label() - visitLabel(l0) - load(thisVar, CLASS_TYPE) - load(algebraVar, ALGEBRA_TYPE) + val l0 = label() + load(thisVar, classType) + invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) + label() + load(thisVar, classType) + load(algebraVar, tAlgebraType) + putfield(classType.internalName, "algebra", tAlgebraType.descriptor) + label() + load(thisVar, classType) load(constantsVar, OBJECT_ARRAY_TYPE) - - invokespecial( - ASM_COMPILED_EXPRESSION_TYPE.internalName, - "", - Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), - false - ) - - val l1 = Label() - visitLabel(l1) + putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + label() visitInsn(RETURN) - val l2 = Label() - visitLabel(l2) - visitLocalVariable("this", CLASS_TYPE.descriptor, null, l0, l2, thisVar) + val l4 = label() + visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) visitLocalVariable( "algebra", - ALGEBRA_TYPE.descriptor, - "L${ALGEBRA_TYPE.internalName}<${T_TYPE.descriptor}>;", + tAlgebraType.descriptor, + null, l0, - l2, + l4, algebraVar ) - visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l2, constantsVar) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) visitMaxs(0, 3) visitEnd() } @@ -151,22 +161,20 @@ internal class AsmBuilder internal constructor( visitMethod( Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, "invoke", - Type.getMethodDescriptor(T_TYPE, MAP_TYPE), - "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;)${T_TYPE.descriptor}", + Type.getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", null ).instructionAdapter { invokeMethodVisitor = this visitCode() - val l0 = Label() - visitLabel(l0) + val l0 = label() invokeLabel0Visitor() - areturn(T_TYPE) - val l1 = Label() - visitLabel(l1) + areturn(tType) + val l1 = label() visitLocalVariable( "this", - CLASS_TYPE.descriptor, + classType.descriptor, null, l0, l1, @@ -176,7 +184,7 @@ internal class AsmBuilder internal constructor( visitLocalVariable( "arguments", MAP_TYPE.descriptor, - "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;", + "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", l0, l1, invokeArgumentsVar @@ -196,18 +204,16 @@ internal class AsmBuilder internal constructor( val thisVar = 0 val argumentsVar = 1 visitCode() - val l0 = Label() - visitLabel(l0) + val l0 = label() load(thisVar, OBJECT_TYPE) load(argumentsVar, MAP_TYPE) - invokevirtual(CLASS_TYPE.internalName, "invoke", Type.getMethodDescriptor(T_TYPE, MAP_TYPE), false) - areturn(T_TYPE) - val l1 = Label() - visitLabel(l1) + invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false) + areturn(tType) + val l1 = label() visitLocalVariable( "this", - CLASS_TYPE.descriptor, + classType.descriptor, null, l0, l1, @@ -225,7 +231,7 @@ internal class AsmBuilder internal constructor( .defineClass(className, classWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants.toTypedArray()) as AsmCompiledExpression + .newInstance(algebra, constants.toTypedArray()) as Expression generatedInstance = new return new @@ -235,21 +241,21 @@ internal class AsmBuilder internal constructor( * Loads a constant from */ internal fun loadTConstant(value: T) { - if (classOfT.java in INLINABLE_NUMBERS) { + if (classOfT in INLINABLE_NUMBERS) { val expectedType = expectationStack.pop()!! val mustBeBoxed = expectedType.sort == Type.OBJECT loadNumberConstant(value as Number, mustBeBoxed) - if (mustBeBoxed) typeStack.push(T_TYPE) else typeStack.push(PRIMITIVE_MASK) + if (mustBeBoxed) typeStack.push(tType) else typeStack.push(PRIMITIVE_MASK) return } - loadConstant(value as Any, T_TYPE) + loadConstant(value as Any, tType) } private fun box(): Unit = invokeMethodVisitor.invokestatic( - T_TYPE.internalName, + tType.internalName, "valueOf", - Type.getMethodDescriptor(T_TYPE, PRIMITIVE_MASK), + Type.getMethodDescriptor(tType, PRIMITIVE_MASK), false ) @@ -263,16 +269,16 @@ internal class AsmBuilder internal constructor( private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex loadThis() - getfield(CLASS_TYPE.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) iconst(idx) visitInsn(AALOAD) checkcast(type) } - private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, CLASS_TYPE) + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) /** - * Either loads a numeric constant [value] from [AsmCompiledExpression] constants field or boxes a primitive + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded * from it). */ @@ -292,7 +298,7 @@ internal class AsmBuilder internal constructor( if (mustBeBoxed) { box() - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) } return @@ -300,11 +306,11 @@ internal class AsmBuilder internal constructor( loadConstant(value, boxed) if (!mustBeBoxed) unbox() - else invokeMethodVisitor.checkcast(T_TYPE) + else invokeMethodVisitor.checkcast(tType) } /** - * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. + * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) @@ -319,7 +325,7 @@ internal class AsmBuilder internal constructor( Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) return } @@ -331,11 +337,11 @@ internal class AsmBuilder internal constructor( Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT) - typeStack.push(T_TYPE) + typeStack.push(tType) else { unbox() typeStack.push(PRIMITIVE_MASK) @@ -343,15 +349,11 @@ internal class AsmBuilder internal constructor( } /** - * Loads algebra from according field of [AsmCompiledExpression] and casts it to class of [algebra] provided. + * Loads algebra from according field of the class and casts it to class of [algebra] provided. */ internal fun loadAlgebra() { loadThis() - - invokeMethodVisitor.run { - getfield(ASM_COMPILED_EXPRESSION_TYPE.internalName, "algebra", ALGEBRA_TYPE.descriptor) - checkcast(T_ALGEBRA_TYPE) - } + invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor) } /** @@ -368,7 +370,12 @@ internal class AsmBuilder internal constructor( tArity: Int, opcode: Int = Opcodes.INVOKEINTERFACE ) { - repeat(tArity) { if (!typeStack.empty()) typeStack.pop() } + run loop@{ + repeat(tArity) { + if (typeStack.empty()) return@loop + typeStack.pop() + } + } invokeMethodVisitor.visitMethodInsn( opcode, @@ -378,12 +385,12 @@ internal class AsmBuilder internal constructor( opcode == Opcodes.INVOKEINTERFACE ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) val isLastExpr = expectationStack.size == 1 val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT || isLastExpr) - typeStack.push(T_TYPE) + typeStack.push(tType) else { unbox() typeStack.push(PRIMITIVE_MASK) @@ -399,27 +406,18 @@ internal class AsmBuilder internal constructor( /** * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. */ - private val SIGNATURE_LETTERS: Map, Type> by lazy { + private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( - java.lang.Byte::class.java to Type.BYTE_TYPE, - java.lang.Short::class.java to Type.SHORT_TYPE, - java.lang.Integer::class.java to Type.INT_TYPE, - java.lang.Long::class.java to Type.LONG_TYPE, - java.lang.Float::class.java to Type.FLOAT_TYPE, - java.lang.Double::class.java to Type.DOUBLE_TYPE + java.lang.Byte::class to Type.BYTE_TYPE, + java.lang.Short::class to Type.SHORT_TYPE, + java.lang.Integer::class to Type.INT_TYPE, + java.lang.Long::class to Type.LONG_TYPE, + java.lang.Float::class to Type.FLOAT_TYPE, + java.lang.Double::class to Type.DOUBLE_TYPE ) } - private val BOXED_TO_PRIMITIVES: Map by lazy { - hashMapOf( - java.lang.Byte::class.asm to Type.BYTE_TYPE, - java.lang.Short::class.asm to Type.SHORT_TYPE, - java.lang.Integer::class.asm to Type.INT_TYPE, - java.lang.Long::class.asm to Type.LONG_TYPE, - java.lang.Float::class.asm to Type.FLOAT_TYPE, - java.lang.Double::class.asm to Type.DOUBLE_TYPE - ) - } + private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } private val NUMBER_CONVERTER_METHODS: Map by lazy { hashMapOf( @@ -435,15 +433,15 @@ internal class AsmBuilder internal constructor( /** * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ - private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } - internal val ASM_COMPILED_EXPRESSION_TYPE: Type = AsmCompiledExpression::class.asm - internal val NUMBER_TYPE: Type = java.lang.Number::class.asm - internal val MAP_TYPE: Type = java.util.Map::class.asm - internal val OBJECT_TYPE: Type = java.lang.Object::class.asm + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } + internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } + internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } + internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") - internal val OBJECT_ARRAY_TYPE: Type = Array::class.asm - internal val ALGEBRA_TYPE: Type = Algebra::class.asm - internal val STRING_TYPE: Type = java.lang.String::class.asm + internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } + internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } + internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt deleted file mode 100644 index 7c4a9fc99..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt +++ /dev/null @@ -1,18 +0,0 @@ -package scientifik.kmath.asm.internal - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra - -/** - * [Expression] partial implementation to have it subclassed by actual implementations. Provides unified storage for - * objects needed to implement the expression. - * - * @property algebra the algebra to delegate calls. - * @property constants the constants array to have persistent objects to reference in [invoke]. - */ -internal abstract class AsmCompiledExpression internal constructor( - @JvmField protected val algebra: Algebra, - @JvmField protected val constants: Array -) : Expression { - abstract override fun invoke(arguments: Map): T -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt index 66bd039c3..41dbf5807 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt @@ -1,9 +1,10 @@ package scientifik.kmath.asm.internal import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression /** - * Creates a class name for [AsmCompiledExpression] subclassed to implement [mst] provided. + * Creates a class name for [Expression] subclassed to implement [mst] provided. * * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt index 95d713b18..af5c1049d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt @@ -1,15 +1,17 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter +import org.objectweb.asm.FieldVisitor import org.objectweb.asm.MethodVisitor -internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = + ClassWriter(flags).apply(block) -internal inline fun ClassWriter.visitMethod( +internal inline fun ClassWriter.visitField( access: Int, name: String, descriptor: String, signature: String?, - exceptions: Array?, - block: MethodVisitor.() -> Unit -): MethodVisitor = visitMethod(access, name, descriptor, signature, exceptions).apply(block) + value: Any?, + block: FieldVisitor.() -> Unit +): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt new file mode 100644 index 000000000..f47293687 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt @@ -0,0 +1,10 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.Label +import org.objectweb.asm.commons.InstructionAdapter + +internal fun InstructionAdapter.label(): Label { + val l = Label() + visitLabel(l) + return l +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt index 7b0d346b7..aaae02ebb 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt @@ -3,7 +3,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.MethodVisitor import org.objectweb.asm.commons.InstructionAdapter -fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) +internal fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) -fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = +internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = instructionAdapter().apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index 2e15a1a93..4c7a0d57e 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -22,7 +22,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: val aName = methodNameAdapters[name] ?: name val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null - val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else T_TYPE + val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else tType repeat(arity) { expectationStack.push(t) } return hasSpecific From c9de04a6106384c0128b4d0445b3e00bc0216379 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 25 Jun 2020 10:24:21 +0700 Subject: [PATCH 084/104] Make benchmarks 'naive' --- .../ast/ExpressionsInterpretersBenchmark.kt | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) rename examples/src/{benchmarks => main}/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt (70%) diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt similarity index 70% rename from examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt rename to examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt index c5474e1d2..17a70a4aa 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -1,21 +1,16 @@ package scientifik.kmath.ast -import org.openjdk.jmh.annotations.Benchmark -import org.openjdk.jmh.annotations.Scope -import org.openjdk.jmh.annotations.State import scientifik.kmath.asm.compile import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.expressionInField +import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.Field import scientifik.kmath.operations.RealField import kotlin.random.Random +import kotlin.system.measureTimeMillis -@State(Scope.Benchmark) class ExpressionsInterpretersBenchmark { private val algebra: Field = RealField - private val random: Random = Random(1) - - @Benchmark fun functionalExpression() { val expr = algebra.expressionInField { variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) @@ -24,7 +19,6 @@ class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } - @Benchmark fun mstExpression() { val expr = algebra.mstInField { symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) @@ -33,7 +27,6 @@ class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } - @Benchmark fun asmExpression() { val expr = algebra.mstInField { symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) @@ -43,6 +36,7 @@ class ExpressionsInterpretersBenchmark { } private fun invokeAndSum(expr: Expression) { + val random = Random(0) var sum = 0.0 repeat(1000000) { @@ -52,3 +46,25 @@ class ExpressionsInterpretersBenchmark { println(sum) } } + +fun main() { + val benchmark = ExpressionsInterpretersBenchmark() + + val fe = measureTimeMillis { + benchmark.functionalExpression() + } + + println("fe=$fe") + + val mst = measureTimeMillis { + benchmark.mstExpression() + } + + println("mst=$mst") + + val asm = measureTimeMillis { + benchmark.asmExpression() + } + + println("asm=$asm") +} From b11a7f1426a707168d8174be4be43fd1d352fed3 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 19:29:31 +0700 Subject: [PATCH 085/104] Update README.md --- kmath-ast/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 4563e17cf..0e375f14d 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -48,12 +48,11 @@ This API is an extension to MST and MSTExpression APIs. You may optimize both MS ```kotlin RealField.mstInField { symbol("x") + 2 }.compile() -RealField.expression("2+2".parseMath()) +RealField.expression("x+2".parseMath()) ``` ### Known issues -- Using numeric algebras causes boxing and calling bridge methods. - The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead. - This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. From 23816d33665314016dfc46b67a34d09568e56640 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 19:42:13 +0700 Subject: [PATCH 086/104] Update KDoc comments, optimize imports --- .../kmath/asm/internal/AsmBuilder.kt | 101 +++++++++++++++--- .../kmath/asm/internal/classWriters.kt | 1 - .../kmath/asm/internal/specialization.kt | 4 +- 3 files changed, 89 insertions(+), 17 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 536d6136d..a291ba4ee 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -18,6 +18,7 @@ import kotlin.reflect.KClass * @param T the type of AsmExpression to unwrap. * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. + * @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. */ internal class AsmBuilder internal constructor( private val classOfT: KClass<*>, @@ -37,8 +38,19 @@ internal class AsmBuilder internal constructor( */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + /** + * ASM Type for [algebra] + */ private val tAlgebraType: Type = algebra::class.asm + + /** + * ASM type for [T] + */ internal val tType: Type = classOfT.asm + + /** + * ASM type for new class + */ private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! /** @@ -60,14 +72,30 @@ internal class AsmBuilder internal constructor( * Method visitor of `invoke` method of the subclass. */ private lateinit var invokeMethodVisitor: InstructionAdapter + + /** + * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. + */ internal var primitiveMode = false - @Suppress("PropertyName") - internal var PRIMITIVE_MASK: Type = OBJECT_TYPE + /** + * Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + */ + internal var primitiveMask: Type = OBJECT_TYPE - @Suppress("PropertyName") - internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE + /** + * Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + */ + internal var primitiveMaskBoxed: Type = OBJECT_TYPE + + /** + * Stack of useful objects types on stack to verify types. + */ private val typeStack = Stack() + + /** + * Stack of useful objects types on stack expected by algebra calls. + */ internal val expectationStack: Stack = Stack().apply { push(tType) } /** @@ -86,8 +114,8 @@ internal class AsmBuilder internal constructor( if (SIGNATURE_LETTERS.containsKey(classOfT)) { primitiveMode = true - PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT) - PRIMITIVE_MASK_BOXED = tType + primitiveMask = SIGNATURE_LETTERS.getValue(classOfT) + primitiveMaskBoxed = tType } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { @@ -238,34 +266,43 @@ internal class AsmBuilder internal constructor( } /** - * Loads a constant from + * Loads a [T] constant from [constants]. */ internal fun loadTConstant(value: T) { if (classOfT in INLINABLE_NUMBERS) { val expectedType = expectationStack.pop()!! val mustBeBoxed = expectedType.sort == Type.OBJECT loadNumberConstant(value as Number, mustBeBoxed) - if (mustBeBoxed) typeStack.push(tType) else typeStack.push(PRIMITIVE_MASK) + if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) return } loadConstant(value as Any, tType) } + /** + * Boxes the current value and pushes it. + */ private fun box(): Unit = invokeMethodVisitor.invokestatic( tType.internalName, "valueOf", - Type.getMethodDescriptor(tType, PRIMITIVE_MASK), + Type.getMethodDescriptor(tType, primitiveMask), false ) + /** + * Unboxes the current boxed value and pushes it. + */ private fun unbox(): Unit = invokeMethodVisitor.invokevirtual( NUMBER_TYPE.internalName, - NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK), - Type.getMethodDescriptor(PRIMITIVE_MASK), + NUMBER_CONVERTER_METHODS.getValue(primitiveMask), + Type.getMethodDescriptor(primitiveMask), false ) + /** + * Loads [java.lang.Object] constant from constants. + */ private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex loadThis() @@ -275,6 +312,9 @@ internal class AsmBuilder internal constructor( checkcast(type) } + /** + * Loads this variable. + */ private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) /** @@ -344,7 +384,7 @@ internal class AsmBuilder internal constructor( typeStack.push(tType) else { unbox() - typeStack.push(PRIMITIVE_MASK) + typeStack.push(primitiveMask) } } @@ -393,7 +433,7 @@ internal class AsmBuilder internal constructor( typeStack.push(tType) else { unbox() - typeStack.push(PRIMITIVE_MASK) + typeStack.push(primitiveMask) } } @@ -404,7 +444,7 @@ internal class AsmBuilder internal constructor( internal companion object { /** - * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. + * Maps JVM primitive numbers boxed types to their primitive ASM types. */ private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( @@ -417,8 +457,14 @@ internal class AsmBuilder internal constructor( ) } + /** + * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. + */ private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } + /** + * Maps primitive ASM types to [Number] functions unboxing them. + */ private val NUMBER_CONVERTER_METHODS: Map by lazy { hashMapOf( Type.BYTE_TYPE to "byteValue", @@ -434,14 +480,41 @@ internal class AsmBuilder internal constructor( * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + + /** + * ASM type for [Expression]. + */ internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } + + /** + * ASM type for [java.lang.Number]. + */ internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } + + /** + * ASM type for [java.util.Map]. + */ internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } + + /** + * ASM type for [java.lang.Object]. + */ internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } + /** + * ASM type for array of [java.lang.Object]. + */ @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } + + /** + * ASM type for [Algebra]. + */ internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } + + /** + * ASM type for [java.lang.String]. + */ internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt index af5c1049d..00093aaa7 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt @@ -2,7 +2,6 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter import org.objectweb.asm.FieldVisitor -import org.objectweb.asm.MethodVisitor internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index 4c7a0d57e..252509d59 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -22,7 +22,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: val aName = methodNameAdapters[name] ?: name val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null - val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else tType + val t = if (primitiveMode && hasSpecific) primitiveMask else tType repeat(arity) { expectationStack.push(t) } return hasSpecific @@ -52,7 +52,7 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri invokeAlgebraOperation( owner = owner, method = aName, - descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }), + descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }), tArity = arity, opcode = Opcodes.INVOKEVIRTUAL ) From 46f99139e2850a4bfd63e3823b48b9e026fec017 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 19:45:33 +0700 Subject: [PATCH 087/104] Update number literal call in tests --- .../src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 4c2be811e..6ce769613 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -92,7 +92,7 @@ class TestAsmAlgebras { "+", (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), - 1 / 2 + number(2.0) * one + number(1) / 2 + number(2.0) * one ) }("x" to 2.0) @@ -101,7 +101,7 @@ class TestAsmAlgebras { "+", (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), - 1 / 2 + number(2.0) * one + number(1) / 2 + number(2.0) * one ) }.compile()("x" to 2.0) From 7faa48be582bc34f997b50104c3e3dc30f3ee979 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 19:46:32 +0700 Subject: [PATCH 088/104] Add zero call in MSTField test --- .../src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 6ce769613..079a6be0f 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -93,7 +93,7 @@ class TestAsmAlgebras { (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one - ) + ) + zero }("x" to 2.0) val res2 = RealField.mstInField { @@ -102,7 +102,7 @@ class TestAsmAlgebras { (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one - ) + ) + zero }.compile()("x" to 2.0) assertEquals(res1, res2) From 3528fa16dbc4c1750a56c7ee292044dbec2b4ef7 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 20:10:38 +0700 Subject: [PATCH 089/104] Add missing dependency in examples --- examples/build.gradle.kts | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index fb47c998f..73def3572 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -24,6 +24,7 @@ sourceSets { } dependencies { + implementation(project(":kmath-ast")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) From e2cc3c8efefa3b6a7d604990e1b1ca0a4860678b Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 20:54:14 +0700 Subject: [PATCH 090/104] Specify type explicitly, minor implementation refactor --- .../kotlin/scientifik/kmath/asm/asm.kt | 4 +-- .../kmath/asm/internal/AsmBuilder.kt | 9 ++--- .../kmath/asm/internal/specialization.kt | 35 ++++++++----------- 3 files changed, 21 insertions(+), 27 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index a3af80ccd..bb456e6eb 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -44,7 +44,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< AsmBuilder.OBJECT_TYPE ), - tArity = 1 + expectedArity = 1 ) } is MST.Binary -> { @@ -64,7 +64,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< AsmBuilder.OBJECT_TYPE ), - tArity = 2 + expectedArity = 2 ) } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index a291ba4ee..ebe25e19f 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -76,7 +76,7 @@ internal class AsmBuilder internal constructor( /** * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. */ - internal var primitiveMode = false + internal var primitiveMode: Boolean = false /** * Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. @@ -91,7 +91,7 @@ internal class AsmBuilder internal constructor( /** * Stack of useful objects types on stack to verify types. */ - private val typeStack = Stack() + private val typeStack: Stack = Stack() /** * Stack of useful objects types on stack expected by algebra calls. @@ -345,6 +345,7 @@ internal class AsmBuilder internal constructor( } loadConstant(value, boxed) + if (!mustBeBoxed) unbox() else invokeMethodVisitor.checkcast(tType) } @@ -407,11 +408,11 @@ internal class AsmBuilder internal constructor( owner: String, method: String, descriptor: String, - tArity: Int, + expectedArity: Int, opcode: Int = Opcodes.INVOKEINTERFACE ) { run loop@{ - repeat(tArity) { + repeat(expectedArity) { if (typeStack.empty()) return@loop typeStack.pop() } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index 252509d59..e54acf6f9 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -5,11 +5,7 @@ import org.objectweb.asm.Type import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map by lazy { - hashMapOf( - "+" to "add", - "*" to "multiply", - "/" to "divide" - ) + hashMapOf("+" to "add", "*" to "multiply", "/" to "divide") } /** @@ -19,12 +15,10 @@ private val methodNameAdapters: Map by lazy { * @return `true` if contains, else `false`. */ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { - val aName = methodNameAdapters[name] ?: name - - val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null + val theName = methodNameAdapters[name] ?: name + val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null val t = if (primitiveMode && hasSpecific) primitiveMask else tType repeat(arity) { expectationStack.push(t) } - return hasSpecific } @@ -35,25 +29,24 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: * @return `true` if contains, else `false`. */ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { - val aName = methodNameAdapters[name] ?: name + val theName = methodNameAdapters[name] ?: name - val method = - context.javaClass.methods.find { - var suitableSignature = it.name == aName && it.parameters.size == arity + context.javaClass.methods.find { + var suitableSignature = it.name == theName && it.parameters.size == arity - if (primitiveMode && it.isBridge) - suitableSignature = false + if (primitiveMode && it.isBridge) + suitableSignature = false - suitableSignature - } ?: return false + suitableSignature + } ?: return false - val owner = context::class.java.name.replace('.', '/') + val owner = context::class.asm invokeAlgebraOperation( - owner = owner, - method = aName, + owner = owner.internalName, + method = theName, descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }), - tArity = arity, + expectedArity = arity, opcode = Opcodes.INVOKEVIRTUAL ) From a275e74cf287bb3cf6b929b78d0db2b329f3cb43 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 14:57:07 +0700 Subject: [PATCH 091/104] Add mapping for other dynamic operations --- .../kotlin/scientifik/kmath/operations/NumberAlgebra.kt | 4 ++++ .../kotlin/scientifik/kmath/operations/OptionalOperations.kt | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index a1b845ccc..3fb57656c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -17,6 +17,10 @@ interface ExtendedFieldOperations : override fun unaryOperation(operation: String, arg: T): T = when (operation) { TrigonometricOperations.COS_OPERATION -> cos(arg) TrigonometricOperations.SIN_OPERATION -> sin(arg) + TrigonometricOperations.TAN_OPERATION -> tan(arg) + InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg) + InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg) + InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg) PowerOperations.SQRT_OPERATION -> sqrt(arg) ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.LN_OPERATION -> ln(arg) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index a0266c78b..542d6376b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -13,7 +13,7 @@ package scientifik.kmath.operations interface TrigonometricOperations : FieldOperations { fun sin(arg: T): T fun cos(arg: T): T - fun tan(arg: T): T = sin(arg) / cos(arg) + fun tan(arg: T): T companion object { const val SIN_OPERATION = "sin" From 5ab6960e9b1a0b51ffca0cf30ee5eb83c550989c Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 15:55:01 +0700 Subject: [PATCH 092/104] Add mapIntrinsics.kt, update specialization mappings --- .../kotlin/scientifik/kmath/asm/asm.kt | 1 + .../kmath/asm/internal/AsmBuilder.kt | 37 +++++++-------- .../kmath/asm/internal/mapIntrinsics.kt | 7 +++ .../kmath/asm/internal/specialization.kt | 15 +++++-- .../scietifik/kmath/asm/TestAsmAlgebras.kt | 2 +- .../scietifik/kmath/asm/TestAsmExpressions.kt | 2 +- .../scietifik/kmath/asm/TestSpecialization.kt | 45 +++++++++++++++++++ .../kotlin/scietifik/kmath/ast/AsmTest.kt | 2 +- 8 files changed, 84 insertions(+), 27 deletions(-) create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt create mode 100644 kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index bb456e6eb..af39d9091 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -47,6 +47,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< expectedArity = 1 ) } + is MST.Binary -> { loadAlgebra() if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index ebe25e19f..89c9dca9d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -354,31 +354,23 @@ internal class AsmBuilder internal constructor( * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { - load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) + load(invokeArgumentsVar, MAP_TYPE) + aconst(name) - if (defaultValue != null) { - loadStringConstant(name) + if (defaultValue != null) loadTConstant(defaultValue) + else + aconst(null) - invokeinterface( - MAP_TYPE.internalName, - "getOrDefault", - Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) - ) - - invokeMethodVisitor.checkcast(tType) - return - } - - loadStringConstant(name) - - invokeinterface( - MAP_TYPE.internalName, - "get", - Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) + invokestatic( + MAP_INTRINSICS_TYPE.internalName, + "getOrFail", + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE), + false ) - invokeMethodVisitor.checkcast(tType) + checkcast(tType) + val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT) @@ -517,5 +509,10 @@ internal class AsmBuilder internal constructor( * ASM type for [java.lang.String]. */ internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } + + /** + * ASM type for MapIntrinsics. + */ + internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt new file mode 100644 index 000000000..7f7126b55 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt @@ -0,0 +1,7 @@ +@file:JvmName("MapIntrinsics") + +package scientifik.kmath.asm.internal + +internal fun Map.getOrFail(key: K, default: V?): V { + return this[key] ?: default ?: error("Parameter not found: $key") +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index e54acf6f9..a8d5a605f 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -4,8 +4,15 @@ import org.objectweb.asm.Opcodes import org.objectweb.asm.Type import scientifik.kmath.operations.Algebra -private val methodNameAdapters: Map by lazy { - hashMapOf("+" to "add", "*" to "multiply", "/" to "divide") +private val methodNameAdapters: Map, String> by lazy { + hashMapOf( + "+" to 2 to "add", + "*" to 2 to "multiply", + "/" to 2 to "divide", + "+" to 1 to "unaryPlus", + "-" to 1 to "unaryMinus", + "-" to 2 to "minus" + ) } /** @@ -15,7 +22,7 @@ private val methodNameAdapters: Map by lazy { * @return `true` if contains, else `false`. */ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { - val theName = methodNameAdapters[name] ?: name + val theName = methodNameAdapters[name to arity] ?: name val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null val t = if (primitiveMode && hasSpecific) primitiveMask else tType repeat(arity) { expectationStack.push(t) } @@ -29,7 +36,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: * @return `true` if contains, else `false`. */ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { - val theName = methodNameAdapters[name] ?: name + val theName = methodNameAdapters[name to arity] ?: name context.javaClass.methods.find { var suitableSignature = it.name == theName && it.parameters.size == arity diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 079a6be0f..3acc6eb28 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -10,7 +10,7 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -class TestAsmAlgebras { +internal class TestAsmAlgebras { @Test fun space() { val res1 = ByteRing.mstInSpace { diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index 824201aa7..36c254c38 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -8,7 +8,7 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -class TestAsmExpressions { +internal class TestAsmExpressions { @Test fun testUnaryOperationInvocation() { val expression = RealField.mstInSpace { -symbol("x") }.compile() diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt new file mode 100644 index 000000000..f3b07df56 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt @@ -0,0 +1,45 @@ +package scietifik.kmath.asm + +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestSpecialization { + @Test + fun testUnaryPlus() { + val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(2.0, res) + } + + @Test + fun testUnaryMinus() { + val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(-2.0, res) + } + + @Test + fun testAdd() { + val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(4.0, res) + } + + @Test + fun testMinus() { + val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(0.0, res) + } + + @Test + fun testDivide() { + val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(1.0, res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index 08d7fff47..23203172e 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -10,7 +10,7 @@ import scientifik.kmath.operations.ComplexField import kotlin.test.Test import kotlin.test.assertEquals -class AsmTest { +internal class AsmTest { @Test fun `compile MST`() { val mst = "2+2*(2+2)".parseMath() From 90c287d42fe2b2f7703ae4dad80da824bf05e125 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 15:59:24 +0700 Subject: [PATCH 093/104] Add tests for MapInstrinsics --- .../scietifik/kmath/asm/TestAsmVariables.kt | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt new file mode 100644 index 000000000..aafc75448 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt @@ -0,0 +1,22 @@ +package scietifik.kmath.asm + +import scientifik.kmath.ast.mstInRing +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +internal class TestAsmVariables { + @Test + fun testVariableWithoutDefault() { + val expr = ByteRing.mstInRing { symbol("x") } + assertEquals(1.toByte(), expr("x" to 1.toByte())) + } + + @Test + fun testVariableWithoutDefaultFails() { + val expr = ByteRing.mstInRing { symbol("x") } + assertFailsWith { expr() } + } +} From 092728b1c328a6e515de5c68ade5e99bc460c4e5 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 16:01:50 +0700 Subject: [PATCH 094/104] Replace Stack with ArrayDeque --- .../scientifik/kmath/asm/internal/AsmBuilder.kt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 89c9dca9d..c9f797787 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -91,12 +91,12 @@ internal class AsmBuilder internal constructor( /** * Stack of useful objects types on stack to verify types. */ - private val typeStack: Stack = Stack() + private val typeStack: ArrayDeque = ArrayDeque() /** * Stack of useful objects types on stack expected by algebra calls. */ - internal val expectationStack: Stack = Stack().apply { push(tType) } + internal val expectationStack: ArrayDeque = ArrayDeque().apply { push(tType) } /** * The cache for instance built by this builder. @@ -270,7 +270,7 @@ internal class AsmBuilder internal constructor( */ internal fun loadTConstant(value: T) { if (classOfT in INLINABLE_NUMBERS) { - val expectedType = expectationStack.pop()!! + val expectedType = expectationStack.pop() val mustBeBoxed = expectedType.sort == Type.OBJECT loadNumberConstant(value as Number, mustBeBoxed) if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) @@ -371,7 +371,7 @@ internal class AsmBuilder internal constructor( checkcast(tType) - val expectedType = expectationStack.pop()!! + val expectedType = expectationStack.pop() if (expectedType.sort == Type.OBJECT) typeStack.push(tType) @@ -405,7 +405,7 @@ internal class AsmBuilder internal constructor( ) { run loop@{ repeat(expectedArity) { - if (typeStack.empty()) return@loop + if (typeStack.isEmpty()) return@loop typeStack.pop() } } @@ -420,7 +420,7 @@ internal class AsmBuilder internal constructor( invokeMethodVisitor.checkcast(tType) val isLastExpr = expectationStack.size == 1 - val expectedType = expectationStack.pop()!! + val expectedType = expectationStack.pop() if (expectedType.sort == Type.OBJECT || isLastExpr) typeStack.push(tType) From 2df97ca4c3e2f33680dae70181e06c3dc2fe1349 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 16:05:13 +0700 Subject: [PATCH 095/104] Update README.md, add suppression --- kmath-ast/README.md | 4 +++- .../kotlin/scientifik/kmath/asm/internal/classWriters.kt | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 0e375f14d..12d425460 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -24,6 +24,7 @@ For example, the following builder: package scientifik.kmath.asm.generated; import java.util.Map; +import scientifik.kmath.asm.internal.MapIntrinsics; import scientifik.kmath.expressions.Expression; import scientifik.kmath.operations.RealField; @@ -37,9 +38,10 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression arguments) { - return (Double)this.algebra.add(((Double)arguments.get("x")).doubleValue(), 2.0D); + return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D); } } + ``` ### Example Usage diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt index 00093aaa7..7f0770b28 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt @@ -3,6 +3,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter import org.objectweb.asm.FieldVisitor +@Suppress("FunctionName") internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) From 0ee1d31571a59659fbac33482a4e8cce16cf83fb Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 20:57:47 +0700 Subject: [PATCH 096/104] Fix MSTField and MSTRing invalid unary operation, update according ASM tests --- .../kotlin/scientifik/kmath/ast/MSTAlgebra.kt | 37 +++++++++---------- ...ialization.kt => TestAsmSpecialization.kt} | 23 ++++++------ 2 files changed, 29 insertions(+), 31 deletions(-) rename kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/{TestSpecialization.kt => TestAsmSpecialization.kt} (68%) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt index 07194a7bb..f741fc8c4 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt @@ -34,43 +34,40 @@ object MSTSpace : Space, NumericAlgebra { } object MSTRing : Ring, NumericAlgebra { - override fun number(value: Number): MST = MST.Numeric(value) - override fun symbol(value: String): MST = MST.Symbolic(value) - override val zero: MST = MSTSpace.number(0.0) override val one: MST = number(1.0) - override fun add(a: MST, b: MST): MST = - MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun number(value: Number): MST = MST.Numeric(value) + override fun symbol(value: String): MST = MST.Symbolic(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) - override fun multiply(a: MST, b: MST): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = MSTAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) } -object MSTField : Field{ - override fun symbol(value: String): MST = MST.Symbolic(value) - override fun number(value: Number): MST = MST.Numeric(value) - +object MSTField : Field { override val zero: MST = MSTSpace.number(0.0) override val one: MST = number(1.0) - override fun add(a: MST, b: MST): MST = - MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MST.Numeric(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) - override fun multiply(a: MST, b: MST): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) - - override fun divide(a: MST, b: MST): MST = - binaryOperation(FieldOperations.DIV_OPERATION, a, b) + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = MSTAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt similarity index 68% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt rename to kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt index f3b07df56..b571e076f 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt @@ -7,39 +7,40 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -internal class TestSpecialization { +internal class TestAsmSpecialization { @Test fun testUnaryPlus() { val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(2.0, res) + assertEquals(2.0, expr("x" to 2.0)) } @Test fun testUnaryMinus() { val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(-2.0, res) + assertEquals(-2.0, expr("x" to 2.0)) } @Test fun testAdd() { val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(4.0, res) + assertEquals(4.0, expr("x" to 2.0)) + } + + @Test + fun testSine() { + val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 0.0)) } @Test fun testMinus() { val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(0.0, res) + assertEquals(0.0, expr("x" to 2.0)) } @Test fun testDivide() { val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(1.0, res) + assertEquals(1.0, expr("x" to 2.0)) } } From d962ab4d11298fd6cd699edd584c3c12c260d11e Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 21:02:22 +0700 Subject: [PATCH 097/104] Rename and refactor MstAlgebra (ex-MSTAlgebra) (and its subclasses), MstExpression (ex-MSTExpression) --- .../scientifik/kmath/ast/MSTExpression.kt | 55 ------------------- .../ast/{MSTAlgebra.kt => MstAlgebra.kt} | 41 +++++++------- .../scientifik/kmath/ast/MstExpression.kt | 55 +++++++++++++++++++ .../kotlin/scientifik/kmath/asm/asm.kt | 6 +- 4 files changed, 78 insertions(+), 79 deletions(-) delete mode 100644 kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt rename kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/{MSTAlgebra.kt => MstAlgebra.kt} (62%) create mode 100644 kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt deleted file mode 100644 index 61703cac7..000000000 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt +++ /dev/null @@ -1,55 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.FunctionalExpressionField -import scientifik.kmath.expressions.FunctionalExpressionRing -import scientifik.kmath.expressions.FunctionalExpressionSpace -import scientifik.kmath.operations.* - -/** - * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. - */ -class MSTExpression(val algebra: Algebra, val mst: MST) : Expression { - - /** - * Substitute algebra raw value - */ - private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra{ - override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) - override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: T, right: T): T =algebra.binaryOperation(operation, left, right) - - override fun number(value: Number): T = if(algebra is NumericAlgebra){ - algebra.number(value) - } else{ - error("Numeric nodes are not supported by $this") - } - } - - override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) -} - - -inline fun , E : Algebra> A.mst( - mstAlgebra: E, - block: E.() -> MST -): MSTExpression = MSTExpression(this, mstAlgebra.block()) - -inline fun Space.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = - MSTExpression(this, MSTSpace.block()) - -inline fun Ring.mstInRing(block: MSTRing.() -> MST): MSTExpression = - MSTExpression(this, MSTRing.block()) - -inline fun Field.mstInField(block: MSTField.() -> MST): MSTExpression = - MSTExpression(this, MSTField.block()) - -inline fun > FunctionalExpressionSpace.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = - algebra.mstInSpace(block) - -inline fun > FunctionalExpressionRing.mstInRing(block: MSTRing.() -> MST): MSTExpression = - algebra.mstInRing(block) - -inline fun > FunctionalExpressionField.mstInField(block: MSTField.() -> MST): MSTExpression = - algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt similarity index 62% rename from kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt rename to kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt index f741fc8c4..007cf57c4 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt @@ -2,7 +2,7 @@ package scientifik.kmath.ast import scientifik.kmath.operations.* -object MSTAlgebra : NumericAlgebra { +object MstAlgebra : NumericAlgebra { override fun number(value: Number): MST = MST.Numeric(value) override fun symbol(value: String): MST = MST.Symbolic(value) @@ -14,12 +14,11 @@ object MSTAlgebra : NumericAlgebra { MST.Binary(operation, left, right) } -object MSTSpace : Space, NumericAlgebra { +object MstSpace : Space, NumericAlgebra { override val zero: MST = number(0.0) - override fun number(value: Number): MST = MST.Numeric(value) - - override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) @@ -28,46 +27,46 @@ object MSTSpace : Space, NumericAlgebra { binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) + MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) } -object MSTRing : Ring, NumericAlgebra { - override val zero: MST = MSTSpace.number(0.0) +object MstRing : Ring, NumericAlgebra { + override val zero: MST = number(0.0) override val one: MST = number(1.0) - override fun number(value: Number): MST = MST.Numeric(value) - override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k)) override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) + MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) } -object MSTField : Field { - override val zero: MST = MSTSpace.number(0.0) +object MstField : Field { + override val zero: MST = number(0.0) override val one: MST = number(1.0) - override fun symbol(value: String): MST = MST.Symbolic(value) - override fun number(value: Number): MST = MST.Numeric(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) + override fun number(value: Number): MST = MstAlgebra.number(value) override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k)) override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) + MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) } diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt new file mode 100644 index 000000000..1468c3ad4 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt @@ -0,0 +1,55 @@ +package scientifik.kmath.ast + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.FunctionalExpressionField +import scientifik.kmath.expressions.FunctionalExpressionRing +import scientifik.kmath.expressions.FunctionalExpressionSpace +import scientifik.kmath.operations.* + +/** + * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. + */ +class MstExpression(val algebra: Algebra, val mst: MST) : Expression { + + /** + * Substitute algebra raw value + */ + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { + override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) + override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: T, right: T): T = + algebra.binaryOperation(operation, left, right) + + override fun number(value: Number): T = if (algebra is NumericAlgebra) + algebra.number(value) + else + error("Numeric nodes are not supported by $this") + } + + override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) +} + + +inline fun , E : Algebra> A.mst( + mstAlgebra: E, + block: E.() -> MST +): MstExpression = MstExpression(this, mstAlgebra.block()) + +inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression = + MstExpression(this, MstSpace.block()) + +inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression = + MstExpression(this, MstRing.block()) + +inline fun Field.mstInField(block: MstField.() -> MST): MstExpression = + MstExpression(this, MstField.block()) + +inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression = + algebra.mstInSpace(block) + +inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression = + algebra.mstInRing(block) + +inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression = + algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index af39d9091..468ed01ba 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -6,7 +6,7 @@ import scientifik.kmath.asm.internal.buildExpectationStack import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST -import scientifik.kmath.ast.MSTExpression +import scientifik.kmath.ast.MstExpression import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.NumericAlgebra @@ -80,6 +80,6 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class, this) /** - * Optimize performance of an [MSTExpression] using ASM codegen + * Optimize performance of an [MstExpression] using ASM codegen */ -inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) +inline fun MstExpression.compile(): Expression = mst.compileWith(T::class, algebra) From ec46f5cf229109fe266e888f29b301d62bb2d3ea Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 21:02:31 +0700 Subject: [PATCH 098/104] Update README.md --- kmath-ast/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 12d425460..62b18b4b5 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -46,7 +46,7 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression Date: Fri, 26 Jun 2020 21:39:39 +0700 Subject: [PATCH 099/104] Add explicit toRegex call to have better IDE support --- .../kotlin/scientifik/kmath/ast/parser.kt | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt index cec61a8ff..30a92c5ae 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt @@ -16,15 +16,15 @@ import scientifik.kmath.operations.SpaceOperations * TODO move to common */ private object ArithmeticsEvaluator : Grammar() { - val num by token("-?[\\d.]+(?:[eE]-?\\d+)?") - val lpar by token("\\(") - val rpar by token("\\)") - val mul by token("\\*") - val pow by token("\\^") - val div by token("/") - val minus by token("-") - val plus by token("\\+") - val ws by token("\\s+", ignore = true) + val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex()) + val lpar by token("\\(".toRegex()) + val rpar by token("\\)".toRegex()) + val mul by token("\\*".toRegex()) + val pow by token("\\^".toRegex()) + val div by token("/".toRegex()) + val minus by token("-".toRegex()) + val plus by token("\\+".toRegex()) + val ws by token("\\s+".toRegex(), ignore = true) val number: Parser by num use { MST.Numeric(text.toDouble()) } From bf89aa09e561da09bde2df50355d7adf312c75cc Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 22:05:42 +0700 Subject: [PATCH 100/104] Add static imports for Opcodes --- .../kmath/asm/internal/AsmBuilder.kt | 26 +++++++++---------- .../kmath/asm/internal/specialization.kt | 4 +-- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index c9f797787..cea6be933 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -1,8 +1,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.* -import org.objectweb.asm.Opcodes.AALOAD -import org.objectweb.asm.Opcodes.RETURN +import org.objectweb.asm.Opcodes.* import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.ast.MST @@ -120,8 +119,8 @@ internal class AsmBuilder internal constructor( val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( - Opcodes.V1_8, - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, + V1_8, + ACC_PUBLIC or ACC_FINAL or ACC_SUPER, classType.internalName, "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", OBJECT_TYPE.internalName, @@ -129,7 +128,7 @@ internal class AsmBuilder internal constructor( ) visitField( - access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + access = ACC_PRIVATE or ACC_FINAL, name = "algebra", descriptor = tAlgebraType.descriptor, signature = null, @@ -138,7 +137,7 @@ internal class AsmBuilder internal constructor( ) visitField( - access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + access = ACC_PRIVATE or ACC_FINAL, name = "constants", descriptor = OBJECT_ARRAY_TYPE.descriptor, signature = null, @@ -147,7 +146,7 @@ internal class AsmBuilder internal constructor( ) visitMethod( - Opcodes.ACC_PUBLIC, + ACC_PUBLIC, "", Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE), null, @@ -187,7 +186,7 @@ internal class AsmBuilder internal constructor( } visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + ACC_PUBLIC or ACC_FINAL, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", @@ -223,7 +222,7 @@ internal class AsmBuilder internal constructor( } visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, "invoke", Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), null, @@ -351,7 +350,8 @@ internal class AsmBuilder internal constructor( } /** - * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be + * provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { load(invokeArgumentsVar, MAP_TYPE) @@ -391,7 +391,7 @@ internal class AsmBuilder internal constructor( /** * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is - * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interface. [loadAlgebra] should be + * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be * called before the arguments and this operation. * * The result is casted to [T] automatically. @@ -401,7 +401,7 @@ internal class AsmBuilder internal constructor( method: String, descriptor: String, expectedArity: Int, - opcode: Int = Opcodes.INVOKEINTERFACE + opcode: Int = INVOKEINTERFACE ) { run loop@{ repeat(expectedArity) { @@ -415,7 +415,7 @@ internal class AsmBuilder internal constructor( owner, method, descriptor, - opcode == Opcodes.INVOKEINTERFACE + opcode == INVOKEINTERFACE ) invokeMethodVisitor.checkcast(tType) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index a8d5a605f..a6d2c045b 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -1,6 +1,6 @@ package scientifik.kmath.asm.internal -import org.objectweb.asm.Opcodes +import org.objectweb.asm.Opcodes.INVOKEVIRTUAL import org.objectweb.asm.Type import scientifik.kmath.operations.Algebra @@ -54,7 +54,7 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri method = theName, descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }), expectedArity = arity, - opcode = Opcodes.INVOKEVIRTUAL + opcode = INVOKEVIRTUAL ) return true From 4b067f7a97c2a147b046944dfb66ddb5c78b5577 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sat, 27 Jun 2020 12:19:43 +0300 Subject: [PATCH 101/104] DoubleBuffer -> RealBuffer. Number algebra refactoring. --- doc/buffers.md | 2 +- .../kmath/structures/BufferBenchmark.kt | 4 +- .../structures/StructureReadBenchmark.kt | 2 +- .../structures/StructureWriteBenchmark.kt | 4 +- .../commons/transform/Transformations.kt | 2 +- .../kmath/domains/HyperSquareDomain.kt | 6 +- .../scientifik/kmath/linear/BufferMatrix.kt | 8 +- .../scientifik/kmath/structures/Buffers.kt | 62 +----------- .../kmath/structures/FlaggedBuffer.kt | 53 +++++++++++ .../scientifik/kmath/structures/IntBuffer.kt | 20 ++++ .../scientifik/kmath/structures/LongBuffer.kt | 19 ++++ .../{DoubleBuffer.kt => RealBuffer.kt} | 12 +-- .../kmath/structures/RealBufferField.kt | 94 +++++++++---------- .../kmath/structures/RealNDField.kt | 6 +- .../kmath/structures/ShortBuffer.kt | 20 ++++ .../scientifik/kmath/streaming/BufferFlow.kt | 8 +- .../scientifik/kmath/real/RealVector.kt | 4 +- .../scientifik/kmath/real/realBuffer.kt | 6 +- .../scientifik/kmath/real/realMatrix.kt | 10 +- .../scientifik/kmath/histogram/Histogram.kt | 6 +- .../kmath/histogram/RealHistogram.kt | 2 +- 21 files changed, 205 insertions(+), 145 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt rename kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/{DoubleBuffer.kt => RealBuffer.kt} (59%) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt diff --git a/doc/buffers.md b/doc/buffers.md index b0b7489b3..52a9df86e 100644 --- a/doc/buffers.md +++ b/doc/buffers.md @@ -2,7 +2,7 @@ Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`). There are different types of buffers: -* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays. +* Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays. * Boxing `ListBuffer` wrapping a list * Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value * `MemoryBuffer` allows direct allocation of objects in continuous memory block. diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt index 9676b5e4a..e40b0c4b7 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt @@ -10,8 +10,8 @@ import scientifik.kmath.operations.complex class BufferBenchmark { @Benchmark - fun genericDoubleBufferReadWrite() { - val buffer = DoubleBuffer(size){it.toDouble()} + fun genericRealBufferReadWrite() { + val buffer = RealBuffer(size){it.toDouble()} (0 until size).forEach { buffer[it] diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt index ecfb4ab20..a33fdb2c4 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt @@ -6,7 +6,7 @@ fun main(args: Array) { val n = 6000 val array = DoubleArray(n * n) { 1.0 } - val buffer = DoubleBuffer(array) + val buffer = RealBuffer(array) val strides = DefaultStrides(intArrayOf(n, n)) val structure = BufferNDStructure(strides, buffer) diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt index 2d16cc8f4..0241f12ad 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt @@ -26,10 +26,10 @@ fun main(args: Array) { } println("Array mapping finished in $time2 millis") - val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 }) + val buffer = RealBuffer(DoubleArray(n * n) { 1.0 }) val time3 = measureTimeMillis { - val target = DoubleBuffer(DoubleArray(n * n)) + val target = RealBuffer(DoubleArray(n * n)) val res = array.forEachIndexed { index, value -> target[index] = value + 1 } diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt index bcb3ea87b..eb1b5b69a 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt @@ -18,7 +18,7 @@ object Transformations { private fun Buffer.toArray(): Array = Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) } - private fun Buffer.asArray() = if (this is DoubleBuffer) { + private fun Buffer.asArray() = if (this is RealBuffer) { array } else { DoubleArray(size) { i -> get(i) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt index 21912b87c..e0019c96b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt @@ -16,7 +16,7 @@ package scientifik.kmath.domains import scientifik.kmath.linear.Point -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.indices /** @@ -25,7 +25,7 @@ import scientifik.kmath.structures.indices * * @author Alexander Nozik */ -class HyperSquareDomain(private val lower: DoubleBuffer, private val upper: DoubleBuffer) : RealDomain { +class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain { override operator fun contains(point: Point): Boolean = point.indices.all { i -> point[i] in lower[i]..upper[i] @@ -49,7 +49,7 @@ class HyperSquareDomain(private val lower: DoubleBuffer, private val upper: Doub else -> point[i] } } - return DoubleBuffer(*res) + return RealBuffer(*res) } override fun volume(): Double { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt index 73b18b810..c4c38284b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt @@ -30,11 +30,11 @@ object RealMatrixContext : GenericMatrixContext { override val elementContext get() = RealField override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix { - val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } + val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } return BufferMatrix(rows, columns, buffer) } - override inline fun point(size: Int, initializer: (Int) -> Double): Point = DoubleBuffer(size,initializer) + override inline fun point(size: Int, initializer: (Int) -> Double): Point = RealBuffer(size,initializer) } class BufferMatrix( @@ -102,7 +102,7 @@ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix.unsafeArray(): DoubleArray = if (this is DoubleBuffer) { + fun Buffer.unsafeArray(): DoubleArray = if (this is RealBuffer) { array } else { DoubleArray(size) { get(it) } @@ -119,6 +119,6 @@ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix { companion object { - inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer { + inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer { val array = DoubleArray(size) { initializer(it) } - return DoubleBuffer(array) + return RealBuffer(array) } /** @@ -51,7 +51,7 @@ interface Buffer { inline fun auto(type: KClass, size: Int, crossinline initializer: (Int) -> T): Buffer { //TODO add resolution based on Annotation or companion resolution return when (type) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer + Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer @@ -93,7 +93,7 @@ interface MutableBuffer : Buffer { @Suppress("UNCHECKED_CAST") inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer { return when (type) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer + Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer @@ -109,12 +109,11 @@ interface MutableBuffer : Buffer { auto(T::class, size, initializer) val real: MutableBufferFactory = { size: Int, initializer: (Int) -> Double -> - DoubleBuffer(DoubleArray(size) { initializer(it) }) + RealBuffer(DoubleArray(size) { initializer(it) }) } } } - inline class ListBuffer(val list: List) : Buffer { override val size: Int @@ -163,57 +162,6 @@ class ArrayBuffer(private val array: Array) : MutableBuffer { fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) -inline class ShortBuffer(val array: ShortArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Short = array[index] - - override fun set(index: Int, value: Short) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = ShortBuffer(array.copyOf()) - -} - -fun ShortArray.asBuffer() = ShortBuffer(this) - -inline class IntBuffer(val array: IntArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Int = array[index] - - override fun set(index: Int, value: Int) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = IntBuffer(array.copyOf()) - -} - -fun IntArray.asBuffer() = IntBuffer(this) - -inline class LongBuffer(val array: LongArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Long = array[index] - - override fun set(index: Int, value: Long) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = LongBuffer(array.copyOf()) - -} - -fun LongArray.asBuffer() = LongBuffer(this) - inline class ReadOnlyBuffer(val buffer: MutableBuffer) : Buffer { override val size: Int get() = buffer.size diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt new file mode 100644 index 000000000..749e4eeec --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt @@ -0,0 +1,53 @@ +package scientifik.kmath.structures + +import kotlin.experimental.and + +enum class ValueFlag(val mask: Byte) { + NAN(0b0000_0001), + MISSING(0b0000_0010), + NEGATIVE_INFINITY(0b0000_0100), + POSITIVE_INFINITY(0b0000_1000) +} + +/** + * A buffer with flagged values + */ +interface FlaggedBuffer : Buffer { + fun getFlag(index: Int): Byte +} + +/** + * The value is valid if all flags are down + */ +fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte() + +fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte() + +fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING) + +/** + * A real buffer which supports flags for each value like NaN or Missing + */ +class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer, Buffer { + init { + require(values.size == flags.size) { "Values and flags must have the same dimensions" } + } + + override fun getFlag(index: Int): Byte = flags[index] + + override val size: Int get() = values.size + + override fun get(index: Int): Double? = if (isValid(index)) values[index] else null + + override fun iterator(): Iterator = values.indices.asSequence().map { + if (isValid(it)) values[it] else null + }.iterator() +} + +inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { + for(i in indices){ + if(isValid(i)){ + block(values[i]) + } + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt new file mode 100644 index 000000000..a354c5de0 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.structures + +inline class IntBuffer(val array: IntArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Int = array[index] + + override fun set(index: Int, value: Int) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + IntBuffer(array.copyOf()) + +} + + +fun IntArray.asBuffer() = IntBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt new file mode 100644 index 000000000..fa6229a71 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt @@ -0,0 +1,19 @@ +package scientifik.kmath.structures + +inline class LongBuffer(val array: LongArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Long = array[index] + + override fun set(index: Int, value: Long) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + LongBuffer(array.copyOf()) + +} + +fun LongArray.asBuffer() = LongBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt similarity index 59% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt index c0b7f713b..f48ace3a9 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt @@ -1,6 +1,6 @@ package scientifik.kmath.structures -inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer { +inline class RealBuffer(val array: DoubleArray) : MutableBuffer { override val size: Int get() = array.size override fun get(index: Int): Double = array[index] @@ -12,23 +12,23 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer { override fun iterator() = array.iterator() override fun copy(): MutableBuffer = - DoubleBuffer(array.copyOf()) + RealBuffer(array.copyOf()) } @Suppress("FunctionName") -inline fun DoubleBuffer(size: Int, init: (Int) -> Double): DoubleBuffer = DoubleBuffer(DoubleArray(size) { init(it) }) +inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) }) @Suppress("FunctionName") -fun DoubleBuffer(vararg doubles: Double): DoubleBuffer = DoubleBuffer(doubles) +fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles) /** * Transform buffer of doubles into array for high performance operations */ val MutableBuffer.array: DoubleArray - get() = if (this is DoubleBuffer) { + get() = if (this is RealBuffer) { array } else { DoubleArray(size) { get(it) } } -fun DoubleArray.asBuffer() = DoubleBuffer(this) \ No newline at end of file +fun DoubleArray.asBuffer() = RealBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 88c8c29db..a91000a2a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -9,143 +9,143 @@ import kotlin.math.* * A simple field over linear buffers of [Double] */ object RealBufferFieldOperations : ExtendedFieldOperations> { - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) + RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) } else { - DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) + RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) } } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { val kValue = k.toDouble() - return if (a is DoubleBuffer) { + return if (a is RealBuffer) { val aArray = a.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue }) + RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * kValue }) + RealBuffer(DoubleArray(a.size) { a[it] * kValue }) } } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) + RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) + RealBuffer(DoubleArray(a.size) { a[it] * b[it] }) } } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) + RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) } else { - DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) + RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) } } - override fun sin(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { + override fun sin(arg: Buffer): RealBuffer { + return if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) + RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) } } - override fun cos(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { + override fun cos(arg: Buffer): RealBuffer { + return if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) + RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) } } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { - return if (arg is DoubleBuffer) { + override fun power(arg: Buffer, pow: Number): RealBuffer { + return if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) + RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) } else { - DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) } } - override fun exp(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { + override fun exp(arg: Buffer): RealBuffer { + return if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) + RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) } } - override fun ln(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { + override fun ln(arg: Buffer): RealBuffer { + return if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) + RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } } } class RealBufferField(val size: Int) : ExtendedField> { - override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } + override val zero: Buffer by lazy { RealBuffer(size) { 0.0 } } - override val one: Buffer by lazy { DoubleBuffer(size) { 1.0 } } + override val one: Buffer by lazy { RealBuffer(size) { 1.0 } } - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.add(a, b) } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, k) } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, b) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.divide(a, b) } - override fun sin(arg: Buffer): DoubleBuffer { + override fun sin(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.sin(arg) } - override fun cos(arg: Buffer): DoubleBuffer { + override fun cos(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.cos(arg) } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { + override fun power(arg: Buffer, pow: Number): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) } - override fun exp(arg: Buffer): DoubleBuffer { + override fun exp(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.exp(arg) } - override fun ln(arg: Buffer): DoubleBuffer { + override fun ln(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.ln(arg) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 8c1bd4239..4a5f10790 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -16,7 +16,7 @@ class RealNDField(override val shape: IntArray) : override val one by lazy { produce { one } } inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = - DoubleBuffer(DoubleArray(size) { initializer(it) }) + RealBuffer(DoubleArray(size) { initializer(it) }) /** * Inline transform an NDStructure to @@ -82,7 +82,7 @@ class RealNDField(override val shape: IntArray) : */ inline fun BufferedNDField.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) } - return BufferedNDFieldElement(this, DoubleBuffer(array)) + return BufferedNDFieldElement(this, RealBuffer(array)) } /** @@ -96,7 +96,7 @@ inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: Int */ inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) } - return BufferedNDFieldElement(context, DoubleBuffer(array)) + return BufferedNDFieldElement(context, RealBuffer(array)) } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt new file mode 100644 index 000000000..f4b2f7d13 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.structures + +inline class ShortBuffer(val array: ShortArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Short = array[index] + + override fun set(index: Int, value: Short) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + ShortBuffer(array.copyOf()) + +} + + +fun ShortArray.asBuffer() = ShortBuffer(this) \ No newline at end of file diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt index bef21a680..54da66bb7 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt @@ -5,7 +5,7 @@ import kotlinx.coroutines.flow.* import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.BufferFactory -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asBuffer /** @@ -45,7 +45,7 @@ fun Flow.chunked(bufferSize: Int, bufferFactory: BufferFactory): Flow< /** * Specialized flow chunker for real buffer */ -fun Flow.chunked(bufferSize: Int): Flow = flow { +fun Flow.chunked(bufferSize: Int): Flow = flow { require(bufferSize > 0) { "Resulting chunk size must be more than zero" } if (this@chunked is BlockingRealChain) { @@ -61,13 +61,13 @@ fun Flow.chunked(bufferSize: Int): Flow = flow { array[counter] = element counter++ if (counter == bufferSize) { - val buffer = DoubleBuffer(array) + val buffer = RealBuffer(array) emit(buffer) counter = 0 } } if (counter > 0) { - emit(DoubleBuffer(counter) { array[it] }) + emit(RealBuffer(counter) { array[it] }) } } } diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt index 23c7e19cb..2b89904e3 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt @@ -7,7 +7,7 @@ import scientifik.kmath.operations.Norm import scientifik.kmath.operations.RealField import scientifik.kmath.operations.SpaceElement import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asIterable import kotlin.math.sqrt @@ -41,7 +41,7 @@ inline class RealVector(private val point: Point) : private val spaceCache = HashMap>() inline operator fun invoke(dim: Int, initializer: (Int) -> Double) = - RealVector(DoubleBuffer(dim, initializer)) + RealVector(RealBuffer(dim, initializer)) operator fun invoke(vararg values: Double): RealVector = values.asVector() diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt index d9ee4d90b..82c0e86b2 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt @@ -1,8 +1,8 @@ package scientifik.kmath.real -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer /** - * Simplified [DoubleBuffer] to array comparison + * Simplified [RealBuffer] to array comparison */ -fun DoubleBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) \ No newline at end of file +fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt index 0f4ccf2a8..65f86eec7 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt @@ -5,8 +5,8 @@ import scientifik.kmath.linear.RealMatrixContext.elementContext import scientifik.kmath.linear.VirtualMatrix import scientifik.kmath.operations.sum import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.Matrix +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asIterable import kotlin.math.pow @@ -133,22 +133,22 @@ fun Matrix.extractColumns(columnRange: IntRange): RealMatrix = fun Matrix.extractColumn(columnIndex: Int): RealMatrix = extractColumns(columnIndex..columnIndex) -fun Matrix.sumByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> val column = columns[j] with(elementContext) { sum(column.asIterable()) } } -fun Matrix.minByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column") } -fun Matrix.maxByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column") } -fun Matrix.averageByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().average() } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index 5199669f5..43d50ad20 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -3,7 +3,7 @@ package scientifik.kmath.histogram import scientifik.kmath.domains.Domain import scientifik.kmath.linear.Point import scientifik.kmath.structures.ArrayBuffer -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer /** * The bin in the histogram. The histogram is by definition always done in the real space @@ -43,9 +43,9 @@ interface MutableHistogram> : Histogram { fun MutableHistogram.put(vararg point: T) = put(ArrayBuffer(point)) fun MutableHistogram.put(vararg point: Number) = - put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) + put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) -fun MutableHistogram.put(vararg point: Double) = put(DoubleBuffer(point)) +fun MutableHistogram.put(vararg point: Double) = put(RealBuffer(point)) fun MutableHistogram.fill(sequence: Iterable>) = sequence.forEach { put(it) } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index f9d815421..628a68461 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -50,7 +50,7 @@ class RealHistogram( override val dimension: Int get() = lower.size - private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } + private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } init { // argument checks From efcfb4425366ee32c2aafb733c1dab8567ea9abe Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 27 Jun 2020 20:04:22 +0700 Subject: [PATCH 102/104] Refactor Algebra call building --- .../kotlin/scientifik/kmath/asm/asm.kt | 51 +++++-------------- .../kmath/asm/internal/specialization.kt | 29 ++++++++++- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 468ed01ba..ef2330533 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,10 +1,8 @@ package scientifik.kmath.asm -import org.objectweb.asm.Type import scientifik.kmath.asm.internal.AsmBuilder -import scientifik.kmath.asm.internal.buildExpectationStack +import scientifik.kmath.asm.internal.buildAlgebraOperationCall import scientifik.kmath.asm.internal.buildName -import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST import scientifik.kmath.ast.MstExpression import scientifik.kmath.expressions.Expression @@ -29,44 +27,21 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< loadTConstant(constant) } - is MST.Unary -> { - loadAlgebra() - if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation) - visit(node.value) + is MST.Unary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "unaryOperation", + arity = 1 + ) { visit(node.value) } - if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = "unaryOperation", - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - AsmBuilder.OBJECT_TYPE - ), - - expectedArity = 1 - ) - } - - is MST.Binary -> { - loadAlgebra() - if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) + is MST.Binary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "binaryOperation", + arity = 2 + ) { visit(node.left) visit(node.right) - - if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = "binaryOperation", - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - AsmBuilder.OBJECT_TYPE, - AsmBuilder.OBJECT_TYPE - ), - - expectedArity = 2 - ) } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index a6d2c045b..003dc47dd 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -21,7 +21,7 @@ private val methodNameAdapters: Map, String> by lazy { * * @return `true` if contains, else `false`. */ -internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { +private fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { val theName = methodNameAdapters[name to arity] ?: name val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null val t = if (primitiveMode && hasSpecific) primitiveMask else tType @@ -35,7 +35,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: * * @return `true` if contains, else `false`. */ -internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { +private fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val theName = methodNameAdapters[name to arity] ?: name context.javaClass.methods.find { @@ -59,3 +59,28 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri return true } + +internal fun AsmBuilder.buildAlgebraOperationCall( + context: Algebra, + name: String, + fallbackMethodName: String, + arity: Int, + parameters: AsmBuilder.() -> Unit +) { + loadAlgebra() + if (!buildExpectationStack(context, name, arity)) loadStringConstant(name) + parameters() + + if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = fallbackMethodName, + + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + *Array(arity) { AsmBuilder.OBJECT_TYPE } + ), + + expectedArity = arity + ) +} From e98fc126c4118bcf8750ea85bb25bca2d85b145f Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 27 Jun 2020 20:15:14 +0700 Subject: [PATCH 103/104] Merge various codegen utilities into one file --- .../kmath/asm/internal/buildName.kt | 22 ------- .../kmath/asm/internal/classWriters.kt | 17 ----- .../scientifik/kmath/asm/internal/classes.kt | 7 -- .../{specialization.kt => codegenUtils.kt} | 64 ++++++++++++++++++- .../kmath/asm/internal/instructionAdapters.kt | 10 --- .../kmath/asm/internal/methodVisitors.kt | 9 --- 6 files changed, 63 insertions(+), 66 deletions(-) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{specialization.kt => codegenUtils.kt} (57%) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt deleted file mode 100644 index 41dbf5807..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt +++ /dev/null @@ -1,22 +0,0 @@ -package scientifik.kmath.asm.internal - -import scientifik.kmath.ast.MST -import scientifik.kmath.expressions.Expression - -/** - * Creates a class name for [Expression] subclassed to implement [mst] provided. - * - * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there - * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. - */ -internal tailrec fun buildName(mst: MST, collision: Int = 0): String { - val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(mst, collision + 1) -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt deleted file mode 100644 index 7f0770b28..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ /dev/null @@ -1,17 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.ClassWriter -import org.objectweb.asm.FieldVisitor - -@Suppress("FunctionName") -internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = - ClassWriter(flags).apply(block) - -internal inline fun ClassWriter.visitField( - access: Int, - name: String, - descriptor: String, - signature: String?, - value: Any?, - block: FieldVisitor.() -> Unit -): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt deleted file mode 100644 index dc0b35531..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt +++ /dev/null @@ -1,7 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.Type -import kotlin.reflect.KClass - -internal val KClass<*>.asm: Type - get() = Type.getType(java) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt similarity index 57% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt index 003dc47dd..46d07976d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt @@ -1,8 +1,12 @@ package scientifik.kmath.asm.internal +import org.objectweb.asm.* import org.objectweb.asm.Opcodes.INVOKEVIRTUAL -import org.objectweb.asm.Type +import org.objectweb.asm.commons.InstructionAdapter +import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra +import kotlin.reflect.KClass private val methodNameAdapters: Map, String> by lazy { hashMapOf( @@ -15,6 +19,60 @@ private val methodNameAdapters: Map, String> by lazy { ) } +internal val KClass<*>.asm: Type + get() = Type.getType(java) + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor]. + */ +private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. + */ +internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = + instructionAdapter().apply(block) + +/** + * Constructs a [Label], then applies it to this visitor. + */ +internal fun MethodVisitor.label(): Label { + val l = Label() + visitLabel(l) + return l +} + +/** + * Creates a class name for [Expression] subclassed to implement [mst] provided. + * + * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there + * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. + */ +internal tailrec fun buildName(mst: MST, collision: Int = 0): String { + val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(mst, collision + 1) +} + +@Suppress("FunctionName") +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = + ClassWriter(flags).apply(block) + +internal inline fun ClassWriter.visitField( + access: Int, + name: String, + descriptor: String, + signature: String?, + value: Any?, + block: FieldVisitor.() -> Unit +): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) + /** * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds * type expectation stack for needed arity. @@ -60,6 +118,9 @@ private fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Strin return true } +/** + * Builds specialized algebra call with option to fallback to generic algebra operation accepting String. + */ internal fun AsmBuilder.buildAlgebraOperationCall( context: Algebra, name: String, @@ -84,3 +145,4 @@ internal fun AsmBuilder.buildAlgebraOperationCall( expectedArity = arity ) } + diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt deleted file mode 100644 index f47293687..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt +++ /dev/null @@ -1,10 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.Label -import org.objectweb.asm.commons.InstructionAdapter - -internal fun InstructionAdapter.label(): Label { - val l = Label() - visitLabel(l) - return l -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt deleted file mode 100644 index aaae02ebb..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ /dev/null @@ -1,9 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.commons.InstructionAdapter - -internal fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) - -internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = - instructionAdapter().apply(block) From e91c5a57c493dbee9f8f1545b294b69848572cf0 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 27 Jun 2020 20:31:42 +0700 Subject: [PATCH 104/104] Minor refactor for changed ExtendedFieldOperations, replace DoubleBuffer with RealBuffer --- .../kmath/operations/NumberAlgebra.kt | 4 +- .../kmath/operations/OptionalOperations.kt | 4 +- .../kmath/structures/RealBufferField.kt | 160 ++++++++---------- 3 files changed, 76 insertions(+), 92 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 3fb57656c..953c5a112 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -7,7 +7,6 @@ import kotlin.math.pow as kpow * Advanced Number-like field that implements basic operations */ interface ExtendedFieldOperations : - FieldOperations, InverseTrigonometricOperations, PowerOperations, ExponentialOperations { @@ -24,9 +23,8 @@ interface ExtendedFieldOperations : PowerOperations.SQRT_OPERATION -> sqrt(arg) ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.LN_OPERATION -> ln(arg) - else -> super.unaryOperation(operation, arg) + else -> super.unaryOperation(operation, arg) } - } interface ExtendedField : ExtendedFieldOperations, Field { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index 542d6376b..709f0260f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -64,7 +64,7 @@ fun >> sqr(arg: T): T = arg pow 2.0 /* Exponential */ -interface ExponentialOperations: Algebra { +interface ExponentialOperations : Algebra { fun exp(arg: T): T fun ln(arg: T): T @@ -81,4 +81,4 @@ interface Norm { fun norm(arg: T): R } -fun >, R> norm(arg: T): R = arg.context.norm(arg) \ No newline at end of file +fun >, R> norm(arg: T): R = arg.context.norm(arg) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 2fb6d15d4..826203d1f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -9,185 +9,171 @@ import kotlin.math.* * A simple field over linear buffers of [Double] */ object RealBufferFieldOperations : ExtendedFieldOperations> { - - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { val kValue = k.toDouble() - return if (a is DoubleBuffer) { + return if (a is RealBuffer) { val aArray = a.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * kValue }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) + } else + RealBuffer(DoubleArray(a.size) { a[it] * kValue }) } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] * b[it] }) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) } - override fun sin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun sin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) + RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) } - override fun cos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun cos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) + RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + + override fun tan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { tan(arg[it]) }) + + override fun asin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { asin(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { asin(arg[it]) }) } - override fun tan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun acos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { tan(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { tan(arg[it]) }) - } + RealBuffer(DoubleArray(arg.size) { acos(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { acos(arg[it]) }) - override fun asin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun atan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { asin(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { asin(arg[it]) }) - } + RealBuffer(DoubleArray(arg.size) { atan(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) - override fun acos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun power(arg: Buffer, pow: Number): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { acos(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { acos(arg[it]) }) - } + RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) + } else + RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) - override fun atan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun exp(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { atan(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { atan(arg[it]) }) - } + RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - override fun power(arg: Buffer, pow: Number): DoubleBuffer = if (arg is DoubleBuffer) { + override fun ln(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) - } - - override fun exp(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - } - - override fun ln(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) - } + RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } class RealBufferField(val size: Int) : ExtendedField> { + override val zero: Buffer by lazy { RealBuffer(size) { 0.0 } } + override val one: Buffer by lazy { RealBuffer(size) { 1.0 } } - override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } - - override val one: Buffer by lazy { DoubleBuffer(size) { 1.0 } } - - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.add(a, b) } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, k) } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, b) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.divide(a, b) } - override fun sin(arg: Buffer): DoubleBuffer { + override fun sin(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.sin(arg) } - override fun cos(arg: Buffer): DoubleBuffer { + override fun cos(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.cos(arg) } - override fun tan(arg: Buffer): DoubleBuffer { + override fun tan(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.tan(arg) } - override fun asin(arg: Buffer): DoubleBuffer { + override fun asin(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.asin(arg) } - override fun acos(arg: Buffer): DoubleBuffer { + override fun acos(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.acos(arg) } - override fun atan(arg: Buffer): DoubleBuffer { + override fun atan(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.atan(arg) } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { + override fun power(arg: Buffer, pow: Number): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) } - override fun exp(arg: Buffer): DoubleBuffer { + override fun exp(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.exp(arg) } - override fun ln(arg: Buffer): DoubleBuffer { + override fun ln(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.ln(arg) }