Dev #127
@ -1,14 +1,15 @@
|
|||||||
package scientifik.kmath.asm
|
package scientifik.kmath.asm
|
||||||
|
|
||||||
|
import org.objectweb.asm.Type
|
||||||
import scientifik.kmath.asm.internal.AsmBuilder
|
import scientifik.kmath.asm.internal.AsmBuilder
|
||||||
|
import scientifik.kmath.asm.internal.buildExpectationStack
|
||||||
import scientifik.kmath.asm.internal.buildName
|
import scientifik.kmath.asm.internal.buildName
|
||||||
import scientifik.kmath.asm.internal.hasSpecific
|
|
||||||
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
||||||
import scientifik.kmath.ast.*
|
import scientifik.kmath.ast.MST
|
||||||
|
import scientifik.kmath.ast.MSTExpression
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.NumericAlgebra
|
import scientifik.kmath.operations.NumericAlgebra
|
||||||
import scientifik.kmath.operations.RealField
|
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -18,6 +19,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
|||||||
fun AsmBuilder<T>.visit(node: MST) {
|
fun AsmBuilder<T>.visit(node: MST) {
|
||||||
when (node) {
|
when (node) {
|
||||||
is MST.Symbolic -> loadVariable(node.value)
|
is MST.Symbolic -> loadVariable(node.value)
|
||||||
|
|
||||||
is MST.Numeric -> {
|
is MST.Numeric -> {
|
||||||
val constant = if (algebra is NumericAlgebra<T>)
|
val constant = if (algebra is NumericAlgebra<T>)
|
||||||
algebra.number(node.value)
|
algebra.number(node.value)
|
||||||
@ -26,48 +28,49 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
|||||||
|
|
||||||
loadTConstant(constant)
|
loadTConstant(constant)
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Unary -> {
|
is MST.Unary -> {
|
||||||
loadAlgebra()
|
loadAlgebra()
|
||||||
|
if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation)
|
||||||
if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation)
|
|
||||||
|
|
||||||
visit(node.value)
|
visit(node.value)
|
||||||
|
|
||||||
if (!tryInvokeSpecific(algebra, node.operation, 1)) {
|
if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation(
|
||||||
invokeAlgebraOperation(
|
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
||||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
method = "unaryOperation",
|
||||||
method = "unaryOperation",
|
|
||||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
descriptor = Type.getMethodDescriptor(
|
||||||
"L${AsmBuilder.OBJECT_CLASS};)" +
|
AsmBuilder.OBJECT_TYPE,
|
||||||
"L${AsmBuilder.OBJECT_CLASS};"
|
AsmBuilder.STRING_TYPE,
|
||||||
)
|
AsmBuilder.OBJECT_TYPE
|
||||||
}
|
),
|
||||||
|
|
||||||
|
tArity = 1
|
||||||
|
)
|
||||||
}
|
}
|
||||||
is MST.Binary -> {
|
is MST.Binary -> {
|
||||||
loadAlgebra()
|
loadAlgebra()
|
||||||
|
if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation)
|
||||||
if (!hasSpecific(algebra, node.operation, 2))
|
|
||||||
loadStringConstant(node.operation)
|
|
||||||
|
|
||||||
visit(node.left)
|
visit(node.left)
|
||||||
visit(node.right)
|
visit(node.right)
|
||||||
|
|
||||||
if (!tryInvokeSpecific(algebra, node.operation, 2)) {
|
if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation(
|
||||||
|
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
||||||
|
method = "binaryOperation",
|
||||||
|
|
||||||
invokeAlgebraOperation(
|
descriptor = Type.getMethodDescriptor(
|
||||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
AsmBuilder.OBJECT_TYPE,
|
||||||
method = "binaryOperation",
|
AsmBuilder.STRING_TYPE,
|
||||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
AsmBuilder.OBJECT_TYPE,
|
||||||
"L${AsmBuilder.OBJECT_CLASS};" +
|
AsmBuilder.OBJECT_TYPE
|
||||||
"L${AsmBuilder.OBJECT_CLASS};)" +
|
),
|
||||||
"L${AsmBuilder.OBJECT_CLASS};"
|
|
||||||
)
|
tArity = 2
|
||||||
}
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -79,7 +82,3 @@ inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = ms
|
|||||||
* Optimize performance of an [MSTExpression] using ASM codegen
|
* Optimize performance of an [MSTExpression] using ASM codegen
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any> MSTExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
inline fun <reified T : Any> MSTExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
||||||
|
|
||||||
fun main() {
|
|
||||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
|
||||||
}
|
|
||||||
|
@ -1,25 +1,25 @@
|
|||||||
package scientifik.kmath.asm.internal
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
import org.objectweb.asm.ClassWriter
|
import org.objectweb.asm.*
|
||||||
import org.objectweb.asm.Label
|
import org.objectweb.asm.Opcodes.AALOAD
|
||||||
import org.objectweb.asm.MethodVisitor
|
import org.objectweb.asm.Opcodes.RETURN
|
||||||
import org.objectweb.asm.Opcodes
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||||
|
import scientifik.kmath.ast.MST
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import java.util.*
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM Builder is a structure that abstracts building a class for unwrapping [MST] to plain Java expression.
|
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
|
||||||
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
|
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
|
||||||
*
|
*
|
||||||
* @param T the type generated [AsmCompiledExpression] operates.
|
* @param T the type of AsmExpression to unwrap.
|
||||||
* @property classOfT the [Class] of T.
|
* @param algebra the algebra the applied AsmExpressions use.
|
||||||
* @property algebra the algebra the applied AsmExpressions use.
|
* @param className the unique class name of new loaded class.
|
||||||
* @property className the unique class name of new loaded class.
|
|
||||||
* @property invokeLabel0Visitor the function applied to this object when the L0 label of invoke implementation is
|
|
||||||
* being written.
|
|
||||||
*/
|
*/
|
||||||
internal class AsmBuilder<T> internal constructor(
|
internal class AsmBuilder<T> internal constructor(
|
||||||
private val classOfT: Class<*>,
|
private val classOfT: KClass<*>,
|
||||||
private val algebra: Algebra<T>,
|
private val algebra: Algebra<T>,
|
||||||
private val className: String,
|
private val className: String,
|
||||||
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
||||||
@ -34,16 +34,16 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* The instance of [ClassLoader] used by this builder.
|
* The instance of [ClassLoader] used by this builder.
|
||||||
*/
|
*/
|
||||||
private val classLoader: ClassLoader =
|
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||||
ClassLoader(javaClass.classLoader)
|
|
||||||
|
|
||||||
@Suppress("PrivatePropertyName")
|
@Suppress("PrivatePropertyName")
|
||||||
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/')
|
private val T_ALGEBRA_TYPE: Type = algebra::class.asm
|
||||||
|
|
||||||
@Suppress("PrivatePropertyName")
|
@Suppress("PrivatePropertyName")
|
||||||
private val T_CLASS: String = classOfT.name.replace('.', '/')
|
internal val T_TYPE: Type = classOfT.asm
|
||||||
|
|
||||||
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
|
@Suppress("PrivatePropertyName")
|
||||||
|
private val CLASS_TYPE: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass.
|
* Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass.
|
||||||
@ -63,7 +63,16 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Method visitor of `invoke` method of [AsmCompiledExpression] subclass.
|
* Method visitor of `invoke` method of [AsmCompiledExpression] subclass.
|
||||||
*/
|
*/
|
||||||
private lateinit var invokeMethodVisitor: MethodVisitor
|
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||||
|
internal var primitiveMode = false
|
||||||
|
|
||||||
|
@Suppress("PropertyName")
|
||||||
|
internal var PRIMITIVE_MASK: Type = OBJECT_TYPE
|
||||||
|
|
||||||
|
@Suppress("PropertyName")
|
||||||
|
internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE
|
||||||
|
private val typeStack = Stack<Type>()
|
||||||
|
internal val expectationStack = Stack<Type>().apply { push(T_TYPE) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The cache of [AsmCompiledExpression] subclass built by this builder.
|
* The cache of [AsmCompiledExpression] subclass built by this builder.
|
||||||
@ -76,81 +85,88 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* The built instance is cached.
|
* The built instance is cached.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
internal fun getInstance(): AsmCompiledExpression<T> {
|
fun getInstance(): AsmCompiledExpression<T> {
|
||||||
generatedInstance?.let { return it }
|
generatedInstance?.let { return it }
|
||||||
|
|
||||||
|
if (SIGNATURE_LETTERS.containsKey(classOfT.java)) {
|
||||||
|
primitiveMode = true
|
||||||
|
PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT.java)
|
||||||
|
PRIMITIVE_MASK_BOXED = T_TYPE
|
||||||
|
}
|
||||||
|
|
||||||
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||||
visit(
|
visit(
|
||||||
Opcodes.V1_8,
|
Opcodes.V1_8,
|
||||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
|
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
|
||||||
slashesClassName,
|
CLASS_TYPE.internalName,
|
||||||
"L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS<L$T_CLASS;>;",
|
"L${ASM_COMPILED_EXPRESSION_TYPE.internalName}<${T_TYPE.descriptor}>;",
|
||||||
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
ASM_COMPILED_EXPRESSION_TYPE.internalName,
|
||||||
arrayOf()
|
arrayOf()
|
||||||
)
|
)
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
access = Opcodes.ACC_PUBLIC,
|
Opcodes.ACC_PUBLIC,
|
||||||
name = "<init>",
|
"<init>",
|
||||||
descriptor = "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V",
|
Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE),
|
||||||
signature = null,
|
null,
|
||||||
exceptions = null
|
null
|
||||||
) {
|
).instructionAdapter {
|
||||||
val thisVar = 0
|
val thisVar = 0
|
||||||
val algebraVar = 1
|
val algebraVar = 1
|
||||||
val constantsVar = 2
|
val constantsVar = 2
|
||||||
val l0 = Label()
|
val l0 = Label()
|
||||||
visitLabel(l0)
|
visitLabel(l0)
|
||||||
visitLoadObjectVar(thisVar)
|
load(thisVar, CLASS_TYPE)
|
||||||
visitLoadObjectVar(algebraVar)
|
load(algebraVar, ALGEBRA_TYPE)
|
||||||
visitLoadObjectVar(constantsVar)
|
load(constantsVar, OBJECT_ARRAY_TYPE)
|
||||||
|
|
||||||
visitInvokeSpecial(
|
invokespecial(
|
||||||
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
ASM_COMPILED_EXPRESSION_TYPE.internalName,
|
||||||
"<init>",
|
"<init>",
|
||||||
"(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V"
|
Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE),
|
||||||
|
false
|
||||||
)
|
)
|
||||||
|
|
||||||
val l1 = Label()
|
val l1 = Label()
|
||||||
visitLabel(l1)
|
visitLabel(l1)
|
||||||
visitReturn()
|
visitInsn(RETURN)
|
||||||
val l2 = Label()
|
val l2 = Label()
|
||||||
visitLabel(l2)
|
visitLabel(l2)
|
||||||
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar)
|
visitLocalVariable("this", CLASS_TYPE.descriptor, null, l0, l2, thisVar)
|
||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
"algebra",
|
"algebra",
|
||||||
"L$ALGEBRA_CLASS;",
|
ALGEBRA_TYPE.descriptor,
|
||||||
"L$ALGEBRA_CLASS<L$T_CLASS;>;",
|
"L${ALGEBRA_TYPE.internalName}<${T_TYPE.descriptor}>;",
|
||||||
l0,
|
l0,
|
||||||
l2,
|
l2,
|
||||||
algebraVar
|
algebraVar
|
||||||
)
|
)
|
||||||
|
|
||||||
visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar)
|
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l2, constantsVar)
|
||||||
visitMaxs(0, 3)
|
visitMaxs(0, 3)
|
||||||
visitEnd()
|
visitEnd()
|
||||||
}
|
}
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
|
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
|
||||||
name = "invoke",
|
"invoke",
|
||||||
descriptor = "(L$MAP_CLASS;)L$T_CLASS;",
|
Type.getMethodDescriptor(T_TYPE, MAP_TYPE),
|
||||||
signature = "(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
|
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;)${T_TYPE.descriptor}",
|
||||||
exceptions = null
|
null
|
||||||
) {
|
).instructionAdapter {
|
||||||
invokeMethodVisitor = this
|
invokeMethodVisitor = this
|
||||||
visitCode()
|
visitCode()
|
||||||
val l0 = Label()
|
val l0 = Label()
|
||||||
visitLabel(l0)
|
visitLabel(l0)
|
||||||
invokeLabel0Visitor()
|
invokeLabel0Visitor()
|
||||||
visitReturnObject()
|
areturn(T_TYPE)
|
||||||
val l1 = Label()
|
val l1 = Label()
|
||||||
visitLabel(l1)
|
visitLabel(l1)
|
||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
"this",
|
"this",
|
||||||
"L$slashesClassName;",
|
CLASS_TYPE.descriptor,
|
||||||
null,
|
null,
|
||||||
l0,
|
l0,
|
||||||
l1,
|
l1,
|
||||||
@ -159,8 +175,8 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
"arguments",
|
"arguments",
|
||||||
"L$MAP_CLASS;",
|
MAP_TYPE.descriptor,
|
||||||
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
|
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;",
|
||||||
l0,
|
l0,
|
||||||
l1,
|
l1,
|
||||||
invokeArgumentsVar
|
invokeArgumentsVar
|
||||||
@ -171,28 +187,28 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
|
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
|
||||||
name = "invoke",
|
"invoke",
|
||||||
descriptor = "(L$MAP_CLASS;)L$OBJECT_CLASS;",
|
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
||||||
signature = null,
|
null,
|
||||||
exceptions = null
|
null
|
||||||
) {
|
).instructionAdapter {
|
||||||
val thisVar = 0
|
val thisVar = 0
|
||||||
val argumentsVar = 1
|
val argumentsVar = 1
|
||||||
visitCode()
|
visitCode()
|
||||||
val l0 = Label()
|
val l0 = Label()
|
||||||
visitLabel(l0)
|
visitLabel(l0)
|
||||||
visitLoadObjectVar(thisVar)
|
load(thisVar, OBJECT_TYPE)
|
||||||
visitLoadObjectVar(argumentsVar)
|
load(argumentsVar, MAP_TYPE)
|
||||||
visitInvokeVirtual(owner = slashesClassName, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;")
|
invokevirtual(CLASS_TYPE.internalName, "invoke", Type.getMethodDescriptor(T_TYPE, MAP_TYPE), false)
|
||||||
visitReturnObject()
|
areturn(T_TYPE)
|
||||||
val l1 = Label()
|
val l1 = Label()
|
||||||
visitLabel(l1)
|
visitLabel(l1)
|
||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
"this",
|
"this",
|
||||||
"L$slashesClassName;",
|
CLASS_TYPE.descriptor,
|
||||||
T_CLASS,
|
null,
|
||||||
l0,
|
l0,
|
||||||
l1,
|
l1,
|
||||||
thisVar
|
thisVar
|
||||||
@ -219,89 +235,111 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* Loads a constant from
|
* Loads a constant from
|
||||||
*/
|
*/
|
||||||
internal fun loadTConstant(value: T) {
|
internal fun loadTConstant(value: T) {
|
||||||
if (classOfT in INLINABLE_NUMBERS) {
|
if (classOfT.java in INLINABLE_NUMBERS) {
|
||||||
loadNumberConstant(value as Number)
|
val expectedType = expectationStack.pop()!!
|
||||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||||
|
loadNumberConstant(value as Number, mustBeBoxed)
|
||||||
|
if (mustBeBoxed) typeStack.push(T_TYPE) else typeStack.push(PRIMITIVE_MASK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
loadConstant(value as Any, T_CLASS)
|
loadConstant(value as Any, T_TYPE)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
private fun box(): Unit = invokeMethodVisitor.invokestatic(
|
||||||
* Loads an object constant [value] stored in [AsmCompiledExpression.constants] and casts it to [type].
|
T_TYPE.internalName,
|
||||||
*/
|
"valueOf",
|
||||||
private fun loadConstant(value: Any, type: String) {
|
Type.getMethodDescriptor(T_TYPE, PRIMITIVE_MASK),
|
||||||
|
false
|
||||||
|
)
|
||||||
|
|
||||||
|
private fun unbox(): Unit = invokeMethodVisitor.invokevirtual(
|
||||||
|
NUMBER_TYPE.internalName,
|
||||||
|
NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK),
|
||||||
|
Type.getMethodDescriptor(PRIMITIVE_MASK),
|
||||||
|
false
|
||||||
|
)
|
||||||
|
|
||||||
|
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
||||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||||
|
loadThis()
|
||||||
invokeMethodVisitor.run {
|
getfield(CLASS_TYPE.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||||
loadThis()
|
iconst(idx)
|
||||||
visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;")
|
visitInsn(AALOAD)
|
||||||
visitLdcOrIntConstant(idx)
|
checkcast(type)
|
||||||
visitGetObjectArrayElement()
|
|
||||||
invokeMethodVisitor.visitCheckCast(type)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar)
|
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, CLASS_TYPE)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Either loads a numeric constant [value] from [AsmCompiledExpression.constants] field or boxes a primitive
|
* Either loads a numeric constant [value] from [AsmCompiledExpression] constants field or boxes a primitive
|
||||||
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
|
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
|
||||||
* from it).
|
* from it).
|
||||||
*/
|
*/
|
||||||
private fun loadNumberConstant(value: Number) {
|
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
|
||||||
val clazz = value.javaClass
|
val boxed = value::class.asm
|
||||||
val c = clazz.name.replace('.', '/')
|
val primitive = BOXED_TO_PRIMITIVES[boxed]
|
||||||
val sigLetter = SIGNATURE_LETTERS[clazz]
|
|
||||||
|
|
||||||
if (sigLetter != null) {
|
if (primitive != null) {
|
||||||
when (value) {
|
when (primitive) {
|
||||||
is Byte -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt())
|
Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
is Short -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt())
|
Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
|
||||||
is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value)
|
Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
|
||||||
is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value)
|
Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
|
||||||
is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value)
|
Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
is Long -> invokeMethodVisitor.visitLdcOrLongConstant(value)
|
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
else -> invokeMethodVisitor.visitLdcInsn(value)
|
}
|
||||||
|
|
||||||
|
if (mustBeBoxed) {
|
||||||
|
box()
|
||||||
|
invokeMethodVisitor.checkcast(T_TYPE)
|
||||||
}
|
}
|
||||||
|
|
||||||
invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
loadConstant(value, c)
|
loadConstant(value, boxed)
|
||||||
|
if (!mustBeBoxed) unbox()
|
||||||
|
else invokeMethodVisitor.checkcast(T_TYPE)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided.
|
* Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided.
|
||||||
*/
|
*/
|
||||||
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
||||||
visitLoadObjectVar(invokeArgumentsVar)
|
load(invokeArgumentsVar, OBJECT_ARRAY_TYPE)
|
||||||
|
|
||||||
if (defaultValue != null) {
|
if (defaultValue != null) {
|
||||||
visitLdcInsn(name)
|
loadStringConstant(name)
|
||||||
loadTConstant(defaultValue)
|
loadTConstant(defaultValue)
|
||||||
|
|
||||||
visitInvokeInterface(
|
invokeinterface(
|
||||||
owner = MAP_CLASS,
|
MAP_TYPE.internalName,
|
||||||
name = "getOrDefault",
|
"getOrDefault",
|
||||||
descriptor = "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;"
|
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE)
|
||||||
)
|
)
|
||||||
|
|
||||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
invokeMethodVisitor.checkcast(T_TYPE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
visitLdcInsn(name)
|
loadStringConstant(name)
|
||||||
|
|
||||||
visitInvokeInterface(
|
invokeinterface(
|
||||||
owner = MAP_CLASS,
|
MAP_TYPE.internalName,
|
||||||
name = "get",
|
"get",
|
||||||
descriptor = "(L$OBJECT_CLASS;)L$OBJECT_CLASS;"
|
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE)
|
||||||
)
|
)
|
||||||
|
|
||||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
invokeMethodVisitor.checkcast(T_TYPE)
|
||||||
|
val expectedType = expectationStack.pop()!!
|
||||||
|
|
||||||
|
if (expectedType.sort == Type.OBJECT)
|
||||||
|
typeStack.push(T_TYPE)
|
||||||
|
else {
|
||||||
|
unbox()
|
||||||
|
typeStack.push(PRIMITIVE_MASK)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -310,13 +348,10 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
internal fun loadAlgebra() {
|
internal fun loadAlgebra() {
|
||||||
loadThis()
|
loadThis()
|
||||||
|
|
||||||
invokeMethodVisitor.visitGetField(
|
invokeMethodVisitor.run {
|
||||||
owner = FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
getfield(ASM_COMPILED_EXPRESSION_TYPE.internalName, "algebra", ALGEBRA_TYPE.descriptor)
|
||||||
name = "algebra",
|
checkcast(T_ALGEBRA_TYPE)
|
||||||
descriptor = "L$ALGEBRA_CLASS;"
|
}
|
||||||
)
|
|
||||||
|
|
||||||
invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -330,29 +365,70 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
owner: String,
|
owner: String,
|
||||||
method: String,
|
method: String,
|
||||||
descriptor: String,
|
descriptor: String,
|
||||||
|
tArity: Int,
|
||||||
opcode: Int = Opcodes.INVOKEINTERFACE
|
opcode: Int = Opcodes.INVOKEINTERFACE
|
||||||
) {
|
) {
|
||||||
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, opcode == Opcodes.INVOKEINTERFACE)
|
repeat(tArity) { if (!typeStack.empty()) typeStack.pop() }
|
||||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
|
||||||
|
invokeMethodVisitor.visitMethodInsn(
|
||||||
|
opcode,
|
||||||
|
owner,
|
||||||
|
method,
|
||||||
|
descriptor,
|
||||||
|
opcode == Opcodes.INVOKEINTERFACE
|
||||||
|
)
|
||||||
|
|
||||||
|
invokeMethodVisitor.checkcast(T_TYPE)
|
||||||
|
val isLastExpr = expectationStack.size == 1
|
||||||
|
val expectedType = expectationStack.pop()!!
|
||||||
|
|
||||||
|
if (expectedType.sort == Type.OBJECT || isLastExpr)
|
||||||
|
typeStack.push(T_TYPE)
|
||||||
|
else {
|
||||||
|
unbox()
|
||||||
|
typeStack.push(PRIMITIVE_MASK)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Writes a LDC Instruction with string constant provided.
|
* Writes a LDC Instruction with string constant provided.
|
||||||
*/
|
*/
|
||||||
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string)
|
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
|
||||||
|
|
||||||
internal companion object {
|
internal companion object {
|
||||||
/**
|
/**
|
||||||
* Maps JVM primitive numbers boxed types to their letters of JVM signature convention.
|
* Maps JVM primitive numbers boxed types to their letters of JVM signature convention.
|
||||||
*/
|
*/
|
||||||
private val SIGNATURE_LETTERS: Map<Class<out Any>, String> by lazy {
|
private val SIGNATURE_LETTERS: Map<Class<out Any>, Type> by lazy {
|
||||||
hashMapOf(
|
hashMapOf(
|
||||||
java.lang.Byte::class.java to "B",
|
java.lang.Byte::class.java to Type.BYTE_TYPE,
|
||||||
java.lang.Short::class.java to "S",
|
java.lang.Short::class.java to Type.SHORT_TYPE,
|
||||||
java.lang.Integer::class.java to "I",
|
java.lang.Integer::class.java to Type.INT_TYPE,
|
||||||
java.lang.Long::class.java to "J",
|
java.lang.Long::class.java to Type.LONG_TYPE,
|
||||||
java.lang.Float::class.java to "F",
|
java.lang.Float::class.java to Type.FLOAT_TYPE,
|
||||||
java.lang.Double::class.java to "D"
|
java.lang.Double::class.java to Type.DOUBLE_TYPE
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy {
|
||||||
|
hashMapOf(
|
||||||
|
java.lang.Byte::class.asm to Type.BYTE_TYPE,
|
||||||
|
java.lang.Short::class.asm to Type.SHORT_TYPE,
|
||||||
|
java.lang.Integer::class.asm to Type.INT_TYPE,
|
||||||
|
java.lang.Long::class.asm to Type.LONG_TYPE,
|
||||||
|
java.lang.Float::class.asm to Type.FLOAT_TYPE,
|
||||||
|
java.lang.Double::class.asm to Type.DOUBLE_TYPE
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
|
||||||
|
hashMapOf(
|
||||||
|
Type.BYTE_TYPE to "byteValue",
|
||||||
|
Type.SHORT_TYPE to "shortValue",
|
||||||
|
Type.INT_TYPE to "intValue",
|
||||||
|
Type.LONG_TYPE to "longValue",
|
||||||
|
Type.FLOAT_TYPE to "floatValue",
|
||||||
|
Type.DOUBLE_TYPE to "doubleValue"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -360,13 +436,14 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
||||||
*/
|
*/
|
||||||
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||||
|
internal val ASM_COMPILED_EXPRESSION_TYPE: Type = AsmCompiledExpression::class.asm
|
||||||
|
internal val NUMBER_TYPE: Type = java.lang.Number::class.asm
|
||||||
|
internal val MAP_TYPE: Type = java.util.Map::class.asm
|
||||||
|
internal val OBJECT_TYPE: Type = java.lang.Object::class.asm
|
||||||
|
|
||||||
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
|
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
||||||
"scientifik/kmath/asm/internal/AsmCompiledExpression"
|
internal val OBJECT_ARRAY_TYPE: Type = Array<java.lang.Object>::class.asm
|
||||||
|
internal val ALGEBRA_TYPE: Type = Algebra::class.asm
|
||||||
internal const val MAP_CLASS = "java/util/Map"
|
internal val STRING_TYPE: Type = java.lang.String::class.asm
|
||||||
internal const val OBJECT_CLASS = "java/lang/Object"
|
|
||||||
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
|
||||||
internal const val STRING_CLASS = "java/lang/String"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,4 +16,3 @@ internal abstract class AsmCompiledExpression<T> internal constructor(
|
|||||||
) : Expression<T> {
|
) : Expression<T> {
|
||||||
abstract override fun invoke(arguments: Map<String, T>): T
|
abstract override fun invoke(arguments: Map<String, T>): T
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import org.objectweb.asm.Type
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
internal val KClass<*>.asm: Type
|
||||||
|
get() = Type.getType(java)
|
@ -1,60 +1,9 @@
|
|||||||
package scientifik.kmath.asm.internal
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
import org.objectweb.asm.MethodVisitor
|
import org.objectweb.asm.MethodVisitor
|
||||||
import org.objectweb.asm.Opcodes.*
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
|
|
||||||
internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value) {
|
fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
|
||||||
-1 -> visitInsn(ICONST_M1)
|
|
||||||
0 -> visitInsn(ICONST_0)
|
|
||||||
1 -> visitInsn(ICONST_1)
|
|
||||||
2 -> visitInsn(ICONST_2)
|
|
||||||
3 -> visitInsn(ICONST_3)
|
|
||||||
4 -> visitInsn(ICONST_4)
|
|
||||||
5 -> visitInsn(ICONST_5)
|
|
||||||
in -128..127 -> visitIntInsn(BIPUSH, value)
|
|
||||||
in -32768..32767 -> visitIntInsn(SIPUSH, value)
|
|
||||||
else -> visitLdcInsn(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when (value) {
|
fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
|
||||||
0.0 -> visitInsn(DCONST_0)
|
instructionAdapter().apply(block)
|
||||||
1.0 -> visitInsn(DCONST_1)
|
|
||||||
else -> visitLdcInsn(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitLdcOrLongConstant(value: Long): Unit = when (value) {
|
|
||||||
0L -> visitInsn(LCONST_0)
|
|
||||||
1L -> visitInsn(LCONST_1)
|
|
||||||
else -> visitLdcInsn(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) {
|
|
||||||
0f -> visitInsn(FCONST_0)
|
|
||||||
1f -> visitInsn(FCONST_1)
|
|
||||||
2f -> visitInsn(FCONST_2)
|
|
||||||
else -> visitLdcInsn(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitInvokeInterface(owner: String, name: String, descriptor: String): Unit =
|
|
||||||
visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true)
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitInvokeVirtual(owner: String, name: String, descriptor: String): Unit =
|
|
||||||
visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, false)
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitInvokeStatic(owner: String, name: String, descriptor: String): Unit =
|
|
||||||
visitMethodInsn(INVOKESTATIC, owner, name, descriptor, false)
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitInvokeSpecial(owner: String, name: String, descriptor: String): Unit =
|
|
||||||
visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, false)
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CHECKCAST, type)
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit =
|
|
||||||
visitFieldInsn(GETFIELD, owner, name, descriptor)
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitLoadObjectVar(`var`: Int): Unit = visitVarInsn(ALOAD, `var`)
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD)
|
|
||||||
|
|
||||||
internal fun MethodVisitor.visitReturn(): Unit = visitInsn(RETURN)
|
|
||||||
internal fun MethodVisitor.visitReturnObject(): Unit = visitInsn(ARETURN)
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.asm.internal
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
import org.objectweb.asm.Opcodes
|
import org.objectweb.asm.Opcodes
|
||||||
|
import org.objectweb.asm.Type
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
|
|
||||||
private val methodNameAdapters: Map<String, String> by lazy {
|
private val methodNameAdapters: Map<String, String> by lazy {
|
||||||
@ -12,17 +13,19 @@ private val methodNameAdapters: Map<String, String> by lazy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if the target [context] for code generation contains a method with needed [name] and [arity].
|
* Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
|
||||||
|
* type expectation stack for needed arity.
|
||||||
*
|
*
|
||||||
* @return `true` if contains, else `false`.
|
* @return `true` if contains, else `false`.
|
||||||
*/
|
*/
|
||||||
internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||||
val aName = methodNameAdapters[name] ?: name
|
val aName = methodNameAdapters[name] ?: name
|
||||||
|
|
||||||
context.javaClass.methods.find { it.name == aName && it.parameters.size == arity }
|
val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null
|
||||||
?: return false
|
val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else T_TYPE
|
||||||
|
repeat(arity) { expectationStack.push(t) }
|
||||||
|
|
||||||
return true
|
return hasSpecific
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -34,22 +37,23 @@ internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boo
|
|||||||
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||||
val aName = methodNameAdapters[name] ?: name
|
val aName = methodNameAdapters[name] ?: name
|
||||||
|
|
||||||
context.javaClass.methods.find { it.name == aName && it.parameters.size == arity }
|
val method =
|
||||||
?: return false
|
context.javaClass.methods.find {
|
||||||
|
var suitableSignature = it.name == aName && it.parameters.size == arity
|
||||||
|
|
||||||
|
if (primitiveMode && it.isBridge)
|
||||||
|
suitableSignature = false
|
||||||
|
|
||||||
|
suitableSignature
|
||||||
|
} ?: return false
|
||||||
|
|
||||||
val owner = context::class.java.name.replace('.', '/')
|
val owner = context::class.java.name.replace('.', '/')
|
||||||
|
|
||||||
val sig = buildString {
|
|
||||||
append('(')
|
|
||||||
repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") }
|
|
||||||
append(')')
|
|
||||||
append("L${AsmBuilder.OBJECT_CLASS};")
|
|
||||||
}
|
|
||||||
|
|
||||||
invokeAlgebraOperation(
|
invokeAlgebraOperation(
|
||||||
owner = owner,
|
owner = owner,
|
||||||
method = aName,
|
method = aName,
|
||||||
descriptor = sig,
|
descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }),
|
||||||
|
tArity = arity,
|
||||||
opcode = Opcodes.INVOKEVIRTUAL
|
opcode = Opcodes.INVOKEVIRTUAL
|
||||||
)
|
)
|
||||||
|
|
@ -1,66 +1,110 @@
|
|||||||
package scietifik.kmath.asm
|
package scietifik.kmath.asm
|
||||||
|
|
||||||
//
|
import scientifik.kmath.asm.compile
|
||||||
//class TestAsmAlgebras {
|
import scientifik.kmath.ast.mstInField
|
||||||
// @Test
|
import scientifik.kmath.ast.mstInRing
|
||||||
// fun space() {
|
import scientifik.kmath.ast.mstInSpace
|
||||||
// val res = ByteRing.asmInRing {
|
import scientifik.kmath.expressions.invoke
|
||||||
// binaryOperation(
|
import scientifik.kmath.operations.ByteRing
|
||||||
// "+",
|
import scientifik.kmath.operations.RealField
|
||||||
//
|
import kotlin.test.Test
|
||||||
// unaryOperation(
|
import kotlin.test.assertEquals
|
||||||
// "+",
|
|
||||||
// 3.toByte() - (2.toByte() + (multiply(
|
class TestAsmAlgebras {
|
||||||
// add(number(1), number(1)),
|
@Test
|
||||||
// 2
|
fun space() {
|
||||||
// ) + 1.toByte()) * 3.toByte() - 1.toByte())
|
val res1 = ByteRing.mstInSpace {
|
||||||
// ),
|
binaryOperation(
|
||||||
//
|
"+",
|
||||||
// number(1)
|
|
||||||
// ) + symbol("x") + zero
|
unaryOperation(
|
||||||
// }("x" to 2.toByte())
|
"+",
|
||||||
//
|
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||||
// assertEquals(16, res)
|
add(number(1), number(1)),
|
||||||
// }
|
2
|
||||||
//
|
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||||
// @Test
|
),
|
||||||
// fun ring() {
|
|
||||||
// val res = ByteRing.asmInRing {
|
number(1)
|
||||||
// binaryOperation(
|
) + symbol("x") + zero
|
||||||
// "+",
|
}("x" to 2.toByte())
|
||||||
//
|
|
||||||
// unaryOperation(
|
val res2 = ByteRing.mstInSpace {
|
||||||
// "+",
|
binaryOperation(
|
||||||
// (3.toByte() - (2.toByte() + (multiply(
|
"+",
|
||||||
// add(const(1), const(1)),
|
|
||||||
// 2
|
unaryOperation(
|
||||||
// ) + 1.toByte()))) * 3.0 - 1.toByte()
|
"+",
|
||||||
// ),
|
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||||
//
|
add(number(1), number(1)),
|
||||||
// number(1)
|
2
|
||||||
// ) * const(2)
|
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||||
// }()
|
),
|
||||||
//
|
|
||||||
// assertEquals(24, res)
|
number(1)
|
||||||
// }
|
) + symbol("x") + zero
|
||||||
//
|
}.compile()("x" to 2.toByte())
|
||||||
// @Test
|
|
||||||
// fun field() {
|
assertEquals(res1, res2)
|
||||||
// val res = RealField.asmInField {
|
}
|
||||||
// +(3 - 2 + 2*(number(1)+1.0)
|
|
||||||
//
|
@Test
|
||||||
// unaryOperation(
|
fun ring() {
|
||||||
// "+",
|
val res1 = ByteRing.mstInRing {
|
||||||
// (3.0 - (2.0 + (multiply(
|
binaryOperation(
|
||||||
// add((1.0), const(1.0)),
|
"+",
|
||||||
// 2
|
|
||||||
// ) + 1.0))) * 3 - 1.0
|
unaryOperation(
|
||||||
// )+
|
"+",
|
||||||
//
|
(symbol("x") - (2.toByte() + (multiply(
|
||||||
// number(1)
|
add(number(1), number(1)),
|
||||||
// ) / 2, const(2.0)) * one
|
2
|
||||||
// }()
|
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||||
//
|
),
|
||||||
// assertEquals(3.0, res)
|
|
||||||
// }
|
number(1)
|
||||||
//}
|
) * number(2)
|
||||||
|
}("x" to 3.toByte())
|
||||||
|
|
||||||
|
val res2 = ByteRing.mstInRing {
|
||||||
|
binaryOperation(
|
||||||
|
"+",
|
||||||
|
|
||||||
|
unaryOperation(
|
||||||
|
"+",
|
||||||
|
(symbol("x") - (2.toByte() + (multiply(
|
||||||
|
add(number(1), number(1)),
|
||||||
|
2
|
||||||
|
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||||
|
),
|
||||||
|
|
||||||
|
number(1)
|
||||||
|
) * number(2)
|
||||||
|
}.compile()("x" to 3.toByte())
|
||||||
|
|
||||||
|
assertEquals(res1, res2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun field() {
|
||||||
|
val res1 = RealField.mstInField {
|
||||||
|
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
|
||||||
|
"+",
|
||||||
|
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||||
|
+ number(1),
|
||||||
|
1 / 2 + number(2.0) * one
|
||||||
|
)
|
||||||
|
}("x" to 2.0)
|
||||||
|
|
||||||
|
val res2 = RealField.mstInField {
|
||||||
|
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
|
||||||
|
"+",
|
||||||
|
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||||
|
+ number(1),
|
||||||
|
1 / 2 + number(2.0) * one
|
||||||
|
)
|
||||||
|
}.compile()("x" to 2.0)
|
||||||
|
|
||||||
|
assertEquals(res1, res2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -16,6 +16,13 @@ class TestAsmExpressions {
|
|||||||
assertEquals(-2.0, res)
|
assertEquals(-2.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBinaryOperationInvocation() {
|
||||||
|
val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile()
|
||||||
|
val res = expression("x" to 2.0)
|
||||||
|
assertEquals(-1.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testConstProductInvocation() {
|
fun testConstProductInvocation() {
|
||||||
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
|
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
package scietifik.kmath.ast
|
package scietifik.kmath.ast
|
||||||
|
|
||||||
import scientifik.kmath.ast.evaluate
|
import scientifik.kmath.asm.compile
|
||||||
|
import scientifik.kmath.asm.expression
|
||||||
|
import scientifik.kmath.ast.mstInField
|
||||||
import scientifik.kmath.ast.parseMath
|
import scientifik.kmath.ast.parseMath
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -9,9 +12,15 @@ import kotlin.test.assertEquals
|
|||||||
|
|
||||||
class AsmTest {
|
class AsmTest {
|
||||||
@Test
|
@Test
|
||||||
fun parsedExpression() {
|
fun `compile MST`() {
|
||||||
val mst = "2+2*(2+2)".parseMath()
|
val mst = "2+2*(2+2)".parseMath()
|
||||||
val res = ComplexField.evaluate(mst)
|
val res = ComplexField.expression(mst)()
|
||||||
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `compile MSTExpression`() {
|
||||||
|
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()()
|
||||||
assertEquals(Complex(10.0, 0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,17 +1,25 @@
|
|||||||
package scietifik.kmath.ast
|
package scietifik.kmath.ast
|
||||||
|
|
||||||
import scientifik.kmath.ast.evaluate
|
import scientifik.kmath.ast.evaluate
|
||||||
|
import scientifik.kmath.ast.mstInField
|
||||||
import scientifik.kmath.ast.parseMath
|
import scientifik.kmath.ast.parseMath
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
import kotlin.test.assertEquals
|
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class ParserTest {
|
internal class ParserTest {
|
||||||
@Test
|
@Test
|
||||||
fun parsedExpression() {
|
fun `evaluate MST`() {
|
||||||
val mst = "2+2*(2+2)".parseMath()
|
val mst = "2+2*(2+2)".parseMath()
|
||||||
val res = ComplexField.evaluate(mst)
|
val res = ComplexField.evaluate(mst)
|
||||||
assertEquals(Complex(10.0, 0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MSTExpression`() {
|
||||||
|
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
||||||
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user