forked from kscience/kmath
Move asm dependency to implementation configuration; rename many ASM API classes, make AsmCompiledExpression implement functional Expression, fix typos, encapsulate AsmGenerationContext
This commit is contained in:
parent
a2e33bf6d8
commit
6ac0297530
@ -4,6 +4,6 @@ plugins {
|
|||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(path = ":kmath-core"))
|
api(project(path = ":kmath-core"))
|
||||||
api("org.ow2.asm:asm:8.0.1")
|
implementation("org.ow2.asm:asm:8.0.1")
|
||||||
api("org.ow2.asm:asm-commons:8.0.1")
|
implementation("org.ow2.asm:asm-commons:8.0.1")
|
||||||
}
|
}
|
||||||
|
@ -8,11 +8,30 @@ import scientifik.kmath.operations.Algebra
|
|||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
|
||||||
abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<out Any>) {
|
abstract class AsmCompiledExpression<T> internal constructor(
|
||||||
abstract fun evaluate(arguments: Map<String, T>): T
|
@JvmField val algebra: Algebra<T>,
|
||||||
|
@JvmField val constants: MutableList<out Any>
|
||||||
|
) : Expression<T> {
|
||||||
|
abstract override fun invoke(arguments: Map<String, T>): T
|
||||||
}
|
}
|
||||||
|
|
||||||
class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T>, private val className: String) {
|
internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String {
|
||||||
|
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision"
|
||||||
|
|
||||||
|
try {
|
||||||
|
Class.forName(name)
|
||||||
|
} catch (ignored: ClassNotFoundException) {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildName(expression, collision + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
class AsmGenerationContext<T>(
|
||||||
|
classOfT: Class<*>,
|
||||||
|
private val algebra: Algebra<T>,
|
||||||
|
private val className: String
|
||||||
|
) {
|
||||||
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||||
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||||
}
|
}
|
||||||
@ -26,30 +45,23 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
private val T_CLASS: String = classOfT.name.replace('.', '/')
|
private val T_CLASS: String = classOfT.name.replace('.', '/')
|
||||||
|
|
||||||
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
|
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
|
||||||
private val evaluateThisVar: Int = 0
|
private val invokeThisVar: Int = 0
|
||||||
private val evaluateArgumentsVar: Int = 1
|
private val invokeArgumentsVar: Int = 1
|
||||||
private var maxStack: Int = 0
|
private var maxStack: Int = 0
|
||||||
private lateinit var constants: MutableList<Any>
|
private val constants: MutableList<Any> = mutableListOf()
|
||||||
private lateinit var asmCompiledClassWriter: ClassWriter
|
private val asmCompiledClassWriter: ClassWriter = ClassWriter(0)
|
||||||
private lateinit var evaluateMethodVisitor: MethodVisitor
|
private val invokeMethodVisitor: MethodVisitor
|
||||||
private lateinit var evaluateL0: Label
|
private val invokeL0: Label
|
||||||
private lateinit var evaluateL1: Label
|
private lateinit var invokeL1: Label
|
||||||
|
private var generatedInstance: AsmCompiledExpression<T>? = null
|
||||||
|
|
||||||
init {
|
init {
|
||||||
start()
|
|
||||||
}
|
|
||||||
|
|
||||||
fun start() {
|
|
||||||
constants = mutableListOf()
|
|
||||||
asmCompiledClassWriter = ClassWriter(0)
|
|
||||||
maxStack = 0
|
|
||||||
|
|
||||||
asmCompiledClassWriter.visit(
|
asmCompiledClassWriter.visit(
|
||||||
V1_8,
|
V1_8,
|
||||||
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
|
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
|
||||||
slashesClassName,
|
slashesClassName,
|
||||||
"L$ASM_COMPILED_CLASS<L$T_CLASS;>;",
|
"L$ASM_COMPILED_EXPRESSION_CLASS<L$T_CLASS;>;",
|
||||||
ASM_COMPILED_CLASS,
|
ASM_COMPILED_EXPRESSION_CLASS,
|
||||||
arrayOf()
|
arrayOf()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,7 +78,7 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
|
|
||||||
visitMethodInsn(
|
visitMethodInsn(
|
||||||
INVOKESPECIAL,
|
INVOKESPECIAL,
|
||||||
ASM_COMPILED_CLASS,
|
ASM_COMPILED_EXPRESSION_CLASS,
|
||||||
"<init>",
|
"<init>",
|
||||||
"(L$ALGEBRA_CLASS;L$LIST_CLASS;)V",
|
"(L$ALGEBRA_CLASS;L$LIST_CLASS;)V",
|
||||||
false
|
false
|
||||||
@ -93,45 +105,48 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
visitEnd()
|
visitEnd()
|
||||||
}
|
}
|
||||||
|
|
||||||
evaluateMethodVisitor = visitMethod(
|
invokeMethodVisitor = visitMethod(
|
||||||
ACC_PUBLIC or ACC_FINAL,
|
ACC_PUBLIC or ACC_FINAL,
|
||||||
"evaluate",
|
"invoke",
|
||||||
"(L$MAP_CLASS;)L$T_CLASS;",
|
"(L$MAP_CLASS;)L$T_CLASS;",
|
||||||
"(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
|
"(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
|
||||||
null
|
null
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluateMethodVisitor.run {
|
invokeMethodVisitor.run {
|
||||||
visitCode()
|
visitCode()
|
||||||
evaluateL0 = Label()
|
invokeL0 = Label()
|
||||||
visitLabel(evaluateL0)
|
visitLabel(invokeL0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@PublishedApi
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
fun generate(): AsmCompiled<T> {
|
internal fun generate(): AsmCompiledExpression<T> {
|
||||||
evaluateMethodVisitor.run {
|
generatedInstance?.let { return it }
|
||||||
|
|
||||||
|
invokeMethodVisitor.run {
|
||||||
visitInsn(ARETURN)
|
visitInsn(ARETURN)
|
||||||
evaluateL1 = Label()
|
invokeL1 = Label()
|
||||||
visitLabel(evaluateL1)
|
visitLabel(invokeL1)
|
||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
"this",
|
"this",
|
||||||
"L$slashesClassName;",
|
"L$slashesClassName;",
|
||||||
T_CLASS,
|
T_CLASS,
|
||||||
evaluateL0,
|
invokeL0,
|
||||||
evaluateL1,
|
invokeL1,
|
||||||
evaluateThisVar
|
invokeThisVar
|
||||||
)
|
)
|
||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
"arguments",
|
"arguments",
|
||||||
"L$MAP_CLASS;",
|
"L$MAP_CLASS;",
|
||||||
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
|
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
|
||||||
evaluateL0,
|
invokeL0,
|
||||||
evaluateL1,
|
invokeL1,
|
||||||
evaluateArgumentsVar
|
invokeArgumentsVar
|
||||||
)
|
)
|
||||||
|
|
||||||
visitMaxs(maxStack + 1, 2)
|
visitMaxs(maxStack + 1, 2)
|
||||||
@ -140,7 +155,7 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
|
|
||||||
asmCompiledClassWriter.visitMethod(
|
asmCompiledClassWriter.visitMethod(
|
||||||
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
||||||
"evaluate",
|
"invoke",
|
||||||
"(L$MAP_CLASS;)L$OBJECT_CLASS;",
|
"(L$MAP_CLASS;)L$OBJECT_CLASS;",
|
||||||
null,
|
null,
|
||||||
null
|
null
|
||||||
@ -151,7 +166,7 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
visitLabel(l0)
|
visitLabel(l0)
|
||||||
visitVarInsn(ALOAD, 0)
|
visitVarInsn(ALOAD, 0)
|
||||||
visitVarInsn(ALOAD, 1)
|
visitVarInsn(ALOAD, 1)
|
||||||
visitMethodInsn(INVOKEVIRTUAL, slashesClassName, "evaluate", "(L$MAP_CLASS;)L$T_CLASS;", false)
|
visitMethodInsn(INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false)
|
||||||
visitInsn(ARETURN)
|
visitInsn(ARETURN)
|
||||||
val l1 = Label()
|
val l1 = Label()
|
||||||
visitLabel(l1)
|
visitLabel(l1)
|
||||||
@ -171,31 +186,34 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
|
|
||||||
asmCompiledClassWriter.visitEnd()
|
asmCompiledClassWriter.visitEnd()
|
||||||
|
|
||||||
return classLoader
|
val new = classLoader
|
||||||
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
||||||
.constructors
|
.constructors
|
||||||
.first()
|
.first()
|
||||||
.newInstance(algebra, constants) as AsmCompiled<T>
|
.newInstance(algebra, constants) as AsmCompiledExpression<T>
|
||||||
|
|
||||||
|
generatedInstance = new
|
||||||
|
return new
|
||||||
}
|
}
|
||||||
|
|
||||||
fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS)
|
internal fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS)
|
||||||
|
|
||||||
fun visitLoadAnyFromConstants(value: Any, type: String) {
|
private fun visitLoadAnyFromConstants(value: Any, type: String) {
|
||||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||||
maxStack++
|
maxStack++
|
||||||
|
|
||||||
evaluateMethodVisitor.run {
|
invokeMethodVisitor.run {
|
||||||
visitLoadThis()
|
visitLoadThis()
|
||||||
visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;")
|
visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;")
|
||||||
visitLdcOrIConstInsn(idx)
|
visitLdcOrIConstInsn(idx)
|
||||||
visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true)
|
visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true)
|
||||||
evaluateMethodVisitor.visitTypeInsn(CHECKCAST, type)
|
invokeMethodVisitor.visitTypeInsn(CHECKCAST, type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
|
private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(ALOAD, invokeThisVar)
|
||||||
|
|
||||||
fun visitNumberConstant(value: Number) {
|
internal fun visitNumberConstant(value: Number) {
|
||||||
maxStack++
|
maxStack++
|
||||||
val clazz = value.javaClass
|
val clazz = value.javaClass
|
||||||
val c = clazz.name.replace('.', '/')
|
val c = clazz.name.replace('.', '/')
|
||||||
@ -204,22 +222,22 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
|
|
||||||
if (sigLetter != null) {
|
if (sigLetter != null) {
|
||||||
when (value) {
|
when (value) {
|
||||||
is Int -> evaluateMethodVisitor.visitLdcOrIConstInsn(value)
|
is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value)
|
||||||
is Double -> evaluateMethodVisitor.visitLdcOrDConstInsn(value)
|
is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value)
|
||||||
is Float -> evaluateMethodVisitor.visitLdcOrFConstInsn(value)
|
is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value)
|
||||||
else -> evaluateMethodVisitor.visitLdcInsn(value)
|
else -> invokeMethodVisitor.visitLdcInsn(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
evaluateMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false)
|
invokeMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
visitLoadAnyFromConstants(value, c)
|
visitLoadAnyFromConstants(value, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run {
|
internal fun visitLoadFromVariables(name: String, defaultValue: T? = null) = invokeMethodVisitor.run {
|
||||||
maxStack += 2
|
maxStack += 2
|
||||||
visitVarInsn(ALOAD, evaluateArgumentsVar)
|
visitVarInsn(ALOAD, invokeArgumentsVar)
|
||||||
|
|
||||||
if (defaultValue != null) {
|
if (defaultValue != null) {
|
||||||
visitLdcInsn(name)
|
visitLdcInsn(name)
|
||||||
@ -242,26 +260,26 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
visitCastToT()
|
visitCastToT()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun visitLoadAlgebra() {
|
internal fun visitLoadAlgebra() {
|
||||||
evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
|
invokeMethodVisitor.visitVarInsn(ALOAD, invokeThisVar)
|
||||||
|
|
||||||
evaluateMethodVisitor.visitFieldInsn(
|
invokeMethodVisitor.visitFieldInsn(
|
||||||
GETFIELD,
|
GETFIELD,
|
||||||
ASM_COMPILED_CLASS, "algebra", "L$ALGEBRA_CLASS;"
|
ASM_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;"
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_ALGEBRA_CLASS)
|
invokeMethodVisitor.visitTypeInsn(CHECKCAST, T_ALGEBRA_CLASS)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun visitAlgebraOperation(owner: String, method: String, descriptor: String) {
|
internal fun visitAlgebraOperation(owner: String, method: String, descriptor: String) {
|
||||||
maxStack++
|
maxStack++
|
||||||
evaluateMethodVisitor.visitMethodInsn(INVOKEINTERFACE, owner, method, descriptor, true)
|
invokeMethodVisitor.visitMethodInsn(INVOKEINTERFACE, owner, method, descriptor, true)
|
||||||
visitCastToT()
|
visitCastToT()
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
|
private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
|
||||||
|
|
||||||
companion object {
|
internal companion object {
|
||||||
private val SIGNATURE_LETTERS = mapOf(
|
private val SIGNATURE_LETTERS = mapOf(
|
||||||
java.lang.Byte::class.java to "B",
|
java.lang.Byte::class.java to "B",
|
||||||
java.lang.Short::class.java to "S",
|
java.lang.Short::class.java to "S",
|
||||||
@ -271,16 +289,16 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
java.lang.Double::class.java to "D"
|
java.lang.Double::class.java to "D"
|
||||||
)
|
)
|
||||||
|
|
||||||
const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled"
|
internal const val ASM_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/AsmCompiledExpression"
|
||||||
const val LIST_CLASS = "java/util/List"
|
internal const val LIST_CLASS = "java/util/List"
|
||||||
const val MAP_CLASS = "java/util/Map"
|
internal const val MAP_CLASS = "java/util/Map"
|
||||||
const val OBJECT_CLASS = "java/lang/Object"
|
internal const val OBJECT_CLASS = "java/lang/Object"
|
||||||
const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||||
const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
||||||
const val STRING_CLASS = "java/lang/String"
|
internal const val STRING_CLASS = "java/lang/String"
|
||||||
const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations"
|
internal const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations"
|
||||||
const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations"
|
internal const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations"
|
||||||
const val NUMBER_CLASS = "java/lang/Number"
|
internal const val NUMBER_CLASS = "java/lang/Number"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -366,43 +384,25 @@ internal class AsmDivExpression<T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
open class AsmFunctionalExpressionSpace<T>(
|
open class AsmExpressionSpace<T>(space: Space<T>) : Space<AsmExpression<T>>,
|
||||||
val space: Space<T>,
|
|
||||||
one: T
|
|
||||||
) : Space<AsmExpression<T>>,
|
|
||||||
ExpressionSpace<T, AsmExpression<T>> {
|
ExpressionSpace<T, AsmExpression<T>> {
|
||||||
override val zero: AsmExpression<T> =
|
override val zero: AsmExpression<T> = AsmConstantExpression(space.zero)
|
||||||
AsmConstantExpression(space.zero)
|
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(value)
|
||||||
|
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(name, default)
|
||||||
override fun const(value: T): AsmExpression<T> =
|
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmSumExpression(a, b)
|
||||||
AsmConstantExpression(value)
|
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(a, k)
|
||||||
|
|
||||||
override fun variable(name: String, default: T?): AsmExpression<T> =
|
|
||||||
AsmVariableExpression(
|
|
||||||
name,
|
|
||||||
default
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
|
||||||
AsmSumExpression(a, b)
|
|
||||||
|
|
||||||
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> =
|
|
||||||
AsmConstProductExpression(a, k)
|
|
||||||
|
|
||||||
|
|
||||||
operator fun AsmExpression<T>.plus(arg: T) = this + const(arg)
|
operator fun AsmExpression<T>.plus(arg: T) = this + const(arg)
|
||||||
operator fun AsmExpression<T>.minus(arg: T) = this - const(arg)
|
operator fun AsmExpression<T>.minus(arg: T) = this - const(arg)
|
||||||
|
|
||||||
operator fun T.plus(arg: AsmExpression<T>) = arg + this
|
operator fun T.plus(arg: AsmExpression<T>) = arg + this
|
||||||
operator fun T.minus(arg: AsmExpression<T>) = arg - this
|
operator fun T.minus(arg: AsmExpression<T>) = arg - this
|
||||||
}
|
}
|
||||||
|
|
||||||
class AsmFunctionalExpressionField<T>(val field: Field<T>) : ExpressionField<T, AsmExpression<T>>,
|
class AsmExpressionField<T>(private val field: Field<T>) : ExpressionField<T, AsmExpression<T>>,
|
||||||
AsmFunctionalExpressionSpace<T>(field, field.one) {
|
AsmExpressionSpace<T>(field) {
|
||||||
override val one: AsmExpression<T>
|
override val one: AsmExpression<T>
|
||||||
get() = const(this.field.one)
|
get() = const(this.field.one)
|
||||||
|
|
||||||
override fun const(value: Double): AsmExpression<T> = const(field.run { one * value })
|
override fun number(value: Number): AsmExpression<T> = const(field.run { one * value })
|
||||||
|
|
||||||
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||||
AsmProductExpression(a, b)
|
AsmProductExpression(a, b)
|
||||||
@ -412,7 +412,6 @@ class AsmFunctionalExpressionField<T>(val field: Field<T>) : ExpressionField<T,
|
|||||||
|
|
||||||
operator fun AsmExpression<T>.times(arg: T) = this * const(arg)
|
operator fun AsmExpression<T>.times(arg: T) = this * const(arg)
|
||||||
operator fun AsmExpression<T>.div(arg: T) = this / const(arg)
|
operator fun AsmExpression<T>.div(arg: T) = this / const(arg)
|
||||||
|
|
||||||
operator fun T.times(arg: AsmExpression<T>) = arg * this
|
operator fun T.times(arg: AsmExpression<T>) = arg * this
|
||||||
operator fun T.div(arg: AsmExpression<T>) = arg / this
|
operator fun T.div(arg: AsmExpression<T>) = arg / this
|
||||||
}
|
}
|
||||||
|
@ -12,15 +12,12 @@ class AsmTest {
|
|||||||
arguments: Map<String, T>,
|
arguments: Map<String, T>,
|
||||||
algebra: Algebra<T>,
|
algebra: Algebra<T>,
|
||||||
clazz: Class<*>
|
clazz: Class<*>
|
||||||
) {
|
): Unit = assertEquals(
|
||||||
assertEquals(
|
expectedValue, AsmGenerationContext(clazz, algebra, "TestAsmCompiled")
|
||||||
expectedValue, AsmGenerationContext(
|
.also(expr::invoke)
|
||||||
clazz,
|
.generate()
|
||||||
algebra,
|
.invoke(arguments)
|
||||||
"TestAsmCompiled"
|
)
|
||||||
).also(expr::invoke).generate().evaluate(arguments)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
private fun testDoubleExpressionValue(
|
private fun testDoubleExpressionValue(
|
||||||
@ -29,7 +26,7 @@ class AsmTest {
|
|||||||
arguments: Map<String, Double>,
|
arguments: Map<String, Double>,
|
||||||
algebra: Algebra<Double> = RealField,
|
algebra: Algebra<Double> = RealField,
|
||||||
clazz: Class<Double> = java.lang.Double::class.java as Class<Double>
|
clazz: Class<Double> = java.lang.Double::class.java as Class<Double>
|
||||||
) = testExpressionValue(expectedValue, expr, arguments, algebra, clazz)
|
): Unit = testExpressionValue(expectedValue, expr, arguments, algebra, clazz)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSum() = testDoubleExpressionValue(
|
fun testSum() = testDoubleExpressionValue(
|
||||||
@ -39,28 +36,28 @@ class AsmTest {
|
|||||||
)
|
)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testConst() = testDoubleExpressionValue(
|
fun testConst(): Unit = testDoubleExpressionValue(
|
||||||
123.0,
|
123.0,
|
||||||
AsmConstantExpression(123.0),
|
AsmConstantExpression(123.0),
|
||||||
mapOf()
|
mapOf()
|
||||||
)
|
)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testDiv() = testDoubleExpressionValue(
|
fun testDiv(): Unit = testDoubleExpressionValue(
|
||||||
0.5,
|
0.5,
|
||||||
AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)),
|
AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)),
|
||||||
mapOf()
|
mapOf()
|
||||||
)
|
)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testProduct() = testDoubleExpressionValue(
|
fun testProduct(): Unit = testDoubleExpressionValue(
|
||||||
25.0,
|
25.0,
|
||||||
AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")),
|
AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")),
|
||||||
mapOf("x" to 5.0)
|
mapOf("x" to 5.0)
|
||||||
)
|
)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testCProduct() = testDoubleExpressionValue(
|
fun testCProduct(): Unit = testDoubleExpressionValue(
|
||||||
25.0,
|
25.0,
|
||||||
AsmConstProductExpression(AsmVariableExpression("x"), 5.0),
|
AsmConstProductExpression(AsmVariableExpression("x"), 5.0),
|
||||||
mapOf("x" to 5.0)
|
mapOf("x" to 5.0)
|
||||||
|
@ -76,7 +76,7 @@ interface ExpressionSpace<T, E> : Space<E>, ExpressionContext<T, E> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
interface ExpressionField<T, E> : Field<E>, ExpressionSpace<T, E> {
|
interface ExpressionField<T, E> : Field<E>, ExpressionSpace<T, E> {
|
||||||
fun const(value: Double): E = one.times(value)
|
fun number(value: Number): E = one * value
|
||||||
|
|
||||||
override fun produce(node: SyntaxTreeNode): E {
|
override fun produce(node: SyntaxTreeNode): E {
|
||||||
if (node is BinaryNode) {
|
if (node is BinaryNode) {
|
||||||
|
@ -27,55 +27,41 @@ internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<
|
|||||||
context.multiply(first.invoke(arguments), second.invoke(arguments))
|
context.multiply(first.invoke(arguments), second.invoke(arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
|
internal class ConstProductExpression<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
|
||||||
Expression<T> {
|
Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
|
internal class DivExpression<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
|
||||||
Expression<T> {
|
Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
|
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
open class FunctionalExpressionSpace<T>(
|
open class FunctionalExpressionSpace<T>(val space: Space<T>) : Space<Expression<T>>, ExpressionSpace<T, Expression<T>> {
|
||||||
val space: Space<T>,
|
|
||||||
one: T
|
|
||||||
) : Space<Expression<T>>, ExpressionSpace<T,Expression<T>> {
|
|
||||||
|
|
||||||
override val zero: Expression<T> = ConstantExpression(space.zero)
|
override val zero: Expression<T> = ConstantExpression(space.zero)
|
||||||
|
|
||||||
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
||||||
|
|
||||||
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
||||||
|
|
||||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
|
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
|
||||||
|
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpression(space, a, k)
|
||||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
|
operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
||||||
|
operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
||||||
|
operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
||||||
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
||||||
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
|
||||||
|
|
||||||
operator fun T.plus(arg: Expression<T>) = arg + this
|
|
||||||
operator fun T.minus(arg: Expression<T>) = arg - this
|
|
||||||
}
|
}
|
||||||
|
|
||||||
open class FunctionalExpressionField<T>(
|
open class FunctionalExpressionField<T>(
|
||||||
val field: Field<T>
|
val field: Field<T>
|
||||||
) : ExpressionField<T,Expression<T>>, FunctionalExpressionSpace<T>(field, field.one) {
|
) : ExpressionField<T, Expression<T>>, FunctionalExpressionSpace<T>(field) {
|
||||||
|
|
||||||
override val one: Expression<T>
|
override val one: Expression<T>
|
||||||
get() = const(this.field.one)
|
get() = const(this.field.one)
|
||||||
|
|
||||||
override fun const(value: Double): Expression<T> = const(field.run { one*value})
|
override fun number(value: Number): Expression<T> = const(field.run { one * value })
|
||||||
|
|
||||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
||||||
|
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpression(field, a, b)
|
||||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||||
|
operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||||
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||||
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||||
|
|
||||||
operator fun T.times(arg: Expression<T>) = arg * this
|
|
||||||
operator fun T.div(arg: Expression<T>) = arg / this
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user