Provide dynamic operations currying for Algebra<T> instead of eager calls and add JS code generation support #162

Merged
CommanderTvis merged 44 commits from feature/dynamic-ops-currying into dev 2021-01-05 16:36:51 +03:00
18 changed files with 385 additions and 747 deletions
Showing only changes of commit a5c00051c2 - Show all commits

View File

@ -4,13 +4,19 @@ import kscience.kmath.asm.compile
import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.expressionInField import kscience.kmath.expressions.expressionInField
import kscience.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import kscience.kmath.expressions.symbol
import kscience.kmath.operations.Field import kscience.kmath.operations.Field
import kscience.kmath.operations.RealField import kscience.kmath.operations.RealField
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import kotlin.random.Random import kotlin.random.Random
import kotlin.system.measureTimeMillis
@State(Scope.Benchmark)
internal class ExpressionsInterpretersBenchmark { internal class ExpressionsInterpretersBenchmark {
private val algebra: Field<Double> = RealField private val algebra: Field<Double> = RealField
@Benchmark
fun functionalExpression() { fun functionalExpression() {
val expr = algebra.expressionInField { val expr = algebra.expressionInField {
symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0) symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
@ -19,6 +25,7 @@ internal class ExpressionsInterpretersBenchmark {
invokeAndSum(expr) invokeAndSum(expr)
} }
@Benchmark
fun mstExpression() { fun mstExpression() {
val expr = algebra.mstInField { val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
@ -27,6 +34,7 @@ internal class ExpressionsInterpretersBenchmark {
invokeAndSum(expr) invokeAndSum(expr)
} }
@Benchmark
fun asmExpression() { fun asmExpression() {
val expr = algebra.mstInField { val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
@ -35,6 +43,13 @@ internal class ExpressionsInterpretersBenchmark {
invokeAndSum(expr) invokeAndSum(expr)
} }
@Benchmark
fun rawExpression() {
val x by symbol
val expr = Expression<Double> { args -> args.getValue(x) * 2.0 + 2.0 / args.getValue(x) - 16.0 }
invokeAndSum(expr)
}
private fun invokeAndSum(expr: Expression<Double>) { private fun invokeAndSum(expr: Expression<Double>) {
val random = Random(0) val random = Random(0)
var sum = 0.0 var sum = 0.0
@ -46,35 +61,3 @@ internal class ExpressionsInterpretersBenchmark {
println(sum) println(sum)
} }
} }
/**
* This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
* core FunctionalExpressions API.
*
* The expected rating is:
*
* 1. ASM.
* 2. MST.
* 3. FE.
*/
fun main() {
val benchmark = ExpressionsInterpretersBenchmark()
val fe = measureTimeMillis {
benchmark.functionalExpression()
}
println("fe=$fe")
val mst = measureTimeMillis {
benchmark.mstExpression()
}
println("mst=$mst")
val asm = measureTimeMillis {
benchmark.asmExpression()
}
println("asm=$asm")
}

View File

@ -16,17 +16,13 @@ internal class ArrayBenchmark {
@Benchmark @Benchmark
fun benchmarkBufferRead() { fun benchmarkBufferRead() {
var res = 0 var res = 0
for (i in 1..size) res += arrayBuffer.get( for (i in 1..size) res += arrayBuffer[size - i]
size - i
)
} }
@Benchmark @Benchmark
fun nativeBufferRead() { fun nativeBufferRead() {
var res = 0 var res = 0
for (i in 1..size) res += nativeBuffer.get( for (i in 1..size) res += nativeBuffer[size - i]
size - i
)
} }
companion object { companion object {

View File

@ -55,24 +55,24 @@ public sealed class MST {
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) { public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value) is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this") ?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value) is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Unary -> unaryOperation(node.operation)(evaluate(node.value))
is MST.Binary -> when { is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) this !is NumericAlgebra -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> { node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation( val number = RealField
node.operation, .binaryOperation(node.operation)
node.left.value.toDouble(), .invoke(node.left.value.toDouble(), node.right.value.toDouble())
node.right.value.toDouble()
)
number(number) number(number)
} }
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) node.left is MST.Numeric -> leftSideNumberOperation(node.operation)(node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) node.right is MST.Numeric -> rightSideNumberOperation(node.operation)(evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) else -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right))
} }
} }

View File

@ -6,32 +6,35 @@ import kscience.kmath.operations.*
* [Algebra] over [MST] nodes. * [Algebra] over [MST] nodes.
*/ */
public object MstAlgebra : NumericAlgebra<MST> { public object MstAlgebra : NumericAlgebra<MST> {
override fun number(value: Number): MST.Numeric = MST.Numeric(value) public override fun number(value: Number): MST.Numeric = MST.Numeric(value)
public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary =
{ arg -> MST.Unary(operation, arg) }
override fun unaryOperation(operation: String, arg: MST): MST.Unary = public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
MST.Unary(operation, arg) { left, right -> MST.Binary(operation, left, right) }
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MST.Binary(operation, left, right)
} }
/** /**
* [Space] over [MST] nodes. * [Space] over [MST] nodes.
*/ */
public object MstSpace : Space<MST>, NumericAlgebra<MST> { public object MstSpace : Space<MST>, NumericAlgebra<MST> {
override val zero: MST.Numeric by lazy { number(0.0) } public override val zero: MST.Numeric by lazy { number(0.0) }
override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) public override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) public override fun MST.unaryMinus(): MST = unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = public override fun multiply(a: MST, k: Number): MST.Binary =
MstAlgebra.binaryOperation(operation, left, right) binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k))
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg) public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
MstAlgebra.binaryOperation(operation)
public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary =
MstAlgebra.unaryOperation(operation)
} }
/** /**
@ -43,16 +46,18 @@ public object MstRing : Ring<MST>, NumericAlgebra<MST> {
override val one: MST.Numeric by lazy { number(1.0) } override val one: MST.Numeric by lazy { number(1.0) }
override fun number(value: Number): MST.Numeric = MstSpace.number(value) public override fun number(value: Number): MST.Numeric = MstSpace.number(value)
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) public override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b) public override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
MstSpace.binaryOperation(operation, left, right) MstSpace.binaryOperation(operation)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg) public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary =
MstAlgebra.unaryOperation(operation)
} }
/** /**
@ -70,51 +75,52 @@ public object MstField : Field<MST> {
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b) public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
MstRing.binaryOperation(operation, left, right) MstRing.binaryOperation(operation)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg) public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstRing.unaryOperation(operation)
} }
/** /**
* [ExtendedField] over [MST] nodes. * [ExtendedField] over [MST] nodes.
*/ */
public object MstExtendedField : ExtendedField<MST> { public object MstExtendedField : ExtendedField<MST> {
override val zero: MST.Numeric public override val zero: MST.Numeric
get() = MstField.zero get() = MstField.zero
override val one: MST.Numeric public override val one: MST.Numeric
get() = MstField.one get() = MstField.one
override fun symbol(value: String): MST.Symbolic = MstField.symbol(value) public override fun symbol(value: String): MST = MstField.symbol(value)
override fun number(value: Number): MST.Numeric = MstField.number(value) public override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg)
override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) public override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg)
override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) public override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg)
override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) public override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg)
override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) public override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg)
override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) public override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg)
override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) public override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg)
override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) public override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg)
override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) public override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg)
override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) public override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg)
override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) public override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg)
override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) public override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg)
override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) public override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) public override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
override fun power(arg: MST, pow: Number): MST.Binary = public override fun power(arg: MST, pow: Number): MST.Binary =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) public override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg)
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg) public override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION)(arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
MstField.binaryOperation(operation, left, right) MstField.binaryOperation(operation)
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg) public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstField.unaryOperation(operation)
} }

