forked from kscience/kmath
Add more tests, fix constant product and product operations impl.
This commit is contained in:
parent
fdd2551c3f
commit
557142c2ba
@ -7,6 +7,7 @@ import org.objectweb.asm.Opcodes.*
|
|||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
import java.io.File
|
||||||
|
|
||||||
abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<T>) {
|
abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<T>) {
|
||||||
abstract fun evaluate(arguments: Map<String, T>): T
|
abstract fun evaluate(arguments: Map<String, T>): T
|
||||||
@ -184,10 +185,13 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
|
|
||||||
private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
|
private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
|
||||||
|
|
||||||
fun visitNumberConstant(value: Number): Unit = evaluateMethodVisitor.visitLdcInsn(value)
|
fun visitNumberConstant(value: Number) {
|
||||||
|
maxStack++
|
||||||
|
evaluateMethodVisitor.visitBoxedNumberConstant(value)
|
||||||
|
}
|
||||||
|
|
||||||
fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run {
|
fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run {
|
||||||
maxStack++
|
maxStack += 2
|
||||||
visitVarInsn(ALOAD, evaluateArgumentsVar)
|
visitVarInsn(ALOAD, evaluateArgumentsVar)
|
||||||
|
|
||||||
if (defaultValue != null) {
|
if (defaultValue != null) {
|
||||||
@ -228,7 +232,7 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
visitCastToT()
|
visitCastToT()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
|
private fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled"
|
const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled"
|
||||||
@ -236,11 +240,11 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
|
|||||||
const val MAP_CLASS = "java/util/Map"
|
const val MAP_CLASS = "java/util/Map"
|
||||||
const val OBJECT_CLASS = "java/lang/Object"
|
const val OBJECT_CLASS = "java/lang/Object"
|
||||||
const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||||
const val SPACE_CLASS = "scientifik/kmath/operations/Space"
|
|
||||||
const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
||||||
const val FIELD_CLASS = "scientifik/kmath/operations/Field"
|
|
||||||
const val STRING_CLASS = "java/lang/String"
|
const val STRING_CLASS = "java/lang/String"
|
||||||
const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations"
|
const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations"
|
||||||
|
const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations"
|
||||||
|
const val NUMBER_CLASS = "java/lang/Number"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -285,8 +289,8 @@ internal class AsmProductExpression<T>(
|
|||||||
second.invoke(gen)
|
second.invoke(gen)
|
||||||
|
|
||||||
gen.visitAlgebraOperation(
|
gen.visitAlgebraOperation(
|
||||||
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
|
owner = AsmGenerationContext.RING_OPERATIONS_CLASS,
|
||||||
method = "times",
|
method = "multiply",
|
||||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -298,13 +302,13 @@ internal class AsmConstProductExpression<T>(
|
|||||||
) : AsmExpression<T> {
|
) : AsmExpression<T> {
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||||
gen.visitLoadAlgebra()
|
gen.visitLoadAlgebra()
|
||||||
expr.invoke(gen)
|
|
||||||
gen.visitNumberConstant(const)
|
gen.visitNumberConstant(const)
|
||||||
|
expr.invoke(gen)
|
||||||
|
|
||||||
gen.visitAlgebraOperation(
|
gen.visitAlgebraOperation(
|
||||||
owner = AsmGenerationContext.SPACE_CLASS,
|
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
|
||||||
method = "multiply",
|
method = "multiply",
|
||||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -326,7 +330,6 @@ internal class AsmDivExpression<T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
open class AsmFunctionalExpressionSpace<T>(
|
open class AsmFunctionalExpressionSpace<T>(
|
||||||
val space: Space<T>,
|
val space: Space<T>,
|
||||||
one: T
|
one: T
|
||||||
|
@ -1,17 +1,34 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
import org.objectweb.asm.MethodVisitor
|
import org.objectweb.asm.MethodVisitor
|
||||||
import org.objectweb.asm.Opcodes
|
import org.objectweb.asm.Opcodes.*
|
||||||
|
|
||||||
fun MethodVisitor.visitLdcOrIConstInsn(value: Int) {
|
fun MethodVisitor.visitLdcOrIConstInsn(value: Int) {
|
||||||
when (value) {
|
when (value) {
|
||||||
-1 -> visitInsn(Opcodes.ICONST_M1)
|
-1 -> visitInsn(ICONST_M1)
|
||||||
0 -> visitInsn(Opcodes.ICONST_0)
|
0 -> visitInsn(ICONST_0)
|
||||||
1 -> visitInsn(Opcodes.ICONST_1)
|
1 -> visitInsn(ICONST_1)
|
||||||
2 -> visitInsn(Opcodes.ICONST_2)
|
2 -> visitInsn(ICONST_2)
|
||||||
3 -> visitInsn(Opcodes.ICONST_3)
|
3 -> visitInsn(ICONST_3)
|
||||||
4 -> visitInsn(Opcodes.ICONST_4)
|
4 -> visitInsn(ICONST_4)
|
||||||
5 -> visitInsn(Opcodes.ICONST_5)
|
5 -> visitInsn(ICONST_5)
|
||||||
else -> visitLdcInsn(value)
|
else -> visitLdcInsn(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private val signatureLetters = mapOf(
|
||||||
|
java.lang.Byte::class.java to "B",
|
||||||
|
java.lang.Short::class.java to "S",
|
||||||
|
java.lang.Integer::class.java to "I",
|
||||||
|
java.lang.Long::class.java to "J",
|
||||||
|
java.lang.Float::class.java to "F",
|
||||||
|
java.lang.Double::class.java to "D",
|
||||||
|
java.lang.Short::class.java to "S"
|
||||||
|
)
|
||||||
|
|
||||||
|
fun MethodVisitor.visitBoxedNumberConstant(number: Number) {
|
||||||
|
val clazz = number.javaClass
|
||||||
|
val c = clazz.name.replace('.', '/')
|
||||||
|
visitLdcInsn(number)
|
||||||
|
visitMethodInsn(INVOKESTATIC, c, "valueOf", "(${signatureLetters[clazz]})L${c};", false)
|
||||||
|
}
|
||||||
|
@ -1,23 +1,68 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class AsmTest {
|
class AsmTest {
|
||||||
@Test
|
private fun <T> testExpressionValue(
|
||||||
fun test() {
|
expectedValue: T,
|
||||||
val expr = AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x"))
|
expr: AsmExpression<T>,
|
||||||
|
arguments: Map<String, T>,
|
||||||
|
algebra: Algebra<T>,
|
||||||
|
clazz: Class<*>
|
||||||
|
) {
|
||||||
|
assertEquals(
|
||||||
|
expectedValue, AsmGenerationContext(
|
||||||
|
clazz,
|
||||||
|
algebra,
|
||||||
|
"TestAsmCompiled"
|
||||||
|
).also(expr::invoke).generate().evaluate(arguments)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
val gen = AsmGenerationContext(
|
@Suppress("UNCHECKED_CAST")
|
||||||
java.lang.Double::class.java,
|
private fun testDoubleExpressionValue(
|
||||||
RealField,
|
expectedValue: Double,
|
||||||
"MyAsmCompiled"
|
expr: AsmExpression<Double>,
|
||||||
|
arguments: Map<String, Double>,
|
||||||
|
algebra: Algebra<Double> = RealField,
|
||||||
|
clazz: Class<Double> = java.lang.Double::class.java as Class<Double>
|
||||||
|
) = testExpressionValue(expectedValue, expr, arguments, algebra, clazz)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSum() = testDoubleExpressionValue(
|
||||||
|
25.0,
|
||||||
|
AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")),
|
||||||
|
mapOf("x" to 24.0)
|
||||||
)
|
)
|
||||||
|
|
||||||
expr.invoke(gen)
|
@Test
|
||||||
val compiled = gen.generate()
|
fun testConst() = testDoubleExpressionValue(
|
||||||
val value = compiled.evaluate(mapOf("x" to 25.0))
|
123.0,
|
||||||
assertEquals(26.0, value)
|
AsmConstantExpression(123.0),
|
||||||
}
|
mapOf()
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDiv() = testDoubleExpressionValue(
|
||||||
|
0.5,
|
||||||
|
AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)),
|
||||||
|
mapOf()
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testProduct() = testDoubleExpressionValue(
|
||||||
|
25.0,
|
||||||
|
AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")),
|
||||||
|
mapOf("x" to 5.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCProduct() = testDoubleExpressionValue(
|
||||||
|
25.0,
|
||||||
|
AsmConstProductExpression(AsmVariableExpression("x"), 5.0),
|
||||||
|
mapOf("x" to 5.0)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user