Provide specializations of AsmBuilder for Double, Long, Int #437

Merged
CommanderTvis merged 1 commits from commandertvis/double-specialized into dev 2021-11-16 11:41:54 +03:00
25 changed files with 961 additions and 425 deletions

View File

@ -240,6 +240,12 @@ One can still use generic algebras though.
> **Maturity**: DEVELOPMENT > **Maturity**: DEVELOPMENT
<hr/> <hr/>
* ### [kmath-multik](kmath-multik)
>
>
> **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-nd4j](kmath-nd4j) * ### [kmath-nd4j](kmath-nd4j)
> >
> >
@ -252,6 +258,12 @@ One can still use generic algebras though.
<hr/> <hr/>
* ### [kmath-optimization](kmath-optimization)
>
>
> **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-stat](kmath-stat) * ### [kmath-stat](kmath-stat)
> >
> >
@ -319,8 +331,8 @@ repositories {
} }
dependencies { dependencies {
api("space.kscience:kmath-core:0.3.0-dev-14") api("space.kscience:kmath-core:0.3.0-dev-17")
// api("space.kscience:kmath-core-jvm:0.3.0-dev-14") for jvm-specific version // api("space.kscience:kmath-core-jvm:0.3.0-dev-17") for jvm-specific version
} }
``` ```

View File