View File

@ -15,11 +15,14 @@ import kotlin.contracts.contract
*/ */
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> { public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> { private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) override fun symbol(value: String): T = try {
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) algebra.symbol(value)
} catch (ignored: IllegalStateException) {
null
} ?: arguments.getValue(StringSymbol(value))
override fun binaryOperation(operation: String, left: T, right: T): T = override fun unaryOperation(operation: String): (arg: T) -> T = algebra.unaryOperation(operation)
algebra.binaryOperation(operation, left, right) override fun binaryOperation(operation: String): (left: T, right: T) -> T = algebra.binaryOperation(operation)
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>) override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)

View File

@ -1,13 +1,13 @@
package kscience.kmath.asm package kscience.kmath.asm
import kscience.kmath.asm.internal.AsmBuilder import kscience.kmath.asm.internal.AsmBuilder
import kscience.kmath.asm.internal.MstType
import kscience.kmath.asm.internal.buildAlgebraOperationCall
import kscience.kmath.asm.internal.buildName import kscience.kmath.asm.internal.buildName
import kscience.kmath.ast.MST import kscience.kmath.ast.MST
import kscience.kmath.ast.MstExpression import kscience.kmath.ast.MstExpression
import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
/** /**
* Compiles given MST to an Expression using AST compiler. * Compiles given MST to an Expression using AST compiler.
@ -23,37 +23,46 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
is MST.Symbolic -> { is MST.Symbolic -> {
val symbol = try { val symbol = try {
algebra.symbol(node.value) algebra.symbol(node.value)
} catch (ignored: Throwable) { } catch (ignored: IllegalStateException) {
null null
} }
if (symbol != null) if (symbol != null)
loadTConstant(symbol) loadObjectConstant(symbol as Any)
else else
loadVariable(node.value) loadVariable(node.value)
} }
is MST.Numeric -> loadNumeric(node.value) is MST.Numeric -> loadNumberConstant(node.value)
is MST.Unary -> buildCall(algebra.unaryOperation(node.operation)) { visit(node.value) }
is MST.Unary -> buildAlgebraOperationCall( is MST.Binary -> when {
context = algebra, algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant(
name = node.operation, algebra.number(
fallbackMethodName = "unaryOperation", RealField
parameterTypes = arrayOf(MstType.fromMst(node.value)) .binaryOperation(node.operation)
) { visit(node.value) } .invoke(node.left.value.toDouble(), node.right.value.toDouble())
)
)
is MST.Binary -> buildAlgebraOperationCall( algebra is NumericAlgebra<T> && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperation(node.operation)) {
context = algebra, visit(node.left)
name = node.operation, visit(node.right)
fallbackMethodName = "binaryOperation", }
parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right))
) { algebra is NumericAlgebra<T> && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperation(node.operation)) {
visit(node.left)
visit(node.right)
}
else -> buildCall(algebra.binaryOperation(node.operation)) {
visit(node.left) visit(node.left)
visit(node.right) visit(node.right)
} }
} }
}
return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
} }
/** /**

View File

@ -3,29 +3,29 @@ package kscience.kmath.asm.internal
import kscience.kmath.asm.internal.AsmBuilder.ClassLoader import kscience.kmath.asm.internal.AsmBuilder.ClassLoader
import kscience.kmath.ast.MST import kscience.kmath.ast.MST
import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra
import org.objectweb.asm.* import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type.*
import org.objectweb.asm.commons.InstructionAdapter import org.objectweb.asm.commons.InstructionAdapter
import java.util.* import java.lang.invoke.MethodHandles
import java.util.stream.Collectors import java.lang.invoke.MethodType
import java.util.stream.Collectors.toMap
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/** /**
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
* *
* @property T the type of AsmExpression to unwrap. * @property T the type of AsmExpression to unwrap.
* @property algebra the algebra the applied AsmExpressions use.
* @property className the unique class name of new loaded class. * @property className the unique class name of new loaded class.
* @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. * @property callbackAtInvokeL0 the function to apply to this object when generating invoke method, label 0.
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
internal class AsmBuilder<T> internal constructor( internal class AsmBuilder<T>(
private val classOfT: Class<*>, classOfT: Class<*>,
private val algebra: Algebra<T>,
private val className: String, private val className: String,
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit, private val callbackAtInvokeL0: AsmBuilder<T>.() -> Unit,
) { ) {
/** /**
* Internal classloader of [AsmBuilder] with alias to define class from byte array. * Internal classloader of [AsmBuilder] with alias to define class from byte array.
@ -39,20 +39,15 @@ internal class AsmBuilder<T> internal constructor(
*/ */
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
/**
* ASM Type for [algebra].
*/
private val tAlgebraType: Type = algebra.javaClass.asm
/** /**
* ASM type for [T]. * ASM type for [T].
*/ */
internal val tType: Type = classOfT.asm private val tType: Type = classOfT.asm
/** /**
* ASM type for new class. * ASM type for new class.
*/ */
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/'))
/** /**
* List of constants to provide to the subclass. * List of constants to provide to the subclass.
@ -64,55 +59,14 @@ internal class AsmBuilder<T> internal constructor(
*/ */
private lateinit var invokeMethodVisitor: InstructionAdapter private lateinit var invokeMethodVisitor: InstructionAdapter
/**
* States whether this [AsmBuilder] needs to generate constants field.
*/
private var hasConstants: Boolean = true
/**
* States whether [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
*/
internal var primitiveMode: Boolean = false
/**
* Primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
*/
internal var primitiveMask: Type = OBJECT_TYPE
/**
* Boxed primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
*/
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
/**
* Stack of useful objects types on stack to verify types.
*/
private val typeStack: ArrayDeque<Type> = ArrayDeque()
/**
* Stack of useful objects types on stack expected by algebra calls.
*/
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>(1).also { it.push(tType) }
/**
* The cache for instance built by this builder.
*/
private var generatedInstance: Expression<T>? = null
/** /**
* Subclasses, loads and instantiates [Expression] for given parameters. * Subclasses, loads and instantiates [Expression] for given parameters.
* *
* The built instance is cached. * The built instance is cached.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
internal fun getInstance(): Expression<T> { val instance: Expression<T> by lazy {
generatedInstance?.let { return it } val hasConstants: Boolean
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
primitiveMode = true
primitiveMask = SIGNATURE_LETTERS.getValue(classOfT)
primitiveMaskBoxed = tType
}
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit( visit(
@ -121,20 +75,20 @@ internal class AsmBuilder<T> internal constructor(
classType.internalName, classType.internalName,
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
OBJECT_TYPE.internalName, OBJECT_TYPE.internalName,
arrayOf(EXPRESSION_TYPE.internalName) arrayOf(EXPRESSION_TYPE.internalName),
) )
visitMethod( visitMethod(
ACC_PUBLIC or ACC_FINAL, ACC_PUBLIC or ACC_FINAL,
"invoke", "invoke",
Type.getMethodDescriptor(tType, MAP_TYPE), getMethodDescriptor(tType, MAP_TYPE),
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
null null,
).instructionAdapter { ).instructionAdapter {
invokeMethodVisitor = this invokeMethodVisitor = this
visitCode() visitCode()
val l0 = label() val l0 = label()
invokeLabel0Visitor() callbackAtInvokeL0()
areturn(tType) areturn(tType)
val l1 = label() val l1 = label()
@ -144,7 +98,7 @@ internal class AsmBuilder<T> internal constructor(
null, null,
l0, l0,
l1, l1,
invokeThisVar 0,
) )
visitLocalVariable( visitLocalVariable(
@ -153,7 +107,7 @@ internal class AsmBuilder<T> internal constructor(
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;",
l0, l0,
l1, l1,
invokeArgumentsVar 1,
) )
visitMaxs(0, 2) visitMaxs(0, 2)
@ -163,17 +117,15 @@ internal class AsmBuilder<T> internal constructor(
visitMethod( visitMethod(
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
"invoke", "invoke",
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
null,
null, null,
null
).instructionAdapter { ).instructionAdapter {
val thisVar = 0
val argumentsVar = 1
visitCode() visitCode()
val l0 = label() val l0 = label()
load(thisVar, OBJECT_TYPE) load(0, OBJECT_TYPE)
load(argumentsVar, MAP_TYPE) load(1, MAP_TYPE)
invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false) invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false)
areturn(tType) areturn(tType)
val l1 = label() val l1 = label()
@ -183,7 +135,7 @@ internal class AsmBuilder<T> internal constructor(
null, null,
l0, l0,
l1, l1,
thisVar 0,
) )
visitMaxs(0, 2) visitMaxs(0, 2)
@ -192,15 +144,6 @@ internal class AsmBuilder<T> internal constructor(
hasConstants = constants.isNotEmpty() hasConstants = constants.isNotEmpty()
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "algebra",
descriptor = tAlgebraType.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
if (hasConstants) if (hasConstants)
visitField( visitField(
access = ACC_PRIVATE or ACC_FINAL, access = ACC_PRIVATE or ACC_FINAL,
@ -208,55 +151,36 @@ internal class AsmBuilder<T> internal constructor(
descriptor = OBJECT_ARRAY_TYPE.descriptor, descriptor = OBJECT_ARRAY_TYPE.descriptor,
signature = null, signature = null,
value = null, value = null,
block = FieldVisitor::visitEnd block = FieldVisitor::visitEnd,
) )
visitMethod( visitMethod(
ACC_PUBLIC, ACC_PUBLIC,
"<init>", "<init>",
getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
Type.getMethodDescriptor(
Type.VOID_TYPE,
tAlgebraType,
*OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
null, null,
null null
).instructionAdapter { ).instructionAdapter {
val thisVar = 0
val algebraVar = 1
val constantsVar = 2
val l0 = label() val l0 = label()
load(thisVar, classType) load(0, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false) invokespecial(OBJECT_TYPE.internalName, "<init>", getMethodDescriptor(VOID_TYPE), false)
label() label()
load(thisVar, classType) load(0, classType)
load(algebraVar, tAlgebraType)
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
if (hasConstants) { if (hasConstants) {
label() label()
load(thisVar, classType) load(0, classType)
load(constantsVar, OBJECT_ARRAY_TYPE) load(1, OBJECT_ARRAY_TYPE)
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
} }
label() label()
visitInsn(RETURN) visitInsn(RETURN)
val l4 = label() val l4 = label()
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) visitLocalVariable("this", classType.descriptor, null, l0, l4, 0)
visitLocalVariable(
"algebra",
tAlgebraType.descriptor,
null,
l0,
l4,
algebraVar
)
if (hasConstants) if (hasConstants)
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, 1)
visitMaxs(0, 3) visitMaxs(0, 3)
visitEnd() visitEnd()
@ -265,33 +189,56 @@ internal class AsmBuilder<T> internal constructor(
visitEnd() visitEnd()
} }
val new = classLoader val cls = classLoader.defineClass(className, classWriter.toByteArray())
.defineClass(className, classWriter.toByteArray()) val l = MethodHandles.publicLookup()
.constructors
.first()
.newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression<T>
generatedInstance = new if (hasConstants)
return new l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array<Any>::class.java))
.invoke(constants.toTypedArray()) as Expression<T>
else
l.findConstructor(cls, MethodType.methodType(Void.TYPE)).invoke() as Expression<T>
} }
/** /**
* Loads a [T] constant from [constants]. * Loads [java.lang.Object] constant from constants.
*/ */
internal fun loadTConstant(value: T) { fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run {
if (classOfT in INLINABLE_NUMBERS) { val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex
val expectedType = expectationStack.pop() loadThis()
val mustBeBoxed = expectedType.sort == Type.OBJECT getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
loadNumberConstant(value as Number, mustBeBoxed) iconst(idx)
visitInsn(AALOAD)
if (type != OBJECT_TYPE) checkcast(type)
}
if (mustBeBoxed) /**
invokeMethodVisitor.checkcast(tType) * Loads `this` variable.
*/
private fun loadThis(): Unit = invokeMethodVisitor.load(0, classType)
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) /**
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
* constant from the constant pool.
*/
fun loadNumberConstant(value: Number) {
val boxed = value.javaClass.asm
val primitive = BOXED_TO_PRIMITIVES[boxed]
if (primitive != null) {
when (primitive) {
BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
}
box(primitive)
return return
} }
loadObjectConstant(value as Any, tType) loadObjectConstant(value, boxed)
} }
/** /**
@ -303,258 +250,100 @@ internal class AsmBuilder<T> internal constructor(
invokeMethodVisitor.invokestatic( invokeMethodVisitor.invokestatic(
r.internalName, r.internalName,
"valueOf", "valueOf",
Type.getMethodDescriptor(r, primitive), getMethodDescriptor(r, primitive),
false false,
) )
} }
/** /**
* Unboxes the current boxed value and pushes it. * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke].
*/ */
private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual( fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
NUMBER_TYPE.internalName, load(1, MAP_TYPE)
NUMBER_CONVERTER_METHODS.getValue(primitive),
Type.getMethodDescriptor(primitive),
false
)
/**
* Loads [java.lang.Object] constant from constants.
*/
private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
loadThis()
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
iconst(idx)
visitInsn(AALOAD)
checkcast(type)
}
internal fun loadNumeric(value: Number) {
if (expectationStack.peek() == NUMBER_TYPE) {
loadNumberConstant(value, true)
expectationStack.pop()
typeStack.push(NUMBER_TYPE)
} else (algebra as? NumericAlgebra<T>)?.number(value)?.let { loadTConstant(it) }
?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.")
}
/**
* Loads this variable.
*/
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
/**
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
* from it).
*/
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
val boxed = value.javaClass.asm
val primitive = BOXED_TO_PRIMITIVES[boxed]
if (primitive != null) {
when (primitive) {
Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
}
if (mustBeBoxed)
box(primitive)
return
}
loadObjectConstant(value, boxed)
if (!mustBeBoxed)
unboxTo(primitiveMask)
}
/**
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
* provided.
*/
internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
load(invokeArgumentsVar, MAP_TYPE)
aconst(name) aconst(name)
invokestatic( invokestatic(
MAP_INTRINSICS_TYPE.internalName, MAP_INTRINSICS_TYPE.internalName,
"getOrFail", "getOrFail",
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
false false,
) )
checkcast(tType) checkcast(tType)
val expectedType = expectationStack.pop()
if (expectedType.sort == Type.OBJECT)
typeStack.push(tType)
else {
unboxTo(primitiveMask)
typeStack.push(primitiveMask)
}
} }
/** inline fun buildCall(function: Function<T>, parameters: AsmBuilder<T>.() -> Unit) {
* Loads algebra from according field of the class and casts it to class of [algebra] provided. contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
*/ val `interface` = function.javaClass.interfaces.first { it.interfaces.contains(Function::class.java) }
internal fun loadAlgebra() {
loadThis()
invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor)
}
/** val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is ?: error("Provided function object doesn't contain invoke method")
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be
* called before the arguments and this operation.
*
* The result is casted to [T] automatically.
*/
internal fun invokeAlgebraOperation(
owner: String,
method: String,
descriptor: String,
expectedArity: Int,
opcode: Int = INVOKEINTERFACE,
) {
run loop@{
repeat(expectedArity) {
if (typeStack.isEmpty()) return@loop
typeStack.pop()
}
}
invokeMethodVisitor.visitMethodInsn( val type = getType(`interface`)
opcode, loadObjectConstant(function, type)
owner, parameters(this)
method,
descriptor, invokeMethodVisitor.invokeinterface(
opcode == INVOKEINTERFACE type.internalName,
"invoke",
getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE }),
) )
invokeMethodVisitor.checkcast(tType) invokeMethodVisitor.checkcast(tType)
val isLastExpr = expectationStack.size == 1
val expectedType = expectationStack.pop()
if (expectedType.sort == Type.OBJECT || isLastExpr)
typeStack.push(tType)
else {
unboxTo(primitiveMask)
typeStack.push(primitiveMask)
}
}
/**
* Writes a LDC Instruction with string constant provided.
*/
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
internal companion object {
/**
* Index of `this` variable in invoke method of the built subclass.
*/
private const val invokeThisVar: Int = 0
/**
* Index of `arguments` variable in invoke method of the built subclass.
*/
private const val invokeArgumentsVar: Int = 1
/**
* Maps JVM primitive numbers boxed types to their primitive ASM types.
*/
private val SIGNATURE_LETTERS: Map<Class<out Any>, Type> by lazy {
hashMapOf(
java.lang.Byte::class.java to Type.BYTE_TYPE,
java.lang.Short::class.java to Type.SHORT_TYPE,
java.lang.Integer::class.java to Type.INT_TYPE,
java.lang.Long::class.java to Type.LONG_TYPE,
java.lang.Float::class.java to Type.FLOAT_TYPE,
java.lang.Double::class.java to Type.DOUBLE_TYPE
)
} }
companion object {
/** /**
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types. * Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
*/ */
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy {
hashMapOf(
Byte::class.java.asm to BYTE_TYPE,
Short::class.java.asm to SHORT_TYPE,
Integer::class.java.asm to INT_TYPE,
Long::class.java.asm to LONG_TYPE,
Float::class.java.asm to FLOAT_TYPE,
Double::class.java.asm to DOUBLE_TYPE,
)
}
/** /**
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types. * Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
*/ */
private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy { private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy {
BOXED_TO_PRIMITIVES.entries.stream().collect( BOXED_TO_PRIMITIVES.entries.stream().collect(
Collectors.toMap( toMap(Map.Entry<Type, Type>::value, Map.Entry<Type, Type>::key),
Map.Entry<Type, Type>::value,
Map.Entry<Type, Type>::key
)
) )
} }
/**
* Maps primitive ASM types to [Number] functions unboxing them.
*/
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
hashMapOf(
Type.BYTE_TYPE to "byteValue",
Type.SHORT_TYPE to "shortValue",
Type.INT_TYPE to "intValue",
Type.LONG_TYPE to "longValue",
Type.FLOAT_TYPE to "floatValue",
Type.DOUBLE_TYPE to "doubleValue"
)
}
/**
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
*/
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
/** /**
* ASM type for [Expression]. * ASM type for [Expression].
*/ */
internal val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/expressions/Expression") } val EXPRESSION_TYPE: Type by lazy { getObjectType("kscience/kmath/expressions/Expression") }
/**
* ASM type for [java.lang.Number].
*/
internal val NUMBER_TYPE: Type by lazy { Type.getObjectType("java/lang/Number") }
/** /**
* ASM type for [java.util.Map]. * ASM type for [java.util.Map].
*/ */
internal val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") } val MAP_TYPE: Type by lazy { getObjectType("java/util/Map") }
/** /**
* ASM type for [java.lang.Object]. * ASM type for [java.lang.Object].
*/ */
internal val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") } val OBJECT_TYPE: Type by lazy { getObjectType("java/lang/Object") }
/** /**
* ASM type for array of [java.lang.Object]. * ASM type for array of [java.lang.Object].
*/ */
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") }
internal val OBJECT_ARRAY_TYPE: Type by lazy { Type.getType("[Ljava/lang/Object;") }
/**
* ASM type for [Algebra].
*/
internal val ALGEBRA_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/operations/Algebra") }
/** /**
* ASM type for [java.lang.String]. * ASM type for [java.lang.String].
*/ */
internal val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") } val STRING_TYPE: Type by lazy { getObjectType("java/lang/String") }
/** /**
* ASM type for MapIntrinsics. * ASM type for MapIntrinsics.
*/ */
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/asm/internal/MapIntrinsics") } val MAP_INTRINSICS_TYPE: Type by lazy { getObjectType("kscience/kmath/asm/internal/MapIntrinsics") }
} }
} }

