forked from kscience/kmath
Add mapIntrinsics.kt, update specialization mappings
This commit is contained in:
parent
e2cc3c8efe
commit
5ab6960e9b
@ -47,6 +47,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
||||
expectedArity = 1
|
||||
)
|
||||
}
|
||||
|
||||
is MST.Binary -> {
|
||||
loadAlgebra()
|
||||
if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation)
|
||||
|
@ -354,31 +354,23 @@ internal class AsmBuilder<T> internal constructor(
|
||||
* Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided.
|
||||
*/
|
||||
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
||||
load(invokeArgumentsVar, OBJECT_ARRAY_TYPE)
|
||||
load(invokeArgumentsVar, MAP_TYPE)
|
||||
aconst(name)
|
||||
|
||||
if (defaultValue != null) {
|
||||
loadStringConstant(name)
|
||||
if (defaultValue != null)
|
||||
loadTConstant(defaultValue)
|
||||
else
|
||||
aconst(null)
|
||||
|
||||
invokeinterface(
|
||||
MAP_TYPE.internalName,
|
||||
"getOrDefault",
|
||||
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE)
|
||||
)
|
||||
|
||||
invokeMethodVisitor.checkcast(tType)
|
||||
return
|
||||
}
|
||||
|
||||
loadStringConstant(name)
|
||||
|
||||
invokeinterface(
|
||||
MAP_TYPE.internalName,
|
||||
"get",
|
||||
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE)
|
||||
invokestatic(
|
||||
MAP_INTRINSICS_TYPE.internalName,
|
||||
"getOrFail",
|
||||
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
|
||||
false
|
||||
)
|
||||
|
||||
invokeMethodVisitor.checkcast(tType)
|
||||
checkcast(tType)
|
||||
|
||||
val expectedType = expectationStack.pop()!!
|
||||
|
||||
if (expectedType.sort == Type.OBJECT)
|
||||
@ -517,5 +509,10 @@ internal class AsmBuilder<T> internal constructor(
|
||||
* ASM type for [java.lang.String].
|
||||
*/
|
||||
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
|
||||
|
||||
/**
|
||||
* ASM type for MapIntrinsics.
|
||||
*/
|
||||
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") }
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,7 @@
|
||||
@file:JvmName("MapIntrinsics")
|
||||
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V {
|
||||
return this[key] ?: default ?: error("Parameter not found: $key")
|
||||
}
|
@ -4,8 +4,15 @@ import org.objectweb.asm.Opcodes
|
||||
import org.objectweb.asm.Type
|
||||
import scientifik.kmath.operations.Algebra
|
||||
|
||||
private val methodNameAdapters: Map<String, String> by lazy {
|
||||
hashMapOf("+" to "add", "*" to "multiply", "/" to "divide")
|
||||
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||
hashMapOf(
|
||||
"+" to 2 to "add",
|
||||
"*" to 2 to "multiply",
|
||||
"/" to 2 to "divide",
|
||||
"+" to 1 to "unaryPlus",
|
||||
"-" to 1 to "unaryMinus",
|
||||
"-" to 2 to "minus"
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -15,7 +22,7 @@ private val methodNameAdapters: Map<String, String> by lazy {
|
||||
* @return `true` if contains, else `false`.
|
||||
*/
|
||||
internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||
val theName = methodNameAdapters[name] ?: name
|
||||
val theName = methodNameAdapters[name to arity] ?: name
|
||||
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null
|
||||
val t = if (primitiveMode && hasSpecific) primitiveMask else tType
|
||||
repeat(arity) { expectationStack.push(t) }
|
||||
@ -29,7 +36,7 @@ internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name:
|
||||
* @return `true` if contains, else `false`.
|
||||
*/
|
||||
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||
val theName = methodNameAdapters[name] ?: name
|
||||
val theName = methodNameAdapters[name to arity] ?: name
|
||||
|
||||
context.javaClass.methods.find {
|
||||
var suitableSignature = it.name == theName && it.parameters.size == arity
|
||||
|
@ -10,7 +10,7 @@ import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class TestAsmAlgebras {
|
||||
internal class TestAsmAlgebras {
|
||||
@Test
|
||||
fun space() {
|
||||
val res1 = ByteRing.mstInSpace {
|
||||
|
@ -8,7 +8,7 @@ import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class TestAsmExpressions {
|
||||
internal class TestAsmExpressions {
|
||||
@Test
|
||||
fun testUnaryOperationInvocation() {
|
||||
val expression = RealField.mstInSpace { -symbol("x") }.compile()
|
||||
|
@ -0,0 +1,45 @@
|
||||
package scietifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.compile
|
||||
import scientifik.kmath.ast.mstInField
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestSpecialization {
|
||||
@Test
|
||||
fun testUnaryPlus() {
|
||||
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
|
||||
val res = expr("x" to 2.0)
|
||||
assertEquals(2.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
|
||||
val res = expr("x" to 2.0)
|
||||
assertEquals(-2.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
|
||||
val res = expr("x" to 2.0)
|
||||
assertEquals(4.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus() {
|
||||
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
|
||||
val res = expr("x" to 2.0)
|
||||
assertEquals(0.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() {
|
||||
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
||||
val res = expr("x" to 2.0)
|
||||
assertEquals(1.0, res)
|
||||
}
|
||||
}
|
@ -10,7 +10,7 @@ import scientifik.kmath.operations.ComplexField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class AsmTest {
|
||||
internal class AsmTest {
|
||||
@Test
|
||||
fun `compile MST`() {
|
||||
val mst = "2+2*(2+2)".parseMath()
|
||||
|
Loading…
Reference in New Issue
Block a user