Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116
@ -1,13 +1,42 @@
|
|||||||
# AST-based expression representation and operations (`kmath-ast`)
|
# Abstract syntax tree expression representation and operations (`kmath-ast`)
|
||||||
|
|
||||||
This subproject implements the following features:
|
This subproject implements the following features:
|
||||||
|
|
||||||
- Expression Language and its parser.
|
- Expression Language and its parser.
|
||||||
- MST as expression language's syntax intermediate representation.
|
- MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation.
|
||||||
- Type-safe builder of MST.
|
- Type-safe builder for MST.
|
||||||
- Evaluating expressions by traversing MST.
|
- Evaluating expressions by traversing MST.
|
||||||
|
|
||||||
## Dynamic expression code generation with OW2 ASM
|
> #### Artifact:
|
||||||
|
> This module is distributed in the artifact `scientifik:kmath-ast:0.1.4-dev-8`.
|
||||||
|
>
|
||||||
|
> **Gradle:**
|
||||||
|
>
|
||||||
|
> ```gradle
|
||||||
|
> repositories {
|
||||||
|
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' }
|
||||||
|
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||||
|
> }
|
||||||
|
>
|
||||||
|
> dependencies {
|
||||||
|
> implementation 'scientifik:kmath-ast:0.1.4-dev-8'
|
||||||
|
> }
|
||||||
|
> ```
|
||||||
|
> **Gradle Kotlin DSL:**
|
||||||
|
>
|
||||||
|
> ```kotlin
|
||||||
|
> repositories {
|
||||||
|
> maven("https://dl.bintray.com/mipt-npm/scientifik")
|
||||||
|
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
> }
|
||||||
|
>
|
||||||
|
> dependencies {
|
||||||
|
> implementation("scientifik:kmath-ast:0.1.4-dev-8")
|
||||||
|
> }
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
|
||||||
|
## Dynamic expression code generation with ObjectWeb ASM
|
||||||
|
|
||||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
|
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
|
||||||
a special implementation of `Expression<T>` with implemented `invoke` function.
|
a special implementation of `Expression<T>` with implemented `invoke` function.
|
||||||
@ -30,15 +59,13 @@ import scientifik.kmath.operations.RealField;
|
|||||||
|
|
||||||
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
|
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
|
||||||
private final RealField algebra;
|
private final RealField algebra;
|
||||||
private final Object[] constants;
|
|
||||||
|
|
||||||
public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) {
|
|
||||||
this.algebra = algebra;
|
|
||||||
this.constants = constants;
|
|
||||||
}
|
|
||||||
|
|
||||||
public final Double invoke(Map<String, ? extends Double> arguments) {
|
public final Double invoke(Map<String, ? extends Double> arguments) {
|
||||||
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D);
|
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(), 2.0D);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AsmCompiledExpression_1073786867_0(RealField algebra) {
|
||||||
|
this.algebra = algebra;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,7 +73,7 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression<Doub
|
|||||||
|
|
||||||
### Example Usage
|
### Example Usage
|
||||||
|
|
||||||
This API is an extension to MST and MstExpression APIs. You may optimize both MST and MSTExpression:
|
This API is an extension to MST and MstExpression, so you may optimize as both of them:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.asm
|
package scientifik.kmath.asm
|
||||||
|
|
||||||
import scientifik.kmath.asm.internal.AsmBuilder
|
import scientifik.kmath.asm.internal.AsmBuilder
|
||||||
|
import scientifik.kmath.asm.internal.MstType
|
||||||
import scientifik.kmath.asm.internal.buildAlgebraOperationCall
|
import scientifik.kmath.asm.internal.buildAlgebraOperationCall
|
||||||
import scientifik.kmath.asm.internal.buildName
|
import scientifik.kmath.asm.internal.buildName
|
||||||
import scientifik.kmath.ast.MST
|
import scientifik.kmath.ast.MST
|
||||||
@ -17,28 +18,20 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
|||||||
fun AsmBuilder<T>.visit(node: MST) {
|
fun AsmBuilder<T>.visit(node: MST) {
|
||||||
when (node) {
|
when (node) {
|
||||||
is MST.Symbolic -> loadVariable(node.value)
|
is MST.Symbolic -> loadVariable(node.value)
|
||||||
|
is MST.Numeric -> loadNumeric(node.value)
|
||||||
is MST.Numeric -> {
|
|
||||||
val constant = if (algebra is NumericAlgebra<T>)
|
|
||||||
algebra.number(node.value)
|
|
||||||
else
|
|
||||||
error("Number literals are not supported in $algebra")
|
|
||||||
|
|
||||||
loadTConstant(constant)
|
|
||||||
}
|
|
||||||
|
|
||||||
is MST.Unary -> buildAlgebraOperationCall(
|
is MST.Unary -> buildAlgebraOperationCall(
|
||||||
context = algebra,
|
context = algebra,
|
||||||
name = node.operation,
|
name = node.operation,
|
||||||
fallbackMethodName = "unaryOperation",
|
fallbackMethodName = "unaryOperation",
|
||||||
arity = 1
|
parameterTypes = arrayOf(MstType.fromMst(node.value))
|
||||||
) { visit(node.value) }
|
) { visit(node.value) }
|
||||||
|
|
||||||
is MST.Binary -> buildAlgebraOperationCall(
|
is MST.Binary -> buildAlgebraOperationCall(
|
||||||
context = algebra,
|
context = algebra,
|
||||||
name = node.operation,
|
name = node.operation,
|
||||||
fallbackMethodName = "binaryOperation",
|
fallbackMethodName = "binaryOperation",
|
||||||
arity = 2
|
parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right))
|
||||||
) {
|
) {
|
||||||
visit(node.left)
|
visit(node.left)
|
||||||
visit(node.right)
|
visit(node.right)
|
||||||
|
@ -7,7 +7,9 @@ import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
|||||||
import scientifik.kmath.ast.MST
|
import scientifik.kmath.ast.MST
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import scientifik.kmath.operations.NumericAlgebra
|
||||||
import java.util.*
|
import java.util.*
|
||||||
|
import java.util.stream.Collectors
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -20,7 +22,7 @@ import kotlin.reflect.KClass
|
|||||||
* @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
|
* @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
|
||||||
*/
|
*/
|
||||||
internal class AsmBuilder<T> internal constructor(
|
internal class AsmBuilder<T> internal constructor(
|
||||||
private val classOfT: KClass<*>,
|
internal val classOfT: KClass<*>,
|
||||||
private val algebra: Algebra<T>,
|
private val algebra: Algebra<T>,
|
||||||
private val className: String,
|
private val className: String,
|
||||||
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
||||||
@ -38,17 +40,17 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM Type for [algebra]
|
* ASM Type for [algebra].
|
||||||
*/
|
*/
|
||||||
private val tAlgebraType: Type = algebra::class.asm
|
private val tAlgebraType: Type = algebra::class.asm
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [T]
|
* ASM type for [T].
|
||||||
*/
|
*/
|
||||||
internal val tType: Type = classOfT.asm
|
internal val tType: Type = classOfT.asm
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for new class
|
* ASM type for new class.
|
||||||
*/
|
*/
|
||||||
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
||||||
|
|
||||||
@ -72,6 +74,11 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
*/
|
*/
|
||||||
private lateinit var invokeMethodVisitor: InstructionAdapter
|
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||||
|
|
||||||
|
/**
|
||||||
|
* State if this [AsmBuilder] needs to generate constants field.
|
||||||
|
*/
|
||||||
|
private var hasConstants: Boolean = true
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
||||||
*/
|
*/
|
||||||
@ -95,7 +102,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Stack of useful objects types on stack expected by algebra calls.
|
* Stack of useful objects types on stack expected by algebra calls.
|
||||||
*/
|
*/
|
||||||
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>().apply { push(tType) }
|
internal val expectationStack: ArrayDeque<Type> = ArrayDeque(listOf(tType))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The cache for instance built by this builder.
|
* The cache for instance built by this builder.
|
||||||
@ -108,7 +115,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* The built instance is cached.
|
* The built instance is cached.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
fun getInstance(): Expression<T> {
|
internal fun getInstance(): Expression<T> {
|
||||||
generatedInstance?.let { return it }
|
generatedInstance?.let { return it }
|
||||||
|
|
||||||
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
|
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
|
||||||
@ -127,64 +134,6 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
arrayOf(EXPRESSION_TYPE.internalName)
|
arrayOf(EXPRESSION_TYPE.internalName)
|
||||||
)
|
)
|
||||||
|
|
||||||
visitField(
|
|
||||||
access = ACC_PRIVATE or ACC_FINAL,
|
|
||||||
name = "algebra",
|
|
||||||
descriptor = tAlgebraType.descriptor,
|
|
||||||
signature = null,
|
|
||||||
value = null,
|
|
||||||
block = FieldVisitor::visitEnd
|
|
||||||
)
|
|
||||||
|
|
||||||
visitField(
|
|
||||||
access = ACC_PRIVATE or ACC_FINAL,
|
|
||||||
name = "constants",
|
|
||||||
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
|
||||||
signature = null,
|
|
||||||
value = null,
|
|
||||||
block = FieldVisitor::visitEnd
|
|
||||||
)
|
|
||||||
|
|
||||||
visitMethod(
|
|
||||||
ACC_PUBLIC,
|
|
||||||
"<init>",
|
|
||||||
Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE),
|
|
||||||
null,
|
|
||||||
null
|
|
||||||
).instructionAdapter {
|
|
||||||
val thisVar = 0
|
|
||||||
val algebraVar = 1
|
|
||||||
val constantsVar = 2
|
|
||||||
val l0 = label()
|
|
||||||
load(thisVar, classType)
|
|
||||||
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
|
|
||||||
label()
|
|
||||||
load(thisVar, classType)
|
|
||||||
load(algebraVar, tAlgebraType)
|
|
||||||
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
|
||||||
label()
|
|
||||||
load(thisVar, classType)
|
|
||||||
load(constantsVar, OBJECT_ARRAY_TYPE)
|
|
||||||
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
|
||||||
label()
|
|
||||||
visitInsn(RETURN)
|
|
||||||
val l4 = label()
|
|
||||||
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
|
|
||||||
|
|
||||||
visitLocalVariable(
|
|
||||||
"algebra",
|
|
||||||
tAlgebraType.descriptor,
|
|
||||||
null,
|
|
||||||
l0,
|
|
||||||
l4,
|
|
||||||
algebraVar
|
|
||||||
)
|
|
||||||
|
|
||||||
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
|
|
||||||
visitMaxs(0, 3)
|
|
||||||
visitEnd()
|
|
||||||
}
|
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
ACC_PUBLIC or ACC_FINAL,
|
ACC_PUBLIC or ACC_FINAL,
|
||||||
"invoke",
|
"invoke",
|
||||||
@ -251,6 +200,78 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
visitEnd()
|
visitEnd()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hasConstants = constants.isNotEmpty()
|
||||||
|
|
||||||
|
visitField(
|
||||||
|
access = ACC_PRIVATE or ACC_FINAL,
|
||||||
|
name = "algebra",
|
||||||
|
descriptor = tAlgebraType.descriptor,
|
||||||
|
signature = null,
|
||||||
|
value = null,
|
||||||
|
block = FieldVisitor::visitEnd
|
||||||
|
)
|
||||||
|
|
||||||
|
if (hasConstants)
|
||||||
|
visitField(
|
||||||
|
access = ACC_PRIVATE or ACC_FINAL,
|
||||||
|
name = "constants",
|
||||||
|
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
||||||
|
signature = null,
|
||||||
|
value = null,
|
||||||
|
block = FieldVisitor::visitEnd
|
||||||
|
)
|
||||||
|
|
||||||
|
visitMethod(
|
||||||
|
ACC_PUBLIC,
|
||||||
|
"<init>",
|
||||||
|
|
||||||
|
Type.getMethodDescriptor(
|
||||||
|
Type.VOID_TYPE,
|
||||||
|
tAlgebraType,
|
||||||
|
*OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
|
||||||
|
|
||||||
|
null,
|
||||||
|
null
|
||||||
|
).instructionAdapter {
|
||||||
|
val thisVar = 0
|
||||||
|
val algebraVar = 1
|
||||||
|
val constantsVar = 2
|
||||||
|
val l0 = label()
|
||||||
|
load(thisVar, classType)
|
||||||
|
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
|
||||||
|
label()
|
||||||
|
load(thisVar, classType)
|
||||||
|
load(algebraVar, tAlgebraType)
|
||||||
|
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
||||||
|
|
||||||
|
if (hasConstants) {
|
||||||
|
label()
|
||||||
|
load(thisVar, classType)
|
||||||
|
load(constantsVar, OBJECT_ARRAY_TYPE)
|
||||||
|
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||||
|
}
|
||||||
|
|
||||||
|
label()
|
||||||
|
visitInsn(RETURN)
|
||||||
|
val l4 = label()
|
||||||
|
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"algebra",
|
||||||
|
tAlgebraType.descriptor,
|
||||||
|
null,
|
||||||
|
l0,
|
||||||
|
l4,
|
||||||
|
algebraVar
|
||||||
|
)
|
||||||
|
|
||||||
|
if (hasConstants)
|
||||||
|
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
|
||||||
|
|
||||||
|
visitMaxs(0, 3)
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
visitEnd()
|
visitEnd()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -258,7 +279,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
.defineClass(className, classWriter.toByteArray())
|
.defineClass(className, classWriter.toByteArray())
|
||||||
.constructors
|
.constructors
|
||||||
.first()
|
.first()
|
||||||
.newInstance(algebra, constants.toTypedArray()) as Expression<T>
|
.newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression<T>
|
||||||
|
|
||||||
generatedInstance = new
|
generatedInstance = new
|
||||||
return new
|
return new
|
||||||
@ -267,42 +288,50 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Loads a [T] constant from [constants].
|
* Loads a [T] constant from [constants].
|
||||||
*/
|
*/
|
||||||
internal fun loadTConstant(value: T) {
|
private fun loadTConstant(value: T) {
|
||||||
if (classOfT in INLINABLE_NUMBERS) {
|
if (classOfT in INLINABLE_NUMBERS) {
|
||||||
val expectedType = expectationStack.pop()
|
val expectedType = expectationStack.pop()
|
||||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||||
loadNumberConstant(value as Number, mustBeBoxed)
|
loadNumberConstant(value as Number, mustBeBoxed)
|
||||||
|
|
||||||
|
if (mustBeBoxed)
|
||||||
|
invokeMethodVisitor.checkcast(tType)
|
||||||
|
|
||||||
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
|
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
loadConstant(value as Any, tType)
|
loadObjectConstant(value as Any, tType)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Boxes the current value and pushes it.
|
* Boxes the current value and pushes it.
|
||||||
*/
|
*/
|
||||||
private fun box(): Unit = invokeMethodVisitor.invokestatic(
|
private fun box(primitive: Type) {
|
||||||
tType.internalName,
|
val r = PRIMITIVES_TO_BOXED.getValue(primitive)
|
||||||
|
|
||||||
|
invokeMethodVisitor.invokestatic(
|
||||||
|
r.internalName,
|
||||||
"valueOf",
|
"valueOf",
|
||||||
Type.getMethodDescriptor(tType, primitiveMask),
|
Type.getMethodDescriptor(r, primitive),
|
||||||
false
|
false
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unboxes the current boxed value and pushes it.
|
* Unboxes the current boxed value and pushes it.
|
||||||
*/
|
*/
|
||||||
private fun unbox(): Unit = invokeMethodVisitor.invokevirtual(
|
private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual(
|
||||||
NUMBER_TYPE.internalName,
|
NUMBER_TYPE.internalName,
|
||||||
NUMBER_CONVERTER_METHODS.getValue(primitiveMask),
|
NUMBER_CONVERTER_METHODS.getValue(primitive),
|
||||||
Type.getMethodDescriptor(primitiveMask),
|
Type.getMethodDescriptor(primitive),
|
||||||
false
|
false
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads [java.lang.Object] constant from constants.
|
* Loads [java.lang.Object] constant from constants.
|
||||||
*/
|
*/
|
||||||
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
||||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||||
loadThis()
|
loadThis()
|
||||||
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||||
@ -311,6 +340,15 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
checkcast(type)
|
checkcast(type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun loadNumeric(value: Number) {
|
||||||
|
if (expectationStack.peek() == NUMBER_TYPE) {
|
||||||
|
loadNumberConstant(value, true)
|
||||||
|
expectationStack.pop()
|
||||||
|
typeStack.push(NUMBER_TYPE)
|
||||||
|
} else (algebra as? NumericAlgebra<T>)?.number(value)?.let { loadTConstant(it) }
|
||||||
|
?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads this variable.
|
* Loads this variable.
|
||||||
*/
|
*/
|
||||||
@ -335,18 +373,16 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mustBeBoxed) {
|
if (mustBeBoxed)
|
||||||
box()
|
box(primitive)
|
||||||
invokeMethodVisitor.checkcast(tType)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
loadConstant(value, boxed)
|
loadObjectConstant(value, boxed)
|
||||||
|
|
||||||
if (!mustBeBoxed) unbox()
|
if (!mustBeBoxed)
|
||||||
else invokeMethodVisitor.checkcast(tType)
|
unboxTo(primitiveMask)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -359,24 +395,26 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
|
|
||||||
if (defaultValue != null)
|
if (defaultValue != null)
|
||||||
loadTConstant(defaultValue)
|
loadTConstant(defaultValue)
|
||||||
else
|
|
||||||
aconst(null)
|
|
||||||
|
|
||||||
invokestatic(
|
invokestatic(
|
||||||
MAP_INTRINSICS_TYPE.internalName,
|
MAP_INTRINSICS_TYPE.internalName,
|
||||||
"getOrFail",
|
"getOrFail",
|
||||||
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
|
|
||||||
|
Type.getMethodDescriptor(
|
||||||
|
OBJECT_TYPE,
|
||||||
|
MAP_TYPE,
|
||||||
|
OBJECT_TYPE,
|
||||||
|
*OBJECT_TYPE.wrapToArrayIf { defaultValue != null }),
|
||||||
false
|
false
|
||||||
)
|
)
|
||||||
|
|
||||||
checkcast(tType)
|
checkcast(tType)
|
||||||
|
|
||||||
val expectedType = expectationStack.pop()
|
val expectedType = expectationStack.pop()
|
||||||
|
|
||||||
if (expectedType.sort == Type.OBJECT)
|
if (expectedType.sort == Type.OBJECT)
|
||||||
typeStack.push(tType)
|
typeStack.push(tType)
|
||||||
else {
|
else {
|
||||||
unbox()
|
unboxTo(primitiveMask)
|
||||||
typeStack.push(primitiveMask)
|
typeStack.push(primitiveMask)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -425,7 +463,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
if (expectedType.sort == Type.OBJECT || isLastExpr)
|
if (expectedType.sort == Type.OBJECT || isLastExpr)
|
||||||
typeStack.push(tType)
|
typeStack.push(tType)
|
||||||
else {
|
else {
|
||||||
unbox()
|
unboxTo(primitiveMask)
|
||||||
typeStack.push(primitiveMask)
|
typeStack.push(primitiveMask)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -455,6 +493,18 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
*/
|
*/
|
||||||
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
|
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
||||||
|
*/
|
||||||
|
private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy {
|
||||||
|
BOXED_TO_PRIMITIVES.entries.stream().collect(
|
||||||
|
Collectors.toMap(
|
||||||
|
Map.Entry<Type, Type>::value,
|
||||||
|
Map.Entry<Type, Type>::key
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps primitive ASM types to [Number] functions unboxing them.
|
* Maps primitive ASM types to [Number] functions unboxing them.
|
||||||
*/
|
*/
|
||||||
|
@ -0,0 +1,17 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import scientifik.kmath.ast.MST
|
||||||
|
|
||||||
|
internal enum class MstType {
|
||||||
|
GENERAL,
|
||||||
|
NUMBER;
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
fun fromMst(mst: MST): MstType {
|
||||||
|
if (mst is MST.Numeric)
|
||||||
|
return NUMBER
|
||||||
|
|
||||||
|
return GENERAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -6,6 +6,7 @@ import org.objectweb.asm.commons.InstructionAdapter
|
|||||||
import scientifik.kmath.ast.MST
|
import scientifik.kmath.ast.MST
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import java.lang.reflect.Method
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||||
@ -22,6 +23,12 @@ private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
|||||||
internal val KClass<*>.asm: Type
|
internal val KClass<*>.asm: Type
|
||||||
get() = Type.getType(java)
|
get() = Type.getType(java)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
|
||||||
|
*/
|
||||||
|
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> =
|
||||||
|
if (predicate(this)) arrayOf(this) else emptyArray()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an [InstructionAdapter] from this [MethodVisitor].
|
* Creates an [InstructionAdapter] from this [MethodVisitor].
|
||||||
*/
|
*/
|
||||||
@ -36,11 +43,7 @@ internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Un
|
|||||||
/**
|
/**
|
||||||
* Constructs a [Label], then applies it to this visitor.
|
* Constructs a [Label], then applies it to this visitor.
|
||||||
*/
|
*/
|
||||||
internal fun MethodVisitor.label(): Label {
|
internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) }
|
||||||
val l = Label()
|
|
||||||
visitLabel(l)
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a class name for [Expression] subclassed to implement [mst] provided.
|
* Creates a class name for [Expression] subclassed to implement [mst] provided.
|
||||||
@ -73,44 +76,71 @@ internal inline fun ClassWriter.visitField(
|
|||||||
block: FieldVisitor.() -> Unit
|
block: FieldVisitor.() -> Unit
|
||||||
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
|
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
|
||||||
|
|
||||||
|
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
|
||||||
|
context.javaClass.methods.find { method ->
|
||||||
|
val nameValid = method.name == name
|
||||||
|
val arityValid = method.parameters.size == parameterTypes.size
|
||||||
|
val notBridgeInPrimitive = !(primitiveMode && method.isBridge)
|
||||||
|
|
||||||
|
val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) ->
|
||||||
|
!(mstType != MstType.NUMBER && type == java.lang.Number::class.java)
|
||||||
|
}
|
||||||
|
|
||||||
|
nameValid && arityValid && notBridgeInPrimitive && paramsValid
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
|
* Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds
|
||||||
* type expectation stack for needed arity.
|
* type expectation stack for needed arity.
|
||||||
*
|
*
|
||||||
* @return `true` if contains, else `false`.
|
* @return `true` if contains, else `false`.
|
||||||
*/
|
*/
|
||||||
private fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
|
private fun <T> AsmBuilder<T>.buildExpectationStack(
|
||||||
val theName = methodNameAdapters[name to arity] ?: name
|
context: Algebra<T>,
|
||||||
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null
|
name: String,
|
||||||
val t = if (primitiveMode && hasSpecific) primitiveMask else tType
|
parameterTypes: Array<MstType>
|
||||||
repeat(arity) { expectationStack.push(t) }
|
): Boolean {
|
||||||
return hasSpecific
|
val arity = parameterTypes.size
|
||||||
|
val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes)
|
||||||
|
|
||||||
|
if (specific != null)
|
||||||
|
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
|
||||||
|
else
|
||||||
|
repeat(arity) { expectationStack.push(tType) }
|
||||||
|
|
||||||
|
return specific != null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun <T> AsmBuilder<T>.mapTypes(method: Method, parameterTypes: Array<MstType>): List<Type> = method
|
||||||
|
.parameterTypes
|
||||||
|
.zip(parameterTypes)
|
||||||
|
.map { (type, mstType) ->
|
||||||
|
when {
|
||||||
|
type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE
|
||||||
|
else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts
|
* Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts
|
||||||
* [AsmBuilder.invokeAlgebraOperation] of this method.
|
* [AsmBuilder.invokeAlgebraOperation] of this method.
|
||||||
*
|
*
|
||||||
* @return `true` if contains, else `false`.
|
* @return `true` if contains, else `false`.
|
||||||
*/
|
*/
|
||||||
private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
private fun <T> AsmBuilder<T>.tryInvokeSpecific(
|
||||||
|
context: Algebra<T>,
|
||||||
|
name: String,
|
||||||
|
parameterTypes: Array<MstType>
|
||||||
|
): Boolean {
|
||||||
|
val arity = parameterTypes.size
|
||||||
val theName = methodNameAdapters[name to arity] ?: name
|
val theName = methodNameAdapters[name to arity] ?: name
|
||||||
|
val spec = findSpecific(context, theName, parameterTypes) ?: return false
|
||||||
context.javaClass.methods.find {
|
|
||||||
var suitableSignature = it.name == theName && it.parameters.size == arity
|
|
||||||
|
|
||||||
if (primitiveMode && it.isBridge)
|
|
||||||
suitableSignature = false
|
|
||||||
|
|
||||||
suitableSignature
|
|
||||||
} ?: return false
|
|
||||||
|
|
||||||
val owner = context::class.asm
|
val owner = context::class.asm
|
||||||
|
|
||||||
invokeAlgebraOperation(
|
invokeAlgebraOperation(
|
||||||
owner = owner.internalName,
|
owner = owner.internalName,
|
||||||
method = theName,
|
method = theName,
|
||||||
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }),
|
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()),
|
||||||
expectedArity = arity,
|
expectedArity = arity,
|
||||||
opcode = INVOKEVIRTUAL
|
opcode = INVOKEVIRTUAL
|
||||||
)
|
)
|
||||||
@ -121,18 +151,19 @@ private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Strin
|
|||||||
/**
|
/**
|
||||||
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
|
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
|
||||||
*/
|
*/
|
||||||
internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
||||||
context: Algebra<T>,
|
context: Algebra<T>,
|
||||||
name: String,
|
name: String,
|
||||||
fallbackMethodName: String,
|
fallbackMethodName: String,
|
||||||
arity: Int,
|
parameterTypes: Array<MstType>,
|
||||||
parameters: AsmBuilder<T>.() -> Unit
|
parameters: AsmBuilder<T>.() -> Unit
|
||||||
) {
|
) {
|
||||||
|
val arity = parameterTypes.size
|
||||||
loadAlgebra()
|
loadAlgebra()
|
||||||
if (!buildExpectationStack(context, name, arity)) loadStringConstant(name)
|
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)
|
||||||
parameters()
|
parameters()
|
||||||
|
|
||||||
if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation(
|
if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation(
|
||||||
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
||||||
method = fallbackMethodName,
|
method = fallbackMethodName,
|
||||||
|
|
||||||
@ -145,4 +176,3 @@ internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
|||||||
expectedArity = arity
|
expectedArity = arity
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +0,0 @@
|
|||||||
package scientifik.kmath.asm.internal
|
|
||||||
|
|
||||||
import org.objectweb.asm.Label
|
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
|
||||||
|
|
||||||
internal fun InstructionAdapter.label(): Label {
|
|
||||||
val l = Label()
|
|
||||||
visitLabel(l)
|
|
||||||
return l
|
|
||||||
}
|
|
@ -2,6 +2,6 @@
|
|||||||
|
|
||||||
package scientifik.kmath.asm.internal
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V {
|
@JvmOverloads
|
||||||
return this[key] ?: default ?: error("Parameter not found: $key")
|
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
|
||||||
}
|
this[key] ?: default ?: error("Parameter not found: $key")
|
||||||
|
@ -43,4 +43,13 @@ internal class TestAsmSpecialization {
|
|||||||
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
||||||
assertEquals(1.0, expr("x" to 2.0))
|
assertEquals(1.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPower() {
|
||||||
|
val expr = RealField
|
||||||
|
.mstInField { binaryOperation("power", symbol("x"), number(2)) }
|
||||||
|
.compile()
|
||||||
|
|
||||||
|
assertEquals(4.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,8 +13,7 @@ import kotlin.test.assertEquals
|
|||||||
internal class AsmTest {
|
internal class AsmTest {
|
||||||
@Test
|
@Test
|
||||||
fun `compile MST`() {
|
fun `compile MST`() {
|
||||||
val mst = "2+2*(2+2)".parseMath()
|
val res = ComplexField.expression("2+2*(2+2)".parseMath())()
|
||||||
val res = ComplexField.expression(mst)()
|
|
||||||
assertEquals(Complex(10.0, 0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user