forked from kscience/kmath
Implement advanced specialization for numeric functions
This commit is contained in:
parent
7372197fe1
commit
2a34110f1d
@ -1,6 +1,7 @@
|
||||
package scientifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.internal.AsmBuilder
|
||||
import scientifik.kmath.asm.internal.MstType
|
||||
import scientifik.kmath.asm.internal.buildAlgebraOperationCall
|
||||
import scientifik.kmath.asm.internal.buildName
|
||||
import scientifik.kmath.ast.MST
|
||||
@ -17,28 +18,20 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
||||
fun AsmBuilder<T>.visit(node: MST) {
|
||||
when (node) {
|
||||
is MST.Symbolic -> loadVariable(node.value)
|
||||
|
||||
is MST.Numeric -> {
|
||||
val constant = if (algebra is NumericAlgebra<T>)
|
||||
algebra.number(node.value)
|
||||
else
|
||||
error("Number literals are not supported in $algebra")
|
||||
|
||||
loadTConstant(constant)
|
||||
}
|
||||
is MST.Numeric -> loadNumeric(node.value)
|
||||
|
||||
is MST.Unary -> buildAlgebraOperationCall(
|
||||
context = algebra,
|
||||
name = node.operation,
|
||||
fallbackMethodName = "unaryOperation",
|
||||
arity = 1
|
||||
parameterTypes = arrayOf(MstType.fromMst(node.value))
|
||||
) { visit(node.value) }
|
||||
|
||||
is MST.Binary -> buildAlgebraOperationCall(
|
||||
context = algebra,
|
||||
name = node.operation,
|
||||
fallbackMethodName = "binaryOperation",
|
||||
arity = 2
|
||||
parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right))
|
||||
) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
|
@ -7,7 +7,9 @@ import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||
import scientifik.kmath.ast.MST
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.NumericAlgebra
|
||||
import java.util.*
|
||||
import java.util.stream.Collectors
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
@ -20,7 +22,7 @@ import kotlin.reflect.KClass
|
||||
* @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
|
||||
*/
|
||||
internal class AsmBuilder<T> internal constructor(
|
||||
private val classOfT: KClass<*>,
|
||||
internal val classOfT: KClass<*>,
|
||||
private val algebra: Algebra<T>,
|
||||
private val className: String,
|
||||
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
||||
@ -100,7 +102,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
/**
|
||||
* Stack of useful objects types on stack expected by algebra calls.
|
||||
*/
|
||||
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>().apply { push(tType) }
|
||||
internal val expectationStack: ArrayDeque<Type> = ArrayDeque(listOf(tType))
|
||||
|
||||
/**
|
||||
* The cache for instance built by this builder.
|
||||
@ -286,42 +288,50 @@ internal class AsmBuilder<T> internal constructor(
|
||||
/**
|
||||
* Loads a [T] constant from [constants].
|
||||
*/
|
||||
internal fun loadTConstant(value: T) {
|
||||
private fun loadTConstant(value: T) {
|
||||
if (classOfT in INLINABLE_NUMBERS) {
|
||||
val expectedType = expectationStack.pop()
|
||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||
loadNumberConstant(value as Number, mustBeBoxed)
|
||||
|
||||
if (mustBeBoxed)
|
||||
invokeMethodVisitor.checkcast(tType)
|
||||
|
||||
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
|
||||
return
|
||||
}
|
||||
|
||||
loadConstant(value as Any, tType)
|
||||
loadObjectConstant(value as Any, tType)
|
||||
}
|
||||
|
||||
/**
|
||||
* Boxes the current value and pushes it.
|
||||
*/
|
||||
private fun box(): Unit = invokeMethodVisitor.invokestatic(
|
||||
tType.internalName,
|
||||
"valueOf",
|
||||
Type.getMethodDescriptor(tType, primitiveMask),
|
||||
false
|
||||
)
|
||||
private fun box(primitive: Type) {
|
||||
val r = PRIMITIVES_TO_BOXED.getValue(primitive)
|
||||
|
||||
invokeMethodVisitor.invokestatic(
|
||||
r.internalName,
|
||||
"valueOf",
|
||||
Type.getMethodDescriptor(r, primitive),
|
||||
false
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Unboxes the current boxed value and pushes it.
|
||||
*/
|
||||
private fun unbox(): Unit = invokeMethodVisitor.invokevirtual(
|
||||
private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual(
|
||||
NUMBER_TYPE.internalName,
|
||||
NUMBER_CONVERTER_METHODS.getValue(primitiveMask),
|
||||
Type.getMethodDescriptor(primitiveMask),
|
||||
NUMBER_CONVERTER_METHODS.getValue(primitive),
|
||||
Type.getMethodDescriptor(primitive),
|
||||
false
|
||||
)
|
||||
|
||||
/**
|
||||
* Loads [java.lang.Object] constant from constants.
|
||||
*/
|
||||
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
||||
private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||
loadThis()
|
||||
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||
@ -330,6 +340,15 @@ internal class AsmBuilder<T> internal constructor(
|
||||
checkcast(type)
|
||||
}
|
||||
|
||||
fun loadNumeric(value: Number) {
|
||||
if (expectationStack.peek() == NUMBER_TYPE) {
|
||||
loadNumberConstant(value, true)
|
||||
expectationStack.pop()
|
||||
typeStack.push(NUMBER_TYPE)
|
||||
} else (algebra as? NumericAlgebra<T>)?.number(value)?.let { loadTConstant(it) }
|
||||
?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.")
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads this variable.
|
||||
*/
|
||||
@ -354,18 +373,16 @@ internal class AsmBuilder<T> internal constructor(
|
||||
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||
}
|
||||
|
||||
if (mustBeBoxed) {
|
||||
box()
|
||||
invokeMethodVisitor.checkcast(tType)
|
||||
}
|
||||
if (mustBeBoxed)
|
||||
box(primitive)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
loadConstant(value, boxed)
|
||||
loadObjectConstant(value, boxed)
|
||||
|
||||
if (!mustBeBoxed) unbox()
|
||||
else invokeMethodVisitor.checkcast(tType)
|
||||
if (!mustBeBoxed)
|
||||
unboxTo(primitiveMask)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -397,7 +414,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
if (expectedType.sort == Type.OBJECT)
|
||||
typeStack.push(tType)
|
||||
else {
|
||||
unbox()
|
||||
unboxTo(primitiveMask)
|
||||
typeStack.push(primitiveMask)
|
||||
}
|
||||
}
|
||||
@ -446,7 +463,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
if (expectedType.sort == Type.OBJECT || isLastExpr)
|
||||
typeStack.push(tType)
|
||||
else {
|
||||
unbox()
|
||||
unboxTo(primitiveMask)
|
||||
typeStack.push(primitiveMask)
|
||||
}
|
||||
}
|
||||
@ -476,6 +493,18 @@ internal class AsmBuilder<T> internal constructor(
|
||||
*/
|
||||
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
|
||||
|
||||
/**
|
||||
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
||||
*/
|
||||
private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy {
|
||||
BOXED_TO_PRIMITIVES.entries.stream().collect(
|
||||
Collectors.toMap(
|
||||
Map.Entry<Type, Type>::value,
|
||||
Map.Entry<Type, Type>::key
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps primitive ASM types to [Number] functions unboxing them.
|
||||
*/
|
||||
|
@ -0,0 +1,17 @@
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
import scientifik.kmath.ast.MST
|
||||
|
||||
internal enum class MstType {
|
||||
GENERAL,
|
||||
NUMBER;
|
||||
|
||||
companion object {
|
||||
fun fromMst(mst: MST): MstType {
|
||||
if (mst is MST.Numeric)
|
||||
return NUMBER
|
||||
|
||||
return GENERAL
|
||||
}
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@ import org.objectweb.asm.commons.InstructionAdapter
|
||||
import scientifik.kmath.ast.MST
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import java.lang.reflect.Method
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||
@ -25,10 +26,8 @@ internal val KClass<*>.asm: Type
|
||||
/**
|
||||
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
|
||||
*/
|
||||
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> = if (predicate(this))
|
||||
arrayOf(this)
|
||||
else
|
||||
emptyArray()
|
||||
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> =
|
||||
if (predicate(this)) arrayOf(this) else emptyArray()
|
||||
|
||||
/**
|
||||
* Creates an [InstructionAdapter] from this [MethodVisitor].
|
||||
@ -44,11 +43,7 @@ internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Un
|
||||
/**
|
||||
* Constructs a [Label], then applies it to this visitor.
|
||||
*/
|
||||
internal fun MethodVisitor.label(): Label {
|
||||
val l = Label()
|
||||
visitLabel(l)
|
||||
return l
|
||||
}
|
||||
internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) }
|
||||
|
||||
/**
|
||||
* Creates a class name for [Expression] subclassed to implement [mst] provided.
|
||||
@ -81,44 +76,71 @@ internal inline fun ClassWriter.visitField(
|
||||
block: FieldVisitor.() -> Unit
|
||||
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
|
||||
|
||||
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
|
||||
context.javaClass.methods.find { method ->
|
||||
val nameValid = method.name == name
|
||||
val arityValid = method.parameters.size == parameterTypes.size
|
||||
val notBridgeInPrimitive = !(primitiveMode && method.isBridge)
|
||||
|
||||
val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) ->
|
||||
!(mstType != MstType.NUMBER && type == java.lang.Number::class.java)
|
||||
}
|
||||
|
||||
nameValid && arityValid && notBridgeInPrimitive && paramsValid
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
|
||||
* Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds
|
||||
* type expectation stack for needed arity.
|
||||
*
|
||||
* @return `true` if contains, else `false`.
|
||||
*/
|
||||
private fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||
val theName = methodNameAdapters[name to arity] ?: name
|
||||
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null
|
||||
val t = if (primitiveMode && hasSpecific) primitiveMask else tType
|
||||
repeat(arity) { expectationStack.push(t) }
|
||||
return hasSpecific
|
||||
private fun <T> AsmBuilder<T>.buildExpectationStack(
|
||||
context: Algebra<T>,
|
||||
name: String,
|
||||
parameterTypes: Array<MstType>
|
||||
): Boolean {
|
||||
val arity = parameterTypes.size
|
||||
val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes)
|
||||
|
||||
if (specific != null)
|
||||
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
|
||||
else
|
||||
repeat(arity) { expectationStack.push(tType) }
|
||||
|
||||
return specific != null
|
||||
}
|
||||
|
||||
private fun <T> AsmBuilder<T>.mapTypes(method: Method, parameterTypes: Array<MstType>): List<Type> = method
|
||||
.parameterTypes
|
||||
.zip(parameterTypes)
|
||||
.map { (type, mstType) ->
|
||||
when {
|
||||
type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE
|
||||
else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts
|
||||
* Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts
|
||||
* [AsmBuilder.invokeAlgebraOperation] of this method.
|
||||
*
|
||||
* @return `true` if contains, else `false`.
|
||||
*/
|
||||
private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||
private fun <T> AsmBuilder<T>.tryInvokeSpecific(
|
||||
context: Algebra<T>,
|
||||
name: String,
|
||||
parameterTypes: Array<MstType>
|
||||
): Boolean {
|
||||
val arity = parameterTypes.size
|
||||
val theName = methodNameAdapters[name to arity] ?: name
|
||||
|
||||
context.javaClass.methods.find {
|
||||
var suitableSignature = it.name == theName && it.parameters.size == arity
|
||||
|
||||
if (primitiveMode && it.isBridge)
|
||||
suitableSignature = false
|
||||
|
||||
suitableSignature
|
||||
} ?: return false
|
||||
|
||||
val spec = findSpecific(context, theName, parameterTypes) ?: return false
|
||||
val owner = context::class.asm
|
||||
|
||||
invokeAlgebraOperation(
|
||||
owner = owner.internalName,
|
||||
method = theName,
|
||||
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }),
|
||||
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()),
|
||||
expectedArity = arity,
|
||||
opcode = INVOKEVIRTUAL
|
||||
)
|
||||
@ -133,14 +155,15 @@ internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
||||
context: Algebra<T>,
|
||||
name: String,
|
||||
fallbackMethodName: String,
|
||||
arity: Int,
|
||||
parameterTypes: Array<MstType>,
|
||||
parameters: AsmBuilder<T>.() -> Unit
|
||||
) {
|
||||
val arity = parameterTypes.size
|
||||
loadAlgebra()
|
||||
if (!buildExpectationStack(context, name, arity)) loadStringConstant(name)
|
||||
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)
|
||||
parameters()
|
||||
|
||||
if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation(
|
||||
if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
||||
method = fallbackMethodName,
|
||||
|
||||
|
@ -43,4 +43,13 @@ internal class TestAsmSpecialization {
|
||||
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(1.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() {
|
||||
val expr = RealField
|
||||
.mstInField { binaryOperation("power", symbol("x"), number(2)) }
|
||||
.compile()
|
||||
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user