Merge remote-tracking branch 'mipt-npm/adv-expr' into adv-expr-improved-trigonometry

This commit is contained in:
Iaroslav 2020-06-20 02:29:03 +07:00
commit d7968c08c6
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
5 changed files with 101 additions and 12 deletions

View File

@ -73,7 +73,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
/** /**
* Compile an [MST] to ASM using given algebra * Compile an [MST] to ASM using given algebra
*/ */
inline fun <reified T : Any> Algebra<T>.expresion(mst: MST): Expression<T> = mst.compileWith(T::class, this) inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
/** /**
* Optimize performance of an [MSTExpression] using ASM codegen * Optimize performance of an [MSTExpression] using ASM codegen

View File

@ -11,9 +11,12 @@ import scientifik.kmath.operations.Algebra
* ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression.
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
* *
* @param T the type of AsmExpression to unwrap. * @param T the type generated [AsmCompiledExpression] operates.
* @param algebra the algebra the applied AsmExpressions use. * @property classOfT the [Class] of T.
* @param className the unique class name of new loaded class. * @property algebra the algebra the applied AsmExpressions use.
* @property className the unique class name of new loaded class.
* @property invokeLabel0Visitor the function applied to this object when the L0 label of invoke implementation is
* being written.
*/ */
internal class AsmBuilder<T> internal constructor( internal class AsmBuilder<T> internal constructor(
private val classOfT: Class<*>, private val classOfT: Class<*>,
@ -21,10 +24,16 @@ internal class AsmBuilder<T> internal constructor(
private val className: String, private val className: String,
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
) { ) {
/**
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
*/
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
} }
/**
* The instance of [ClassLoader] used by this builder.
*/
private val classLoader: ClassLoader = private val classLoader: ClassLoader =
ClassLoader(javaClass.classLoader) ClassLoader(javaClass.classLoader)
@ -35,14 +44,39 @@ internal class AsmBuilder<T> internal constructor(
private val T_CLASS: String = classOfT.name.replace('.', '/') private val T_CLASS: String = classOfT.name.replace('.', '/')
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
/**
* Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass.
*/
private val invokeThisVar: Int = 0 private val invokeThisVar: Int = 0
/**
* Index of `arguments` variable in invoke method of [AsmCompiledExpression] built subclass.
*/
private val invokeArgumentsVar: Int = 1 private val invokeArgumentsVar: Int = 1
/**
* List of constants to provide to [AsmCompiledExpression] subclass.
*/
private val constants: MutableList<Any> = mutableListOf() private val constants: MutableList<Any> = mutableListOf()
/**
* Method visitor of `invoke` method of [AsmCompiledExpression] subclass.
*/
private lateinit var invokeMethodVisitor: MethodVisitor private lateinit var invokeMethodVisitor: MethodVisitor
/**
* The cache of [AsmCompiledExpression] subclass built by this builder.
*/
private var generatedInstance: AsmCompiledExpression<T>? = null private var generatedInstance: AsmCompiledExpression<T>? = null
/**
* Subclasses, loads and instantiates the [AsmCompiledExpression] for given parameters.
*
* The built instance is cached.
*/
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
fun getInstance(): AsmCompiledExpression<T> { internal fun getInstance(): AsmCompiledExpression<T> {
generatedInstance?.let { return it } generatedInstance?.let { return it }
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
@ -181,6 +215,9 @@ internal class AsmBuilder<T> internal constructor(
return new return new
} }
/**
* Loads a constant from
*/
internal fun loadTConstant(value: T) { internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) { if (classOfT in INLINABLE_NUMBERS) {
loadNumberConstant(value as Number) loadNumberConstant(value as Number)
@ -191,6 +228,9 @@ internal class AsmBuilder<T> internal constructor(
loadConstant(value as Any, T_CLASS) loadConstant(value as Any, T_CLASS)
} }
/**
* Loads an object constant [value] stored in [AsmCompiledExpression.constants] and casts it to [type].
*/
private fun loadConstant(value: Any, type: String) { private fun loadConstant(value: Any, type: String) {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
@ -205,7 +245,12 @@ internal class AsmBuilder<T> internal constructor(
private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar)
internal fun loadNumberConstant(value: Number) { /**
* Either loads a numeric constant [value] from [AsmCompiledExpression.constants] field or boxes a primitive
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
* from it).
*/
private fun loadNumberConstant(value: Number) {
val clazz = value.javaClass val clazz = value.javaClass
val c = clazz.name.replace('.', '/') val c = clazz.name.replace('.', '/')
val sigLetter = SIGNATURE_LETTERS[clazz] val sigLetter = SIGNATURE_LETTERS[clazz]
@ -225,6 +270,9 @@ internal class AsmBuilder<T> internal constructor(
loadConstant(value, c) loadConstant(value, c)
} }
/**
* Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided.
*/
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
visitLoadObjectVar(invokeArgumentsVar) visitLoadObjectVar(invokeArgumentsVar)
@ -253,6 +301,9 @@ internal class AsmBuilder<T> internal constructor(
invokeMethodVisitor.visitCheckCast(T_CLASS) invokeMethodVisitor.visitCheckCast(T_CLASS)
} }
/**
* Loads algebra from according field of [AsmCompiledExpression] and casts it to class of [algebra] provided.
*/
internal fun loadAlgebra() { internal fun loadAlgebra() {
loadThis() loadThis()
@ -265,20 +316,32 @@ internal class AsmBuilder<T> internal constructor(
invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS) invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS)
} }
/**
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interface. [loadAlgebra] should be
* called before the arguments and this operation.
*
* The result is casted to [T] automatically.
*/
internal fun invokeAlgebraOperation( internal fun invokeAlgebraOperation(
owner: String, owner: String,
method: String, method: String,
descriptor: String, descriptor: String,
opcode: Int = Opcodes.INVOKEINTERFACE, opcode: Int = Opcodes.INVOKEINTERFACE
isInterface: Boolean = true
) { ) {
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, opcode == Opcodes.INVOKEINTERFACE)
invokeMethodVisitor.visitCheckCast(T_CLASS) invokeMethodVisitor.visitCheckCast(T_CLASS)
} }
/**
* Writes a LDC Instruction with string constant provided.
*/
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string)
internal companion object { internal companion object {
/**
* Maps JVM primitive numbers boxed types to their letters of JVM signature convention.
*/
private val SIGNATURE_LETTERS: Map<Class<out Any>, String> by lazy { private val SIGNATURE_LETTERS: Map<Class<out Any>, String> by lazy {
mapOf( mapOf(
java.lang.Byte::class.java to "B", java.lang.Byte::class.java to "B",
@ -290,6 +353,9 @@ internal class AsmBuilder<T> internal constructor(
) )
} }
/**
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
*/
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys } private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =

View File

@ -3,6 +3,13 @@ package scientifik.kmath.asm.internal
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
/**
* [Expression] partial implementation to have it subclassed by actual implementations. Provides unified storage for
* objects needed to implement the expression.
*
* @property algebra the algebra to delegate calls.
* @property constants the constants array to have persistent objects to reference in [invoke].
*/
internal abstract class AsmCompiledExpression<T> internal constructor( internal abstract class AsmCompiledExpression<T> internal constructor(
@JvmField protected val algebra: Algebra<T>, @JvmField protected val algebra: Algebra<T>,
@JvmField protected val constants: Array<Any> @JvmField protected val constants: Array<Any>

View File

@ -2,7 +2,13 @@ package scientifik.kmath.asm.internal
import scientifik.kmath.ast.MST import scientifik.kmath.ast.MST
internal fun buildName(mst: MST, collision: Int = 0): String { /**
* Creates a class name for [AsmCompiledExpression] subclassed to implement [mst] provided.
*
* This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
*/
internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try { try {

View File

@ -5,6 +5,11 @@ import scientifik.kmath.operations.Algebra
private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide") private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide")
/**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity].
*
* @return `true` if contains, else `false`.
*/
internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boolean { internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
val aName = methodNameAdapters[name] ?: name val aName = methodNameAdapters[name] ?: name
@ -14,6 +19,12 @@ internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boo
return true return true
} }
/**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts
* [AsmBuilder.invokeAlgebraOperation] of this method.
*
* @return `true` if contains, else `false`.
*/
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean { internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
val aName = methodNameAdapters[name] ?: name val aName = methodNameAdapters[name] ?: name
@ -33,8 +44,7 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
owner = owner, owner = owner,
method = aName, method = aName,
descriptor = sig, descriptor = sig,
opcode = Opcodes.INVOKEVIRTUAL, opcode = Opcodes.INVOKEVIRTUAL
isInterface = false
) )
return true return true