@ -11,6 +11,7 @@ import kotlinx.benchmark.Scope
import kotlinx.benchmark.State import kotlinx.benchmark.State
import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
@ -35,7 +36,14 @@ internal class ExpressionsInterpretersBenchmark {
* Benchmark case for [Expression] created with [compileToExpression]. * Benchmark case for [Expression] created with [compileToExpression].
*/ */
@Benchmark @Benchmark
fun asmExpression(blackhole: Blackhole) = invokeAndSum(asm, blackhole) fun asmGenericExpression(blackhole: Blackhole) = invokeAndSum(asmGeneric, blackhole)
/**
* Benchmark case for [Expression] created with [compileToExpression].
*/
@Benchmark
fun asmPrimitiveExpression(blackhole: Blackhole) = invokeAndSum(asmPrimitive, blackhole)
/** /**
* Benchmark case for [Expression] implemented manually with `kotlin.math` functions. * Benchmark case for [Expression] implemented manually with `kotlin.math` functions.
@ -87,7 +95,8 @@ internal class ExpressionsInterpretersBenchmark {
} }
private val mst = node.toExpression(DoubleField) private val mst = node.toExpression(DoubleField)
private val asm = node.compileToExpression(DoubleField) private val asmPrimitive = node.compileToExpression(DoubleField)
private val asmGeneric = node.compileToExpression(DoubleField as Algebra<Double>)
private val raw = Expression<Double> { args -> private val raw = Expression<Double> { args ->
val x = args[x]!! val x = args[x]!!

View File

@ -6,15 +6,15 @@
package space.kscience.kmath.ast package space.kscience.kmath.ast
import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.MstField import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
fun main() { fun main() {
val expr = MstField { val expr = MstExtendedField {
x * 2.0 + number(2.0) / x - 16.0 x * 2.0 + number(2.0) / x - number(16.0) + asinh(x) / sin(x)
}.compileToExpression(DoubleField) }.compileToExpression(DoubleField)
val m = HashMap<Symbol, Double>() val m = HashMap<Symbol, Double>()

View File

@ -10,7 +10,7 @@ Performance and visualization extensions to MST API.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-14`. The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-17`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -20,7 +20,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-ast:0.3.0-dev-14' implementation 'space.kscience:kmath-ast:0.3.0-dev-17'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -31,7 +31,7 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-ast:0.3.0-dev-14") implementation("space.kscience:kmath-ast:0.3.0-dev-17")
} }
``` ```

View File

@ -55,6 +55,10 @@ tasks.dokkaHtml {
dependsOn(tasks.build) dependsOn(tasks.build)
} }
tasks.jvmTest {
jvmArgs = (jvmArgs ?: emptyList()) + listOf("-Dspace.kscience.kmath.ast.dump.generated.classes=1")
}
readme { readme {
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))

View File

@ -44,6 +44,30 @@ internal class TestCompilerOperations {
assertEquals(1.0, expr(x to 0.0)) assertEquals(1.0, expr(x to 0.0))
} }
@Test
fun testTangent() = runCompilerTest {
val expr = MstExtendedField { tan(x) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test
fun testArcSine() = runCompilerTest {
val expr = MstExtendedField { asin(x) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test
fun testArcCosine() = runCompilerTest {
val expr = MstExtendedField { acos(x) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 1.0))
}
@Test
fun testAreaHyperbolicSine() = runCompilerTest {
val expr = MstExtendedField { asinh(x) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test @Test
fun testSubtract() = runCompilerTest { fun testSubtract() = runCompilerTest {
val expr = MstExtendedField { x - x }.compileToExpression(DoubleField) val expr = MstExtendedField { x - x }.compileToExpression(DoubleField)

View File

@ -18,12 +18,12 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule
private val spreader = eval("(obj, args) => obj(...args)") private val spreader = eval("(obj, args) => obj(...args)")
@Suppress("UnsafeCastFromDynamic") @Suppress("UnsafeCastFromDynamic")
internal sealed class WasmBuilder<T>( internal sealed class WasmBuilder<T : Number>(
val binaryenType: Type, protected val binaryenType: Type,
val algebra: Algebra<T>, protected val algebra: Algebra<T>,
val target: MST, protected val target: MST,
) where T : Number { ) {
val keys: MutableList<Symbol> = mutableListOf() protected val keys: MutableList<Symbol> = mutableListOf()
lateinit var ctx: BinaryenModule lateinit var ctx: BinaryenModule
open fun visitSymbolic(mst: Symbol): ExpressionRef { open fun visitSymbolic(mst: Symbol): ExpressionRef {
@ -41,30 +41,36 @@ internal sealed class WasmBuilder<T>(
abstract fun visitNumeric(mst: Numeric): ExpressionRef abstract fun visitNumeric(mst: Numeric): ExpressionRef
open fun visitUnary(mst: Unary): ExpressionRef = protected open fun visitUnary(mst: Unary): ExpressionRef =
error("Unary operation ${mst.operation} not defined in $this") error("Unary operation ${mst.operation} not defined in $this")
open fun visitBinary(mst: Binary): ExpressionRef = protected open fun visitBinary(mst: Binary): ExpressionRef =
error("Binary operation ${mst.operation} not defined in $this") error("Binary operation ${mst.operation} not defined in $this")
open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()") protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
fun visit(mst: MST): ExpressionRef = when (mst) { protected fun visit(mst: MST): ExpressionRef = when (mst) {
is Symbol -> visitSymbolic(mst) is Symbol -> visitSymbolic(mst)
is Numeric -> visitNumeric(mst) is Numeric -> visitNumeric(mst)
is Unary -> when { is Unary -> when {
algebra is NumericAlgebra && mst.value is Numeric -> visitNumeric( algebra is NumericAlgebra && mst.value is Numeric -> visitNumeric(
Numeric(algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as Numeric).value)))) Numeric(algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as Numeric).value)))
)
else -> visitUnary(mst) else -> visitUnary(mst)
} }
is Binary -> when { is Binary -> when {
algebra is NumericAlgebra && mst.left is Numeric && mst.right is Numeric -> visitNumeric(Numeric( algebra is NumericAlgebra && mst.left is Numeric && mst.right is Numeric -> visitNumeric(
algebra.binaryOperationFunction(mst.operation) Numeric(
.invoke(algebra.number((mst.left as Numeric).value), algebra.number((mst.right as Numeric).value)) algebra.binaryOperationFunction(mst.operation)
)) .invoke(
algebra.number((mst.left as Numeric).value),
algebra.number((mst.right as Numeric).value)
)
)
)
else -> visitBinary(mst) else -> visitBinary(mst)
} }

View File

@ -16,24 +16,6 @@ import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.wasm.internal.DoubleWasmBuilder import space.kscience.kmath.wasm.internal.DoubleWasmBuilder
import space.kscience.kmath.wasm.internal.IntWasmBuilder import space.kscience.kmath.wasm.internal.IntWasmBuilder
/**
* Compiles an [MST] to WASM in the context of reals.
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun DoubleField.expression(mst: MST): Expression<Double> =
DoubleWasmBuilder(mst).instance
/**
* Compiles an [MST] to WASM in the context of integers.
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun IntRing.expression(mst: MST): Expression<Int> =
IntWasmBuilder(mst).instance
/** /**
* Create a compiled expression with given [MST] and given [algebra]. * Create a compiled expression with given [MST] and given [algebra].
* *

View File

@ -3,18 +3,18 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/ */
@file:Suppress("UNUSED_PARAMETER")
package space.kscience.kmath.asm package space.kscience.kmath.asm
import space.kscience.kmath.asm.internal.AsmBuilder import space.kscience.kmath.asm.internal.*
import space.kscience.kmath.asm.internal.buildName
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.MST.* import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.NumericAlgebra import space.kscience.kmath.operations.*
import space.kscience.kmath.operations.bindSymbolOrNull
/** /**
* Compiles given MST to an Expression using AST compiler. * Compiles given MST to an Expression using AST compiler.
@ -26,7 +26,7 @@ import space.kscience.kmath.operations.bindSymbolOrNull
*/ */
@PublishedApi @PublishedApi
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> { internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) { fun GenericAsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) {
is Symbol -> prepareVariable(node.identity) is Symbol -> prepareVariable(node.identity)
is Unary -> variablesVisitor(node.value) is Unary -> variablesVisitor(node.value)
@ -38,7 +38,7 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
else -> Unit else -> Unit
} }
fun AsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) { fun GenericAsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) {
is Symbol -> { is Symbol -> {
val symbol = algebra.bindSymbolOrNull(node) val symbol = algebra.bindSymbolOrNull(node)
@ -87,7 +87,7 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
} }
} }
return AsmBuilder<T>( return GenericAsmBuilder<T>(
type, type,
buildName(this), buildName(this),
{ variablesVisitor(this@compileWith) }, { variablesVisitor(this@compileWith) },
@ -114,3 +114,77 @@ public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments:
*/ */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
compileToExpression(algebra).invoke(*arguments) compileToExpression(algebra).invoke(*arguments)
/**
* Create a compiled expression with given [MST] and given [algebra].
*
* @author Iaroslav Postovalov
*/
public fun MST.compileToExpression(algebra: IntRing): Expression<Int> = IntAsmBuilder(this).instance
/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
compileToExpression(algebra).invoke(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int =
compileToExpression(algebra)(*arguments)
/**
* Create a compiled expression with given [MST] and given [algebra].
*
* @author Iaroslav Postovalov
*/
public fun MST.compileToExpression(algebra: LongRing): Expression<Long> = LongAsmBuilder(this).instance
/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
compileToExpression(algebra).invoke(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>): Long =
compileToExpression(algebra)(*arguments)
/**
* Create a compiled expression with given [MST] and given [algebra].
*
* @author Iaroslav Postovalov
*/
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = DoubleAsmBuilder(this).instance
/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
compileToExpression(algebra).invoke(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
compileToExpression(algebra).invoke(*arguments)

View File

@ -1,377 +1,47 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.asm.internal package space.kscience.kmath.asm.internal
import org.objectweb.asm.* import org.objectweb.asm.Type
import org.objectweb.asm.Opcodes.* import space.kscience.kmath.expressions.Expression
import org.objectweb.asm.Type.*
import org.objectweb.asm.commons.InstructionAdapter
import space.kscience.kmath.asm.internal.AsmBuilder.ClassLoader
import space.kscience.kmath.expressions.*
import java.lang.invoke.MethodHandles
import java.lang.invoke.MethodType
import java.nio.file.Paths
import java.util.stream.Collectors.toMap
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.io.path.writeBytes
/** internal abstract class AsmBuilder {
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
*
* @property T the type of AsmExpression to unwrap.
* @property className the unique class name of new loaded class.
* @property expressionResultCallback the function to apply to this object when generating expression value.
* @author Iaroslav Postovalov
*/
internal class AsmBuilder<T>(
classOfT: Class<*>,
private val className: String,
private val variablesPrepareCallback: AsmBuilder<T>.() -> Unit,
private val expressionResultCallback: AsmBuilder<T>.() -> Unit,
) {
/** /**
* Internal classloader of [AsmBuilder] with alias to define class from byte array. * Internal classloader with alias to define class from byte array.
*/ */
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { class ByteArrayClassLoader(parent: ClassLoader) : ClassLoader(parent) {
fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
} }
/** protected val classLoader = ByteArrayClassLoader(javaClass.classLoader)
* The instance of [ClassLoader] used by this builder.
*/
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
/**
* ASM type for [T].
*/
private val tType: Type = classOfT.asm
/**
* ASM type for new class.
*/
private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/'))
/**
* List of constants to provide to the subclass.
*/
private val constants: MutableList<Any> = mutableListOf()
/**
* Method visitor of `invoke` method of the subclass.
*/
private lateinit var invokeMethodVisitor: InstructionAdapter
/**
* Local variables indices are indices of symbols in this list.
*/
private val argumentsLocals = mutableListOf<String>()
/**
* Subclasses, loads and instantiates [Expression] for given parameters.
*
* The built instance is cached.
*/
@Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE")
val instance: Expression<T> by lazy {
val hasConstants: Boolean
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit(
V1_8,
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
classType.internalName,
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
OBJECT_TYPE.internalName,
arrayOf(EXPRESSION_TYPE.internalName),
)
visitMethod(
ACC_PUBLIC or ACC_FINAL,
"invoke",
getMethodDescriptor(tType, MAP_TYPE),
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
null,
).instructionAdapter {
invokeMethodVisitor = this
visitCode()
val preparingVariables = label()
variablesPrepareCallback()
val expressionResult = label()
expressionResultCallback()
areturn(tType)
val end = label()
visitLocalVariable(
"this",
classType.descriptor,
null,
preparingVariables,
end,
0,
)
visitLocalVariable(
"arguments",
MAP_TYPE.descriptor,
"L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;",
preparingVariables,
end,
1,
)
visitMaxs(0, 2)
visitEnd()
}
visitMethod(
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
"invoke",
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
null,
null,
).instructionAdapter {
visitCode()
val l0 = label()
load(0, OBJECT_TYPE)
load(1, MAP_TYPE)
invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false)
areturn(tType)
val l1 = label()
visitLocalVariable(
"this",
classType.descriptor,
null,
l0,
l1,
0,
)
visitMaxs(0, 2)
visitEnd()
}
hasConstants = constants.isNotEmpty()
if (hasConstants)
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "constants",
descriptor = OBJECT_ARRAY_TYPE.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd,
)
visitMethod(
ACC_PUBLIC,
"<init>",
getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
null,
null,
).instructionAdapter {
val l0 = label()
load(0, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", getMethodDescriptor(VOID_TYPE), false)
label()
load(0, classType)
if (hasConstants) {
label()
load(0, classType)
load(1, OBJECT_ARRAY_TYPE)
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
}
label()
visitInsn(RETURN)
val l4 = label()
visitLocalVariable("this", classType.descriptor, null, l0, l4, 0)
if (hasConstants)
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, 1)
visitMaxs(0, 3)
visitEnd()
}
visitEnd()
}
val binary = classWriter.toByteArray()
val cls = classLoader.defineClass(className, binary)
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
Paths.get("$className.class").writeBytes(binary)
val l = MethodHandles.publicLookup()
(if (hasConstants)
l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array<Any>::class.java))(constants.toTypedArray())
else
l.findConstructor(cls, MethodType.methodType(Void.TYPE))()) as Expression<T>
}
/**
* Loads [java.lang.Object] constant from constants.
*/
fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex
invokeMethodVisitor.load(0, classType)
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
iconst(idx)
visitInsn(AALOAD)
if (type != OBJECT_TYPE) checkcast(type)
}
/**
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
* constant from the constant pool.
*/
fun loadNumberConstant(value: Number) {
val boxed = value.javaClass.asm
val primitive = BOXED_TO_PRIMITIVES[boxed]
if (primitive != null) {
when (primitive) {
BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
}
val r = PRIMITIVES_TO_BOXED.getValue(primitive)
invokeMethodVisitor.invokestatic(
r.internalName,
"valueOf",
getMethodDescriptor(r, primitive),
false,
)
return
}
loadObjectConstant(value, boxed)
}
/**
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
* [loadVariable].
*/
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
if (name in argumentsLocals) return@run
load(1, MAP_TYPE)
aconst(name)
invokestatic(
MAP_INTRINSICS_TYPE.internalName,
"getOrFail",
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
false,
)
checkcast(tType)
var idx = argumentsLocals.indexOf(name)
if (idx == -1) {
argumentsLocals += name
idx = argumentsLocals.lastIndex
}
store(2 + idx, tType)
}
/**
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
* with [prepareVariable] first.
*/
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
inline fun buildCall(function: Function<T>, parameters: AsmBuilder<T>.() -> Unit) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces }
val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount
?: error("Provided function object doesn't contain invoke method")
val type = getType(`interface`)
loadObjectConstant(function, type)
parameters(this)
invokeMethodVisitor.invokeinterface(
type.internalName,
"invoke",
getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE }),
)
invokeMethodVisitor.checkcast(tType)
}
companion object {
/**
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
*/
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy {
hashMapOf(
Byte::class.java.asm to BYTE_TYPE,
Short::class.java.asm to SHORT_TYPE,
Integer::class.java.asm to INT_TYPE,
Long::class.java.asm to LONG_TYPE,
Float::class.java.asm to FLOAT_TYPE,
Double::class.java.asm to DOUBLE_TYPE,
)
}
/**
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
*/
private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy {
BOXED_TO_PRIMITIVES.entries.stream().collect(
toMap(Map.Entry<Type, Type>::value, Map.Entry<Type, Type>::key),
)
}
protected companion object {
/** /**
* ASM type for [Expression]. * ASM type for [Expression].
*/ */
val EXPRESSION_TYPE: Type by lazy { getObjectType("space/kscience/kmath/expressions/Expression") } val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/expressions/Expression") }
/** /**
* ASM type for [java.util.Map]. * ASM type for [java.util.Map].
*/ */
val MAP_TYPE: Type by lazy { getObjectType("java/util/Map") } val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") }
/** /**
* ASM type for [java.lang.Object]. * ASM type for [java.lang.Object].
*/ */
val OBJECT_TYPE: Type by lazy { getObjectType("java/lang/Object") } val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") }
/**
* ASM type for array of [java.lang.Object].
*/
val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") }
/** /**
* ASM type for [java.lang.String]. * ASM type for [java.lang.String].
*/ */
val STRING_TYPE: Type by lazy { getObjectType("java/lang/String") } val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") }
/** /**
* ASM type for MapIntrinsics. * ASM type for MapIntrinsics.
*/ */
val MAP_INTRINSICS_TYPE: Type by lazy { getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics") } val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics") }
/** /**
* ASM Type for [space.kscience.kmath.expressions.Symbol]. * ASM Type for [space.kscience.kmath.expressions.Symbol].
*/ */
val SYMBOL_TYPE: Type by lazy { getObjectType("space/kscience/kmath/expressions/Symbol") } val SYMBOL_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/expressions/Symbol") }
} }
} }

View File

@ -0,0 +1 @@
package space.kscience.kmath.asm.internal

View File

@ -0,0 +1,325 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.asm.internal
import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type.*
import org.objectweb.asm.commons.InstructionAdapter
import space.kscience.kmath.expressions.*
import java.lang.invoke.MethodHandles
import java.lang.invoke.MethodType
import java.nio.file.Paths
import java.util.stream.Collectors.toMap
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.io.path.writeBytes
/**
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
*
* @property T the type of AsmExpression to unwrap.
* @property className the unique class name of new loaded class.
* @property expressionResultCallback the function to apply to this object when generating expression value.
* @author Iaroslav Postovalov
*/
internal class GenericAsmBuilder<T>(
classOfT: Class<*>,
private val className: String,
private val variablesPrepareCallback: GenericAsmBuilder<T>.() -> Unit,
private val expressionResultCallback: GenericAsmBuilder<T>.() -> Unit,
) : AsmBuilder() {
/**
* ASM type for [T].
*/
private val tType: Type = classOfT.asm
/**
* ASM type for new class.
*/
private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/'))
/**
* List of constants to provide to the subclass.
*/
private val constants = mutableListOf<Any>()
/**
* Method visitor of `invoke` method of the subclass.
*/
private lateinit var invokeMethodVisitor: InstructionAdapter
/**
* Local variables indices are indices of symbols in this list.
*/
private val argumentsLocals = mutableListOf<String>()
/**
* Subclasses, loads and instantiates [Expression] for given parameters.
*
* The built instance is cached.
*/
@Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE")
val instance: Expression<T> by lazy {
val hasConstants: Boolean
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit(
V1_8,
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
classType.internalName,
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
OBJECT_TYPE.internalName,
arrayOf(EXPRESSION_TYPE.internalName),
)
visitMethod(
ACC_PUBLIC or ACC_FINAL,
"invoke",
getMethodDescriptor(tType, MAP_TYPE),
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
null,
).instructionAdapter {
invokeMethodVisitor = this
visitCode()
val preparingVariables = label()
variablesPrepareCallback()
val expressionResult = label()
expressionResultCallback()
areturn(tType)
val end = label()
visitLocalVariable(
"this",
classType.descriptor,
null,
preparingVariables,
end,
0,
)
visitLocalVariable(
"arguments",
MAP_TYPE.descriptor,
"L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;",
preparingVariables,
end,
1,
)
visitMaxs(0, 0)
visitEnd()
}
visitMethod(
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
"invoke",
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
null,
null,
).instructionAdapter {
visitCode()
val start = label()
load(0, OBJECT_TYPE)
load(1, MAP_TYPE)
invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false)
areturn(tType)
val end = label()
visitLocalVariable(
"this",
classType.descriptor,
null,
start,
end,
0,
)
visitMaxs(0, 0)
visitEnd()
}
hasConstants = constants.isNotEmpty()
if (hasConstants)
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "constants",
descriptor = OBJECT_ARRAY_TYPE.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd,
)
visitMethod(
ACC_PUBLIC,
"<init>",
getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
null,
null,
).instructionAdapter {
val l0 = label()
load(0, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", getMethodDescriptor(VOID_TYPE), false)
label()
load(0, classType)
if (hasConstants) {
label()
load(0, classType)
load(1, OBJECT_ARRAY_TYPE)
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
}
label()
visitInsn(RETURN)
val l4 = label()
visitLocalVariable("this", classType.descriptor, null, l0, l4, 0)
if (hasConstants)
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, 1)
visitMaxs(0, 0)
visitEnd()
}
visitEnd()
}
val binary = classWriter.toByteArray()
val cls = classLoader.defineClass(className, binary)
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
Paths.get("${className.split('.').last()}.class").writeBytes(binary)
val l = MethodHandles.publicLookup()
(if (hasConstants)
l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array<Any>::class.java))(constants.toTypedArray())
else
l.findConstructor(cls, MethodType.methodType(Void.TYPE))()) as Expression<T>
}
/**
* Loads [java.lang.Object] constant from constants.
*/
fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex
invokeMethodVisitor.load(0, classType)
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
iconst(idx)
visitInsn(AALOAD)
if (type != OBJECT_TYPE) checkcast(type)
}
/**
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
* constant from the constant pool.
*/
fun loadNumberConstant(value: Number) {
val boxed = value.javaClass.asm
val primitive = BOXED_TO_PRIMITIVES[boxed]
if (primitive != null) {
when (primitive) {
BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
}
val r = boxed
invokeMethodVisitor.invokestatic(
r.internalName,
"valueOf",
getMethodDescriptor(r, primitive),
false,
)
return
}
loadObjectConstant(value, boxed)
}
/**
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
* [loadVariable].
*/
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
if (name in argumentsLocals) return@run
load(1, MAP_TYPE)
aconst(name)
invokestatic(
MAP_INTRINSICS_TYPE.internalName,
"getOrFail",
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
false,
)
checkcast(tType)
var idx = argumentsLocals.indexOf(name)
if (idx == -1) {
argumentsLocals += name
idx = argumentsLocals.lastIndex
}
store(2 + idx, tType)
}
/**
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
* with [prepareVariable] first.
*/
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
inline fun buildCall(function: Function<T>, parameters: GenericAsmBuilder<T>.() -> Unit) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces }
val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount
?: error("Provided function object doesn't contain invoke method")
val type = getType(`interface`)
loadObjectConstant(function, type)
parameters(this)
invokeMethodVisitor.invokeinterface(
type.internalName,
"invoke",
getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE }),
)
invokeMethodVisitor.checkcast(tType)
}
private companion object {
/**
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
*/
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy {
hashMapOf(
Byte::class.java.asm to BYTE_TYPE,
Short::class.java.asm to SHORT_TYPE,
Integer::class.java.asm to INT_TYPE,
Long::class.java.asm to LONG_TYPE,
Float::class.java.asm to FLOAT_TYPE,
Double::class.java.asm to DOUBLE_TYPE,
)
}
/**
* ASM type for array of [java.lang.Object].
*/
val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") }
}
}

