Add mapIntrinsics.kt, update specialization mappings

This commit is contained in:
Iaroslav 2020-06-26 15:55:01 +07:00
parent e2cc3c8efe
commit 5ab6960e9b
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
8 changed files with 84 additions and 27 deletions

View File

@ -47,6 +47,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
expectedArity = 1 expectedArity = 1
) )
} }
is MST.Binary -> { is MST.Binary -> {
loadAlgebra() loadAlgebra()
if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation)

View File

@ -354,31 +354,23 @@ internal class AsmBuilder<T> internal constructor(
* Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided.
*/ */
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) load(invokeArgumentsVar, MAP_TYPE)
aconst(name)
if (defaultValue != null) { if (defaultValue != null)
loadStringConstant(name)
loadTConstant(defaultValue) loadTConstant(defaultValue)
else
aconst(null)
invokeinterface( invokestatic(
MAP_TYPE.internalName, MAP_INTRINSICS_TYPE.internalName,
"getOrDefault", "getOrFail",
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
) false
invokeMethodVisitor.checkcast(tType)
return
}
loadStringConstant(name)
invokeinterface(
MAP_TYPE.internalName,
"get",
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE)
) )
invokeMethodVisitor.checkcast(tType) checkcast(tType)
val expectedType = expectationStack.pop()!! val expectedType = expectationStack.pop()!!
if (expectedType.sort == Type.OBJECT) if (expectedType.sort == Type.OBJECT)
@ -517,5 +509,10 @@ internal class AsmBuilder<T> internal constructor(
* ASM type for [java.lang.String]. * ASM type for [java.lang.String].
*/ */
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
/**
* ASM type for MapIntrinsics.
*/
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") }
} }
} }

View File

@ -0,0 +1,7 @@
@file:JvmName("MapIntrinsics")
package scientifik.kmath.asm.internal
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V {
return this[key] ?: default ?: error("Parameter not found: $key")
}

View File

@ -4,8 +4,15 @@ import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type import org.objectweb.asm.Type
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
private val methodNameAdapters: Map<String, String> by lazy { private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
hashMapOf("+" to "add", "*" to "multiply", "/" to "divide") hashMapOf(
"+" to 2 to "add",
"*" to 2 to "multiply",
"/" to 2 to "divide",
"+" to 1 to "unaryPlus",
"-" to 1 to "unaryMinus",
"-" to 2 to "minus"
)
} }
/** /**
@ -15,7 +22,7 @@ private val methodNameAdapters: Map<String, String> by lazy {
* @return `true` if contains, else `false`. * @return `true` if contains, else `false`.
*/ */
internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean { internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
val theName = methodNameAdapters[name] ?: name val theName = methodNameAdapters[name to arity] ?: name
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null
val t = if (primitiveMode && hasSpecific) primitiveMask else tType val t = if (primitiveMode && hasSpecific) primitiveMask else tType
repeat(arity) { expectationStack.push(t) } repeat(arity) { expectationStack.push(t) }
@ -29,7 +36,7 @@ internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name:
* @return `true` if contains, else `false`. * @return `true` if contains, else `false`.
*/ */
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean { internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
val theName = methodNameAdapters[name] ?: name val theName = methodNameAdapters[name to arity] ?: name
context.javaClass.methods.find { context.javaClass.methods.find {
var suitableSignature = it.name == theName && it.parameters.size == arity var suitableSignature = it.name == theName && it.parameters.size == arity

View File

@ -10,7 +10,7 @@ import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class TestAsmAlgebras { internal class TestAsmAlgebras {
@Test @Test
fun space() { fun space() {
val res1 = ByteRing.mstInSpace { val res1 = ByteRing.mstInSpace {

View File

@ -8,7 +8,7 @@ import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class TestAsmExpressions { internal class TestAsmExpressions {
@Test @Test
fun testUnaryOperationInvocation() { fun testUnaryOperationInvocation() {
val expression = RealField.mstInSpace { -symbol("x") }.compile() val expression = RealField.mstInSpace { -symbol("x") }.compile()

View File

@ -0,0 +1,45 @@
package scietifik.kmath.asm
import scientifik.kmath.asm.compile
import scientifik.kmath.ast.mstInField
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestSpecialization {
@Test
fun testUnaryPlus() {
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
val res = expr("x" to 2.0)
assertEquals(2.0, res)
}
@Test
fun testUnaryMinus() {
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
val res = expr("x" to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testAdd() {
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
val res = expr("x" to 2.0)
assertEquals(4.0, res)
}
@Test
fun testMinus() {
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
val res = expr("x" to 2.0)
assertEquals(0.0, res)
}
@Test
fun testDivide() {
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
val res = expr("x" to 2.0)
assertEquals(1.0, res)
}
}

View File

@ -10,7 +10,7 @@ import scientifik.kmath.operations.ComplexField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class AsmTest { internal class AsmTest {
@Test @Test
fun `compile MST`() { fun `compile MST`() {
val mst = "2+2*(2+2)".parseMath() val mst = "2+2*(2+2)".parseMath()