View File

@ -1,20 +0,0 @@
package kscience.kmath.asm.internal
import kscience.kmath.ast.MST
/**
* Represents types known in [MST], numbers and general values.
*/
internal enum class MstType {
GENERAL,
NUMBER;
companion object {
fun fromMst(mst: MST): MstType {
if (mst is MST.Numeric)
return NUMBER
return GENERAL
}
}
}

View File

@ -2,29 +2,11 @@ package kscience.kmath.asm.internal
import kscience.kmath.ast.MST import kscience.kmath.ast.MST
import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.FieldOperations
import kscience.kmath.operations.RingOperations
import kscience.kmath.operations.SpaceOperations
import org.objectweb.asm.* import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
import org.objectweb.asm.commons.InstructionAdapter import org.objectweb.asm.commons.InstructionAdapter
import java.lang.reflect.Method
import java.util.*
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
hashMapOf(
SpaceOperations.PLUS_OPERATION to 2 to "add",
RingOperations.TIMES_OPERATION to 2 to "multiply",
FieldOperations.DIV_OPERATION to 2 to "divide",
SpaceOperations.PLUS_OPERATION to 1 to "unaryPlus",
SpaceOperations.MINUS_OPERATION to 1 to "unaryMinus",
SpaceOperations.MINUS_OPERATION to 2 to "minus"
)
}
/** /**
* Returns ASM [Type] for given [Class]. * Returns ASM [Type] for given [Class].
* *
@ -109,107 +91,3 @@ internal inline fun ClassWriter.visitField(
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return visitField(access, name, descriptor, signature, value).apply(block) return visitField(access, name, descriptor, signature, value).apply(block)
} }
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
context.javaClass.methods.find { method ->
val nameValid = method.name == name
val arityValid = method.parameters.size == parameterTypes.size
val notBridgeInPrimitive = !(primitiveMode && method.isBridge)
val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) ->
!(mstType != MstType.NUMBER && type == java.lang.Number::class.java)
}
nameValid && arityValid && notBridgeInPrimitive && paramsValid
}
/**
* Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds
* type expectation stack for needed arity.
*
* @author Iaroslav Postovalov
*/
private fun <T> AsmBuilder<T>.buildExpectationStack(
context: Algebra<T>,
name: String,
parameterTypes: Array<MstType>
): Boolean {
val arity = parameterTypes.size
val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes)
if (specific != null)
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
else
expectationStack.addAll(Collections.nCopies(arity, tType))
return specific != null
}
private fun <T> AsmBuilder<T>.mapTypes(method: Method, parameterTypes: Array<MstType>): List<Type> = method
.parameterTypes
.zip(parameterTypes)
.map { (type, mstType) ->
when {
type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE
else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed
}
}
/**
* Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts
* [AsmBuilder.invokeAlgebraOperation] of this method.
*
* @author Iaroslav Postovalov
*/
private fun <T> AsmBuilder<T>.tryInvokeSpecific(
context: Algebra<T>,
name: String,
parameterTypes: Array<MstType>
): Boolean {
val arity = parameterTypes.size
val theName = methodNameAdapters[name to arity] ?: name
val spec = findSpecific(context, theName, parameterTypes) ?: return false
val owner = context.javaClass.asm
invokeAlgebraOperation(
owner = owner.internalName,
method = theName,
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()),
expectedArity = arity,
opcode = INVOKEVIRTUAL
)
return true
}
/**
* Builds specialized [context] call with option to fallback to generic algebra operation accepting [String].
*
* @author Iaroslav Postovalov
*/
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
context: Algebra<T>,
name: String,
fallbackMethodName: String,
parameterTypes: Array<MstType>,
parameters: AsmBuilder<T>.() -> Unit
) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
val arity = parameterTypes.size
loadAlgebra()
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)
parameters()
if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
method = fallbackMethodName,
descriptor = Type.getMethodDescriptor(
AsmBuilder.OBJECT_TYPE,
AsmBuilder.STRING_TYPE,
*Array(arity) { AsmBuilder.OBJECT_TYPE }
),
expectedArity = arity
)
}