View File

@ -0,0 +1,411 @@
package space.kscience.kmath.asm.internal
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import org.objectweb.asm.Type.*
import org.objectweb.asm.commons.InstructionAdapter
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.*
import java.lang.invoke.MethodHandles
import java.lang.invoke.MethodType
import java.nio.file.Paths
import kotlin.io.path.writeBytes
internal sealed class PrimitiveAsmBuilder<T : Number>(
protected val algebra: Algebra<T>,
classOfT: Class<*>,
protected val classOfTPrimitive: Class<*>,
protected val target: MST,
) : AsmBuilder() {
private val className: String = buildName(target)
/**
* ASM type for [T].
*/
private val tType: Type = classOfT.asm
/**
* ASM type for [T].
*/
protected val tTypePrimitive: Type = classOfTPrimitive.asm
/**
* ASM type for new class.
*/
private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/'))
/**
* Method visitor of `invoke` method of the subclass.
*/
protected lateinit var invokeMethodVisitor: InstructionAdapter
/**
* Local variables indices are indices of symbols in this list.
*/
private val argumentsLocals = mutableListOf<String>()
/**
* Subclasses, loads and instantiates [Expression] for given parameters.
*
* The built instance is cached.
*/
@Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE")
val instance: Expression<T> by lazy {
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit(
Opcodes.V1_8,
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
classType.internalName,
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
OBJECT_TYPE.internalName,
arrayOf(EXPRESSION_TYPE.internalName),
)
visitMethod(
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
"invoke",
getMethodDescriptor(tType, MAP_TYPE),
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
null,
).instructionAdapter {
invokeMethodVisitor = this
visitCode()
val preparingVariables = label()
visitVariables(target)
val expressionResult = label()
visitExpression(target)
box()
areturn(tType)
val end = label()
visitLocalVariable(
"this",
classType.descriptor,
null,
preparingVariables,
end,
0,
)
visitLocalVariable(
"arguments",
MAP_TYPE.descriptor,
"L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;",
preparingVariables,
end,
1,
)
visitMaxs(0, 0)
visitEnd()
}
visitMethod(
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
"invoke",
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
null,
null,
).instructionAdapter {
visitCode()
val start = label()
load(0, OBJECT_TYPE)
load(1, MAP_TYPE)
invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false)
areturn(tType)
val end = label()
visitLocalVariable(
"this",
classType.descriptor,
null,
start,
end,
0,
)
visitMaxs(0, 0)
visitEnd()
}
visitMethod(
Opcodes.ACC_PUBLIC,
"<init>",
getMethodDescriptor(VOID_TYPE),
null,
null,
).instructionAdapter {
val start = label()
load(0, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", getMethodDescriptor(VOID_TYPE), false)
label()
load(0, classType)
label()
visitInsn(Opcodes.RETURN)
val end = label()
visitLocalVariable("this", classType.descriptor, null, start, end, 0)
visitMaxs(0, 0)
visitEnd()
}
visitEnd()
}
val binary = classWriter.toByteArray()
val cls = classLoader.defineClass(className, binary)
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
Paths.get("${className.split('.').last()}.class").writeBytes(binary)
MethodHandles.publicLookup().findConstructor(cls, MethodType.methodType(Void.TYPE))() as Expression<T>
}
/**
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
* constant from the constant pool.
*/
fun loadNumberConstant(value: Number) {
when (tTypePrimitive) {
BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
}
}
/**
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
* [loadVariable].
*/
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
if (name in argumentsLocals) return@run
load(1, MAP_TYPE)
aconst(name)
invokestatic(
MAP_INTRINSICS_TYPE.internalName,
"getOrFail",
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
false,
)
checkcast(tType)
var idx = argumentsLocals.indexOf(name)
if (idx == -1) {
argumentsLocals += name
idx = argumentsLocals.lastIndex
}
unbox()
store(2 + idx, tTypePrimitive)
}
/**
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
* with [prepareVariable] first.
*/
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tTypePrimitive)
private fun unbox() = invokeMethodVisitor.run {
invokevirtual(
NUMBER_TYPE.internalName,
"${classOfTPrimitive.simpleName}Value",
getMethodDescriptor(tTypePrimitive),
false
)
}
private fun box() = invokeMethodVisitor.run {
invokestatic(tType.internalName, "valueOf", getMethodDescriptor(tType, tTypePrimitive), false)
}
protected fun visitVariables(node: MST): Unit = when (node) {
is Symbol -> prepareVariable(node.identity)
is MST.Unary -> visitVariables(node.value)
is MST.Binary -> {
visitVariables(node.left)
visitVariables(node.right)
}
else -> Unit
}
protected fun visitExpression(mst: MST): Unit = when (mst) {
is Symbol -> loadVariable(mst.identity)
is MST.Numeric -> loadNumberConstant(mst.value)
is MST.Unary -> when {
algebra is NumericAlgebra && mst.value is MST.Numeric -> {
loadNumberConstant(
MST.Numeric(
algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as MST.Numeric).value)),
).value,
)
}
else -> visitUnary(mst)
}
is MST.Binary -> when {
algebra is NumericAlgebra && mst.left is MST.Numeric && mst.right is MST.Numeric -> {
loadNumberConstant(
MST.Numeric(
algebra.binaryOperationFunction(mst.operation)(
algebra.number((mst.left as MST.Numeric).value),
algebra.number((mst.right as MST.Numeric).value),
),
).value,
)
}
else -> visitBinary(mst)
}
}
protected open fun visitUnary(mst: MST.Unary) {
visitExpression(mst.value)
}
protected open fun visitBinary(mst: MST.Binary) {
visitExpression(mst.left)
visitExpression(mst.right)
}
protected companion object {
/**
* ASM type for [java.lang.Number].
*/
val NUMBER_TYPE: Type by lazy { getObjectType("java/lang/Number") }
}
}
internal class DoubleAsmBuilder(target: MST) :
PrimitiveAsmBuilder<Double>(DoubleField, java.lang.Double::class.java, java.lang.Double.TYPE, target) {
private fun buildUnaryJavaMathCall(name: String) {
invokeMethodVisitor.invokestatic(
MATH_TYPE.internalName,
name,
getMethodDescriptor(tTypePrimitive, tTypePrimitive),
false,
)
}
private fun buildBinaryJavaMathCall(name: String) {
invokeMethodVisitor.invokestatic(
MATH_TYPE.internalName,
name,
getMethodDescriptor(tTypePrimitive, tTypePrimitive, tTypePrimitive),
false,
)
}
private fun buildUnaryKotlinMathCall(name: String) {
invokeMethodVisitor.invokestatic(
MATH_KT_TYPE.internalName,
name,
getMethodDescriptor(tTypePrimitive, tTypePrimitive),
false,
)
}
override fun visitUnary(mst: MST.Unary) {
super.visitUnary(mst)
when (mst.operation) {
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DNEG)
GroupOps.PLUS_OPERATION -> Unit
PowerOperations.SQRT_OPERATION -> buildUnaryJavaMathCall("sqrt")
TrigonometricOperations.SIN_OPERATION -> buildUnaryJavaMathCall("sin")
TrigonometricOperations.COS_OPERATION -> buildUnaryJavaMathCall("cos")
TrigonometricOperations.TAN_OPERATION -> buildUnaryJavaMathCall("tan")
TrigonometricOperations.ASIN_OPERATION -> buildUnaryJavaMathCall("asin")
TrigonometricOperations.ACOS_OPERATION -> buildUnaryJavaMathCall("acos")
TrigonometricOperations.ATAN_OPERATION -> buildUnaryJavaMathCall("atan")
ExponentialOperations.SINH_OPERATION -> buildUnaryJavaMathCall("sqrt")
ExponentialOperations.COSH_OPERATION -> buildUnaryJavaMathCall("cosh")
ExponentialOperations.TANH_OPERATION -> buildUnaryJavaMathCall("tanh")
ExponentialOperations.ASINH_OPERATION -> buildUnaryKotlinMathCall("asinh")
ExponentialOperations.ACOSH_OPERATION -> buildUnaryKotlinMathCall("acosh")
ExponentialOperations.ATANH_OPERATION -> buildUnaryKotlinMathCall("atanh")
ExponentialOperations.EXP_OPERATION -> buildUnaryJavaMathCall("exp")
ExponentialOperations.LN_OPERATION -> buildUnaryJavaMathCall("log")
else -> super.visitUnary(mst)
}
}
override fun visitBinary(mst: MST.Binary) {
super.visitBinary(mst)
when (mst.operation) {
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DADD)
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DSUB)
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DMUL)
FieldOps.DIV_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DDIV)
PowerOperations.POW_OPERATION -> buildBinaryJavaMathCall("pow")
else -> super.visitBinary(mst)
}
}
companion object {
val MATH_TYPE: Type by lazy { getObjectType("java/lang/Math") }
val MATH_KT_TYPE: Type by lazy { getObjectType("kotlin/math/MathKt") }
}
}
internal class IntAsmBuilder(target: MST) :
PrimitiveAsmBuilder<Int>(IntRing, Integer::class.java, Integer.TYPE, target) {
override fun visitUnary(mst: MST.Unary) {
super.visitUnary(mst)
when (mst.operation) {
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.INEG)
GroupOps.PLUS_OPERATION -> Unit
else -> super.visitUnary(mst)
}
}
override fun visitBinary(mst: MST.Binary) {
super.visitBinary(mst)
when (mst.operation) {
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.IADD)
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.ISUB)
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.IMUL)
else -> super.visitBinary(mst)
}
}
}
internal class LongAsmBuilder(target: MST) :
PrimitiveAsmBuilder<Long>(LongRing, java.lang.Long::class.java, java.lang.Long.TYPE, target) {
override fun visitUnary(mst: MST.Unary) {
super.visitUnary(mst)
when (mst.operation) {
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LNEG)
GroupOps.PLUS_OPERATION -> Unit
else -> super.visitUnary(mst)
}
}
override fun visitBinary(mst: MST.Binary) {
super.visitBinary(mst)
when (mst.operation) {
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LADD)
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LSUB)
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LMUL)
else -> super.visitBinary(mst)
}
}
}

