Add more tests, fix constant product and product operations impl.

This commit is contained in:
Iaroslav 2020-06-05 23:02:16 +07:00
parent fdd2551c3f
commit 557142c2ba
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
3 changed files with 97 additions and 32 deletions

View File

@ -7,6 +7,7 @@ import org.objectweb.asm.Opcodes.*
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Space
import java.io.File
abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<T>) {
abstract fun evaluate(arguments: Map<String, T>): T
@ -184,10 +185,13 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
fun visitNumberConstant(value: Number): Unit = evaluateMethodVisitor.visitLdcInsn(value)
fun visitNumberConstant(value: Number) {
maxStack++
evaluateMethodVisitor.visitBoxedNumberConstant(value)
}
fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run {
maxStack++
maxStack += 2
visitVarInsn(ALOAD, evaluateArgumentsVar)
if (defaultValue != null) {
@ -228,7 +232,7 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
visitCastToT()
}
fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
private fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
companion object {
const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled"
@ -236,11 +240,11 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
const val MAP_CLASS = "java/util/Map"
const val OBJECT_CLASS = "java/lang/Object"
const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
const val SPACE_CLASS = "scientifik/kmath/operations/Space"
const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
const val FIELD_CLASS = "scientifik/kmath/operations/Field"
const val STRING_CLASS = "java/lang/String"
const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations"
const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations"
const val NUMBER_CLASS = "java/lang/Number"
}
}
@ -285,8 +289,8 @@ internal class AsmProductExpression<T>(
second.invoke(gen)
gen.visitAlgebraOperation(
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
method = "times",
owner = AsmGenerationContext.RING_OPERATIONS_CLASS,
method = "multiply",
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
)
}
@ -298,13 +302,13 @@ internal class AsmConstProductExpression<T>(
) : AsmExpression<T> {
override fun invoke(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra()
expr.invoke(gen)
gen.visitNumberConstant(const)
expr.invoke(gen)
gen.visitAlgebraOperation(
owner = AsmGenerationContext.SPACE_CLASS,
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
method = "multiply",
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
)
}
}
@ -326,7 +330,6 @@ internal class AsmDivExpression<T>(
}
}
open class AsmFunctionalExpressionSpace<T>(
val space: Space<T>,
one: T

View File

@ -1,17 +1,34 @@
package scientifik.kmath.expressions
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Opcodes.*
fun MethodVisitor.visitLdcOrIConstInsn(value: Int) {
when (value) {
-1 -> visitInsn(Opcodes.ICONST_M1)
0 -> visitInsn(Opcodes.ICONST_0)
1 -> visitInsn(Opcodes.ICONST_1)
2 -> visitInsn(Opcodes.ICONST_2)
3 -> visitInsn(Opcodes.ICONST_3)
4 -> visitInsn(Opcodes.ICONST_4)
5 -> visitInsn(Opcodes.ICONST_5)
-1 -> visitInsn(ICONST_M1)
0 -> visitInsn(ICONST_0)
1 -> visitInsn(ICONST_1)
2 -> visitInsn(ICONST_2)
3 -> visitInsn(ICONST_3)
4 -> visitInsn(ICONST_4)
5 -> visitInsn(ICONST_5)
else -> visitLdcInsn(value)
}
}
private val signatureLetters = mapOf(
java.lang.Byte::class.java to "B",
java.lang.Short::class.java to "S",
java.lang.Integer::class.java to "I",
java.lang.Long::class.java to "J",
java.lang.Float::class.java to "F",
java.lang.Double::class.java to "D",
java.lang.Short::class.java to "S"
)
fun MethodVisitor.visitBoxedNumberConstant(number: Number) {
val clazz = number.javaClass
val c = clazz.name.replace('.', '/')
visitLdcInsn(number)
visitMethodInsn(INVOKESTATIC, c, "valueOf", "(${signatureLetters[clazz]})L${c};", false)
}

View File

@ -1,23 +1,68 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
class AsmTest {
@Test
fun test() {
val expr = AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x"))
private fun <T> testExpressionValue(
expectedValue: T,
expr: AsmExpression<T>,
arguments: Map<String, T>,
algebra: Algebra<T>,
clazz: Class<*>
) {
assertEquals(
expectedValue, AsmGenerationContext(
clazz,
algebra,
"TestAsmCompiled"
).also(expr::invoke).generate().evaluate(arguments)
)
}
val gen = AsmGenerationContext(
java.lang.Double::class.java,
RealField,
"MyAsmCompiled"
@Suppress("UNCHECKED_CAST")
private fun testDoubleExpressionValue(
expectedValue: Double,
expr: AsmExpression<Double>,
arguments: Map<String, Double>,
algebra: Algebra<Double> = RealField,
clazz: Class<Double> = java.lang.Double::class.java as Class<Double>
) = testExpressionValue(expectedValue, expr, arguments, algebra, clazz)
@Test
fun testSum() = testDoubleExpressionValue(
25.0,
AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")),
mapOf("x" to 24.0)
)
expr.invoke(gen)
val compiled = gen.generate()
val value = compiled.evaluate(mapOf("x" to 25.0))
assertEquals(26.0, value)
}
@Test
fun testConst() = testDoubleExpressionValue(
123.0,
AsmConstantExpression(123.0),
mapOf()
)
@Test
fun testDiv() = testDoubleExpressionValue(
0.5,
AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)),
mapOf()
)
@Test
fun testProduct() = testDoubleExpressionValue(
25.0,
AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")),
mapOf("x" to 5.0)
)
@Test
fun testCProduct() = testDoubleExpressionValue(
25.0,
AsmConstProductExpression(AsmVariableExpression("x"), 5.0),
mapOf("x" to 5.0)
)
}