Rewrite ASM codegen to use curried operators, fix bugs, update benchmarks

This commit is contained in:
Iaroslav Postovalov 2020-12-08 14:42:42 +07:00
parent 0595950820
commit e62cf4fc65
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
16 changed files with 158 additions and 479 deletions

View File

@ -4,13 +4,19 @@ import kscience.kmath.asm.compile
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.expressionInField
import kscience.kmath.expressions.invoke
import kscience.kmath.expressions.symbol
import kscience.kmath.operations.Field
import kscience.kmath.operations.RealField
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import kotlin.random.Random
import kotlin.system.measureTimeMillis
@State(Scope.Benchmark)
internal class ExpressionsInterpretersBenchmark {
private val algebra: Field<Double> = RealField
@Benchmark
fun functionalExpression() {
val expr = algebra.expressionInField {
symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
@ -19,6 +25,7 @@ internal class ExpressionsInterpretersBenchmark {
invokeAndSum(expr)
}
@Benchmark
fun mstExpression() {
val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
@ -27,6 +34,7 @@ internal class ExpressionsInterpretersBenchmark {
invokeAndSum(expr)
}
@Benchmark
fun asmExpression() {
val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
@ -35,6 +43,13 @@ internal class ExpressionsInterpretersBenchmark {
invokeAndSum(expr)
}
@Benchmark
fun rawExpression() {
val x by symbol
val expr = Expression<Double> { args -> args.getValue(x) * 2.0 + 2.0 / args.getValue(x) - 16.0 }
invokeAndSum(expr)
}
private fun invokeAndSum(expr: Expression<Double>) {
val random = Random(0)
var sum = 0.0
@ -46,35 +61,3 @@ internal class ExpressionsInterpretersBenchmark {
println(sum)
}
}
/**
* This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
* core FunctionalExpressions API.
*
* The expected rating is:
*
* 1. ASM.
* 2. MST.
* 3. FE.
*/
fun main() {
val benchmark = ExpressionsInterpretersBenchmark()
val fe = measureTimeMillis {
benchmark.functionalExpression()
}
println("fe=$fe")
val mst = measureTimeMillis {
benchmark.mstExpression()
}
println("mst=$mst")
val asm = measureTimeMillis {
benchmark.asmExpression()
}
println("asm=$asm")
}

View File

@ -16,17 +16,13 @@ internal class ArrayBenchmark {
@Benchmark
fun benchmarkBufferRead() {
var res = 0
for (i in 1..size) res += arrayBuffer.get(
size - i
)
for (i in 1..size) res += arrayBuffer[size - i]
}
@Benchmark
fun nativeBufferRead() {
var res = 0
for (i in 1..size) res += nativeBuffer.get(
size - i
)
for (i in 1..size) res += nativeBuffer[size - i]
}
companion object {

View File

@ -36,7 +36,7 @@ internal class ViktorBenchmark {
@Benchmark
fun rawViktor() {
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
val one = F64Array.full(init = 1.0, shape = intArrayOf(dim, dim))
var res = one
repeat(n) { res = res + one }
}

View File

@ -64,7 +64,8 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField
.binaryOperation(node.operation)(node.left.value.toDouble(), node.right.value.toDouble())
.binaryOperation(node.operation)
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
number(number)
}

View File

@ -24,13 +24,17 @@ public object MstSpace : Space<MST>, NumericAlgebra<MST> {
public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k))
public override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
public override fun MST.unaryMinus(): MST = unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
public override fun multiply(a: MST, k: Number): MST.Binary =
binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k))
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
MstAlgebra.binaryOperation(operation)
override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstAlgebra.unaryOperation(operation)
public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary =
MstAlgebra.unaryOperation(operation)
}
/**
@ -47,6 +51,7 @@ public object MstRing : Ring<MST>, NumericAlgebra<MST> {
public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
MstSpace.binaryOperation(operation)
@ -71,6 +76,7 @@ public object MstField : Field<MST> {
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
MstRing.binaryOperation(operation)
@ -105,6 +111,7 @@ public object MstExtendedField : ExtendedField<MST> {
public override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
public override fun power(arg: MST, pow: Number): MST.Binary =
binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))

View File

@ -15,7 +15,12 @@ import kotlin.contracts.contract
*/
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
override fun symbol(value: String): T = try {
algebra.symbol(value)
} catch (ignored: IllegalStateException) {
null
} ?: arguments.getValue(StringSymbol(value))
override fun unaryOperation(operation: String): (arg: T) -> T = algebra.unaryOperation(operation)
override fun binaryOperation(operation: String): (left: T, right: T) -> T = algebra.binaryOperation(operation)

View File

@ -1,13 +1,13 @@
package kscience.kmath.asm
import kscience.kmath.asm.internal.AsmBuilder
import kscience.kmath.asm.internal.MstType
import kscience.kmath.asm.internal.buildAlgebraOperationCall
import kscience.kmath.asm.internal.buildName
import kscience.kmath.ast.MST
import kscience.kmath.ast.MstExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
/**
* Compiles given MST to an Expression using AST compiler.
@ -23,37 +23,46 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
is MST.Symbolic -> {
val symbol = try {
algebra.symbol(node.value)
} catch (ignored: Throwable) {
} catch (ignored: IllegalStateException) {
null
}
if (symbol != null)
loadTConstant(symbol)
loadObjectConstant(symbol as Any)
else
loadVariable(node.value)
}
is MST.Numeric -> loadNumeric(node.value)
is MST.Numeric -> loadNumberConstant(node.value)
is MST.Unary -> buildCall(algebra.unaryOperation(node.operation)) { visit(node.value) }
is MST.Unary -> buildAlgebraOperationCall(
context = algebra,
name = node.operation,
fallbackMethodName = "unaryOperation",
parameterTypes = arrayOf(MstType.fromMst(node.value))
) { visit(node.value) }
is MST.Binary -> when {
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant(
algebra.number(
RealField
.binaryOperation(node.operation)
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
)
)
is MST.Binary -> buildAlgebraOperationCall(
context = algebra,
name = node.operation,
fallbackMethodName = "binaryOperation",
parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right))
) {
visit(node.left)
visit(node.right)
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperation(node.operation)) {
visit(node.left)
visit(node.right)
}
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperation(node.operation)) {
visit(node.left)
visit(node.right)
}
else -> buildCall(algebra.binaryOperation(node.operation)) {
visit(node.left)
visit(node.right)
}
}
}
return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
}
/**

View File

@ -4,26 +4,24 @@ import kscience.kmath.asm.internal.AsmBuilder.ClassLoader
import kscience.kmath.ast.MST
import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra
import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.commons.InstructionAdapter
import java.util.*
import java.util.stream.Collectors
import java.util.stream.Collectors.toMap
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/**
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
*
* @property T the type of AsmExpression to unwrap.
* @property algebra the algebra the applied AsmExpressions use.
* @property className the unique class name of new loaded class.
* @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
* @author Iaroslav Postovalov
*/
internal class AsmBuilder<T> internal constructor(
private val classOfT: Class<*>,
private val algebra: Algebra<T>,
classOfT: Class<*>,
private val className: String,
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit,
) {
@ -39,15 +37,10 @@ internal class AsmBuilder<T> internal constructor(
*/
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
/**
* ASM Type for [algebra].
*/
private val tAlgebraType: Type = algebra.javaClass.asm
/**
* ASM type for [T].
*/
internal val tType: Type = classOfT.asm
private val tType: Type = classOfT.asm
/**
* ASM type for new class.
@ -69,51 +62,13 @@ internal class AsmBuilder<T> internal constructor(
*/
private var hasConstants: Boolean = true
/**
* States whether [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
*/
internal var primitiveMode: Boolean = false
/**
* Primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
*/
internal var primitiveMask: Type = OBJECT_TYPE
/**
* Boxed primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
*/
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
/**
* Stack of useful objects types on stack to verify types.
*/
private val typeStack: ArrayDeque<Type> = ArrayDeque()
/**
* Stack of useful objects types on stack expected by algebra calls.
*/
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>(1).also { it.push(tType) }
/**
* The cache for instance built by this builder.
*/
private var generatedInstance: Expression<T>? = null
/**
* Subclasses, loads and instantiates [Expression] for given parameters.
*
* The built instance is cached.
*/
@Suppress("UNCHECKED_CAST")
internal fun getInstance(): Expression<T> {
generatedInstance?.let { return it }
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
primitiveMode = true
primitiveMask = SIGNATURE_LETTERS.getValue(classOfT)
primitiveMaskBoxed = tType
}
val instance: Expression<T> by lazy {
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit(
V1_8,
@ -192,15 +147,6 @@ internal class AsmBuilder<T> internal constructor(
hasConstants = constants.isNotEmpty()
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "algebra",
descriptor = tAlgebraType.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
if (hasConstants)
visitField(
access = ACC_PRIVATE or ACC_FINAL,
@ -214,25 +160,17 @@ internal class AsmBuilder<T> internal constructor(
visitMethod(
ACC_PUBLIC,
"<init>",
Type.getMethodDescriptor(
Type.VOID_TYPE,
tAlgebraType,
*OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
Type.getMethodDescriptor(Type.VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
null,
null
).instructionAdapter {
val thisVar = 0
val algebraVar = 1
val constantsVar = 2
val constantsVar = 1
val l0 = label()
load(thisVar, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
label()
load(thisVar, classType)
load(algebraVar, tAlgebraType)
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
if (hasConstants) {
label()
@ -246,15 +184,6 @@ internal class AsmBuilder<T> internal constructor(
val l4 = label()
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
visitLocalVariable(
"algebra",
tAlgebraType.descriptor,
null,
l0,
l4,
algebraVar
)
if (hasConstants)
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
@ -265,33 +194,55 @@ internal class AsmBuilder<T> internal constructor(
visitEnd()
}
val new = classLoader
// java.io.File("dump.class").writeBytes(classWriter.toByteArray())
classLoader
.defineClass(className, classWriter.toByteArray())
.constructors
.first()
.newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression<T>
generatedInstance = new
return new
.newInstance(*(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression<T>
}
/**
* Loads a [T] constant from [constants].
* Loads [java.lang.Object] constant from constants.
*/
internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) {
val expectedType = expectationStack.pop()
val mustBeBoxed = expectedType.sort == Type.OBJECT
loadNumberConstant(value as Number, mustBeBoxed)
fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex
loadThis()
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
iconst(idx)
visitInsn(AALOAD)
checkcast(type)
}
if (mustBeBoxed)
invokeMethodVisitor.checkcast(tType)
/**
* Loads `this` variable.
*/
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
/**
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
* constant from the constant pool.
*/
fun loadNumberConstant(value: Number) {
val boxed = value.javaClass.asm
val primitive = BOXED_TO_PRIMITIVES[boxed]
if (primitive != null) {
when (primitive) {
Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
}
box(primitive)
return
}
loadObjectConstant(value as Any, tType)
loadObjectConstant(value, boxed)
}
/**
@ -309,77 +260,9 @@ internal class AsmBuilder<T> internal constructor(
}
/**
* Unboxes the current boxed value and pushes it.
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke].
*/
private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual(
NUMBER_TYPE.internalName,
NUMBER_CONVERTER_METHODS.getValue(primitive),
Type.getMethodDescriptor(primitive),
false
)
/**
* Loads [java.lang.Object] constant from constants.
*/
private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
loadThis()
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
iconst(idx)
visitInsn(AALOAD)
checkcast(type)
}
internal fun loadNumeric(value: Number) {
if (expectationStack.peek() == NUMBER_TYPE) {
loadNumberConstant(value, true)
expectationStack.pop()
typeStack.push(NUMBER_TYPE)
} else (algebra as? NumericAlgebra<T>)?.number(value)?.let { loadTConstant(it) }
?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.")
}
/**
* Loads this variable.
*/
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
/**
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
* from it).
*/
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
val boxed = value.javaClass.asm
val primitive = BOXED_TO_PRIMITIVES[boxed]
if (primitive != null) {
when (primitive) {
Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
}
if (mustBeBoxed)
box(primitive)
return
}
loadObjectConstant(value, boxed)
if (!mustBeBoxed)
unboxTo(primitiveMask)
}
/**
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
* provided.
*/
internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
load(invokeArgumentsVar, MAP_TYPE)
aconst(name)
@ -391,70 +274,28 @@ internal class AsmBuilder<T> internal constructor(
)
checkcast(tType)
val expectedType = expectationStack.pop()
if (expectedType.sort == Type.OBJECT)
typeStack.push(tType)
else {
unboxTo(primitiveMask)
typeStack.push(primitiveMask)
}
}
/**
* Loads algebra from according field of the class and casts it to class of [algebra] provided.
*/
internal fun loadAlgebra() {
loadThis()
invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor)
}
inline fun buildCall(function: Function<T>, parameters: AsmBuilder<T>.() -> Unit) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
val `interface` = function.javaClass.interfaces.first { it.interfaces.contains(Function::class.java) }
/**
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be
* called before the arguments and this operation.
*
* The result is casted to [T] automatically.
*/
internal fun invokeAlgebraOperation(
owner: String,
method: String,
descriptor: String,
expectedArity: Int,
opcode: Int = INVOKEINTERFACE,
) {
run loop@{
repeat(expectedArity) {
if (typeStack.isEmpty()) return@loop
typeStack.pop()
}
}
val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount
?: error("Provided function object doesn't contain invoke method")
invokeMethodVisitor.visitMethodInsn(
opcode,
owner,
method,
descriptor,
opcode == INVOKEINTERFACE
val type = Type.getType(`interface`)
loadObjectConstant(function, type)
parameters(this)
invokeMethodVisitor.invokeinterface(
type.internalName,
"invoke",
Type.getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE}),
)
invokeMethodVisitor.checkcast(tType)
val isLastExpr = expectationStack.size == 1
val expectedType = expectationStack.pop()
if (expectedType.sort == Type.OBJECT || isLastExpr)
typeStack.push(tType)
else {
unboxTo(primitiveMask)
typeStack.push(primitiveMask)
}
}
/**
* Writes a LDC Instruction with string constant provided.
*/
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
internal companion object {
/**
* Index of `this` variable in invoke method of the built subclass.
@ -490,32 +331,13 @@ internal class AsmBuilder<T> internal constructor(
*/
private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy {
BOXED_TO_PRIMITIVES.entries.stream().collect(
Collectors.toMap(
toMap(
Map.Entry<Type, Type>::value,
Map.Entry<Type, Type>::key
)
)
}
/**
* Maps primitive ASM types to [Number] functions unboxing them.
*/
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
hashMapOf(
Type.BYTE_TYPE to "byteValue",
Type.SHORT_TYPE to "shortValue",
Type.INT_TYPE to "intValue",
Type.LONG_TYPE to "longValue",
Type.FLOAT_TYPE to "floatValue",
Type.DOUBLE_TYPE to "doubleValue"
)
}
/**
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
*/
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
/**
* ASM type for [Expression].
*/

View File

@ -1,20 +0,0 @@
package kscience.kmath.asm.internal
import kscience.kmath.ast.MST
/**
* Represents types known in [MST], numbers and general values.
*/
internal enum class MstType {
GENERAL,
NUMBER;
companion object {
fun fromMst(mst: MST): MstType {
if (mst is MST.Numeric)
return NUMBER
return GENERAL
}
}
}

View File

@ -2,29 +2,11 @@ package kscience.kmath.asm.internal
import kscience.kmath.ast.MST
import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.FieldOperations
import kscience.kmath.operations.RingOperations
import kscience.kmath.operations.SpaceOperations
import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
import org.objectweb.asm.commons.InstructionAdapter
import java.lang.reflect.Method
import java.util.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
hashMapOf(
SpaceOperations.PLUS_OPERATION to 2 to "add",
RingOperations.TIMES_OPERATION to 2 to "multiply",
FieldOperations.DIV_OPERATION to 2 to "divide",
SpaceOperations.PLUS_OPERATION to 1 to "unaryPlus",
SpaceOperations.MINUS_OPERATION to 1 to "unaryMinus",
SpaceOperations.MINUS_OPERATION to 2 to "minus"
)
}
/**
* Returns ASM [Type] for given [Class].
*
@ -110,106 +92,4 @@ internal inline fun ClassWriter.visitField(
return visitField(access, name, descriptor, signature, value).apply(block)
}
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
context.javaClass.methods.find { method ->
val nameValid = method.name == name
val arityValid = method.parameters.size == parameterTypes.size
val notBridgeInPrimitive = !(primitiveMode && method.isBridge)
val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) ->
!(mstType != MstType.NUMBER && type == java.lang.Number::class.java)
}
nameValid && arityValid && notBridgeInPrimitive && paramsValid
}
/**
* Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds
* type expectation stack for needed arity.
*
* @author Iaroslav Postovalov
*/
private fun <T> AsmBuilder<T>.buildExpectationStack(
context: Algebra<T>,
name: String,
parameterTypes: Array<MstType>
): Boolean {
val arity = parameterTypes.size
val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes)
if (specific != null)
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
else
expectationStack.addAll(Collections.nCopies(arity, tType))
return specific != null
}
private fun <T> AsmBuilder<T>.mapTypes(method: Method, parameterTypes: Array<MstType>): List<Type> = method
.parameterTypes
.zip(parameterTypes)
.map { (type, mstType) ->
when {
type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE
else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed
}
}
/**
* Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts
* [AsmBuilder.invokeAlgebraOperation] of this method.
*
* @author Iaroslav Postovalov
*/
private fun <T> AsmBuilder<T>.tryInvokeSpecific(
context: Algebra<T>,
name: String,
parameterTypes: Array<MstType>
): Boolean {
val arity = parameterTypes.size
val theName = methodNameAdapters[name to arity] ?: name
val spec = findSpecific(context, theName, parameterTypes) ?: return false
val owner = context.javaClass.asm
invokeAlgebraOperation(
owner = owner.internalName,
method = theName,
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()),
expectedArity = arity,
opcode = INVOKEVIRTUAL
)
return true
}
/**
* Builds specialized [context] call with option to fallback to generic algebra operation accepting [String].
*
* @author Iaroslav Postovalov
*/
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
context: Algebra<T>,
name: String,
fallbackMethodName: String,
parameterTypes: Array<MstType>,
parameters: AsmBuilder<T>.() -> Unit
) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
val arity = parameterTypes.size
loadAlgebra()
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)
parameters()
if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
method = fallbackMethodName,
descriptor = Type.getMethodDescriptor(
AsmBuilder.OBJECT_TYPE,
AsmBuilder.STRING_TYPE,
*Array(arity) { AsmBuilder.OBJECT_TYPE }
),
expectedArity = arity
)
}

View File

@ -10,15 +10,11 @@ import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmAlgebras {
@Test
fun space() {
val res1 = ByteRing.mstInSpace {
binaryOperation(
"+",
unaryOperation(
"+",
binaryOperation("+")(
unaryOperation("+")(
number(3.toByte()) - (number(2.toByte()) + (multiply(
add(number(1), number(1)),
2
@ -30,11 +26,8 @@ internal class TestAsmAlgebras {
}("x" to 2.toByte())
val res2 = ByteRing.mstInSpace {
binaryOperation(
"+",
unaryOperation(
"+",
binaryOperation("+")(
unaryOperation("+")(
number(3.toByte()) - (number(2.toByte()) + (multiply(
add(number(1), number(1)),
2
@ -51,11 +44,8 @@ internal class TestAsmAlgebras {
@Test
fun ring() {
val res1 = ByteRing.mstInRing {
binaryOperation(
"+",
unaryOperation(
"+",
binaryOperation("+")(
unaryOperation("+")(
(symbol("x") - (2.toByte() + (multiply(
add(number(1), number(1)),
2
@ -67,17 +57,13 @@ internal class TestAsmAlgebras {
}("x" to 3.toByte())
val res2 = ByteRing.mstInRing {
binaryOperation(
"+",
unaryOperation(
"+",
binaryOperation("+")(
unaryOperation("+")(
(symbol("x") - (2.toByte() + (multiply(
add(number(1), number(1)),
2
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}.compile()("x" to 3.toByte())
@ -88,8 +74,7 @@ internal class TestAsmAlgebras {
@Test
fun field() {
val res1 = RealField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
"+",
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation("+")(
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
@ -97,8 +82,7 @@ internal class TestAsmAlgebras {
}("x" to 2.0)
val res2 = RealField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
"+",
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation("+")(
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one

View File

@ -1,10 +1,11 @@
package kscience.kmath.asm
import kscience.kmath.asm.compile
import kscience.kmath.ast.mstInExtendedField
import kscience.kmath.ast.mstInField
import kscience.kmath.ast.mstInSpace
import kscience.kmath.expressions.invoke
import kscience.kmath.operations.RealField
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertEquals
@ -28,4 +29,13 @@ internal class TestAsmExpressions {
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
assertEquals(4.0, res)
}
@Test
fun testMultipleCalls() {
val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile()
val r = Random(0)
var s = 0.0
repeat(1000000) { s += e("x" to r.nextDouble()) }
println(s)
}
}

View File

@ -46,7 +46,7 @@ internal class TestAsmSpecialization {
@Test
fun testPower() {
val expr = RealField
.mstInField { binaryOperation("power")(symbol("x"), number(2)) }
.mstInField { binaryOperation("pow")(symbol("x"), number(2)) }
.compile()
assertEquals(4.0, expr("x" to 2.0))

View File

@ -17,6 +17,6 @@ internal class TestAsmVariables {
@Test
fun testVariableWithoutDefaultFails() {
val expr = ByteRing.mstInRing { symbol("x") }
assertFailsWith<IllegalStateException> { expr() }
assertFailsWith<NoSuchElementException> { expr() }
}
}

View File

@ -20,12 +20,14 @@ public interface Algebra<T> {
/**
* Dynamically dispatches an unary operation with name [operation].
*/
public fun unaryOperation(operation: String): (arg: T) -> T
public fun unaryOperation(operation: String): (arg: T) -> T =
error("Unary operation $operation not defined in $this")
/**
* Dynamically dispatches a binary operation with name [operation].
*/
public fun binaryOperation(operation: String): (left: T, right: T) -> T
public fun binaryOperation(operation: String): (left: T, right: T) -> T =
error("Binary operation $operation not defined in $this")
}
/**
@ -161,13 +163,13 @@ public interface SpaceOperations<T> : Algebra<T> {
override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
PLUS_OPERATION -> { arg -> arg }
MINUS_OPERATION -> { arg -> -arg }
else -> error("Unary operation $operation not defined in $this")
else -> super.unaryOperation(operation)
}
override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
PLUS_OPERATION -> ::add
MINUS_OPERATION -> { left, right -> left - right }
else -> error("Binary operation $operation not defined in $this")
else -> super.binaryOperation(operation)
}
public companion object {

View File

@ -31,7 +31,7 @@ public interface ExtendedFieldOperations<T> :
PowerOperations.SQRT_OPERATION -> ::sqrt
ExponentialOperations.EXP_OPERATION -> ::exp
ExponentialOperations.LN_OPERATION -> ::ln
else -> super.unaryOperation(operation)
else -> super<FieldOperations>.unaryOperation(operation)
}
}