Add mapIntrinsics.kt, update specialization mappings
This commit is contained in:
parent
e2cc3c8efe
commit
5ab6960e9b
@ -47,6 +47,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
|||||||
expectedArity = 1
|
expectedArity = 1
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Binary -> {
|
is MST.Binary -> {
|
||||||
loadAlgebra()
|
loadAlgebra()
|
||||||
if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation)
|
if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation)
|
||||||
|
@ -354,31 +354,23 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided.
|
* Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided.
|
||||||
*/
|
*/
|
||||||
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
||||||
load(invokeArgumentsVar, OBJECT_ARRAY_TYPE)
|
load(invokeArgumentsVar, MAP_TYPE)
|
||||||
|
aconst(name)
|
||||||
|
|
||||||
if (defaultValue != null) {
|
if (defaultValue != null)
|
||||||
loadStringConstant(name)
|
|
||||||
loadTConstant(defaultValue)
|
loadTConstant(defaultValue)
|
||||||
|
else
|
||||||
|
aconst(null)
|
||||||
|
|
||||||
invokeinterface(
|
invokestatic(
|
||||||
MAP_TYPE.internalName,
|
MAP_INTRINSICS_TYPE.internalName,
|
||||||
"getOrDefault",
|
"getOrFail",
|
||||||
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE)
|
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
|
||||||
|
false
|
||||||
)
|
)
|
||||||
|
|
||||||
invokeMethodVisitor.checkcast(tType)
|
checkcast(tType)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
loadStringConstant(name)
|
|
||||||
|
|
||||||
invokeinterface(
|
|
||||||
MAP_TYPE.internalName,
|
|
||||||
"get",
|
|
||||||
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE)
|
|
||||||
)
|
|
||||||
|
|
||||||
invokeMethodVisitor.checkcast(tType)
|
|
||||||
val expectedType = expectationStack.pop()!!
|
val expectedType = expectationStack.pop()!!
|
||||||
|
|
||||||
if (expectedType.sort == Type.OBJECT)
|
if (expectedType.sort == Type.OBJECT)
|
||||||
@ -517,5 +509,10 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* ASM type for [java.lang.String].
|
* ASM type for [java.lang.String].
|
||||||
*/
|
*/
|
||||||
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
|
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for MapIntrinsics.
|
||||||
|
*/
|
||||||
|
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
@file:JvmName("MapIntrinsics")
|
||||||
|
|
||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V {
|
||||||
|
return this[key] ?: default ?: error("Parameter not found: $key")
|
||||||
|
}
|
@ -4,8 +4,15 @@ import org.objectweb.asm.Opcodes
|
|||||||
import org.objectweb.asm.Type
|
import org.objectweb.asm.Type
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
|
|
||||||
private val methodNameAdapters: Map<String, String> by lazy {
|
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||||
hashMapOf("+" to "add", "*" to "multiply", "/" to "divide")
|
hashMapOf(
|
||||||
|
"+" to 2 to "add",
|
||||||
|
"*" to 2 to "multiply",
|
||||||
|
"/" to 2 to "divide",
|
||||||
|
"+" to 1 to "unaryPlus",
|
||||||
|
"-" to 1 to "unaryMinus",
|
||||||
|
"-" to 2 to "minus"
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -15,7 +22,7 @@ private val methodNameAdapters: Map<String, String> by lazy {
|
|||||||
* @return `true` if contains, else `false`.
|
* @return `true` if contains, else `false`.
|
||||||
*/
|
*/
|
||||||
internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
|
internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||||
val theName = methodNameAdapters[name] ?: name
|
val theName = methodNameAdapters[name to arity] ?: name
|
||||||
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null
|
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null
|
||||||
val t = if (primitiveMode && hasSpecific) primitiveMask else tType
|
val t = if (primitiveMode && hasSpecific) primitiveMask else tType
|
||||||
repeat(arity) { expectationStack.push(t) }
|
repeat(arity) { expectationStack.push(t) }
|
||||||
@ -29,7 +36,7 @@ internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name:
|
|||||||
* @return `true` if contains, else `false`.
|
* @return `true` if contains, else `false`.
|
||||||
*/
|
*/
|
||||||
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||||
val theName = methodNameAdapters[name] ?: name
|
val theName = methodNameAdapters[name to arity] ?: name
|
||||||
|
|
||||||
context.javaClass.methods.find {
|
context.javaClass.methods.find {
|
||||||
var suitableSignature = it.name == theName && it.parameters.size == arity
|
var suitableSignature = it.name == theName && it.parameters.size == arity
|
||||||
|
@ -10,7 +10,7 @@ import scientifik.kmath.operations.RealField
|
|||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class TestAsmAlgebras {
|
internal class TestAsmAlgebras {
|
||||||
@Test
|
@Test
|
||||||
fun space() {
|
fun space() {
|
||||||
val res1 = ByteRing.mstInSpace {
|
val res1 = ByteRing.mstInSpace {
|
||||||
|
@ -8,7 +8,7 @@ import scientifik.kmath.operations.RealField
|
|||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class TestAsmExpressions {
|
internal class TestAsmExpressions {
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryOperationInvocation() {
|
fun testUnaryOperationInvocation() {
|
||||||
val expression = RealField.mstInSpace { -symbol("x") }.compile()
|
val expression = RealField.mstInSpace { -symbol("x") }.compile()
|
||||||
|
@ -0,0 +1,45 @@
|
|||||||
|
package scietifik.kmath.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.compile
|
||||||
|
import scientifik.kmath.ast.mstInField
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal class TestSpecialization {
|
||||||
|
@Test
|
||||||
|
fun testUnaryPlus() {
|
||||||
|
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
|
||||||
|
val res = expr("x" to 2.0)
|
||||||
|
assertEquals(2.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testUnaryMinus() {
|
||||||
|
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
|
||||||
|
val res = expr("x" to 2.0)
|
||||||
|
assertEquals(-2.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAdd() {
|
||||||
|
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
|
||||||
|
val res = expr("x" to 2.0)
|
||||||
|
assertEquals(4.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMinus() {
|
||||||
|
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
|
||||||
|
val res = expr("x" to 2.0)
|
||||||
|
assertEquals(0.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDivide() {
|
||||||
|
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
||||||
|
val res = expr("x" to 2.0)
|
||||||
|
assertEquals(1.0, res)
|
||||||
|
}
|
||||||
|
}
|
@ -10,7 +10,7 @@ import scientifik.kmath.operations.ComplexField
|
|||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class AsmTest {
|
internal class AsmTest {
|
||||||
@Test
|
@Test
|
||||||
fun `compile MST`() {
|
fun `compile MST`() {
|
||||||
val mst = "2+2*(2+2)".parseMath()
|
val mst = "2+2*(2+2)".parseMath()
|
||||||
|
Loading…
Reference in New Issue
Block a user