Add more tests, fix constant product and product operations impl.

This commit is contained in:
Iaroslav 2020-06-05 23:02:16 +07:00
parent fdd2551c3f
commit 557142c2ba
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
3 changed files with 97 additions and 32 deletions

View File

@ -7,6 +7,7 @@ import org.objectweb.asm.Opcodes.*
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import java.io.File
abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<T>) { abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<T>) {
abstract fun evaluate(arguments: Map<String, T>): T abstract fun evaluate(arguments: Map<String, T>): T
@ -184,10 +185,13 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar) private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
fun visitNumberConstant(value: Number): Unit = evaluateMethodVisitor.visitLdcInsn(value) fun visitNumberConstant(value: Number) {
maxStack++
evaluateMethodVisitor.visitBoxedNumberConstant(value)
}
fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run { fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run {
maxStack++ maxStack += 2
visitVarInsn(ALOAD, evaluateArgumentsVar) visitVarInsn(ALOAD, evaluateArgumentsVar)
if (defaultValue != null) { if (defaultValue != null) {
@ -228,7 +232,7 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
visitCastToT() visitCastToT()
} }
fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS) private fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
companion object { companion object {
const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled" const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled"
@ -236,11 +240,11 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
const val MAP_CLASS = "java/util/Map" const val MAP_CLASS = "java/util/Map"
const val OBJECT_CLASS = "java/lang/Object" const val OBJECT_CLASS = "java/lang/Object"
const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
const val SPACE_CLASS = "scientifik/kmath/operations/Space"
const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
const val FIELD_CLASS = "scientifik/kmath/operations/Field"
const val STRING_CLASS = "java/lang/String" const val STRING_CLASS = "java/lang/String"
const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations" const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations"
const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations"
const val NUMBER_CLASS = "java/lang/Number"
} }
} }
@ -285,8 +289,8 @@ internal class AsmProductExpression<T>(
second.invoke(gen) second.invoke(gen)
gen.visitAlgebraOperation( gen.visitAlgebraOperation(
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, owner = AsmGenerationContext.RING_OPERATIONS_CLASS,
method = "times", method = "multiply",
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
) )
} }
@ -298,13 +302,13 @@ internal class AsmConstProductExpression<T>(
) : AsmExpression<T> { ) : AsmExpression<T> {
override fun invoke(gen: AsmGenerationContext<T>) { override fun invoke(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra() gen.visitLoadAlgebra()
expr.invoke(gen)
gen.visitNumberConstant(const) gen.visitNumberConstant(const)
expr.invoke(gen)
gen.visitAlgebraOperation( gen.visitAlgebraOperation(
owner = AsmGenerationContext.SPACE_CLASS, owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
method = "multiply", method = "multiply",
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
) )
} }
} }
@ -326,7 +330,6 @@ internal class AsmDivExpression<T>(
} }
} }
open class AsmFunctionalExpressionSpace<T>( open class AsmFunctionalExpressionSpace<T>(
val space: Space<T>, val space: Space<T>,
one: T one: T

View File

@ -1,17 +1,34 @@
package scientifik.kmath.expressions package scientifik.kmath.expressions
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes.*
fun MethodVisitor.visitLdcOrIConstInsn(value: Int) { fun MethodVisitor.visitLdcOrIConstInsn(value: Int) {
when (value) { when (value) {
-1 -> visitInsn(Opcodes.ICONST_M1) -1 -> visitInsn(ICONST_M1)
0 -> visitInsn(Opcodes.ICONST_0) 0 -> visitInsn(ICONST_0)
1 -> visitInsn(Opcodes.ICONST_1) 1 -> visitInsn(ICONST_1)
2 -> visitInsn(Opcodes.ICONST_2) 2 -> visitInsn(ICONST_2)
3 -> visitInsn(Opcodes.ICONST_3) 3 -> visitInsn(ICONST_3)
4 -> visitInsn(Opcodes.ICONST_4) 4 -> visitInsn(ICONST_4)
5 -> visitInsn(Opcodes.ICONST_5) 5 -> visitInsn(ICONST_5)
else -> visitLdcInsn(value) else -> visitLdcInsn(value)
} }
} }
private val signatureLetters = mapOf(
java.lang.Byte::class.java to "B",
java.lang.Short::class.java to "S",
java.lang.Integer::class.java to "I",
java.lang.Long::class.java to "J",
java.lang.Float::class.java to "F",
java.lang.Double::class.java to "D",
java.lang.Short::class.java to "S"
)
fun MethodVisitor.visitBoxedNumberConstant(number: Number) {
val clazz = number.javaClass
val c = clazz.name.replace('.', '/')
visitLdcInsn(number)
visitMethodInsn(INVOKESTATIC, c, "valueOf", "(${signatureLetters[clazz]})L${c};", false)
}

View File

@ -1,23 +1,68 @@
package scientifik.kmath.expressions package scientifik.kmath.expressions
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class AsmTest { class AsmTest {
@Test private fun <T> testExpressionValue(
fun test() { expectedValue: T,
val expr = AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")) expr: AsmExpression<T>,
arguments: Map<String, T>,
algebra: Algebra<T>,
clazz: Class<*>
) {
assertEquals(
expectedValue, AsmGenerationContext(
clazz,
algebra,
"TestAsmCompiled"
).also(expr::invoke).generate().evaluate(arguments)
)
}
val gen = AsmGenerationContext( @Suppress("UNCHECKED_CAST")
java.lang.Double::class.java, private fun testDoubleExpressionValue(
RealField, expectedValue: Double,
"MyAsmCompiled" expr: AsmExpression<Double>,
arguments: Map<String, Double>,
algebra: Algebra<Double> = RealField,
clazz: Class<Double> = java.lang.Double::class.java as Class<Double>
) = testExpressionValue(expectedValue, expr, arguments, algebra, clazz)
@Test
fun testSum() = testDoubleExpressionValue(
25.0,
AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")),
mapOf("x" to 24.0)
) )
expr.invoke(gen) @Test
val compiled = gen.generate() fun testConst() = testDoubleExpressionValue(
val value = compiled.evaluate(mapOf("x" to 25.0)) 123.0,
assertEquals(26.0, value) AsmConstantExpression(123.0),
} mapOf()
)
@Test
fun testDiv() = testDoubleExpressionValue(
0.5,
AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)),
mapOf()
)
@Test
fun testProduct() = testDoubleExpressionValue(
25.0,
AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")),
mapOf("x" to 5.0)
)
@Test
fun testCProduct() = testDoubleExpressionValue(
25.0,
AsmConstProductExpression(AsmVariableExpression("x"), 5.0),
mapOf("x" to 5.0)
)
} }