Apply the suggested changes

This commit is contained in:
Iaroslav 2020-06-14 00:18:40 +07:00
parent e91f6470d3
commit af410dde70
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
5 changed files with 42 additions and 19 deletions

View File

@ -18,9 +18,10 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String
return buildName(expression, collision + 1) return buildName(expression, collision + 1)
} }
inline fun <reified T> AsmExpression<T>.compile(algebra: Algebra<T>): Expression<T> { @PublishedApi
internal inline fun <reified T> AsmExpression<T>.compile(algebra: Algebra<T>): Expression<T> {
val ctx = AsmGenerationContext(T::class.java, algebra, buildName(this)) val ctx = AsmGenerationContext(T::class.java, algebra, buildName(this))
invoke(ctx) compile(ctx)
return ctx.generate() return ctx.generate()
} }

View File

@ -1,14 +1,31 @@
package scientifik.kmath.asm package scientifik.kmath.asm
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.optimize
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.* import scientifik.kmath.operations.*
import kotlin.reflect.full.memberFunctions
import kotlin.reflect.jvm.jvmName
/**
* A function declaration that could be compiled to [AsmGenerationContext].
*
* @param T the type the stored function returns.
*/
interface AsmExpression<T> { interface AsmExpression<T> {
/**
* Tries to evaluate this function without its variables. This method is intended for optimization.
*
* @return `null` if the function depends on its variables, the value if the function is a constant.
*/
fun tryEvaluate(): T? = null fun tryEvaluate(): T? = null
fun invoke(gen: AsmGenerationContext<T>)
/**
* Compiles this declaration.
*
* @param gen the target [AsmGenerationContext].
*/
fun compile(gen: AsmGenerationContext<T>)
} }
internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val name: String, expr: AsmExpression<T>) : internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val name: String, expr: AsmExpression<T>) :
@ -16,13 +33,13 @@ internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val
private val expr: AsmExpression<T> = expr.optimize() private val expr: AsmExpression<T> = expr.optimize()
override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) }
override fun invoke(gen: AsmGenerationContext<T>) { override fun compile(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra() gen.visitLoadAlgebra()
if (!hasSpecific(context, name, 1)) if (!hasSpecific(context, name, 1))
gen.visitStringConstant(name) gen.visitStringConstant(name)
expr.invoke(gen) expr.compile(gen)
if (gen.tryInvokeSpecific(context, name, 1)) if (gen.tryInvokeSpecific(context, name, 1))
return return
@ -54,14 +71,14 @@ internal class AsmBinaryOperation<T>(
) )
} }
override fun invoke(gen: AsmGenerationContext<T>) { override fun compile(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra() gen.visitLoadAlgebra()
if (!hasSpecific(context, name, 2)) if (!hasSpecific(context, name, 2))
gen.visitStringConstant(name) gen.visitStringConstant(name)
first.invoke(gen) first.compile(gen)
second.invoke(gen) second.compile(gen)
if (gen.tryInvokeSpecific(context, name, 2)) if (gen.tryInvokeSpecific(context, name, 2))
return return
@ -79,13 +96,13 @@ internal class AsmBinaryOperation<T>(
internal class AsmVariableExpression<T>(private val name: String, private val default: T? = null) : internal class AsmVariableExpression<T>(private val name: String, private val default: T? = null) :
AsmExpression<T> { AsmExpression<T> {
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromVariables(name, default) override fun compile(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromVariables(name, default)
} }
internal class AsmConstantExpression<T>(private val value: T) : internal class AsmConstantExpression<T>(private val value: T) :
AsmExpression<T> { AsmExpression<T> {
override fun tryEvaluate(): T = value override fun tryEvaluate(): T = value
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromConstants(value) override fun compile(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromConstants(value)
} }
internal class AsmConstProductExpression<T>( internal class AsmConstProductExpression<T>(
@ -98,10 +115,10 @@ internal class AsmConstProductExpression<T>(
override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const }
override fun invoke(gen: AsmGenerationContext<T>) { override fun compile(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra() gen.visitLoadAlgebra()
gen.visitNumberConstant(const) gen.visitNumberConstant(const)
expr.invoke(gen) expr.compile(gen)
gen.visitAlgebraOperation( gen.visitAlgebraOperation(
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
@ -117,7 +134,7 @@ internal class AsmNumberExpression<T>(private val context: NumericAlgebra<T>, pr
AsmExpression<T> { AsmExpression<T> {
override fun tryEvaluate(): T? = context.number(value) override fun tryEvaluate(): T? = context.number(value)
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitNumberConstant(value) override fun compile(gen: AsmGenerationContext<T>): Unit = gen.visitNumberConstant(value)
} }
internal abstract class FunctionalCompiledExpression<T> internal constructor( internal abstract class FunctionalCompiledExpression<T> internal constructor(

View File

@ -5,6 +5,9 @@ import org.objectweb.asm.Label
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.AsmGenerationContext.ClassLoader import scientifik.kmath.asm.AsmGenerationContext.ClassLoader
import scientifik.kmath.asm.internal.visitLdcOrDConstInsn
import scientifik.kmath.asm.internal.visitLdcOrFConstInsn
import scientifik.kmath.asm.internal.visitLdcOrIConstInsn
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
/** /**
@ -16,7 +19,7 @@ import scientifik.kmath.operations.Algebra
* @param algebra the algebra the applied AsmExpressions use. * @param algebra the algebra the applied AsmExpressions use.
* @param className the unique class name of new loaded class. * @param className the unique class name of new loaded class.
*/ */
class AsmGenerationContext<T>( class AsmGenerationContext<T> @PublishedApi internal constructor(
private val classOfT: Class<*>, private val classOfT: Class<*>,
private val algebra: Algebra<T>, private val algebra: Algebra<T>,
private val className: String private val className: String

View File

@ -1,4 +1,4 @@
package scientifik.kmath.asm package scientifik.kmath.asm.internal
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Opcodes.*

View File

@ -1,8 +1,10 @@
package scientifik.kmath.asm package scientifik.kmath.asm.internal
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.AsmConstantExpression
import scientifik.kmath.asm.AsmExpression
import scientifik.kmath.asm.AsmGenerationContext
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.ByteRing
private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide") private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide")