forked from kscience/kmath
Add more tests, fix constant product and product operations impl.
This commit is contained in:
parent
fdd2551c3f
commit
557142c2ba
@ -7,6 +7,7 @@ import org.objectweb.asm.Opcodes.*
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Space
|
||||
import java.io.File
|
||||
|
||||
abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<T>) {
|
||||
abstract fun evaluate(arguments: Map<String, T>): T
|
||||
@ -184,10 +185,13 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
||||
|
||||
private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
|
||||
|
||||
fun visitNumberConstant(value: Number): Unit = evaluateMethodVisitor.visitLdcInsn(value)
|
||||
fun visitNumberConstant(value: Number) {
|
||||
maxStack++
|
||||
evaluateMethodVisitor.visitBoxedNumberConstant(value)
|
||||
}
|
||||
|
||||
fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run {
|
||||
maxStack++
|
||||
maxStack += 2
|
||||
visitVarInsn(ALOAD, evaluateArgumentsVar)
|
||||
|
||||
if (defaultValue != null) {
|
||||
@ -228,7 +232,7 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
||||
visitCastToT()
|
||||
}
|
||||
|
||||
fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
|
||||
private fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
|
||||
|
||||
companion object {
|
||||
const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled"
|
||||
@ -236,11 +240,11 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
||||
const val MAP_CLASS = "java/util/Map"
|
||||
const val OBJECT_CLASS = "java/lang/Object"
|
||||
const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||
const val SPACE_CLASS = "scientifik/kmath/operations/Space"
|
||||
const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
||||
const val FIELD_CLASS = "scientifik/kmath/operations/Field"
|
||||
const val STRING_CLASS = "java/lang/String"
|
||||
const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations"
|
||||
const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations"
|
||||
const val NUMBER_CLASS = "java/lang/Number"
|
||||
}
|
||||
}
|
||||
|
||||
@ -285,8 +289,8 @@ internal class AsmProductExpression<T>(
|
||||
second.invoke(gen)
|
||||
|
||||
gen.visitAlgebraOperation(
|
||||
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
|
||||
method = "times",
|
||||
owner = AsmGenerationContext.RING_OPERATIONS_CLASS,
|
||||
method = "multiply",
|
||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
@ -298,13 +302,13 @@ internal class AsmConstProductExpression<T>(
|
||||
) : AsmExpression<T> {
|
||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||
gen.visitLoadAlgebra()
|
||||
expr.invoke(gen)
|
||||
gen.visitNumberConstant(const)
|
||||
expr.invoke(gen)
|
||||
|
||||
gen.visitAlgebraOperation(
|
||||
owner = AsmGenerationContext.SPACE_CLASS,
|
||||
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
|
||||
method = "multiply",
|
||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -326,7 +330,6 @@ internal class AsmDivExpression<T>(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
open class AsmFunctionalExpressionSpace<T>(
|
||||
val space: Space<T>,
|
||||
one: T
|
||||
|
@ -1,17 +1,34 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import org.objectweb.asm.MethodVisitor
|
||||
import org.objectweb.asm.Opcodes
|
||||
import org.objectweb.asm.Opcodes.*
|
||||
|
||||
fun MethodVisitor.visitLdcOrIConstInsn(value: Int) {
|
||||
when (value) {
|
||||
-1 -> visitInsn(Opcodes.ICONST_M1)
|
||||
0 -> visitInsn(Opcodes.ICONST_0)
|
||||
1 -> visitInsn(Opcodes.ICONST_1)
|
||||
2 -> visitInsn(Opcodes.ICONST_2)
|
||||
3 -> visitInsn(Opcodes.ICONST_3)
|
||||
4 -> visitInsn(Opcodes.ICONST_4)
|
||||
5 -> visitInsn(Opcodes.ICONST_5)
|
||||
-1 -> visitInsn(ICONST_M1)
|
||||
0 -> visitInsn(ICONST_0)
|
||||
1 -> visitInsn(ICONST_1)
|
||||
2 -> visitInsn(ICONST_2)
|
||||
3 -> visitInsn(ICONST_3)
|
||||
4 -> visitInsn(ICONST_4)
|
||||
5 -> visitInsn(ICONST_5)
|
||||
else -> visitLdcInsn(value)
|
||||
}
|
||||
}
|
||||
|
||||
private val signatureLetters = mapOf(
|
||||
java.lang.Byte::class.java to "B",
|
||||
java.lang.Short::class.java to "S",
|
||||
java.lang.Integer::class.java to "I",
|
||||
java.lang.Long::class.java to "J",
|
||||
java.lang.Float::class.java to "F",
|
||||
java.lang.Double::class.java to "D",
|
||||
java.lang.Short::class.java to "S"
|
||||
)
|
||||
|
||||
fun MethodVisitor.visitBoxedNumberConstant(number: Number) {
|
||||
val clazz = number.javaClass
|
||||
val c = clazz.name.replace('.', '/')
|
||||
visitLdcInsn(number)
|
||||
visitMethodInsn(INVOKESTATIC, c, "valueOf", "(${signatureLetters[clazz]})L${c};", false)
|
||||
}
|
||||
|
@ -1,23 +1,68 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class AsmTest {
|
||||
@Test
|
||||
fun test() {
|
||||
val expr = AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x"))
|
||||
|
||||
val gen = AsmGenerationContext(
|
||||
java.lang.Double::class.java,
|
||||
RealField,
|
||||
"MyAsmCompiled"
|
||||
private fun <T> testExpressionValue(
|
||||
expectedValue: T,
|
||||
expr: AsmExpression<T>,
|
||||
arguments: Map<String, T>,
|
||||
algebra: Algebra<T>,
|
||||
clazz: Class<*>
|
||||
) {
|
||||
assertEquals(
|
||||
expectedValue, AsmGenerationContext(
|
||||
clazz,
|
||||
algebra,
|
||||
"TestAsmCompiled"
|
||||
).also(expr::invoke).generate().evaluate(arguments)
|
||||
)
|
||||
|
||||
expr.invoke(gen)
|
||||
val compiled = gen.generate()
|
||||
val value = compiled.evaluate(mapOf("x" to 25.0))
|
||||
assertEquals(26.0, value)
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private fun testDoubleExpressionValue(
|
||||
expectedValue: Double,
|
||||
expr: AsmExpression<Double>,
|
||||
arguments: Map<String, Double>,
|
||||
algebra: Algebra<Double> = RealField,
|
||||
clazz: Class<Double> = java.lang.Double::class.java as Class<Double>
|
||||
) = testExpressionValue(expectedValue, expr, arguments, algebra, clazz)
|
||||
|
||||
@Test
|
||||
fun testSum() = testDoubleExpressionValue(
|
||||
25.0,
|
||||
AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")),
|
||||
mapOf("x" to 24.0)
|
||||
)
|
||||
|
||||
@Test
|
||||
fun testConst() = testDoubleExpressionValue(
|
||||
123.0,
|
||||
AsmConstantExpression(123.0),
|
||||
mapOf()
|
||||
)
|
||||
|
||||
@Test
|
||||
fun testDiv() = testDoubleExpressionValue(
|
||||
0.5,
|
||||
AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)),
|
||||
mapOf()
|
||||
)
|
||||
|
||||
@Test
|
||||
fun testProduct() = testDoubleExpressionValue(
|
||||
25.0,
|
||||
AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")),
|
||||
mapOf("x" to 5.0)
|
||||
)
|
||||
|
||||
@Test
|
||||
fun testCProduct() = testDoubleExpressionValue(
|
||||
25.0,
|
||||
AsmConstProductExpression(AsmVariableExpression("x"), 5.0),
|
||||
mapOf("x" to 5.0)
|
||||
)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user