View File

@ -10,15 +10,11 @@ import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class TestAsmAlgebras { internal class TestAsmAlgebras {
@Test @Test
fun space() { fun space() {
val res1 = ByteRing.mstInSpace { val res1 = ByteRing.mstInSpace {
binaryOperation( binaryOperation("+")(
"+", unaryOperation("+")(
unaryOperation(
"+",
number(3.toByte()) - (number(2.toByte()) + (multiply( number(3.toByte()) - (number(2.toByte()) + (multiply(
add(number(1), number(1)), add(number(1), number(1)),
2 2
@ -30,11 +26,8 @@ internal class TestAsmAlgebras {
}("x" to 2.toByte()) }("x" to 2.toByte())
val res2 = ByteRing.mstInSpace { val res2 = ByteRing.mstInSpace {
binaryOperation( binaryOperation("+")(
"+", unaryOperation("+")(
unaryOperation(
"+",
number(3.toByte()) - (number(2.toByte()) + (multiply( number(3.toByte()) - (number(2.toByte()) + (multiply(
add(number(1), number(1)), add(number(1), number(1)),
2 2
@ -51,11 +44,8 @@ internal class TestAsmAlgebras {
@Test @Test
fun ring() { fun ring() {
val res1 = ByteRing.mstInRing { val res1 = ByteRing.mstInRing {
binaryOperation( binaryOperation("+")(
"+", unaryOperation("+")(
unaryOperation(
"+",
(symbol("x") - (2.toByte() + (multiply( (symbol("x") - (2.toByte() + (multiply(
add(number(1), number(1)), add(number(1), number(1)),
2 2
@ -67,17 +57,13 @@ internal class TestAsmAlgebras {
}("x" to 3.toByte()) }("x" to 3.toByte())
val res2 = ByteRing.mstInRing { val res2 = ByteRing.mstInRing {
binaryOperation( binaryOperation("+")(
"+", unaryOperation("+")(
unaryOperation(
"+",
(symbol("x") - (2.toByte() + (multiply( (symbol("x") - (2.toByte() + (multiply(
add(number(1), number(1)), add(number(1), number(1)),
2 2
) + 1.toByte()))) * 3.0 - 1.toByte() ) + 1.toByte()))) * 3.0 - 1.toByte()
), ),
number(1) number(1)
) * number(2) ) * number(2)
}.compile()("x" to 3.toByte()) }.compile()("x" to 3.toByte())
@ -88,8 +74,7 @@ internal class TestAsmAlgebras {
@Test @Test
fun field() { fun field() {
val res1 = RealField.mstInField { val res1 = RealField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation("+")(
"+",
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
+ number(1), + number(1),
number(1) / 2 + number(2.0) * one number(1) / 2 + number(2.0) * one
@ -97,8 +82,7 @@ internal class TestAsmAlgebras {
}("x" to 2.0) }("x" to 2.0)
val res2 = RealField.mstInField { val res2 = RealField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation("+")(
"+",
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
+ number(1), + number(1),
number(1) / 2 + number(2.0) * one number(1) / 2 + number(2.0) * one

View File

@ -1,10 +1,11 @@
package kscience.kmath.asm package kscience.kmath.asm
import kscience.kmath.asm.compile import kscience.kmath.ast.mstInExtendedField
import kscience.kmath.ast.mstInField import kscience.kmath.ast.mstInField
import kscience.kmath.ast.mstInSpace import kscience.kmath.ast.mstInSpace
import kscience.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import kscience.kmath.operations.RealField import kscience.kmath.operations.RealField
import kotlin.random.Random
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -28,4 +29,13 @@ internal class TestAsmExpressions {
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
assertEquals(4.0, res) assertEquals(4.0, res)
} }
@Test
fun testMultipleCalls() {
val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile()
val r = Random(0)
var s = 0.0
repeat(1000000) { s += e("x" to r.nextDouble()) }
println(s)
}
} }

View File

@ -1,6 +1,5 @@
package kscience.kmath.asm package kscience.kmath.asm
import kscience.kmath.asm.compile
import kscience.kmath.ast.mstInField import kscience.kmath.ast.mstInField
import kscience.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import kscience.kmath.operations.RealField import kscience.kmath.operations.RealField
@ -10,44 +9,44 @@ import kotlin.test.assertEquals
internal class TestAsmSpecialization { internal class TestAsmSpecialization {
@Test @Test
fun testUnaryPlus() { fun testUnaryPlus() {
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() val expr = RealField.mstInField { unaryOperation("+")(symbol("x")) }.compile()
assertEquals(2.0, expr("x" to 2.0)) assertEquals(2.0, expr("x" to 2.0))
} }
@Test @Test
fun testUnaryMinus() { fun testUnaryMinus() {
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() val expr = RealField.mstInField { unaryOperation("-")(symbol("x")) }.compile()
assertEquals(-2.0, expr("x" to 2.0)) assertEquals(-2.0, expr("x" to 2.0))
} }
@Test @Test
fun testAdd() { fun testAdd() {
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() val expr = RealField.mstInField { binaryOperation("+")(symbol("x"), symbol("x")) }.compile()
assertEquals(4.0, expr("x" to 2.0)) assertEquals(4.0, expr("x" to 2.0))
} }
@Test @Test
fun testSine() { fun testSine() {
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() val expr = RealField.mstInField { unaryOperation("sin")(symbol("x")) }.compile()
assertEquals(0.0, expr("x" to 0.0)) assertEquals(0.0, expr("x" to 0.0))
} }
@Test @Test
fun testMinus() { fun testMinus() {
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() val expr = RealField.mstInField { binaryOperation("-")(symbol("x"), symbol("x")) }.compile()
assertEquals(0.0, expr("x" to 2.0)) assertEquals(0.0, expr("x" to 2.0))
} }
@Test @Test
fun testDivide() { fun testDivide() {
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() val expr = RealField.mstInField { binaryOperation("/")(symbol("x"), symbol("x")) }.compile()
assertEquals(1.0, expr("x" to 2.0)) assertEquals(1.0, expr("x" to 2.0))
} }
@Test @Test
fun testPower() { fun testPower() {
val expr = RealField val expr = RealField
.mstInField { binaryOperation("power", symbol("x"), number(2)) } .mstInField { binaryOperation("pow")(symbol("x"), number(2)) }
.compile() .compile()
assertEquals(4.0, expr("x" to 2.0)) assertEquals(4.0, expr("x" to 2.0))

View File

@ -17,6 +17,6 @@ internal class TestAsmVariables {
@Test @Test
fun testVariableWithoutDefaultFails() { fun testVariableWithoutDefaultFails() {
val expr = ByteRing.mstInRing { symbol("x") } val expr = ByteRing.mstInRing { symbol("x") }
assertFailsWith<IllegalStateException> { expr() } assertFailsWith<NoSuchElementException> { expr() }
} }
} }

View File

@ -1,8 +1,5 @@
package kscience.kmath.ast package kscience.kmath.ast
import kscience.kmath.ast.evaluate
import kscience.kmath.ast.mstInField
import kscience.kmath.ast.parseMath
import kscience.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import kscience.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kscience.kmath.operations.Complex import kscience.kmath.operations.Complex
@ -45,10 +42,13 @@ internal class ParserTest {
val magicalAlgebra = object : Algebra<String> { val magicalAlgebra = object : Algebra<String> {
override fun symbol(value: String): String = value override fun symbol(value: String): String = value
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError() override fun unaryOperation(operation: String): (arg: String) -> String {
throw NotImplementedError()
}
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) { override fun binaryOperation(operation: String): (left: String, right: String) -> String =
"magic" -> "$left$right" when (operation) {
"magic" -> { left, right -> "$left$right" }
else -> throw NotImplementedError() else -> throw NotImplementedError()
} }
} }

View File

@ -7,9 +7,8 @@ import kscience.kmath.operations.*
* *
* @param algebra The algebra to provide for Expressions built. * @param algebra The algebra to provide for Expressions built.
*/ */
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>( public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
public val algebra: A, ExpressionAlgebra<T, Expression<T>> {
) : ExpressionAlgebra<T, Expression<T>> {
/** /**
* Builds an Expression of constant expression which does not depend on arguments. * Builds an Expression of constant expression which does not depend on arguments.
*/ */
@ -25,19 +24,18 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
/** /**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/ */
public override fun binaryOperation( public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
operation: String, { left, right ->
left: Expression<T>, Expression { arguments ->
right: Expression<T>, algebra.binaryOperation(operation)(left.invoke(arguments), right.invoke(arguments))
): Expression<T> = Expression { arguments -> }
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
} }
/** /**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. * Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/ */
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments -> public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
algebra.unaryOperation(operation, arg.invoke(arguments)) Expression { arguments -> algebra.unaryOperation(operation)(arg.invoke(arguments)) }
} }
} }
@ -52,7 +50,7 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
* Builds an Expression of addition of two another expressions. * Builds an Expression of addition of two another expressions.
*/ */
public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
/** /**
* Builds an Expression of multiplication of expression by number. * Builds an Expression of multiplication of expression by number.
@ -66,11 +64,11 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg) super<FunctionalExpressionAlgebra>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right) super<FunctionalExpressionAlgebra>.binaryOperation(operation)
} }
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra), public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
@ -82,16 +80,16 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
* Builds an Expression of multiplication of two expressions. * Builds an Expression of multiplication of two expressions.
*/ */
public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(RingOperations.TIMES_OPERATION, a, b) binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg) public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionSpace>.unaryOperation(operation, arg) super<FunctionalExpressionSpace>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right) super<FunctionalExpressionSpace>.binaryOperation(operation)
} }
public open class FunctionalExpressionField<T, A>(algebra: A) : public open class FunctionalExpressionField<T, A>(algebra: A) :
@ -101,49 +99,49 @@ public open class FunctionalExpressionField<T, A>(algebra: A) :
* Builds an Expression of division an expression by another one. * Builds an Expression of division an expression by another one.
*/ */
public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(FieldOperations.DIV_OPERATION, a, b) binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg) public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionRing>.unaryOperation(operation, arg) super<FunctionalExpressionRing>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionRing>.binaryOperation(operation, left, right) super<FunctionalExpressionRing>.binaryOperation(operation)
} }
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) : public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
FunctionalExpressionField<T, A>(algebra), FunctionalExpressionField<T, A>(algebra),
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> { ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
public override fun sin(arg: Expression<T>): Expression<T> = public override fun sin(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg)
public override fun cos(arg: Expression<T>): Expression<T> = public override fun cos(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.COS_OPERATION, arg) unaryOperation(TrigonometricOperations.COS_OPERATION)(arg)
public override fun asin(arg: Expression<T>): Expression<T> = public override fun asin(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg)
public override fun acos(arg: Expression<T>): Expression<T> = public override fun acos(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg)
public override fun atan(arg: Expression<T>): Expression<T> = public override fun atan(arg: Expression<T>): Expression<T> =
unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg)
public override fun power(arg: Expression<T>, pow: Number): Expression<T> = public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))
public override fun exp(arg: Expression<T>): Expression<T> = public override fun exp(arg: Expression<T>): Expression<T> =
unaryOperation(ExponentialOperations.EXP_OPERATION, arg) unaryOperation(ExponentialOperations.EXP_OPERATION)(arg)
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg) public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION)(arg)
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
super<FunctionalExpressionField>.unaryOperation(operation, arg) super<FunctionalExpressionField>.unaryOperation(operation)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionField>.binaryOperation(operation, left, right) super<FunctionalExpressionField>.binaryOperation(operation)
} }
public inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> = public inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =

View File

@ -19,9 +19,10 @@ public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Mat
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public override fun binaryOperation(operation: String, left: Matrix<T>, right: Matrix<T>): M = when (operation) { public override fun binaryOperation(operation: String): (left: Matrix<T>, right: Matrix<T>) -> M =
"dot" -> left dot right when (operation) {
else -> super.binaryOperation(operation, left, right) as M "dot" -> { left, right -> left dot right }
else -> super.binaryOperation(operation) as (Matrix<T>, Matrix<T>) -> M
} }
/** /**

View File

@ -13,19 +13,21 @@ public annotation class KMathContext
*/ */
public interface Algebra<T> { public interface Algebra<T> {
/** /**
* Wrap raw string or variable * Wraps raw string or variable.
*/ */
public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
/** /**
* Dynamic call of unary operation with name [operation] on [arg] * Dynamically dispatches an unary operation with name [operation].
*/ */
public fun unaryOperation(operation: String, arg: T): T public fun unaryOperation(operation: String): (arg: T) -> T =
error("Unary operation $operation not defined in $this")
/** /**
* Dynamic call of binary operation [operation] on [left] and [right] * Dynamically dispatches a binary operation with name [operation].
*/ */
public fun binaryOperation(operation: String, left: T, right: T): T public fun binaryOperation(operation: String): (left: T, right: T) -> T =
error("Binary operation $operation not defined in $this")
} }
/** /**
@ -40,16 +42,28 @@ public interface NumericAlgebra<T> : Algebra<T> {
public fun number(value: Number): T public fun number(value: Number): T
/** /**
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number]. * Dynamically dispatches a binary operation with name [operation] where the left argument is [Number].
*/ */
public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = public fun leftSideNumberOperation(operation: String): (left: Number, right: T) -> T =
binaryOperation(operation, number(left), right) { l, r -> binaryOperation(operation)(number(l), r) }
// /**
// * Dynamically calls a binary operation with name [operation] where the left argument is [Number].
// */
// public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
// leftSideNumberOperation(operation)(left, right)
/** /**
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number]. * Dynamically dispatches a binary operation with name [operation] where the right argument is [Number].
*/ */
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = public fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T =
leftSideNumberOperation(operation, right, left) { l, r -> binaryOperation(operation)(l, number(r)) }
// /**
// * Dynamically calls a binary operation with name [operation] where the right argument is [Number].
// */
// public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
// rightSideNumberOperation(operation)(left, right)
} }
/** /**
@ -146,16 +160,16 @@ public interface SpaceOperations<T> : Algebra<T> {
*/ */
public operator fun Number.times(b: T): T = b * this public operator fun Number.times(b: T): T = b * this
override fun unaryOperation(operation: String, arg: T): T = when (operation) { override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
PLUS_OPERATION -> arg PLUS_OPERATION -> { arg -> arg }
MINUS_OPERATION -> -arg MINUS_OPERATION -> { arg -> -arg }
else -> error("Unary operation $operation not defined in $this") else -> super.unaryOperation(operation)
} }
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
PLUS_OPERATION -> add(left, right) PLUS_OPERATION -> ::add
MINUS_OPERATION -> left - right MINUS_OPERATION -> { left, right -> left - right }
else -> error("Binary operation $operation not defined in $this") else -> super.binaryOperation(operation)
} }
public companion object { public companion object {
@ -207,9 +221,9 @@ public interface RingOperations<T> : SpaceOperations<T> {
*/ */
public operator fun T.times(b: T): T = multiply(this, b) public operator fun T.times(b: T): T = multiply(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
TIMES_OPERATION -> multiply(left, right) TIMES_OPERATION -> ::multiply
else -> super.binaryOperation(operation, left, right) else -> super.binaryOperation(operation)
} }
public companion object { public companion object {
@ -234,20 +248,6 @@ public interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = one * value.toDouble() override fun number(value: Number): T = one * value.toDouble()
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.leftSideNumberOperation(operation, left, right)
}
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.rightSideNumberOperation(operation, left, right)
}
/** /**
* Addition of element and scalar. * Addition of element and scalar.
* *
@ -308,9 +308,9 @@ public interface FieldOperations<T> : RingOperations<T> {
*/ */
public operator fun T.div(b: T): T = divide(this, b) public operator fun T.div(b: T): T = divide(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
DIV_OPERATION -> divide(left, right) DIV_OPERATION -> ::divide
else -> super.binaryOperation(operation, left, right) else -> super.binaryOperation(operation)
} }
public companion object { public companion object {

View File

@ -15,23 +15,23 @@ public interface ExtendedFieldOperations<T> :
public override fun tan(arg: T): T = sin(arg) / cos(arg) public override fun tan(arg: T): T = sin(arg) / cos(arg)
public override fun tanh(arg: T): T = sinh(arg) / cosh(arg) public override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
public override fun unaryOperation(operation: String, arg: T): T = when (operation) { public override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
TrigonometricOperations.COS_OPERATION -> cos(arg) TrigonometricOperations.COS_OPERATION -> ::cos
TrigonometricOperations.SIN_OPERATION -> sin(arg) TrigonometricOperations.SIN_OPERATION -> ::sin
TrigonometricOperations.TAN_OPERATION -> tan(arg) TrigonometricOperations.TAN_OPERATION -> ::tan
TrigonometricOperations.ACOS_OPERATION -> acos(arg) TrigonometricOperations.ACOS_OPERATION -> ::acos
TrigonometricOperations.ASIN_OPERATION -> asin(arg) TrigonometricOperations.ASIN_OPERATION -> ::asin
TrigonometricOperations.ATAN_OPERATION -> atan(arg) TrigonometricOperations.ATAN_OPERATION -> ::atan
HyperbolicOperations.COSH_OPERATION -> cosh(arg) HyperbolicOperations.COSH_OPERATION -> ::cosh
HyperbolicOperations.SINH_OPERATION -> sinh(arg) HyperbolicOperations.SINH_OPERATION -> ::sinh
HyperbolicOperations.TANH_OPERATION -> tanh(arg) HyperbolicOperations.TANH_OPERATION -> ::tanh
HyperbolicOperations.ACOSH_OPERATION -> acosh(arg) HyperbolicOperations.ACOSH_OPERATION -> ::acosh
HyperbolicOperations.ASINH_OPERATION -> asinh(arg) HyperbolicOperations.ASINH_OPERATION -> ::asinh
HyperbolicOperations.ATANH_OPERATION -> atanh(arg) HyperbolicOperations.ATANH_OPERATION -> ::atanh
PowerOperations.SQRT_OPERATION -> sqrt(arg) PowerOperations.SQRT_OPERATION -> ::sqrt
ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.EXP_OPERATION -> ::exp
ExponentialOperations.LN_OPERATION -> ln(arg) ExponentialOperations.LN_OPERATION -> ::ln
else -> super.unaryOperation(operation, arg) else -> super<FieldOperations>.unaryOperation(operation)
} }
} }
@ -46,9 +46,10 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2 public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { public override fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T =
PowerOperations.POW_OPERATION -> power(left, right) when (operation) {
else -> super.rightSideNumberOperation(operation, left, right) PowerOperations.POW_OPERATION -> ::power
else -> super.rightSideNumberOperation(operation)
} }
} }
@ -80,9 +81,10 @@ public object RealField : ExtendedField<Double>, Norm<Double, Double> {
public override val one: Double public override val one: Double
get() = 1.0 get() = 1.0
public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { public override fun binaryOperation(operation: String): (left: Double, right: Double) -> Double =
PowerOperations.POW_OPERATION -> left pow right when (operation) {
else -> super.binaryOperation(operation, left, right) PowerOperations.POW_OPERATION -> ::power
else -> super.binaryOperation(operation)
} }
public override inline fun add(a: Double, b: Double): Double = a + b public override inline fun add(a: Double, b: Double): Double = a + b
@ -130,9 +132,9 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
public override val one: Float public override val one: Float
get() = 1.0f get() = 1.0f
public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) { public override fun binaryOperation(operation: String): (left: Float, right: Float) -> Float = when (operation) {
PowerOperations.POW_OPERATION -> left pow right PowerOperations.POW_OPERATION -> ::power
else -> super.binaryOperation(operation, left, right) else -> super.binaryOperation(operation)
} }
public override inline fun add(a: Float, b: Float): Float = a + b public override inline fun add(a: Float, b: Float): Float = a + b