Implement Commons RNG-like samplers in kmath-prob module for Multiplatform #164

Merged
CommanderTvis merged 44 commits from feature/mp-samplers into dev 2021-03-31 09:25:44 +03:00
9 changed files with 278 additions and 161 deletions
Showing only changes of commit 019f60c721 - Show all commits

View File

@ -1,13 +1,44 @@
# AST-based expression representation and operations (`kmath-ast`) # Abstract syntax tree expression representation and operations (`kmath-ast`)
This subproject implements the following features: This subproject implements the following features:
- Expression Language and its parser. - Expression Language and its parser.
- MST as expression language's syntax intermediate representation. - MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation.
- Type-safe builder of MST. - Type-safe builder for MST.
- Evaluating expressions by traversing MST. - Evaluating expressions by traversing MST.
## Dynamic expression code generation with OW2 ASM > #### Artifact:
> This module is distributed in the artifact `scientifik:kmath-ast:0.1.4-dev-8`.
>
> **Gradle:**
>
> ```gradle
> repositories {
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' }
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
> maven { url https://dl.bintray.com/hotkeytlt/maven' }
> }
>
> dependencies {
> implementation 'scientifik:kmath-ast:0.1.4-dev-8'
> }
> ```
> **Gradle Kotlin DSL:**
>
> ```kotlin
> repositories {
> maven("https://dl.bintray.com/mipt-npm/scientifik")
> maven("https://dl.bintray.com/mipt-npm/dev")
> maven("https://dl.bintray.com/hotkeytlt/maven")
> }
>
> dependencies {
> implementation("scientifik:kmath-ast:0.1.4-dev-8")
> }
> ```
>
## Dynamic expression code generation with ObjectWeb ASM
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds `kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
a special implementation of `Expression<T>` with implemented `invoke` function. a special implementation of `Expression<T>` with implemented `invoke` function.
@ -30,15 +61,13 @@ import scientifik.kmath.operations.RealField;
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> { public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
private final RealField algebra; private final RealField algebra;
private final Object[] constants;
public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) {
this.algebra = algebra;
this.constants = constants;
}
public final Double invoke(Map<String, ? extends Double> arguments) { public final Double invoke(Map<String, ? extends Double> arguments) {
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D); return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(), 2.0D);
}
public AsmCompiledExpression_1073786867_0(RealField algebra) {
this.algebra = algebra;
} }
} }
@ -46,7 +75,7 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression<Doub
### Example Usage ### Example Usage
This API is an extension to MST and MstExpression APIs. You may optimize both MST and MSTExpression: This API is an extension to MST and MstExpression, so you may optimize as both of them:
```kotlin ```kotlin
RealField.mstInField { symbol("x") + 2 }.compile() RealField.mstInField { symbol("x") + 2 }.compile()

View File

@ -1,6 +1,7 @@
package scientifik.kmath.asm package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.MstType
import scientifik.kmath.asm.internal.buildAlgebraOperationCall import scientifik.kmath.asm.internal.buildAlgebraOperationCall
import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.buildName
import scientifik.kmath.ast.MST import scientifik.kmath.ast.MST
@ -17,28 +18,20 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
fun AsmBuilder<T>.visit(node: MST) { fun AsmBuilder<T>.visit(node: MST) {
when (node) { when (node) {
is MST.Symbolic -> loadVariable(node.value) is MST.Symbolic -> loadVariable(node.value)
is MST.Numeric -> loadNumeric(node.value)
is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>)
algebra.number(node.value)
else
error("Number literals are not supported in $algebra")
loadTConstant(constant)
}
is MST.Unary -> buildAlgebraOperationCall( is MST.Unary -> buildAlgebraOperationCall(
context = algebra, context = algebra,
name = node.operation, name = node.operation,
fallbackMethodName = "unaryOperation", fallbackMethodName = "unaryOperation",
arity = 1 parameterTypes = arrayOf(MstType.fromMst(node.value))
) { visit(node.value) } ) { visit(node.value) }
is MST.Binary -> buildAlgebraOperationCall( is MST.Binary -> buildAlgebraOperationCall(
context = algebra, context = algebra,
name = node.operation, name = node.operation,
fallbackMethodName = "binaryOperation", fallbackMethodName = "binaryOperation",
arity = 2 parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right))
) { ) {
visit(node.left) visit(node.left)
visit(node.right) visit(node.right)

View File

@ -7,17 +7,19 @@ import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
import scientifik.kmath.ast.MST import scientifik.kmath.ast.MST
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import java.util.* import java.util.*
import java.util.stream.Collectors
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
* *
* @param T the type of AsmExpression to unwrap. * @property T the type of AsmExpression to unwrap.
* @param algebra the algebra the applied AsmExpressions use. * @property algebra the algebra the applied AsmExpressions use.
* @param className the unique class name of new loaded class. * @property className the unique class name of new loaded class.
* @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. * @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
*/ */
internal class AsmBuilder<T> internal constructor( internal class AsmBuilder<T> internal constructor(
private val classOfT: KClass<*>, private val classOfT: KClass<*>,
@ -38,17 +40,17 @@ internal class AsmBuilder<T> internal constructor(
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
/** /**
* ASM Type for [algebra] * ASM Type for [algebra].
*/ */
private val tAlgebraType: Type = algebra::class.asm private val tAlgebraType: Type = algebra::class.asm
/** /**
* ASM type for [T] * ASM type for [T].
*/ */
internal val tType: Type = classOfT.asm internal val tType: Type = classOfT.asm
/** /**
* ASM type for new class * ASM type for new class.
*/ */
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
@ -72,6 +74,11 @@ internal class AsmBuilder<T> internal constructor(
*/ */
private lateinit var invokeMethodVisitor: InstructionAdapter private lateinit var invokeMethodVisitor: InstructionAdapter
/**
* State if this [AsmBuilder] needs to generate constants field.
*/
private var hasConstants: Boolean = true
/** /**
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
*/ */
@ -95,7 +102,7 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* Stack of useful objects types on stack expected by algebra calls. * Stack of useful objects types on stack expected by algebra calls.
*/ */
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>().apply { push(tType) } internal val expectationStack: ArrayDeque<Type> = ArrayDeque(listOf(tType))
/** /**
* The cache for instance built by this builder. * The cache for instance built by this builder.
@ -108,7 +115,7 @@ internal class AsmBuilder<T> internal constructor(
* The built instance is cached. * The built instance is cached.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
fun getInstance(): Expression<T> { internal fun getInstance(): Expression<T> {
generatedInstance?.let { return it } generatedInstance?.let { return it }
if (SIGNATURE_LETTERS.containsKey(classOfT)) { if (SIGNATURE_LETTERS.containsKey(classOfT)) {
@ -127,64 +134,6 @@ internal class AsmBuilder<T> internal constructor(
arrayOf(EXPRESSION_TYPE.internalName) arrayOf(EXPRESSION_TYPE.internalName)
) )
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "algebra",
descriptor = tAlgebraType.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "constants",
descriptor = OBJECT_ARRAY_TYPE.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
visitMethod(
ACC_PUBLIC,
"<init>",
Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE),
null,
null
).instructionAdapter {
val thisVar = 0
val algebraVar = 1
val constantsVar = 2
val l0 = label()
load(thisVar, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
label()
load(thisVar, classType)
load(algebraVar, tAlgebraType)
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
label()
load(thisVar, classType)
load(constantsVar, OBJECT_ARRAY_TYPE)
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
label()
visitInsn(RETURN)
val l4 = label()
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
visitLocalVariable(
"algebra",
tAlgebraType.descriptor,
null,
l0,
l4,
algebraVar
)
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
visitMaxs(0, 3)
visitEnd()
}
visitMethod( visitMethod(
ACC_PUBLIC or ACC_FINAL, ACC_PUBLIC or ACC_FINAL,
"invoke", "invoke",
@ -251,6 +200,78 @@ internal class AsmBuilder<T> internal constructor(
visitEnd() visitEnd()
} }
hasConstants = constants.isNotEmpty()
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "algebra",
descriptor = tAlgebraType.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
if (hasConstants)
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "constants",
descriptor = OBJECT_ARRAY_TYPE.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
visitMethod(
ACC_PUBLIC,
"<init>",
Type.getMethodDescriptor(
Type.VOID_TYPE,
tAlgebraType,
*OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
null,
null
).instructionAdapter {
val thisVar = 0
val algebraVar = 1
val constantsVar = 2
val l0 = label()
load(thisVar, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
label()
load(thisVar, classType)
load(algebraVar, tAlgebraType)
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
if (hasConstants) {
label()
load(thisVar, classType)
load(constantsVar, OBJECT_ARRAY_TYPE)
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
}
label()
visitInsn(RETURN)
val l4 = label()
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
visitLocalVariable(
"algebra",
tAlgebraType.descriptor,
null,
l0,
l4,
algebraVar
)
if (hasConstants)
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
visitMaxs(0, 3)
visitEnd()
}
visitEnd() visitEnd()
} }
@ -258,7 +279,7 @@ internal class AsmBuilder<T> internal constructor(
.defineClass(className, classWriter.toByteArray()) .defineClass(className, classWriter.toByteArray())
.constructors .constructors
.first() .first()
.newInstance(algebra, constants.toTypedArray()) as Expression<T> .newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression<T>
generatedInstance = new generatedInstance = new
return new return new
@ -267,42 +288,50 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* Loads a [T] constant from [constants]. * Loads a [T] constant from [constants].
*/ */
internal fun loadTConstant(value: T) { private fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) { if (classOfT in INLINABLE_NUMBERS) {
val expectedType = expectationStack.pop() val expectedType = expectationStack.pop()
val mustBeBoxed = expectedType.sort == Type.OBJECT val mustBeBoxed = expectedType.sort == Type.OBJECT
loadNumberConstant(value as Number, mustBeBoxed) loadNumberConstant(value as Number, mustBeBoxed)
if (mustBeBoxed)
invokeMethodVisitor.checkcast(tType)
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
return return
} }
loadConstant(value as Any, tType) loadObjectConstant(value as Any, tType)
} }
/** /**
* Boxes the current value and pushes it. * Boxes the current value and pushes it.
*/ */
private fun box(): Unit = invokeMethodVisitor.invokestatic( private fun box(primitive: Type) {
tType.internalName, val r = PRIMITIVES_TO_BOXED.getValue(primitive)
invokeMethodVisitor.invokestatic(
r.internalName,
"valueOf", "valueOf",
Type.getMethodDescriptor(tType, primitiveMask), Type.getMethodDescriptor(r, primitive),
false false
) )
}
/** /**
* Unboxes the current boxed value and pushes it. * Unboxes the current boxed value and pushes it.
*/ */
private fun unbox(): Unit = invokeMethodVisitor.invokevirtual( private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual(
NUMBER_TYPE.internalName, NUMBER_TYPE.internalName,
NUMBER_CONVERTER_METHODS.getValue(primitiveMask), NUMBER_CONVERTER_METHODS.getValue(primitive),
Type.getMethodDescriptor(primitiveMask), Type.getMethodDescriptor(primitive),
false false
) )
/** /**
* Loads [java.lang.Object] constant from constants. * Loads [java.lang.Object] constant from constants.
*/ */
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
loadThis() loadThis()
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
@ -311,6 +340,15 @@ internal class AsmBuilder<T> internal constructor(
checkcast(type) checkcast(type)
} }
fun loadNumeric(value: Number) {
if (expectationStack.peek() == NUMBER_TYPE) {
loadNumberConstant(value, true)
expectationStack.pop()
typeStack.push(NUMBER_TYPE)
} else (algebra as? NumericAlgebra<T>)?.number(value)?.let { loadTConstant(it) }
?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.")
}
/** /**
* Loads this variable. * Loads this variable.
*/ */
@ -335,18 +373,16 @@ internal class AsmBuilder<T> internal constructor(
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
} }
if (mustBeBoxed) { if (mustBeBoxed)
box() box(primitive)
invokeMethodVisitor.checkcast(tType)
}
return return
} }
loadConstant(value, boxed) loadObjectConstant(value, boxed)
if (!mustBeBoxed) unbox() if (!mustBeBoxed)
else invokeMethodVisitor.checkcast(tType) unboxTo(primitiveMask)
} }
/** /**
@ -359,24 +395,26 @@ internal class AsmBuilder<T> internal constructor(
if (defaultValue != null) if (defaultValue != null)
loadTConstant(defaultValue) loadTConstant(defaultValue)
else
aconst(null)
invokestatic( invokestatic(
MAP_INTRINSICS_TYPE.internalName, MAP_INTRINSICS_TYPE.internalName,
"getOrFail", "getOrFail",
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
Type.getMethodDescriptor(
OBJECT_TYPE,
MAP_TYPE,
OBJECT_TYPE,
*OBJECT_TYPE.wrapToArrayIf { defaultValue != null }),
false false
) )
checkcast(tType) checkcast(tType)
val expectedType = expectationStack.pop() val expectedType = expectationStack.pop()
if (expectedType.sort == Type.OBJECT) if (expectedType.sort == Type.OBJECT)
typeStack.push(tType) typeStack.push(tType)
else { else {
unbox() unboxTo(primitiveMask)
typeStack.push(primitiveMask) typeStack.push(primitiveMask)
} }
} }
@ -425,7 +463,7 @@ internal class AsmBuilder<T> internal constructor(
if (expectedType.sort == Type.OBJECT || isLastExpr) if (expectedType.sort == Type.OBJECT || isLastExpr)
typeStack.push(tType) typeStack.push(tType)
else { else {
unbox() unboxTo(primitiveMask)
typeStack.push(primitiveMask) typeStack.push(primitiveMask)
} }
} }
@ -455,6 +493,18 @@ internal class AsmBuilder<T> internal constructor(
*/ */
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
/**
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
*/
private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy {
BOXED_TO_PRIMITIVES.entries.stream().collect(
Collectors.toMap(
Map.Entry<Type, Type>::value,
Map.Entry<Type, Type>::key
)
)
}
/** /**
* Maps primitive ASM types to [Number] functions unboxing them. * Maps primitive ASM types to [Number] functions unboxing them.
*/ */

View File

@ -0,0 +1,17 @@
package scientifik.kmath.asm.internal
import scientifik.kmath.ast.MST
internal enum class MstType {
GENERAL,
NUMBER;
companion object {
fun fromMst(mst: MST): MstType {
if (mst is MST.Numeric)
return NUMBER
return GENERAL
}
}
}

View File

@ -6,6 +6,7 @@ import org.objectweb.asm.commons.InstructionAdapter
import scientifik.kmath.ast.MST import scientifik.kmath.ast.MST
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import java.lang.reflect.Method
import kotlin.reflect.KClass import kotlin.reflect.KClass
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy { private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
@ -22,6 +23,12 @@ private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
internal val KClass<*>.asm: Type internal val KClass<*>.asm: Type
get() = Type.getType(java) get() = Type.getType(java)
/**
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
*/
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> =
if (predicate(this)) arrayOf(this) else emptyArray()
/** /**
* Creates an [InstructionAdapter] from this [MethodVisitor]. * Creates an [InstructionAdapter] from this [MethodVisitor].
*/ */
@ -36,11 +43,7 @@ internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Un
/** /**
* Constructs a [Label], then applies it to this visitor. * Constructs a [Label], then applies it to this visitor.
*/ */
internal fun MethodVisitor.label(): Label { internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) }
val l = Label()
visitLabel(l)
return l
}
/** /**
* Creates a class name for [Expression] subclassed to implement [mst] provided. * Creates a class name for [Expression] subclassed to implement [mst] provided.
@ -73,44 +76,71 @@ internal inline fun ClassWriter.visitField(
block: FieldVisitor.() -> Unit block: FieldVisitor.() -> Unit
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) ): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
context.javaClass.methods.find { method ->
val nameValid = method.name == name
val arityValid = method.parameters.size == parameterTypes.size
val notBridgeInPrimitive = !(primitiveMode && method.isBridge)
val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) ->
!(mstType != MstType.NUMBER && type == java.lang.Number::class.java)
}
nameValid && arityValid && notBridgeInPrimitive && paramsValid
}
/** /**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds * Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds
* type expectation stack for needed arity. * type expectation stack for needed arity.
* *
* @return `true` if contains, else `false`. * @return `true` if contains, else `false`.
*/ */
private fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean { private fun <T> AsmBuilder<T>.buildExpectationStack(
val theName = methodNameAdapters[name to arity] ?: name context: Algebra<T>,
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null name: String,
val t = if (primitiveMode && hasSpecific) primitiveMask else tType parameterTypes: Array<MstType>
repeat(arity) { expectationStack.push(t) } ): Boolean {
return hasSpecific val arity = parameterTypes.size
val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes)
if (specific != null)
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
else
repeat(arity) { expectationStack.push(tType) }
return specific != null
} }
private fun <T> AsmBuilder<T>.mapTypes(method: Method, parameterTypes: Array<MstType>): List<Type> = method
.parameterTypes
.zip(parameterTypes)
.map { (type, mstType) ->
when {
type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE
else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed
}
}
/** /**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts * Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts
* [AsmBuilder.invokeAlgebraOperation] of this method. * [AsmBuilder.invokeAlgebraOperation] of this method.
* *
* @return `true` if contains, else `false`. * @return `true` if contains, else `false`.
*/ */
private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean { private fun <T> AsmBuilder<T>.tryInvokeSpecific(
context: Algebra<T>,
name: String,
parameterTypes: Array<MstType>
): Boolean {
val arity = parameterTypes.size
val theName = methodNameAdapters[name to arity] ?: name val theName = methodNameAdapters[name to arity] ?: name
val spec = findSpecific(context, theName, parameterTypes) ?: return false
context.javaClass.methods.find {
var suitableSignature = it.name == theName && it.parameters.size == arity
if (primitiveMode && it.isBridge)
suitableSignature = false
suitableSignature
} ?: return false
val owner = context::class.asm val owner = context::class.asm
invokeAlgebraOperation( invokeAlgebraOperation(
owner = owner.internalName, owner = owner.internalName,
method = theName, method = theName,
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }), descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()),
expectedArity = arity, expectedArity = arity,
opcode = INVOKEVIRTUAL opcode = INVOKEVIRTUAL
) )
@ -121,18 +151,19 @@ private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Strin
/** /**
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String. * Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
*/ */
internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall( internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
context: Algebra<T>, context: Algebra<T>,
name: String, name: String,
fallbackMethodName: String, fallbackMethodName: String,
arity: Int, parameterTypes: Array<MstType>,
parameters: AsmBuilder<T>.() -> Unit parameters: AsmBuilder<T>.() -> Unit
) { ) {
val arity = parameterTypes.size
loadAlgebra() loadAlgebra()
if (!buildExpectationStack(context, name, arity)) loadStringConstant(name) if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)
parameters() parameters()
if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation( if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_TYPE.internalName, owner = AsmBuilder.ALGEBRA_TYPE.internalName,
method = fallbackMethodName, method = fallbackMethodName,
@ -145,4 +176,3 @@ internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
expectedArity = arity expectedArity = arity
) )
} }

View File

@ -1,10 +0,0 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.Label
import org.objectweb.asm.commons.InstructionAdapter
internal fun InstructionAdapter.label(): Label {
val l = Label()
visitLabel(l)
return l
}

View File

@ -2,6 +2,6 @@
package scientifik.kmath.asm.internal package scientifik.kmath.asm.internal
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V { @JvmOverloads
return this[key] ?: default ?: error("Parameter not found: $key") internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
} this[key] ?: default ?: error("Parameter not found: $key")

View File

@ -43,4 +43,13 @@ internal class TestAsmSpecialization {
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
assertEquals(1.0, expr("x" to 2.0)) assertEquals(1.0, expr("x" to 2.0))
} }
@Test
fun testPower() {
val expr = RealField
.mstInField { binaryOperation("power", symbol("x"), number(2)) }
.compile()
assertEquals(4.0, expr("x" to 2.0))
}
} }

View File

@ -13,8 +13,7 @@ import kotlin.test.assertEquals
internal class AsmTest { internal class AsmTest {
@Test @Test
fun `compile MST`() { fun `compile MST`() {
val mst = "2+2*(2+2)".parseMath() val res = ComplexField.expression("2+2*(2+2)".parseMath())()
val res = ComplexField.expression(mst)()
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }