From 7b50400de5ad44ee123da699e7b8136130942117 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Tue, 16 Nov 2021 14:03:12 +0700 Subject: [PATCH] Provide specializations of AsmBuilder for Double, Long, Int --- README.md | 16 +- .../ExpressionsInterpretersBenchmark.kt | 13 +- .../space/kscience/kmath/ast/expressions.kt | 6 +- kmath-ast/README.md | 6 +- kmath-ast/build.gradle.kts | 4 + .../kmath/ast/TestCompilerOperations.kt | 24 + .../kmath/wasm/internal/WasmBuilder.kt | 36 +- .../kotlin/space/kscience/kmath/wasm/wasm.kt | 18 - .../kotlin/space/kscience/kmath/asm/asm.kt | 90 +++- .../kscience/kmath/asm/internal/AsmBuilder.kt | 356 +-------------- .../asm/internal/ByteArrayClassLoader.kt | 1 + .../kmath/asm/internal/GenericAsmBuilder.kt | 325 ++++++++++++++ .../kmath/asm/internal/PrimitiveAsmBuilder.kt | 411 ++++++++++++++++++ .../kmath/asm/internal/mapIntrinsics.kt | 1 + .../kotlin/space/kscience/kmath/ast/utils.kt | 21 +- kmath-complex/README.md | 6 +- kmath-core/README.md | 6 +- .../kscience/kmath/expressions/Symbol.kt | 4 +- kmath-ejml/README.md | 6 +- kmath-for-real/README.md | 6 +- kmath-functions/README.md | 6 +- kmath-jafama/README.md | 6 +- kmath-kotlingrad/README.md | 6 +- kmath-nd4j/README.md | 6 +- kmath-tensors/README.md | 6 +- 25 files changed, 961 insertions(+), 425 deletions(-) create mode 100644 kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/ByteArrayClassLoader.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt diff --git a/README.md b/README.md index db069d4e0..8604873ae 100644 --- a/README.md +++ b/README.md @@ -240,6 +240,12 @@ One can still use generic algebras though. > **Maturity**: DEVELOPMENT
+* ### [kmath-multik](kmath-multik) +> +> +> **Maturity**: PROTOTYPE +
+ * ### [kmath-nd4j](kmath-nd4j) > > @@ -252,6 +258,12 @@ One can still use generic algebras though.
+* ### [kmath-optimization](kmath-optimization) +> +> +> **Maturity**: EXPERIMENTAL +
+ * ### [kmath-stat](kmath-stat) > > @@ -319,8 +331,8 @@ repositories { } dependencies { - api("space.kscience:kmath-core:0.3.0-dev-14") - // api("space.kscience:kmath-core-jvm:0.3.0-dev-14") for jvm-specific version + api("space.kscience:kmath-core:0.3.0-dev-17") + // api("space.kscience:kmath-core-jvm:0.3.0-dev-17") for jvm-specific version } ``` diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt index b824a0d69..4c6b92248 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt @@ -11,6 +11,7 @@ import kotlinx.benchmark.Scope import kotlinx.benchmark.State import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.expressions.* +import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.invoke @@ -35,7 +36,14 @@ internal class ExpressionsInterpretersBenchmark { * Benchmark case for [Expression] created with [compileToExpression]. */ @Benchmark - fun asmExpression(blackhole: Blackhole) = invokeAndSum(asm, blackhole) + fun asmGenericExpression(blackhole: Blackhole) = invokeAndSum(asmGeneric, blackhole) + + + /** + * Benchmark case for [Expression] created with [compileToExpression]. + */ + @Benchmark + fun asmPrimitiveExpression(blackhole: Blackhole) = invokeAndSum(asmPrimitive, blackhole) /** * Benchmark case for [Expression] implemented manually with `kotlin.math` functions. @@ -87,7 +95,8 @@ internal class ExpressionsInterpretersBenchmark { } private val mst = node.toExpression(DoubleField) - private val asm = node.compileToExpression(DoubleField) + private val asmPrimitive = node.compileToExpression(DoubleField) + private val asmGeneric = node.compileToExpression(DoubleField as Algebra) private val raw = Expression { args -> val x = args[x]!! diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt index b37d120fb..3c68f9bb5 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt @@ -6,15 +6,15 @@ package space.kscience.kmath.ast import space.kscience.kmath.asm.compileToExpression -import space.kscience.kmath.expressions.MstField +import space.kscience.kmath.expressions.MstExtendedField import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke fun main() { - val expr = MstField { - x * 2.0 + number(2.0) / x - 16.0 + val expr = MstExtendedField { + x * 2.0 + number(2.0) / x - number(16.0) + asinh(x) / sin(x) }.compileToExpression(DoubleField) val m = HashMap() diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 686506f6f..5e3366881 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -10,7 +10,7 @@ Performance and visualization extensions to MST API. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-14`. +The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-17`. **Gradle:** ```gradle @@ -20,7 +20,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-ast:0.3.0-dev-14' + implementation 'space.kscience:kmath-ast:0.3.0-dev-17' } ``` **Gradle Kotlin DSL:** @@ -31,7 +31,7 @@ repositories { } dependencies { - implementation("space.kscience:kmath-ast:0.3.0-dev-14") + implementation("space.kscience:kmath-ast:0.3.0-dev-17") } ``` diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index 9de7e9980..586652349 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -55,6 +55,10 @@ tasks.dokkaHtml { dependsOn(tasks.build) } +tasks.jvmTest { + jvmArgs = (jvmArgs ?: emptyList()) + listOf("-Dspace.kscience.kmath.ast.dump.generated.classes=1") +} + readme { maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt index f5b1e2842..00344410c 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt @@ -44,6 +44,30 @@ internal class TestCompilerOperations { assertEquals(1.0, expr(x to 0.0)) } + @Test + fun testTangent() = runCompilerTest { + val expr = MstExtendedField { tan(x) }.compileToExpression(DoubleField) + assertEquals(0.0, expr(x to 0.0)) + } + + @Test + fun testArcSine() = runCompilerTest { + val expr = MstExtendedField { asin(x) }.compileToExpression(DoubleField) + assertEquals(0.0, expr(x to 0.0)) + } + + @Test + fun testArcCosine() = runCompilerTest { + val expr = MstExtendedField { acos(x) }.compileToExpression(DoubleField) + assertEquals(0.0, expr(x to 1.0)) + } + + @Test + fun testAreaHyperbolicSine() = runCompilerTest { + val expr = MstExtendedField { asinh(x) }.compileToExpression(DoubleField) + assertEquals(0.0, expr(x to 0.0)) + } + @Test fun testSubtract() = runCompilerTest { val expr = MstExtendedField { x - x }.compileToExpression(DoubleField) diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt index b04c4d48f..6039d5c27 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt @@ -18,12 +18,12 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule private val spreader = eval("(obj, args) => obj(...args)") @Suppress("UnsafeCastFromDynamic") -internal sealed class WasmBuilder( - val binaryenType: Type, - val algebra: Algebra, - val target: MST, -) where T : Number { - val keys: MutableList = mutableListOf() +internal sealed class WasmBuilder( + protected val binaryenType: Type, + protected val algebra: Algebra, + protected val target: MST, +) { + protected val keys: MutableList = mutableListOf() lateinit var ctx: BinaryenModule open fun visitSymbolic(mst: Symbol): ExpressionRef { @@ -41,30 +41,36 @@ internal sealed class WasmBuilder( abstract fun visitNumeric(mst: Numeric): ExpressionRef - open fun visitUnary(mst: Unary): ExpressionRef = + protected open fun visitUnary(mst: Unary): ExpressionRef = error("Unary operation ${mst.operation} not defined in $this") - open fun visitBinary(mst: Binary): ExpressionRef = + protected open fun visitBinary(mst: Binary): ExpressionRef = error("Binary operation ${mst.operation} not defined in $this") - open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()") + protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()") - fun visit(mst: MST): ExpressionRef = when (mst) { + protected fun visit(mst: MST): ExpressionRef = when (mst) { is Symbol -> visitSymbolic(mst) is Numeric -> visitNumeric(mst) is Unary -> when { algebra is NumericAlgebra && mst.value is Numeric -> visitNumeric( - Numeric(algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as Numeric).value)))) + Numeric(algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as Numeric).value))) + ) else -> visitUnary(mst) } is Binary -> when { - algebra is NumericAlgebra && mst.left is Numeric && mst.right is Numeric -> visitNumeric(Numeric( - algebra.binaryOperationFunction(mst.operation) - .invoke(algebra.number((mst.left as Numeric).value), algebra.number((mst.right as Numeric).value)) - )) + algebra is NumericAlgebra && mst.left is Numeric && mst.right is Numeric -> visitNumeric( + Numeric( + algebra.binaryOperationFunction(mst.operation) + .invoke( + algebra.number((mst.left as Numeric).value), + algebra.number((mst.right as Numeric).value) + ) + ) + ) else -> visitBinary(mst) } diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt index 5b28b8782..393a02e3d 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt @@ -16,24 +16,6 @@ import space.kscience.kmath.operations.IntRing import space.kscience.kmath.wasm.internal.DoubleWasmBuilder import space.kscience.kmath.wasm.internal.IntWasmBuilder -/** - * Compiles an [MST] to WASM in the context of reals. - * - * @author Iaroslav Postovalov - */ -@UnstableKMathAPI -public fun DoubleField.expression(mst: MST): Expression = - DoubleWasmBuilder(mst).instance - -/** - * Compiles an [MST] to WASM in the context of integers. - * - * @author Iaroslav Postovalov - */ -@UnstableKMathAPI -public fun IntRing.expression(mst: MST): Expression = - IntWasmBuilder(mst).instance - /** * Create a compiled expression with given [MST] and given [algebra]. * diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt index 9a7c4b023..19b400cbd 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt @@ -3,18 +3,18 @@ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ +@file:Suppress("UNUSED_PARAMETER") + package space.kscience.kmath.asm -import space.kscience.kmath.asm.internal.AsmBuilder -import space.kscience.kmath.asm.internal.buildName +import space.kscience.kmath.asm.internal.* import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST.* import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.invoke -import space.kscience.kmath.operations.Algebra -import space.kscience.kmath.operations.NumericAlgebra -import space.kscience.kmath.operations.bindSymbolOrNull +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.* /** * Compiles given MST to an Expression using AST compiler. @@ -26,7 +26,7 @@ import space.kscience.kmath.operations.bindSymbolOrNull */ @PublishedApi internal fun MST.compileWith(type: Class, algebra: Algebra): Expression { - fun AsmBuilder.variablesVisitor(node: MST): Unit = when (node) { + fun GenericAsmBuilder.variablesVisitor(node: MST): Unit = when (node) { is Symbol -> prepareVariable(node.identity) is Unary -> variablesVisitor(node.value) @@ -38,7 +38,7 @@ internal fun MST.compileWith(type: Class, algebra: Algebra): Exp else -> Unit } - fun AsmBuilder.expressionVisitor(node: MST): Unit = when (node) { + fun GenericAsmBuilder.expressionVisitor(node: MST): Unit = when (node) { is Symbol -> { val symbol = algebra.bindSymbolOrNull(node) @@ -87,7 +87,7 @@ internal fun MST.compileWith(type: Class, algebra: Algebra): Exp } } - return AsmBuilder( + return GenericAsmBuilder( type, buildName(this), { variablesVisitor(this@compileWith) }, @@ -114,3 +114,77 @@ public inline fun MST.compile(algebra: Algebra, arguments: */ public inline fun MST.compile(algebra: Algebra, vararg arguments: Pair): T = compileToExpression(algebra).invoke(*arguments) + + +/** + * Create a compiled expression with given [MST] and given [algebra]. + * + * @author Iaroslav Postovalov + */ +public fun MST.compileToExpression(algebra: IntRing): Expression = IntAsmBuilder(this).instance + +/** + * Compile given MST to expression and evaluate it against [arguments]. + * + * @author Iaroslav Postovalov + */ +public fun MST.compile(algebra: IntRing, arguments: Map): Int = + compileToExpression(algebra).invoke(arguments) + +/** + * Compile given MST to expression and evaluate it against [arguments]. + * + * @author Iaroslav Postovalov + */ +public fun MST.compile(algebra: IntRing, vararg arguments: Pair): Int = + compileToExpression(algebra)(*arguments) + + +/** + * Create a compiled expression with given [MST] and given [algebra]. + * + * @author Iaroslav Postovalov + */ +public fun MST.compileToExpression(algebra: LongRing): Expression = LongAsmBuilder(this).instance + + +/** + * Compile given MST to expression and evaluate it against [arguments]. + * + * @author Iaroslav Postovalov + */ +public fun MST.compile(algebra: LongRing, arguments: Map): Long = + compileToExpression(algebra).invoke(arguments) + + +/** + * Compile given MST to expression and evaluate it against [arguments]. + * + * @author Iaroslav Postovalov + */ +public fun MST.compile(algebra: LongRing, vararg arguments: Pair): Long = + compileToExpression(algebra)(*arguments) + + +/** + * Create a compiled expression with given [MST] and given [algebra]. + * + * @author Iaroslav Postovalov + */ +public fun MST.compileToExpression(algebra: DoubleField): Expression = DoubleAsmBuilder(this).instance + +/** + * Compile given MST to expression and evaluate it against [arguments]. + * + * @author Iaroslav Postovalov + */ +public fun MST.compile(algebra: DoubleField, arguments: Map): Double = + compileToExpression(algebra).invoke(arguments) + +/** + * Compile given MST to expression and evaluate it against [arguments]. + * + * @author Iaroslav Postovalov + */ +public fun MST.compile(algebra: DoubleField, vararg arguments: Pair): Double = + compileToExpression(algebra).invoke(*arguments) diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt index 4e29c77e0..2a7b3dea9 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt @@ -1,377 +1,47 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. - */ - package space.kscience.kmath.asm.internal -import org.objectweb.asm.* -import org.objectweb.asm.Opcodes.* -import org.objectweb.asm.Type.* -import org.objectweb.asm.commons.InstructionAdapter -import space.kscience.kmath.asm.internal.AsmBuilder.ClassLoader -import space.kscience.kmath.expressions.* -import java.lang.invoke.MethodHandles -import java.lang.invoke.MethodType -import java.nio.file.Paths -import java.util.stream.Collectors.toMap -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract -import kotlin.io.path.writeBytes +import org.objectweb.asm.Type +import space.kscience.kmath.expressions.Expression -/** - * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. - * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. - * - * @property T the type of AsmExpression to unwrap. - * @property className the unique class name of new loaded class. - * @property expressionResultCallback the function to apply to this object when generating expression value. - * @author Iaroslav Postovalov - */ -internal class AsmBuilder( - classOfT: Class<*>, - private val className: String, - private val variablesPrepareCallback: AsmBuilder.() -> Unit, - private val expressionResultCallback: AsmBuilder.() -> Unit, -) { +internal abstract class AsmBuilder { /** - * Internal classloader of [AsmBuilder] with alias to define class from byte array. + * Internal classloader with alias to define class from byte array. */ - private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + class ByteArrayClassLoader(parent: ClassLoader) : ClassLoader(parent) { fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) } - /** - * The instance of [ClassLoader] used by this builder. - */ - private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - - /** - * ASM type for [T]. - */ - private val tType: Type = classOfT.asm - - /** - * ASM type for new class. - */ - private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/')) - - /** - * List of constants to provide to the subclass. - */ - private val constants: MutableList = mutableListOf() - - /** - * Method visitor of `invoke` method of the subclass. - */ - private lateinit var invokeMethodVisitor: InstructionAdapter - - /** - * Local variables indices are indices of symbols in this list. - */ - private val argumentsLocals = mutableListOf() - - /** - * Subclasses, loads and instantiates [Expression] for given parameters. - * - * The built instance is cached. - */ - @Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE") - val instance: Expression by lazy { - val hasConstants: Boolean - - val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { - visit( - V1_8, - ACC_PUBLIC or ACC_FINAL or ACC_SUPER, - classType.internalName, - "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", - OBJECT_TYPE.internalName, - arrayOf(EXPRESSION_TYPE.internalName), - ) - - visitMethod( - ACC_PUBLIC or ACC_FINAL, - "invoke", - getMethodDescriptor(tType, MAP_TYPE), - "(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", - null, - ).instructionAdapter { - invokeMethodVisitor = this - visitCode() - val preparingVariables = label() - variablesPrepareCallback() - val expressionResult = label() - expressionResultCallback() - areturn(tType) - val end = label() - - visitLocalVariable( - "this", - classType.descriptor, - null, - preparingVariables, - end, - 0, - ) - - visitLocalVariable( - "arguments", - MAP_TYPE.descriptor, - "L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;", - preparingVariables, - end, - 1, - ) - - visitMaxs(0, 2) - visitEnd() - } - - visitMethod( - ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, - "invoke", - getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), - null, - null, - ).instructionAdapter { - visitCode() - val l0 = label() - load(0, OBJECT_TYPE) - load(1, MAP_TYPE) - invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false) - areturn(tType) - val l1 = label() - - visitLocalVariable( - "this", - classType.descriptor, - null, - l0, - l1, - 0, - ) - - visitMaxs(0, 2) - visitEnd() - } - - hasConstants = constants.isNotEmpty() - - if (hasConstants) - visitField( - access = ACC_PRIVATE or ACC_FINAL, - name = "constants", - descriptor = OBJECT_ARRAY_TYPE.descriptor, - signature = null, - value = null, - block = FieldVisitor::visitEnd, - ) - - visitMethod( - ACC_PUBLIC, - "", - getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), - null, - null, - ).instructionAdapter { - val l0 = label() - load(0, classType) - invokespecial(OBJECT_TYPE.internalName, "", getMethodDescriptor(VOID_TYPE), false) - label() - load(0, classType) - - if (hasConstants) { - label() - load(0, classType) - load(1, OBJECT_ARRAY_TYPE) - putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) - } - - label() - visitInsn(RETURN) - val l4 = label() - visitLocalVariable("this", classType.descriptor, null, l0, l4, 0) - - if (hasConstants) - visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, 1) - - visitMaxs(0, 3) - visitEnd() - } - - visitEnd() - } - - val binary = classWriter.toByteArray() - val cls = classLoader.defineClass(className, binary) - - if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1") - Paths.get("$className.class").writeBytes(binary) - - val l = MethodHandles.publicLookup() - - (if (hasConstants) - l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array::class.java))(constants.toTypedArray()) - else - l.findConstructor(cls, MethodType.methodType(Void.TYPE))()) as Expression - } - - /** - * Loads [java.lang.Object] constant from constants. - */ - fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run { - val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex - invokeMethodVisitor.load(0, classType) - getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) - iconst(idx) - visitInsn(AALOAD) - if (type != OBJECT_TYPE) checkcast(type) - } - - /** - * Either loads a numeric constant [value] from the class's constants field or boxes a primitive - * constant from the constant pool. - */ - fun loadNumberConstant(value: Number) { - val boxed = value.javaClass.asm - val primitive = BOXED_TO_PRIMITIVES[boxed] - - if (primitive != null) { - when (primitive) { - BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) - FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) - LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) - INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - } - - val r = PRIMITIVES_TO_BOXED.getValue(primitive) - - invokeMethodVisitor.invokestatic( - r.internalName, - "valueOf", - getMethodDescriptor(r, primitive), - false, - ) - - return - } - - loadObjectConstant(value, boxed) - } - - /** - * Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using - * [loadVariable]. - */ - fun prepareVariable(name: String): Unit = invokeMethodVisitor.run { - if (name in argumentsLocals) return@run - load(1, MAP_TYPE) - aconst(name) - - invokestatic( - MAP_INTRINSICS_TYPE.internalName, - "getOrFail", - getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), - false, - ) - - checkcast(tType) - var idx = argumentsLocals.indexOf(name) - - if (idx == -1) { - argumentsLocals += name - idx = argumentsLocals.lastIndex - } - - store(2 + idx, tType) - } - - /** - * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored - * with [prepareVariable] first. - */ - fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType) - - inline fun buildCall(function: Function, parameters: AsmBuilder.() -> Unit) { - contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } - val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces } - - val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount - ?: error("Provided function object doesn't contain invoke method") - - val type = getType(`interface`) - loadObjectConstant(function, type) - parameters(this) - - invokeMethodVisitor.invokeinterface( - type.internalName, - "invoke", - getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE }), - ) - - invokeMethodVisitor.checkcast(tType) - } - - companion object { - /** - * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. - */ - private val BOXED_TO_PRIMITIVES: Map by lazy { - hashMapOf( - Byte::class.java.asm to BYTE_TYPE, - Short::class.java.asm to SHORT_TYPE, - Integer::class.java.asm to INT_TYPE, - Long::class.java.asm to LONG_TYPE, - Float::class.java.asm to FLOAT_TYPE, - Double::class.java.asm to DOUBLE_TYPE, - ) - } - - /** - * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. - */ - private val PRIMITIVES_TO_BOXED: Map by lazy { - BOXED_TO_PRIMITIVES.entries.stream().collect( - toMap(Map.Entry::value, Map.Entry::key), - ) - } + protected val classLoader = ByteArrayClassLoader(javaClass.classLoader) + protected companion object { /** * ASM type for [Expression]. */ - val EXPRESSION_TYPE: Type by lazy { getObjectType("space/kscience/kmath/expressions/Expression") } + val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/expressions/Expression") } /** * ASM type for [java.util.Map]. */ - val MAP_TYPE: Type by lazy { getObjectType("java/util/Map") } + val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") } /** * ASM type for [java.lang.Object]. */ - val OBJECT_TYPE: Type by lazy { getObjectType("java/lang/Object") } - - /** - * ASM type for array of [java.lang.Object]. - */ - val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") } + val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") } /** * ASM type for [java.lang.String]. */ - val STRING_TYPE: Type by lazy { getObjectType("java/lang/String") } + val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") } /** * ASM type for MapIntrinsics. */ - val MAP_INTRINSICS_TYPE: Type by lazy { getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics") } + val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics") } /** * ASM Type for [space.kscience.kmath.expressions.Symbol]. */ - val SYMBOL_TYPE: Type by lazy { getObjectType("space/kscience/kmath/expressions/Symbol") } + val SYMBOL_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/expressions/Symbol") } } } diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/ByteArrayClassLoader.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/ByteArrayClassLoader.kt new file mode 100644 index 000000000..afff48ef6 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/ByteArrayClassLoader.kt @@ -0,0 +1 @@ +package space.kscience.kmath.asm.internal diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt new file mode 100644 index 000000000..1b45b0625 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt @@ -0,0 +1,325 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.asm.internal + +import org.objectweb.asm.* +import org.objectweb.asm.Opcodes.* +import org.objectweb.asm.Type.* +import org.objectweb.asm.commons.InstructionAdapter +import space.kscience.kmath.expressions.* +import java.lang.invoke.MethodHandles +import java.lang.invoke.MethodType +import java.nio.file.Paths +import java.util.stream.Collectors.toMap +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract +import kotlin.io.path.writeBytes + +/** + * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. + * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. + * + * @property T the type of AsmExpression to unwrap. + * @property className the unique class name of new loaded class. + * @property expressionResultCallback the function to apply to this object when generating expression value. + * @author Iaroslav Postovalov + */ +internal class GenericAsmBuilder( + classOfT: Class<*>, + private val className: String, + private val variablesPrepareCallback: GenericAsmBuilder.() -> Unit, + private val expressionResultCallback: GenericAsmBuilder.() -> Unit, +) : AsmBuilder() { + /** + * ASM type for [T]. + */ + private val tType: Type = classOfT.asm + + /** + * ASM type for new class. + */ + private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/')) + + /** + * List of constants to provide to the subclass. + */ + private val constants = mutableListOf() + + /** + * Method visitor of `invoke` method of the subclass. + */ + private lateinit var invokeMethodVisitor: InstructionAdapter + + /** + * Local variables indices are indices of symbols in this list. + */ + private val argumentsLocals = mutableListOf() + + /** + * Subclasses, loads and instantiates [Expression] for given parameters. + * + * The built instance is cached. + */ + @Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE") + val instance: Expression by lazy { + val hasConstants: Boolean + + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { + visit( + V1_8, + ACC_PUBLIC or ACC_FINAL or ACC_SUPER, + classType.internalName, + "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", + OBJECT_TYPE.internalName, + arrayOf(EXPRESSION_TYPE.internalName), + ) + + visitMethod( + ACC_PUBLIC or ACC_FINAL, + "invoke", + getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", + null, + ).instructionAdapter { + invokeMethodVisitor = this + visitCode() + val preparingVariables = label() + variablesPrepareCallback() + val expressionResult = label() + expressionResultCallback() + areturn(tType) + val end = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + preparingVariables, + end, + 0, + ) + + visitLocalVariable( + "arguments", + MAP_TYPE.descriptor, + "L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;", + preparingVariables, + end, + 1, + ) + + visitMaxs(0, 0) + visitEnd() + } + + visitMethod( + ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, + "invoke", + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + null, + null, + ).instructionAdapter { + visitCode() + val start = label() + load(0, OBJECT_TYPE) + load(1, MAP_TYPE) + invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false) + areturn(tType) + val end = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + start, + end, + 0, + ) + + visitMaxs(0, 0) + visitEnd() + } + + hasConstants = constants.isNotEmpty() + + if (hasConstants) + visitField( + access = ACC_PRIVATE or ACC_FINAL, + name = "constants", + descriptor = OBJECT_ARRAY_TYPE.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd, + ) + + visitMethod( + ACC_PUBLIC, + "", + getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), + null, + null, + ).instructionAdapter { + val l0 = label() + load(0, classType) + invokespecial(OBJECT_TYPE.internalName, "", getMethodDescriptor(VOID_TYPE), false) + label() + load(0, classType) + + if (hasConstants) { + label() + load(0, classType) + load(1, OBJECT_ARRAY_TYPE) + putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + } + + label() + visitInsn(RETURN) + val l4 = label() + visitLocalVariable("this", classType.descriptor, null, l0, l4, 0) + + if (hasConstants) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, 1) + + visitMaxs(0, 0) + visitEnd() + } + + visitEnd() + } + + val binary = classWriter.toByteArray() + val cls = classLoader.defineClass(className, binary) + + if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1") + Paths.get("${className.split('.').last()}.class").writeBytes(binary) + + val l = MethodHandles.publicLookup() + + (if (hasConstants) + l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array::class.java))(constants.toTypedArray()) + else + l.findConstructor(cls, MethodType.methodType(Void.TYPE))()) as Expression + } + + /** + * Loads [java.lang.Object] constant from constants. + */ + fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run { + val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex + invokeMethodVisitor.load(0, classType) + getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + iconst(idx) + visitInsn(AALOAD) + if (type != OBJECT_TYPE) checkcast(type) + } + + /** + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive + * constant from the constant pool. + */ + fun loadNumberConstant(value: Number) { + val boxed = value.javaClass.asm + val primitive = BOXED_TO_PRIMITIVES[boxed] + + if (primitive != null) { + when (primitive) { + BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + } + + val r = boxed + + invokeMethodVisitor.invokestatic( + r.internalName, + "valueOf", + getMethodDescriptor(r, primitive), + false, + ) + + return + } + + loadObjectConstant(value, boxed) + } + + /** + * Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using + * [loadVariable]. + */ + fun prepareVariable(name: String): Unit = invokeMethodVisitor.run { + if (name in argumentsLocals) return@run + load(1, MAP_TYPE) + aconst(name) + + invokestatic( + MAP_INTRINSICS_TYPE.internalName, + "getOrFail", + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), + false, + ) + + checkcast(tType) + var idx = argumentsLocals.indexOf(name) + + if (idx == -1) { + argumentsLocals += name + idx = argumentsLocals.lastIndex + } + + store(2 + idx, tType) + } + + /** + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored + * with [prepareVariable] first. + */ + fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType) + + inline fun buildCall(function: Function, parameters: GenericAsmBuilder.() -> Unit) { + contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } + val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces } + + val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount + ?: error("Provided function object doesn't contain invoke method") + + val type = getType(`interface`) + loadObjectConstant(function, type) + parameters(this) + + invokeMethodVisitor.invokeinterface( + type.internalName, + "invoke", + getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE }), + ) + + invokeMethodVisitor.checkcast(tType) + } + + private companion object { + /** + * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. + */ + private val BOXED_TO_PRIMITIVES: Map by lazy { + hashMapOf( + Byte::class.java.asm to BYTE_TYPE, + Short::class.java.asm to SHORT_TYPE, + Integer::class.java.asm to INT_TYPE, + Long::class.java.asm to LONG_TYPE, + Float::class.java.asm to FLOAT_TYPE, + Double::class.java.asm to DOUBLE_TYPE, + ) + } + + /** + * ASM type for array of [java.lang.Object]. + */ + val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") } + } +} diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt new file mode 100644 index 000000000..2819368ed --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt @@ -0,0 +1,411 @@ +package space.kscience.kmath.asm.internal + +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.Opcodes +import org.objectweb.asm.Type +import org.objectweb.asm.Type.* +import org.objectweb.asm.commons.InstructionAdapter +import space.kscience.kmath.expressions.Expression +import space.kscience.kmath.expressions.MST +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.operations.* +import java.lang.invoke.MethodHandles +import java.lang.invoke.MethodType +import java.nio.file.Paths +import kotlin.io.path.writeBytes + +internal sealed class PrimitiveAsmBuilder( + protected val algebra: Algebra, + classOfT: Class<*>, + protected val classOfTPrimitive: Class<*>, + protected val target: MST, +) : AsmBuilder() { + private val className: String = buildName(target) + + /** + * ASM type for [T]. + */ + private val tType: Type = classOfT.asm + + /** + * ASM type for [T]. + */ + protected val tTypePrimitive: Type = classOfTPrimitive.asm + + /** + * ASM type for new class. + */ + private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/')) + + /** + * Method visitor of `invoke` method of the subclass. + */ + protected lateinit var invokeMethodVisitor: InstructionAdapter + + /** + * Local variables indices are indices of symbols in this list. + */ + private val argumentsLocals = mutableListOf() + + /** + * Subclasses, loads and instantiates [Expression] for given parameters. + * + * The built instance is cached. + */ + @Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE") + val instance: Expression by lazy { + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { + visit( + Opcodes.V1_8, + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, + classType.internalName, + "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", + OBJECT_TYPE.internalName, + arrayOf(EXPRESSION_TYPE.internalName), + ) + + visitMethod( + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + "invoke", + getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", + null, + ).instructionAdapter { + invokeMethodVisitor = this + visitCode() + val preparingVariables = label() + visitVariables(target) + val expressionResult = label() + visitExpression(target) + box() + areturn(tType) + val end = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + preparingVariables, + end, + 0, + ) + + visitLocalVariable( + "arguments", + MAP_TYPE.descriptor, + "L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;", + preparingVariables, + end, + 1, + ) + + visitMaxs(0, 0) + visitEnd() + } + + visitMethod( + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + "invoke", + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + null, + null, + ).instructionAdapter { + visitCode() + val start = label() + load(0, OBJECT_TYPE) + load(1, MAP_TYPE) + invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false) + areturn(tType) + val end = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + start, + end, + 0, + ) + + visitMaxs(0, 0) + visitEnd() + } + + visitMethod( + Opcodes.ACC_PUBLIC, + "", + getMethodDescriptor(VOID_TYPE), + null, + null, + ).instructionAdapter { + val start = label() + load(0, classType) + invokespecial(OBJECT_TYPE.internalName, "", getMethodDescriptor(VOID_TYPE), false) + label() + load(0, classType) + label() + visitInsn(Opcodes.RETURN) + val end = label() + visitLocalVariable("this", classType.descriptor, null, start, end, 0) + visitMaxs(0, 0) + visitEnd() + } + + visitEnd() + } + + val binary = classWriter.toByteArray() + val cls = classLoader.defineClass(className, binary) + + if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1") + Paths.get("${className.split('.').last()}.class").writeBytes(binary) + + MethodHandles.publicLookup().findConstructor(cls, MethodType.methodType(Void.TYPE))() as Expression + } + + /** + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive + * constant from the constant pool. + */ + fun loadNumberConstant(value: Number) { + when (tTypePrimitive) { + BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + } + } + + /** + * Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using + * [loadVariable]. + */ + fun prepareVariable(name: String): Unit = invokeMethodVisitor.run { + if (name in argumentsLocals) return@run + load(1, MAP_TYPE) + aconst(name) + + invokestatic( + MAP_INTRINSICS_TYPE.internalName, + "getOrFail", + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), + false, + ) + + checkcast(tType) + var idx = argumentsLocals.indexOf(name) + + if (idx == -1) { + argumentsLocals += name + idx = argumentsLocals.lastIndex + } + + unbox() + store(2 + idx, tTypePrimitive) + } + + /** + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored + * with [prepareVariable] first. + */ + fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tTypePrimitive) + + private fun unbox() = invokeMethodVisitor.run { + invokevirtual( + NUMBER_TYPE.internalName, + "${classOfTPrimitive.simpleName}Value", + getMethodDescriptor(tTypePrimitive), + false + ) + } + + private fun box() = invokeMethodVisitor.run { + invokestatic(tType.internalName, "valueOf", getMethodDescriptor(tType, tTypePrimitive), false) + } + + protected fun visitVariables(node: MST): Unit = when (node) { + is Symbol -> prepareVariable(node.identity) + is MST.Unary -> visitVariables(node.value) + + is MST.Binary -> { + visitVariables(node.left) + visitVariables(node.right) + } + + else -> Unit + } + + protected fun visitExpression(mst: MST): Unit = when (mst) { + is Symbol -> loadVariable(mst.identity) + is MST.Numeric -> loadNumberConstant(mst.value) + + is MST.Unary -> when { + algebra is NumericAlgebra && mst.value is MST.Numeric -> { + loadNumberConstant( + MST.Numeric( + algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as MST.Numeric).value)), + ).value, + ) + } + + else -> visitUnary(mst) + } + + is MST.Binary -> when { + algebra is NumericAlgebra && mst.left is MST.Numeric && mst.right is MST.Numeric -> { + loadNumberConstant( + MST.Numeric( + algebra.binaryOperationFunction(mst.operation)( + algebra.number((mst.left as MST.Numeric).value), + algebra.number((mst.right as MST.Numeric).value), + ), + ).value, + ) + } + + else -> visitBinary(mst) + } + } + + protected open fun visitUnary(mst: MST.Unary) { + visitExpression(mst.value) + } + + protected open fun visitBinary(mst: MST.Binary) { + visitExpression(mst.left) + visitExpression(mst.right) + } + + protected companion object { + /** + * ASM type for [java.lang.Number]. + */ + val NUMBER_TYPE: Type by lazy { getObjectType("java/lang/Number") } + } +} + +internal class DoubleAsmBuilder(target: MST) : + PrimitiveAsmBuilder(DoubleField, java.lang.Double::class.java, java.lang.Double.TYPE, target) { + + private fun buildUnaryJavaMathCall(name: String) { + invokeMethodVisitor.invokestatic( + MATH_TYPE.internalName, + name, + getMethodDescriptor(tTypePrimitive, tTypePrimitive), + false, + ) + } + + private fun buildBinaryJavaMathCall(name: String) { + invokeMethodVisitor.invokestatic( + MATH_TYPE.internalName, + name, + getMethodDescriptor(tTypePrimitive, tTypePrimitive, tTypePrimitive), + false, + ) + } + + private fun buildUnaryKotlinMathCall(name: String) { + invokeMethodVisitor.invokestatic( + MATH_KT_TYPE.internalName, + name, + getMethodDescriptor(tTypePrimitive, tTypePrimitive), + false, + ) + } + + override fun visitUnary(mst: MST.Unary) { + super.visitUnary(mst) + + when (mst.operation) { + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DNEG) + GroupOps.PLUS_OPERATION -> Unit + PowerOperations.SQRT_OPERATION -> buildUnaryJavaMathCall("sqrt") + TrigonometricOperations.SIN_OPERATION -> buildUnaryJavaMathCall("sin") + TrigonometricOperations.COS_OPERATION -> buildUnaryJavaMathCall("cos") + TrigonometricOperations.TAN_OPERATION -> buildUnaryJavaMathCall("tan") + TrigonometricOperations.ASIN_OPERATION -> buildUnaryJavaMathCall("asin") + TrigonometricOperations.ACOS_OPERATION -> buildUnaryJavaMathCall("acos") + TrigonometricOperations.ATAN_OPERATION -> buildUnaryJavaMathCall("atan") + ExponentialOperations.SINH_OPERATION -> buildUnaryJavaMathCall("sqrt") + ExponentialOperations.COSH_OPERATION -> buildUnaryJavaMathCall("cosh") + ExponentialOperations.TANH_OPERATION -> buildUnaryJavaMathCall("tanh") + ExponentialOperations.ASINH_OPERATION -> buildUnaryKotlinMathCall("asinh") + ExponentialOperations.ACOSH_OPERATION -> buildUnaryKotlinMathCall("acosh") + ExponentialOperations.ATANH_OPERATION -> buildUnaryKotlinMathCall("atanh") + ExponentialOperations.EXP_OPERATION -> buildUnaryJavaMathCall("exp") + ExponentialOperations.LN_OPERATION -> buildUnaryJavaMathCall("log") + else -> super.visitUnary(mst) + } + } + + override fun visitBinary(mst: MST.Binary) { + super.visitBinary(mst) + + when (mst.operation) { + GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DADD) + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DSUB) + RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DMUL) + FieldOps.DIV_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DDIV) + PowerOperations.POW_OPERATION -> buildBinaryJavaMathCall("pow") + else -> super.visitBinary(mst) + } + } + + companion object { + val MATH_TYPE: Type by lazy { getObjectType("java/lang/Math") } + val MATH_KT_TYPE: Type by lazy { getObjectType("kotlin/math/MathKt") } + } +} + +internal class IntAsmBuilder(target: MST) : + PrimitiveAsmBuilder(IntRing, Integer::class.java, Integer.TYPE, target) { + override fun visitUnary(mst: MST.Unary) { + super.visitUnary(mst) + + when (mst.operation) { + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.INEG) + GroupOps.PLUS_OPERATION -> Unit + else -> super.visitUnary(mst) + } + } + + override fun visitBinary(mst: MST.Binary) { + super.visitBinary(mst) + + when (mst.operation) { + GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.IADD) + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.ISUB) + RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.IMUL) + else -> super.visitBinary(mst) + } + } +} + +internal class LongAsmBuilder(target: MST) : + PrimitiveAsmBuilder(LongRing, java.lang.Long::class.java, java.lang.Long.TYPE, target) { + override fun visitUnary(mst: MST.Unary) { + super.visitUnary(mst) + + when (mst.operation) { + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LNEG) + GroupOps.PLUS_OPERATION -> Unit + else -> super.visitUnary(mst) + } + } + + override fun visitBinary(mst: MST.Binary) { + super.visitBinary(mst) + + when (mst.operation) { + GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LADD) + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LSUB) + RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LMUL) + else -> super.visitBinary(mst) + } + } +} + diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/mapIntrinsics.kt index 40d9d8fe6..b8a2a9669 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/mapIntrinsics.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/mapIntrinsics.kt @@ -14,4 +14,5 @@ import space.kscience.kmath.expressions.Symbol * * @author Iaroslav Postovalov */ +@Suppress("unused") internal fun Map.getOrFail(key: String): V = getValue(Symbol(key)) diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/utils.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/utils.kt index a0bdd68a0..1ea4cd0be 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/utils.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/utils.kt @@ -8,6 +8,7 @@ package space.kscience.kmath.ast import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.IntRing import kotlin.contracts.InvocationKind @@ -15,7 +16,21 @@ import kotlin.contracts.contract import space.kscience.kmath.asm.compile as asmCompile import space.kscience.kmath.asm.compileToExpression as asmCompileToExpression -private object AsmCompilerTestContext : CompilerTestContext { +private object GenericAsmCompilerTestContext : CompilerTestContext { + override fun MST.compileToExpression(algebra: IntRing): Expression = + asmCompileToExpression(algebra as Algebra) + + override fun MST.compile(algebra: IntRing, arguments: Map): Int = + asmCompile(algebra as Algebra, arguments) + + override fun MST.compileToExpression(algebra: DoubleField): Expression = + asmCompileToExpression(algebra as Algebra) + + override fun MST.compile(algebra: DoubleField, arguments: Map): Double = + asmCompile(algebra as Algebra, arguments) +} + +private object PrimitiveAsmCompilerTestContext : CompilerTestContext { override fun MST.compileToExpression(algebra: IntRing): Expression = asmCompileToExpression(algebra) override fun MST.compile(algebra: IntRing, arguments: Map): Int = asmCompile(algebra, arguments) override fun MST.compileToExpression(algebra: DoubleField): Expression = asmCompileToExpression(algebra) @@ -24,7 +39,9 @@ private object AsmCompilerTestContext : CompilerTestContext { asmCompile(algebra, arguments) } + internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) { contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } - action(AsmCompilerTestContext) + action(GenericAsmCompilerTestContext) + action(PrimitiveAsmCompilerTestContext) } diff --git a/kmath-complex/README.md b/kmath-complex/README.md index 110529b72..92f2435ba 100644 --- a/kmath-complex/README.md +++ b/kmath-complex/README.md @@ -8,7 +8,7 @@ Complex and hypercomplex number systems in KMath. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-14`. +The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-17`. **Gradle:** ```gradle @@ -18,7 +18,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-complex:0.3.0-dev-14' + implementation 'space.kscience:kmath-complex:0.3.0-dev-17' } ``` **Gradle Kotlin DSL:** @@ -29,6 +29,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-complex:0.3.0-dev-14") + implementation("space.kscience:kmath-complex:0.3.0-dev-17") } ``` diff --git a/kmath-core/README.md b/kmath-core/README.md index 4ea493f44..e765ad50c 100644 --- a/kmath-core/README.md +++ b/kmath-core/README.md @@ -15,7 +15,7 @@ performance calculations to code generation. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-14`. +The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-17`. **Gradle:** ```gradle @@ -25,7 +25,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-core:0.3.0-dev-14' + implementation 'space.kscience:kmath-core:0.3.0-dev-17' } ``` **Gradle Kotlin DSL:** @@ -36,6 +36,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-core:0.3.0-dev-14") + implementation("space.kscience:kmath-core:0.3.0-dev-17") } ``` diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Symbol.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Symbol.kt index cd49e4519..fcfd57e82 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Symbol.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Symbol.kt @@ -9,8 +9,8 @@ import kotlin.jvm.JvmInline import kotlin.properties.ReadOnlyProperty /** - * A marker interface for a symbol. A symbol must have an identity. - * Ic + * A marker interface for a symbol. A symbol must have an identity with equality relation based on it. + * Other properties are to store additional, transient data only. */ public interface Symbol : MST { /** diff --git a/kmath-ejml/README.md b/kmath-ejml/README.md index f88f53000..fcd092bf1 100644 --- a/kmath-ejml/README.md +++ b/kmath-ejml/README.md @@ -9,7 +9,7 @@ EJML based linear algebra implementation. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-14`. +The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-17`. **Gradle:** ```gradle @@ -19,7 +19,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-ejml:0.3.0-dev-14' + implementation 'space.kscience:kmath-ejml:0.3.0-dev-17' } ``` **Gradle Kotlin DSL:** @@ -30,6 +30,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-ejml:0.3.0-dev-14") + implementation("space.kscience:kmath-ejml:0.3.0-dev-17") } ``` diff --git a/kmath-for-real/README.md b/kmath-for-real/README.md index d449b4540..938327612 100644 --- a/kmath-for-real/README.md +++ b/kmath-for-real/README.md @@ -9,7 +9,7 @@ Specialization of KMath APIs for Double numbers. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-14`. +The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-17`. **Gradle:** ```gradle @@ -19,7 +19,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-for-real:0.3.0-dev-14' + implementation 'space.kscience:kmath-for-real:0.3.0-dev-17' } ``` **Gradle Kotlin DSL:** @@ -30,6 +30,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-for-real:0.3.0-dev-14") + implementation("space.kscience:kmath-for-real:0.3.0-dev-17") } ``` diff --git a/kmath-functions/README.md b/kmath-functions/README.md index d0beae2c8..3d4beee47 100644 --- a/kmath-functions/README.md +++ b/kmath-functions/README.md @@ -11,7 +11,7 @@ Functions and interpolations. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-14`. +The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-17`. **Gradle:** ```gradle @@ -21,7 +21,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-functions:0.3.0-dev-14' + implementation 'space.kscience:kmath-functions:0.3.0-dev-17' } ``` **Gradle Kotlin DSL:** @@ -32,6 +32,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-functions:0.3.0-dev-14") + implementation("space.kscience:kmath-functions:0.3.0-dev-17") } ``` diff --git a/kmath-jafama/README.md b/kmath-jafama/README.md index 3c5d4e19d..760244751 100644 --- a/kmath-jafama/README.md +++ b/kmath-jafama/README.md @@ -7,7 +7,7 @@ Integration with [Jafama](https://github.com/jeffhain/jafama). ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-jafama:0.3.0-dev-14`. +The Maven coordinates of this project are `space.kscience:kmath-jafama:0.3.0-dev-17`. **Gradle:** ```gradle @@ -17,7 +17,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-jafama:0.3.0-dev-14' + implementation 'space.kscience:kmath-jafama:0.3.0-dev-17' } ``` **Gradle Kotlin DSL:** @@ -28,7 +28,7 @@ repositories { } dependencies { - implementation("space.kscience:kmath-jafama:0.3.0-dev-14") + implementation("space.kscience:kmath-jafama:0.3.0-dev-17") } ``` diff --git a/kmath-kotlingrad/README.md b/kmath-kotlingrad/README.md index aeb44ea13..588ccb9b4 100644 --- a/kmath-kotlingrad/README.md +++ b/kmath-kotlingrad/README.md @@ -8,7 +8,7 @@ ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-kotlingrad:0.3.0-dev-14`. +The Maven coordinates of this project are `space.kscience:kmath-kotlingrad:0.3.0-dev-17`. **Gradle:** ```gradle @@ -18,7 +18,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-kotlingrad:0.3.0-dev-14' + implementation 'space.kscience:kmath-kotlingrad:0.3.0-dev-17' } ``` **Gradle Kotlin DSL:** @@ -29,6 +29,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-kotlingrad:0.3.0-dev-14") + implementation("space.kscience:kmath-kotlingrad:0.3.0-dev-17") } ``` diff --git a/kmath-nd4j/README.md b/kmath-nd4j/README.md index 5cbb31d5a..7ca9cd4fd 100644 --- a/kmath-nd4j/README.md +++ b/kmath-nd4j/README.md @@ -9,7 +9,7 @@ ND4J based implementations of KMath abstractions. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-14`. +The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-17`. **Gradle:** ```gradle @@ -19,7 +19,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-nd4j:0.3.0-dev-14' + implementation 'space.kscience:kmath-nd4j:0.3.0-dev-17' } ``` **Gradle Kotlin DSL:** @@ -30,7 +30,7 @@ repositories { } dependencies { - implementation("space.kscience:kmath-nd4j:0.3.0-dev-14") + implementation("space.kscience:kmath-nd4j:0.3.0-dev-17") } ``` diff --git a/kmath-tensors/README.md b/kmath-tensors/README.md index b19a55381..42ce91336 100644 --- a/kmath-tensors/README.md +++ b/kmath-tensors/README.md @@ -9,7 +9,7 @@ Common linear algebra operations on tensors. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-tensors:0.3.0-dev-14`. +The Maven coordinates of this project are `space.kscience:kmath-tensors:0.3.0-dev-17`. **Gradle:** ```gradle @@ -19,7 +19,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-tensors:0.3.0-dev-14' + implementation 'space.kscience:kmath-tensors:0.3.0-dev-17' } ``` **Gradle Kotlin DSL:** @@ -30,6 +30,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-tensors:0.3.0-dev-14") + implementation("space.kscience:kmath-tensors:0.3.0-dev-17") } ```