View File

@ -14,4 +14,5 @@ import space.kscience.kmath.expressions.Symbol
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@Suppress("unused")
internal fun <V> Map<Symbol, V>.getOrFail(key: String): V = getValue(Symbol(key)) internal fun <V> Map<Symbol, V>.getOrFail(key: String): V = getValue(Symbol(key))

View File

@ -8,6 +8,7 @@ package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
@ -15,7 +16,21 @@ import kotlin.contracts.contract
import space.kscience.kmath.asm.compile as asmCompile import space.kscience.kmath.asm.compile as asmCompile
import space.kscience.kmath.asm.compileToExpression as asmCompileToExpression import space.kscience.kmath.asm.compileToExpression as asmCompileToExpression
private object AsmCompilerTestContext : CompilerTestContext { private object GenericAsmCompilerTestContext : CompilerTestContext {
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> =
asmCompileToExpression(algebra as Algebra<Int>)
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
asmCompile(algebra as Algebra<Int>, arguments)
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> =
asmCompileToExpression(algebra as Algebra<Double>)
override fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
asmCompile(algebra as Algebra<Double>, arguments)
}
private object PrimitiveAsmCompilerTestContext : CompilerTestContext {
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = asmCompileToExpression(algebra) override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = asmCompileToExpression(algebra)
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = asmCompile(algebra, arguments) override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = asmCompile(algebra, arguments)
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = asmCompileToExpression(algebra) override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = asmCompileToExpression(algebra)
@ -24,7 +39,9 @@ private object AsmCompilerTestContext : CompilerTestContext {
asmCompile(algebra, arguments) asmCompile(algebra, arguments)
} }
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) { internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
action(AsmCompilerTestContext) action(GenericAsmCompilerTestContext)
action(PrimitiveAsmCompilerTestContext)
} }

View File

@ -8,7 +8,7 @@ Complex and hypercomplex number systems in KMath.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-14`. The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-17`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -18,7 +18,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-complex:0.3.0-dev-14' implementation 'space.kscience:kmath-complex:0.3.0-dev-17'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -29,6 +29,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-complex:0.3.0-dev-14") implementation("space.kscience:kmath-complex:0.3.0-dev-17")
} }
``` ```

View File

@ -15,7 +15,7 @@ performance calculations to code generation.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-14`. The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-17`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -25,7 +25,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-core:0.3.0-dev-14' implementation 'space.kscience:kmath-core:0.3.0-dev-17'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -36,6 +36,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-core:0.3.0-dev-14") implementation("space.kscience:kmath-core:0.3.0-dev-17")
} }
``` ```

View File

@ -9,8 +9,8 @@ import kotlin.jvm.JvmInline
import kotlin.properties.ReadOnlyProperty import kotlin.properties.ReadOnlyProperty
/** /**
* A marker interface for a symbol. A symbol must have an identity. * A marker interface for a symbol. A symbol must have an identity with equality relation based on it.
* Ic * Other properties are to store additional, transient data only.
*/ */
public interface Symbol : MST { public interface Symbol : MST {
/** /**

View File

@ -9,7 +9,7 @@ EJML based linear algebra implementation.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-14`. The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-17`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -19,7 +19,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-ejml:0.3.0-dev-14' implementation 'space.kscience:kmath-ejml:0.3.0-dev-17'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -30,6 +30,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-ejml:0.3.0-dev-14") implementation("space.kscience:kmath-ejml:0.3.0-dev-17")
} }
``` ```

View File

@ -9,7 +9,7 @@ Specialization of KMath APIs for Double numbers.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-14`. The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-17`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -19,7 +19,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-for-real:0.3.0-dev-14' implementation 'space.kscience:kmath-for-real:0.3.0-dev-17'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -30,6 +30,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-for-real:0.3.0-dev-14") implementation("space.kscience:kmath-for-real:0.3.0-dev-17")
} }
``` ```

View File

@ -11,7 +11,7 @@ Functions and interpolations.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-14`. The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-17`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -21,7 +21,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-functions:0.3.0-dev-14' implementation 'space.kscience:kmath-functions:0.3.0-dev-17'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -32,6 +32,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-functions:0.3.0-dev-14") implementation("space.kscience:kmath-functions:0.3.0-dev-17")
} }
``` ```

View File

@ -7,7 +7,7 @@ Integration with [Jafama](https://github.com/jeffhain/jafama).
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-jafama:0.3.0-dev-14`. The Maven coordinates of this project are `space.kscience:kmath-jafama:0.3.0-dev-17`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -17,7 +17,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-jafama:0.3.0-dev-14' implementation 'space.kscience:kmath-jafama:0.3.0-dev-17'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -28,7 +28,7 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-jafama:0.3.0-dev-14") implementation("space.kscience:kmath-jafama:0.3.0-dev-17")
} }
``` ```

View File

@ -8,7 +8,7 @@
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-kotlingrad:0.3.0-dev-14`. The Maven coordinates of this project are `space.kscience:kmath-kotlingrad:0.3.0-dev-17`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -18,7 +18,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-kotlingrad:0.3.0-dev-14' implementation 'space.kscience:kmath-kotlingrad:0.3.0-dev-17'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -29,6 +29,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-kotlingrad:0.3.0-dev-14") implementation("space.kscience:kmath-kotlingrad:0.3.0-dev-17")
} }
``` ```

View File

@ -9,7 +9,7 @@ ND4J based implementations of KMath abstractions.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-14`. The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-17`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -19,7 +19,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-nd4j:0.3.0-dev-14' implementation 'space.kscience:kmath-nd4j:0.3.0-dev-17'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -30,7 +30,7 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-nd4j:0.3.0-dev-14") implementation("space.kscience:kmath-nd4j:0.3.0-dev-17")
} }
``` ```

View File

@ -9,7 +9,7 @@ Common linear algebra operations on tensors.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-tensors:0.3.0-dev-14`. The Maven coordinates of this project are `space.kscience:kmath-tensors:0.3.0-dev-17`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -19,7 +19,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-tensors:0.3.0-dev-14' implementation 'space.kscience:kmath-tensors:0.3.0-dev-17'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -30,6 +30,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-tensors:0.3.0-dev-14") implementation("space.kscience:kmath-tensors:0.3.0-dev-17")
} }
``` ```