forked from kscience/kmath
Merge branch 'feature/dynamic-ops-currying' into feature/estree-codegen
This commit is contained in:
commit
a5c00051c2
@ -4,13 +4,19 @@ import kscience.kmath.asm.compile
|
|||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.expressions.expressionInField
|
import kscience.kmath.expressions.expressionInField
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
|
import kscience.kmath.expressions.symbol
|
||||||
import kscience.kmath.operations.Field
|
import kscience.kmath.operations.Field
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
|
import org.openjdk.jmh.annotations.Benchmark
|
||||||
|
import org.openjdk.jmh.annotations.Scope
|
||||||
|
import org.openjdk.jmh.annotations.State
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.system.measureTimeMillis
|
|
||||||
|
|
||||||
|
@State(Scope.Benchmark)
|
||||||
internal class ExpressionsInterpretersBenchmark {
|
internal class ExpressionsInterpretersBenchmark {
|
||||||
private val algebra: Field<Double> = RealField
|
private val algebra: Field<Double> = RealField
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
fun functionalExpression() {
|
fun functionalExpression() {
|
||||||
val expr = algebra.expressionInField {
|
val expr = algebra.expressionInField {
|
||||||
symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
|
symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
|
||||||
@ -19,6 +25,7 @@ internal class ExpressionsInterpretersBenchmark {
|
|||||||
invokeAndSum(expr)
|
invokeAndSum(expr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
fun mstExpression() {
|
fun mstExpression() {
|
||||||
val expr = algebra.mstInField {
|
val expr = algebra.mstInField {
|
||||||
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||||
@ -27,6 +34,7 @@ internal class ExpressionsInterpretersBenchmark {
|
|||||||
invokeAndSum(expr)
|
invokeAndSum(expr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
fun asmExpression() {
|
fun asmExpression() {
|
||||||
val expr = algebra.mstInField {
|
val expr = algebra.mstInField {
|
||||||
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||||
@ -35,6 +43,13 @@ internal class ExpressionsInterpretersBenchmark {
|
|||||||
invokeAndSum(expr)
|
invokeAndSum(expr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun rawExpression() {
|
||||||
|
val x by symbol
|
||||||
|
val expr = Expression<Double> { args -> args.getValue(x) * 2.0 + 2.0 / args.getValue(x) - 16.0 }
|
||||||
|
invokeAndSum(expr)
|
||||||
|
}
|
||||||
|
|
||||||
private fun invokeAndSum(expr: Expression<Double>) {
|
private fun invokeAndSum(expr: Expression<Double>) {
|
||||||
val random = Random(0)
|
val random = Random(0)
|
||||||
var sum = 0.0
|
var sum = 0.0
|
||||||
@ -46,35 +61,3 @@ internal class ExpressionsInterpretersBenchmark {
|
|||||||
println(sum)
|
println(sum)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
|
|
||||||
* core FunctionalExpressions API.
|
|
||||||
*
|
|
||||||
* The expected rating is:
|
|
||||||
*
|
|
||||||
* 1. ASM.
|
|
||||||
* 2. MST.
|
|
||||||
* 3. FE.
|
|
||||||
*/
|
|
||||||
fun main() {
|
|
||||||
val benchmark = ExpressionsInterpretersBenchmark()
|
|
||||||
|
|
||||||
val fe = measureTimeMillis {
|
|
||||||
benchmark.functionalExpression()
|
|
||||||
}
|
|
||||||
|
|
||||||
println("fe=$fe")
|
|
||||||
|
|
||||||
val mst = measureTimeMillis {
|
|
||||||
benchmark.mstExpression()
|
|
||||||
}
|
|
||||||
|
|
||||||
println("mst=$mst")
|
|
||||||
|
|
||||||
val asm = measureTimeMillis {
|
|
||||||
benchmark.asmExpression()
|
|
||||||
}
|
|
||||||
|
|
||||||
println("asm=$asm")
|
|
||||||
}
|
|
@ -16,17 +16,13 @@ internal class ArrayBenchmark {
|
|||||||
@Benchmark
|
@Benchmark
|
||||||
fun benchmarkBufferRead() {
|
fun benchmarkBufferRead() {
|
||||||
var res = 0
|
var res = 0
|
||||||
for (i in 1..size) res += arrayBuffer.get(
|
for (i in 1..size) res += arrayBuffer[size - i]
|
||||||
size - i
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun nativeBufferRead() {
|
fun nativeBufferRead() {
|
||||||
var res = 0
|
var res = 0
|
||||||
for (i in 1..size) res += nativeBuffer.get(
|
for (i in 1..size) res += nativeBuffer[size - i]
|
||||||
size - i
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
@ -55,24 +55,24 @@ public sealed class MST {
|
|||||||
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||||
?: error("Numeric nodes are not supported by $this")
|
?: error("Numeric nodes are not supported by $this")
|
||||||
|
|
||||||
is MST.Symbolic -> symbol(node.value)
|
is MST.Symbolic -> symbol(node.value)
|
||||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
is MST.Unary -> unaryOperation(node.operation)(evaluate(node.value))
|
||||||
|
|
||||||
is MST.Binary -> when {
|
is MST.Binary -> when {
|
||||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
this !is NumericAlgebra -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||||
|
|
||||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||||
val number = RealField.binaryOperation(
|
val number = RealField
|
||||||
node.operation,
|
.binaryOperation(node.operation)
|
||||||
node.left.value.toDouble(),
|
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
||||||
node.right.value.toDouble()
|
|
||||||
)
|
|
||||||
|
|
||||||
number(number)
|
number(number)
|
||||||
}
|
}
|
||||||
|
|
||||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
node.left is MST.Numeric -> leftSideNumberOperation(node.operation)(node.left.value, evaluate(node.right))
|
||||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
node.right is MST.Numeric -> rightSideNumberOperation(node.operation)(evaluate(node.left), node.right.value)
|
||||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
else -> binaryOperation(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,32 +6,35 @@ import kscience.kmath.operations.*
|
|||||||
* [Algebra] over [MST] nodes.
|
* [Algebra] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstAlgebra : NumericAlgebra<MST> {
|
public object MstAlgebra : NumericAlgebra<MST> {
|
||||||
override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
public override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
||||||
|
public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
|
||||||
|
|
||||||
override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
|
public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary =
|
||||||
|
{ arg -> MST.Unary(operation, arg) }
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary =
|
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||||
MST.Unary(operation, arg)
|
{ left, right -> MST.Binary(operation, left, right) }
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
|
||||||
MST.Binary(operation, left, right)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Space] over [MST] nodes.
|
* [Space] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST.Numeric by lazy { number(0.0) }
|
public override val zero: MST.Numeric by lazy { number(0.0) }
|
||||||
|
|
||||||
override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
||||||
override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
||||||
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
public override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
|
||||||
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
public override fun MST.unaryMinus(): MST = unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
public override fun multiply(a: MST, k: Number): MST.Binary =
|
||||||
MstAlgebra.binaryOperation(operation, left, right)
|
binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k))
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg)
|
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||||
|
MstAlgebra.binaryOperation(operation)
|
||||||
|
|
||||||
|
public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary =
|
||||||
|
MstAlgebra.unaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -43,16 +46,18 @@ public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
|||||||
|
|
||||||
override val one: MST.Numeric by lazy { number(1.0) }
|
override val one: MST.Numeric by lazy { number(1.0) }
|
||||||
|
|
||||||
override fun number(value: Number): MST.Numeric = MstSpace.number(value)
|
public override fun number(value: Number): MST.Numeric = MstSpace.number(value)
|
||||||
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
|
public override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
|
||||||
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
|
public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
|
||||||
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
|
public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
|
||||||
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
public override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
|
||||||
|
public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||||
MstSpace.binaryOperation(operation, left, right)
|
MstSpace.binaryOperation(operation)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
|
public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary =
|
||||||
|
MstAlgebra.unaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -70,51 +75,52 @@ public object MstField : Field<MST> {
|
|||||||
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
||||||
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
|
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
|
||||||
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
||||||
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
|
||||||
|
public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
|
||||||
|
|
||||||
public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||||
MstRing.binaryOperation(operation, left, right)
|
MstRing.binaryOperation(operation)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg)
|
public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstRing.unaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [ExtendedField] over [MST] nodes.
|
* [ExtendedField] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstExtendedField : ExtendedField<MST> {
|
public object MstExtendedField : ExtendedField<MST> {
|
||||||
override val zero: MST.Numeric
|
public override val zero: MST.Numeric
|
||||||
get() = MstField.zero
|
get() = MstField.zero
|
||||||
|
|
||||||
override val one: MST.Numeric
|
public override val one: MST.Numeric
|
||||||
get() = MstField.one
|
get() = MstField.one
|
||||||
|
|
||||||
override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
|
public override fun symbol(value: String): MST = MstField.symbol(value)
|
||||||
override fun number(value: Number): MST.Numeric = MstField.number(value)
|
public override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg)
|
||||||
override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
public override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION)(arg)
|
||||||
override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
public override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION)(arg)
|
||||||
override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
public override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg)
|
||||||
override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
public override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg)
|
||||||
override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
public override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg)
|
||||||
override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
public override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION)(arg)
|
||||||
override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
public override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION)(arg)
|
||||||
override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
public override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION)(arg)
|
||||||
override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
public override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION)(arg)
|
||||||
override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
public override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION)(arg)
|
||||||
override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
public override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION)(arg)
|
||||||
override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
public override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
||||||
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
public override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
|
||||||
override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
|
public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
||||||
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
||||||
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this)
|
||||||
|
|
||||||
override fun power(arg: MST, pow: Number): MST.Binary =
|
public override fun power(arg: MST, pow: Number): MST.Binary =
|
||||||
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))
|
||||||
|
|
||||||
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
public override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION)(arg)
|
||||||
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
public override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION)(arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||||
MstField.binaryOperation(operation, left, right)
|
MstField.binaryOperation(operation)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
|
public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstField.unaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
@ -15,11 +15,14 @@ import kotlin.contracts.contract
|
|||||||
*/
|
*/
|
||||||
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
|
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
|
||||||
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||||
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
override fun symbol(value: String): T = try {
|
||||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
algebra.symbol(value)
|
||||||
|
} catch (ignored: IllegalStateException) {
|
||||||
|
null
|
||||||
|
} ?: arguments.getValue(StringSymbol(value))
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
override fun unaryOperation(operation: String): (arg: T) -> T = algebra.unaryOperation(operation)
|
||||||
algebra.binaryOperation(operation, left, right)
|
override fun binaryOperation(operation: String): (left: T, right: T) -> T = algebra.binaryOperation(operation)
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
package kscience.kmath.asm
|
package kscience.kmath.asm
|
||||||
|
|
||||||
import kscience.kmath.asm.internal.AsmBuilder
|
import kscience.kmath.asm.internal.AsmBuilder
|
||||||
import kscience.kmath.asm.internal.MstType
|
|
||||||
import kscience.kmath.asm.internal.buildAlgebraOperationCall
|
|
||||||
import kscience.kmath.asm.internal.buildName
|
import kscience.kmath.asm.internal.buildName
|
||||||
import kscience.kmath.ast.MST
|
import kscience.kmath.ast.MST
|
||||||
import kscience.kmath.ast.MstExpression
|
import kscience.kmath.ast.MstExpression
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compiles given MST to an Expression using AST compiler.
|
* Compiles given MST to an Expression using AST compiler.
|
||||||
@ -23,37 +23,46 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
|
|||||||
is MST.Symbolic -> {
|
is MST.Symbolic -> {
|
||||||
val symbol = try {
|
val symbol = try {
|
||||||
algebra.symbol(node.value)
|
algebra.symbol(node.value)
|
||||||
} catch (ignored: Throwable) {
|
} catch (ignored: IllegalStateException) {
|
||||||
null
|
null
|
||||||
}
|
}
|
||||||
|
|
||||||
if (symbol != null)
|
if (symbol != null)
|
||||||
loadTConstant(symbol)
|
loadObjectConstant(symbol as Any)
|
||||||
else
|
else
|
||||||
loadVariable(node.value)
|
loadVariable(node.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Numeric -> loadNumeric(node.value)
|
is MST.Numeric -> loadNumberConstant(node.value)
|
||||||
|
is MST.Unary -> buildCall(algebra.unaryOperation(node.operation)) { visit(node.value) }
|
||||||
|
|
||||||
is MST.Unary -> buildAlgebraOperationCall(
|
is MST.Binary -> when {
|
||||||
context = algebra,
|
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant(
|
||||||
name = node.operation,
|
algebra.number(
|
||||||
fallbackMethodName = "unaryOperation",
|
RealField
|
||||||
parameterTypes = arrayOf(MstType.fromMst(node.value))
|
.binaryOperation(node.operation)
|
||||||
) { visit(node.value) }
|
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
is MST.Binary -> buildAlgebraOperationCall(
|
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperation(node.operation)) {
|
||||||
context = algebra,
|
visit(node.left)
|
||||||
name = node.operation,
|
visit(node.right)
|
||||||
fallbackMethodName = "binaryOperation",
|
}
|
||||||
parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right))
|
|
||||||
) {
|
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperation(node.operation)) {
|
||||||
visit(node.left)
|
visit(node.left)
|
||||||
visit(node.right)
|
visit(node.right)
|
||||||
|
}
|
||||||
|
|
||||||
|
else -> buildCall(algebra.binaryOperation(node.operation)) {
|
||||||
|
visit(node.left)
|
||||||
|
visit(node.right)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -3,29 +3,29 @@ package kscience.kmath.asm.internal
|
|||||||
import kscience.kmath.asm.internal.AsmBuilder.ClassLoader
|
import kscience.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||||
import kscience.kmath.ast.MST
|
import kscience.kmath.ast.MST
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.operations.Algebra
|
|
||||||
import kscience.kmath.operations.NumericAlgebra
|
|
||||||
import org.objectweb.asm.*
|
import org.objectweb.asm.*
|
||||||
import org.objectweb.asm.Opcodes.*
|
import org.objectweb.asm.Opcodes.*
|
||||||
|
import org.objectweb.asm.Type.*
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
import java.util.*
|
import java.lang.invoke.MethodHandles
|
||||||
import java.util.stream.Collectors
|
import java.lang.invoke.MethodType
|
||||||
|
import java.util.stream.Collectors.toMap
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
|
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
|
||||||
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
|
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
|
||||||
*
|
*
|
||||||
* @property T the type of AsmExpression to unwrap.
|
* @property T the type of AsmExpression to unwrap.
|
||||||
* @property algebra the algebra the applied AsmExpressions use.
|
|
||||||
* @property className the unique class name of new loaded class.
|
* @property className the unique class name of new loaded class.
|
||||||
* @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
|
* @property callbackAtInvokeL0 the function to apply to this object when generating invoke method, label 0.
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
internal class AsmBuilder<T> internal constructor(
|
internal class AsmBuilder<T>(
|
||||||
private val classOfT: Class<*>,
|
classOfT: Class<*>,
|
||||||
private val algebra: Algebra<T>,
|
|
||||||
private val className: String,
|
private val className: String,
|
||||||
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit,
|
private val callbackAtInvokeL0: AsmBuilder<T>.() -> Unit,
|
||||||
) {
|
) {
|
||||||
/**
|
/**
|
||||||
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
||||||
@ -39,20 +39,15 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
*/
|
*/
|
||||||
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||||
|
|
||||||
/**
|
|
||||||
* ASM Type for [algebra].
|
|
||||||
*/
|
|
||||||
private val tAlgebraType: Type = algebra.javaClass.asm
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [T].
|
* ASM type for [T].
|
||||||
*/
|
*/
|
||||||
internal val tType: Type = classOfT.asm
|
private val tType: Type = classOfT.asm
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for new class.
|
* ASM type for new class.
|
||||||
*/
|
*/
|
||||||
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/'))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List of constants to provide to the subclass.
|
* List of constants to provide to the subclass.
|
||||||
@ -64,55 +59,14 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
*/
|
*/
|
||||||
private lateinit var invokeMethodVisitor: InstructionAdapter
|
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||||
|
|
||||||
/**
|
|
||||||
* States whether this [AsmBuilder] needs to generate constants field.
|
|
||||||
*/
|
|
||||||
private var hasConstants: Boolean = true
|
|
||||||
|
|
||||||
/**
|
|
||||||
* States whether [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
|
||||||
*/
|
|
||||||
internal var primitiveMode: Boolean = false
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
|
||||||
*/
|
|
||||||
internal var primitiveMask: Type = OBJECT_TYPE
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Boxed primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
|
||||||
*/
|
|
||||||
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Stack of useful objects types on stack to verify types.
|
|
||||||
*/
|
|
||||||
private val typeStack: ArrayDeque<Type> = ArrayDeque()
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Stack of useful objects types on stack expected by algebra calls.
|
|
||||||
*/
|
|
||||||
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>(1).also { it.push(tType) }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The cache for instance built by this builder.
|
|
||||||
*/
|
|
||||||
private var generatedInstance: Expression<T>? = null
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subclasses, loads and instantiates [Expression] for given parameters.
|
* Subclasses, loads and instantiates [Expression] for given parameters.
|
||||||
*
|
*
|
||||||
* The built instance is cached.
|
* The built instance is cached.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
internal fun getInstance(): Expression<T> {
|
val instance: Expression<T> by lazy {
|
||||||
generatedInstance?.let { return it }
|
val hasConstants: Boolean
|
||||||
|
|
||||||
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
|
|
||||||
primitiveMode = true
|
|
||||||
primitiveMask = SIGNATURE_LETTERS.getValue(classOfT)
|
|
||||||
primitiveMaskBoxed = tType
|
|
||||||
}
|
|
||||||
|
|
||||||
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||||
visit(
|
visit(
|
||||||
@ -121,20 +75,20 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
classType.internalName,
|
classType.internalName,
|
||||||
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
|
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
|
||||||
OBJECT_TYPE.internalName,
|
OBJECT_TYPE.internalName,
|
||||||
arrayOf(EXPRESSION_TYPE.internalName)
|
arrayOf(EXPRESSION_TYPE.internalName),
|
||||||
)
|
)
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
ACC_PUBLIC or ACC_FINAL,
|
ACC_PUBLIC or ACC_FINAL,
|
||||||
"invoke",
|
"invoke",
|
||||||
Type.getMethodDescriptor(tType, MAP_TYPE),
|
getMethodDescriptor(tType, MAP_TYPE),
|
||||||
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
||||||
null
|
null,
|
||||||
).instructionAdapter {
|
).instructionAdapter {
|
||||||
invokeMethodVisitor = this
|
invokeMethodVisitor = this
|
||||||
visitCode()
|
visitCode()
|
||||||
val l0 = label()
|
val l0 = label()
|
||||||
invokeLabel0Visitor()
|
callbackAtInvokeL0()
|
||||||
areturn(tType)
|
areturn(tType)
|
||||||
val l1 = label()
|
val l1 = label()
|
||||||
|
|
||||||
@ -144,7 +98,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
null,
|
null,
|
||||||
l0,
|
l0,
|
||||||
l1,
|
l1,
|
||||||
invokeThisVar
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
@ -153,7 +107,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;",
|
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;",
|
||||||
l0,
|
l0,
|
||||||
l1,
|
l1,
|
||||||
invokeArgumentsVar
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
visitMaxs(0, 2)
|
visitMaxs(0, 2)
|
||||||
@ -163,17 +117,15 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
visitMethod(
|
visitMethod(
|
||||||
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
||||||
"invoke",
|
"invoke",
|
||||||
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
||||||
|
null,
|
||||||
null,
|
null,
|
||||||
null
|
|
||||||
).instructionAdapter {
|
).instructionAdapter {
|
||||||
val thisVar = 0
|
|
||||||
val argumentsVar = 1
|
|
||||||
visitCode()
|
visitCode()
|
||||||
val l0 = label()
|
val l0 = label()
|
||||||
load(thisVar, OBJECT_TYPE)
|
load(0, OBJECT_TYPE)
|
||||||
load(argumentsVar, MAP_TYPE)
|
load(1, MAP_TYPE)
|
||||||
invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false)
|
invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false)
|
||||||
areturn(tType)
|
areturn(tType)
|
||||||
val l1 = label()
|
val l1 = label()
|
||||||
|
|
||||||
@ -183,7 +135,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
null,
|
null,
|
||||||
l0,
|
l0,
|
||||||
l1,
|
l1,
|
||||||
thisVar
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
visitMaxs(0, 2)
|
visitMaxs(0, 2)
|
||||||
@ -192,15 +144,6 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
|
|
||||||
hasConstants = constants.isNotEmpty()
|
hasConstants = constants.isNotEmpty()
|
||||||
|
|
||||||
visitField(
|
|
||||||
access = ACC_PRIVATE or ACC_FINAL,
|
|
||||||
name = "algebra",
|
|
||||||
descriptor = tAlgebraType.descriptor,
|
|
||||||
signature = null,
|
|
||||||
value = null,
|
|
||||||
block = FieldVisitor::visitEnd
|
|
||||||
)
|
|
||||||
|
|
||||||
if (hasConstants)
|
if (hasConstants)
|
||||||
visitField(
|
visitField(
|
||||||
access = ACC_PRIVATE or ACC_FINAL,
|
access = ACC_PRIVATE or ACC_FINAL,
|
||||||
@ -208,55 +151,36 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
||||||
signature = null,
|
signature = null,
|
||||||
value = null,
|
value = null,
|
||||||
block = FieldVisitor::visitEnd
|
block = FieldVisitor::visitEnd,
|
||||||
)
|
)
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
ACC_PUBLIC,
|
ACC_PUBLIC,
|
||||||
"<init>",
|
"<init>",
|
||||||
|
getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
|
||||||
Type.getMethodDescriptor(
|
|
||||||
Type.VOID_TYPE,
|
|
||||||
tAlgebraType,
|
|
||||||
*OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
|
|
||||||
|
|
||||||
null,
|
null,
|
||||||
null
|
null
|
||||||
).instructionAdapter {
|
).instructionAdapter {
|
||||||
val thisVar = 0
|
|
||||||
val algebraVar = 1
|
|
||||||
val constantsVar = 2
|
|
||||||
val l0 = label()
|
val l0 = label()
|
||||||
load(thisVar, classType)
|
load(0, classType)
|
||||||
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
|
invokespecial(OBJECT_TYPE.internalName, "<init>", getMethodDescriptor(VOID_TYPE), false)
|
||||||
label()
|
label()
|
||||||
load(thisVar, classType)
|
load(0, classType)
|
||||||
load(algebraVar, tAlgebraType)
|
|
||||||
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
|
||||||
|
|
||||||
if (hasConstants) {
|
if (hasConstants) {
|
||||||
label()
|
label()
|
||||||
load(thisVar, classType)
|
load(0, classType)
|
||||||
load(constantsVar, OBJECT_ARRAY_TYPE)
|
load(1, OBJECT_ARRAY_TYPE)
|
||||||
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||||
}
|
}
|
||||||
|
|
||||||
label()
|
label()
|
||||||
visitInsn(RETURN)
|
visitInsn(RETURN)
|
||||||
val l4 = label()
|
val l4 = label()
|
||||||
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
|
visitLocalVariable("this", classType.descriptor, null, l0, l4, 0)
|
||||||
|
|
||||||
visitLocalVariable(
|
|
||||||
"algebra",
|
|
||||||
tAlgebraType.descriptor,
|
|
||||||
null,
|
|
||||||
l0,
|
|
||||||
l4,
|
|
||||||
algebraVar
|
|
||||||
)
|
|
||||||
|
|
||||||
if (hasConstants)
|
if (hasConstants)
|
||||||
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
|
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, 1)
|
||||||
|
|
||||||
visitMaxs(0, 3)
|
visitMaxs(0, 3)
|
||||||
visitEnd()
|
visitEnd()
|
||||||
@ -265,33 +189,56 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
visitEnd()
|
visitEnd()
|
||||||
}
|
}
|
||||||
|
|
||||||
val new = classLoader
|
val cls = classLoader.defineClass(className, classWriter.toByteArray())
|
||||||
.defineClass(className, classWriter.toByteArray())
|
val l = MethodHandles.publicLookup()
|
||||||
.constructors
|
|
||||||
.first()
|
|
||||||
.newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression<T>
|
|
||||||
|
|
||||||
generatedInstance = new
|
if (hasConstants)
|
||||||
return new
|
l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array<Any>::class.java))
|
||||||
|
.invoke(constants.toTypedArray()) as Expression<T>
|
||||||
|
else
|
||||||
|
l.findConstructor(cls, MethodType.methodType(Void.TYPE)).invoke() as Expression<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads a [T] constant from [constants].
|
* Loads [java.lang.Object] constant from constants.
|
||||||
*/
|
*/
|
||||||
internal fun loadTConstant(value: T) {
|
fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run {
|
||||||
if (classOfT in INLINABLE_NUMBERS) {
|
val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex
|
||||||
val expectedType = expectationStack.pop()
|
loadThis()
|
||||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||||
loadNumberConstant(value as Number, mustBeBoxed)
|
iconst(idx)
|
||||||
|
visitInsn(AALOAD)
|
||||||
|
if (type != OBJECT_TYPE) checkcast(type)
|
||||||
|
}
|
||||||
|
|
||||||
if (mustBeBoxed)
|
/**
|
||||||
invokeMethodVisitor.checkcast(tType)
|
* Loads `this` variable.
|
||||||
|
*/
|
||||||
|
private fun loadThis(): Unit = invokeMethodVisitor.load(0, classType)
|
||||||
|
|
||||||
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
|
/**
|
||||||
|
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
|
||||||
|
* constant from the constant pool.
|
||||||
|
*/
|
||||||
|
fun loadNumberConstant(value: Number) {
|
||||||
|
val boxed = value.javaClass.asm
|
||||||
|
val primitive = BOXED_TO_PRIMITIVES[boxed]
|
||||||
|
|
||||||
|
if (primitive != null) {
|
||||||
|
when (primitive) {
|
||||||
|
BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
|
DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
|
||||||
|
FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
|
||||||
|
LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
|
||||||
|
INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
|
SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
|
}
|
||||||
|
|
||||||
|
box(primitive)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
loadObjectConstant(value as Any, tType)
|
loadObjectConstant(value, boxed)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -303,258 +250,100 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
invokeMethodVisitor.invokestatic(
|
invokeMethodVisitor.invokestatic(
|
||||||
r.internalName,
|
r.internalName,
|
||||||
"valueOf",
|
"valueOf",
|
||||||
Type.getMethodDescriptor(r, primitive),
|
getMethodDescriptor(r, primitive),
|
||||||
false
|
false,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unboxes the current boxed value and pushes it.
|
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke].
|
||||||
*/
|
*/
|
||||||
private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual(
|
fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
|
||||||
NUMBER_TYPE.internalName,
|
load(1, MAP_TYPE)
|
||||||
NUMBER_CONVERTER_METHODS.getValue(primitive),
|
|
||||||
Type.getMethodDescriptor(primitive),
|
|
||||||
false
|
|
||||||
)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loads [java.lang.Object] constant from constants.
|
|
||||||
*/
|
|
||||||
private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
|
||||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
|
||||||
loadThis()
|
|
||||||
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
|
||||||
iconst(idx)
|
|
||||||
visitInsn(AALOAD)
|
|
||||||
checkcast(type)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal fun loadNumeric(value: Number) {
|
|
||||||
if (expectationStack.peek() == NUMBER_TYPE) {
|
|
||||||
loadNumberConstant(value, true)
|
|
||||||
expectationStack.pop()
|
|
||||||
typeStack.push(NUMBER_TYPE)
|
|
||||||
} else (algebra as? NumericAlgebra<T>)?.number(value)?.let { loadTConstant(it) }
|
|
||||||
?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loads this variable.
|
|
||||||
*/
|
|
||||||
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
|
|
||||||
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
|
|
||||||
* from it).
|
|
||||||
*/
|
|
||||||
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
|
|
||||||
val boxed = value.javaClass.asm
|
|
||||||
val primitive = BOXED_TO_PRIMITIVES[boxed]
|
|
||||||
|
|
||||||
if (primitive != null) {
|
|
||||||
when (primitive) {
|
|
||||||
Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
|
||||||
Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
|
|
||||||
Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
|
|
||||||
Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
|
|
||||||
Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
|
||||||
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mustBeBoxed)
|
|
||||||
box(primitive)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
loadObjectConstant(value, boxed)
|
|
||||||
|
|
||||||
if (!mustBeBoxed)
|
|
||||||
unboxTo(primitiveMask)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
|
|
||||||
* provided.
|
|
||||||
*/
|
|
||||||
internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
|
|
||||||
load(invokeArgumentsVar, MAP_TYPE)
|
|
||||||
aconst(name)
|
aconst(name)
|
||||||
|
|
||||||
invokestatic(
|
invokestatic(
|
||||||
MAP_INTRINSICS_TYPE.internalName,
|
MAP_INTRINSICS_TYPE.internalName,
|
||||||
"getOrFail",
|
"getOrFail",
|
||||||
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
|
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
|
||||||
false
|
false,
|
||||||
)
|
)
|
||||||
|
|
||||||
checkcast(tType)
|
checkcast(tType)
|
||||||
val expectedType = expectationStack.pop()
|
|
||||||
|
|
||||||
if (expectedType.sort == Type.OBJECT)
|
|
||||||
typeStack.push(tType)
|
|
||||||
else {
|
|
||||||
unboxTo(primitiveMask)
|
|
||||||
typeStack.push(primitiveMask)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
inline fun buildCall(function: Function<T>, parameters: AsmBuilder<T>.() -> Unit) {
|
||||||
* Loads algebra from according field of the class and casts it to class of [algebra] provided.
|
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
||||||
*/
|
val `interface` = function.javaClass.interfaces.first { it.interfaces.contains(Function::class.java) }
|
||||||
internal fun loadAlgebra() {
|
|
||||||
loadThis()
|
|
||||||
invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount
|
||||||
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is
|
?: error("Provided function object doesn't contain invoke method")
|
||||||
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be
|
|
||||||
* called before the arguments and this operation.
|
|
||||||
*
|
|
||||||
* The result is casted to [T] automatically.
|
|
||||||
*/
|
|
||||||
internal fun invokeAlgebraOperation(
|
|
||||||
owner: String,
|
|
||||||
method: String,
|
|
||||||
descriptor: String,
|
|
||||||
expectedArity: Int,
|
|
||||||
opcode: Int = INVOKEINTERFACE,
|
|
||||||
) {
|
|
||||||
run loop@{
|
|
||||||
repeat(expectedArity) {
|
|
||||||
if (typeStack.isEmpty()) return@loop
|
|
||||||
typeStack.pop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
invokeMethodVisitor.visitMethodInsn(
|
val type = getType(`interface`)
|
||||||
opcode,
|
loadObjectConstant(function, type)
|
||||||
owner,
|
parameters(this)
|
||||||
method,
|
|
||||||
descriptor,
|
invokeMethodVisitor.invokeinterface(
|
||||||
opcode == INVOKEINTERFACE
|
type.internalName,
|
||||||
|
"invoke",
|
||||||
|
getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE }),
|
||||||
)
|
)
|
||||||
|
|
||||||
invokeMethodVisitor.checkcast(tType)
|
invokeMethodVisitor.checkcast(tType)
|
||||||
val isLastExpr = expectationStack.size == 1
|
|
||||||
val expectedType = expectationStack.pop()
|
|
||||||
|
|
||||||
if (expectedType.sort == Type.OBJECT || isLastExpr)
|
|
||||||
typeStack.push(tType)
|
|
||||||
else {
|
|
||||||
unboxTo(primitiveMask)
|
|
||||||
typeStack.push(primitiveMask)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
companion object {
|
||||||
* Writes a LDC Instruction with string constant provided.
|
|
||||||
*/
|
|
||||||
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
|
|
||||||
|
|
||||||
internal companion object {
|
|
||||||
/**
|
|
||||||
* Index of `this` variable in invoke method of the built subclass.
|
|
||||||
*/
|
|
||||||
private const val invokeThisVar: Int = 0
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Index of `arguments` variable in invoke method of the built subclass.
|
|
||||||
*/
|
|
||||||
private const val invokeArgumentsVar: Int = 1
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Maps JVM primitive numbers boxed types to their primitive ASM types.
|
|
||||||
*/
|
|
||||||
private val SIGNATURE_LETTERS: Map<Class<out Any>, Type> by lazy {
|
|
||||||
hashMapOf(
|
|
||||||
java.lang.Byte::class.java to Type.BYTE_TYPE,
|
|
||||||
java.lang.Short::class.java to Type.SHORT_TYPE,
|
|
||||||
java.lang.Integer::class.java to Type.INT_TYPE,
|
|
||||||
java.lang.Long::class.java to Type.LONG_TYPE,
|
|
||||||
java.lang.Float::class.java to Type.FLOAT_TYPE,
|
|
||||||
java.lang.Double::class.java to Type.DOUBLE_TYPE
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
||||||
*/
|
*/
|
||||||
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
|
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy {
|
||||||
|
hashMapOf(
|
||||||
|
Byte::class.java.asm to BYTE_TYPE,
|
||||||
|
Short::class.java.asm to SHORT_TYPE,
|
||||||
|
Integer::class.java.asm to INT_TYPE,
|
||||||
|
Long::class.java.asm to LONG_TYPE,
|
||||||
|
Float::class.java.asm to FLOAT_TYPE,
|
||||||
|
Double::class.java.asm to DOUBLE_TYPE,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
||||||
*/
|
*/
|
||||||
private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy {
|
private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy {
|
||||||
BOXED_TO_PRIMITIVES.entries.stream().collect(
|
BOXED_TO_PRIMITIVES.entries.stream().collect(
|
||||||
Collectors.toMap(
|
toMap(Map.Entry<Type, Type>::value, Map.Entry<Type, Type>::key),
|
||||||
Map.Entry<Type, Type>::value,
|
|
||||||
Map.Entry<Type, Type>::key
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Maps primitive ASM types to [Number] functions unboxing them.
|
|
||||||
*/
|
|
||||||
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
|
|
||||||
hashMapOf(
|
|
||||||
Type.BYTE_TYPE to "byteValue",
|
|
||||||
Type.SHORT_TYPE to "shortValue",
|
|
||||||
Type.INT_TYPE to "intValue",
|
|
||||||
Type.LONG_TYPE to "longValue",
|
|
||||||
Type.FLOAT_TYPE to "floatValue",
|
|
||||||
Type.DOUBLE_TYPE to "doubleValue"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
|
||||||
*/
|
|
||||||
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [Expression].
|
* ASM type for [Expression].
|
||||||
*/
|
*/
|
||||||
internal val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/expressions/Expression") }
|
val EXPRESSION_TYPE: Type by lazy { getObjectType("kscience/kmath/expressions/Expression") }
|
||||||
|
|
||||||
/**
|
|
||||||
* ASM type for [java.lang.Number].
|
|
||||||
*/
|
|
||||||
internal val NUMBER_TYPE: Type by lazy { Type.getObjectType("java/lang/Number") }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.util.Map].
|
* ASM type for [java.util.Map].
|
||||||
*/
|
*/
|
||||||
internal val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") }
|
val MAP_TYPE: Type by lazy { getObjectType("java/util/Map") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.Object].
|
* ASM type for [java.lang.Object].
|
||||||
*/
|
*/
|
||||||
internal val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") }
|
val OBJECT_TYPE: Type by lazy { getObjectType("java/lang/Object") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for array of [java.lang.Object].
|
* ASM type for array of [java.lang.Object].
|
||||||
*/
|
*/
|
||||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") }
|
||||||
internal val OBJECT_ARRAY_TYPE: Type by lazy { Type.getType("[Ljava/lang/Object;") }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* ASM type for [Algebra].
|
|
||||||
*/
|
|
||||||
internal val ALGEBRA_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/operations/Algebra") }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.String].
|
* ASM type for [java.lang.String].
|
||||||
*/
|
*/
|
||||||
internal val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") }
|
val STRING_TYPE: Type by lazy { getObjectType("java/lang/String") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for MapIntrinsics.
|
* ASM type for MapIntrinsics.
|
||||||
*/
|
*/
|
||||||
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/asm/internal/MapIntrinsics") }
|
val MAP_INTRINSICS_TYPE: Type by lazy { getObjectType("kscience/kmath/asm/internal/MapIntrinsics") }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
package kscience.kmath.asm.internal
|
|
||||||
|
|
||||||
import kscience.kmath.ast.MST
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents types known in [MST], numbers and general values.
|
|
||||||
*/
|
|
||||||
internal enum class MstType {
|
|
||||||
GENERAL,
|
|
||||||
NUMBER;
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
fun fromMst(mst: MST): MstType {
|
|
||||||
if (mst is MST.Numeric)
|
|
||||||
return NUMBER
|
|
||||||
|
|
||||||
return GENERAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -2,29 +2,11 @@ package kscience.kmath.asm.internal
|
|||||||
|
|
||||||
import kscience.kmath.ast.MST
|
import kscience.kmath.ast.MST
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.operations.Algebra
|
|
||||||
import kscience.kmath.operations.FieldOperations
|
|
||||||
import kscience.kmath.operations.RingOperations
|
|
||||||
import kscience.kmath.operations.SpaceOperations
|
|
||||||
import org.objectweb.asm.*
|
import org.objectweb.asm.*
|
||||||
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
|
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
import java.lang.reflect.Method
|
|
||||||
import java.util.*
|
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
|
||||||
hashMapOf(
|
|
||||||
SpaceOperations.PLUS_OPERATION to 2 to "add",
|
|
||||||
RingOperations.TIMES_OPERATION to 2 to "multiply",
|
|
||||||
FieldOperations.DIV_OPERATION to 2 to "divide",
|
|
||||||
SpaceOperations.PLUS_OPERATION to 1 to "unaryPlus",
|
|
||||||
SpaceOperations.MINUS_OPERATION to 1 to "unaryMinus",
|
|
||||||
SpaceOperations.MINUS_OPERATION to 2 to "minus"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns ASM [Type] for given [Class].
|
* Returns ASM [Type] for given [Class].
|
||||||
*
|
*
|
||||||
@ -109,107 +91,3 @@ internal inline fun ClassWriter.visitField(
|
|||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return visitField(access, name, descriptor, signature, value).apply(block)
|
return visitField(access, name, descriptor, signature, value).apply(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
|
|
||||||
context.javaClass.methods.find { method ->
|
|
||||||
val nameValid = method.name == name
|
|
||||||
val arityValid = method.parameters.size == parameterTypes.size
|
|
||||||
val notBridgeInPrimitive = !(primitiveMode && method.isBridge)
|
|
||||||
|
|
||||||
val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) ->
|
|
||||||
!(mstType != MstType.NUMBER && type == java.lang.Number::class.java)
|
|
||||||
}
|
|
||||||
|
|
||||||
nameValid && arityValid && notBridgeInPrimitive && paramsValid
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds
|
|
||||||
* type expectation stack for needed arity.
|
|
||||||
*
|
|
||||||
* @author Iaroslav Postovalov
|
|
||||||
*/
|
|
||||||
private fun <T> AsmBuilder<T>.buildExpectationStack(
|
|
||||||
context: Algebra<T>,
|
|
||||||
name: String,
|
|
||||||
parameterTypes: Array<MstType>
|
|
||||||
): Boolean {
|
|
||||||
val arity = parameterTypes.size
|
|
||||||
val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes)
|
|
||||||
|
|
||||||
if (specific != null)
|
|
||||||
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
|
|
||||||
else
|
|
||||||
expectationStack.addAll(Collections.nCopies(arity, tType))
|
|
||||||
|
|
||||||
return specific != null
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun <T> AsmBuilder<T>.mapTypes(method: Method, parameterTypes: Array<MstType>): List<Type> = method
|
|
||||||
.parameterTypes
|
|
||||||
.zip(parameterTypes)
|
|
||||||
.map { (type, mstType) ->
|
|
||||||
when {
|
|
||||||
type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE
|
|
||||||
else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts
|
|
||||||
* [AsmBuilder.invokeAlgebraOperation] of this method.
|
|
||||||
*
|
|
||||||
* @author Iaroslav Postovalov
|
|
||||||
*/
|
|
||||||
private fun <T> AsmBuilder<T>.tryInvokeSpecific(
|
|
||||||
context: Algebra<T>,
|
|
||||||
name: String,
|
|
||||||
parameterTypes: Array<MstType>
|
|
||||||
): Boolean {
|
|
||||||
val arity = parameterTypes.size
|
|
||||||
val theName = methodNameAdapters[name to arity] ?: name
|
|
||||||
val spec = findSpecific(context, theName, parameterTypes) ?: return false
|
|
||||||
val owner = context.javaClass.asm
|
|
||||||
|
|
||||||
invokeAlgebraOperation(
|
|
||||||
owner = owner.internalName,
|
|
||||||
method = theName,
|
|
||||||
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()),
|
|
||||||
expectedArity = arity,
|
|
||||||
opcode = INVOKEVIRTUAL
|
|
||||||
)
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds specialized [context] call with option to fallback to generic algebra operation accepting [String].
|
|
||||||
*
|
|
||||||
* @author Iaroslav Postovalov
|
|
||||||
*/
|
|
||||||
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
|
||||||
context: Algebra<T>,
|
|
||||||
name: String,
|
|
||||||
fallbackMethodName: String,
|
|
||||||
parameterTypes: Array<MstType>,
|
|
||||||
parameters: AsmBuilder<T>.() -> Unit
|
|
||||||
) {
|
|
||||||
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
val arity = parameterTypes.size
|
|
||||||
loadAlgebra()
|
|
||||||
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)
|
|
||||||
parameters()
|
|
||||||
|
|
||||||
if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation(
|
|
||||||
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
|
||||||
method = fallbackMethodName,
|
|
||||||
|
|
||||||
descriptor = Type.getMethodDescriptor(
|
|
||||||
AsmBuilder.OBJECT_TYPE,
|
|
||||||
AsmBuilder.STRING_TYPE,
|
|
||||||
*Array(arity) { AsmBuilder.OBJECT_TYPE }
|
|
||||||
),
|
|
||||||
|
|
||||||
expectedArity = arity
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
@ -10,15 +10,11 @@ import kotlin.test.Test
|
|||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class TestAsmAlgebras {
|
internal class TestAsmAlgebras {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun space() {
|
fun space() {
|
||||||
val res1 = ByteRing.mstInSpace {
|
val res1 = ByteRing.mstInSpace {
|
||||||
binaryOperation(
|
binaryOperation("+")(
|
||||||
"+",
|
unaryOperation("+")(
|
||||||
|
|
||||||
unaryOperation(
|
|
||||||
"+",
|
|
||||||
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||||
add(number(1), number(1)),
|
add(number(1), number(1)),
|
||||||
2
|
2
|
||||||
@ -30,11 +26,8 @@ internal class TestAsmAlgebras {
|
|||||||
}("x" to 2.toByte())
|
}("x" to 2.toByte())
|
||||||
|
|
||||||
val res2 = ByteRing.mstInSpace {
|
val res2 = ByteRing.mstInSpace {
|
||||||
binaryOperation(
|
binaryOperation("+")(
|
||||||
"+",
|
unaryOperation("+")(
|
||||||
|
|
||||||
unaryOperation(
|
|
||||||
"+",
|
|
||||||
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||||
add(number(1), number(1)),
|
add(number(1), number(1)),
|
||||||
2
|
2
|
||||||
@ -51,11 +44,8 @@ internal class TestAsmAlgebras {
|
|||||||
@Test
|
@Test
|
||||||
fun ring() {
|
fun ring() {
|
||||||
val res1 = ByteRing.mstInRing {
|
val res1 = ByteRing.mstInRing {
|
||||||
binaryOperation(
|
binaryOperation("+")(
|
||||||
"+",
|
unaryOperation("+")(
|
||||||
|
|
||||||
unaryOperation(
|
|
||||||
"+",
|
|
||||||
(symbol("x") - (2.toByte() + (multiply(
|
(symbol("x") - (2.toByte() + (multiply(
|
||||||
add(number(1), number(1)),
|
add(number(1), number(1)),
|
||||||
2
|
2
|
||||||
@ -67,17 +57,13 @@ internal class TestAsmAlgebras {
|
|||||||
}("x" to 3.toByte())
|
}("x" to 3.toByte())
|
||||||
|
|
||||||
val res2 = ByteRing.mstInRing {
|
val res2 = ByteRing.mstInRing {
|
||||||
binaryOperation(
|
binaryOperation("+")(
|
||||||
"+",
|
unaryOperation("+")(
|
||||||
|
|
||||||
unaryOperation(
|
|
||||||
"+",
|
|
||||||
(symbol("x") - (2.toByte() + (multiply(
|
(symbol("x") - (2.toByte() + (multiply(
|
||||||
add(number(1), number(1)),
|
add(number(1), number(1)),
|
||||||
2
|
2
|
||||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||||
),
|
),
|
||||||
|
|
||||||
number(1)
|
number(1)
|
||||||
) * number(2)
|
) * number(2)
|
||||||
}.compile()("x" to 3.toByte())
|
}.compile()("x" to 3.toByte())
|
||||||
@ -88,8 +74,7 @@ internal class TestAsmAlgebras {
|
|||||||
@Test
|
@Test
|
||||||
fun field() {
|
fun field() {
|
||||||
val res1 = RealField.mstInField {
|
val res1 = RealField.mstInField {
|
||||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
|
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation("+")(
|
||||||
"+",
|
|
||||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||||
+ number(1),
|
+ number(1),
|
||||||
number(1) / 2 + number(2.0) * one
|
number(1) / 2 + number(2.0) * one
|
||||||
@ -97,8 +82,7 @@ internal class TestAsmAlgebras {
|
|||||||
}("x" to 2.0)
|
}("x" to 2.0)
|
||||||
|
|
||||||
val res2 = RealField.mstInField {
|
val res2 = RealField.mstInField {
|
||||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
|
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation("+")(
|
||||||
"+",
|
|
||||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||||
+ number(1),
|
+ number(1),
|
||||||
number(1) / 2 + number(2.0) * one
|
number(1) / 2 + number(2.0) * one
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
package kscience.kmath.asm
|
package kscience.kmath.asm
|
||||||
|
|
||||||
import kscience.kmath.asm.compile
|
import kscience.kmath.ast.mstInExtendedField
|
||||||
import kscience.kmath.ast.mstInField
|
import kscience.kmath.ast.mstInField
|
||||||
import kscience.kmath.ast.mstInSpace
|
import kscience.kmath.ast.mstInSpace
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
|
import kotlin.random.Random
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -28,4 +29,13 @@ internal class TestAsmExpressions {
|
|||||||
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
|
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
|
||||||
assertEquals(4.0, res)
|
assertEquals(4.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMultipleCalls() {
|
||||||
|
val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile()
|
||||||
|
val r = Random(0)
|
||||||
|
var s = 0.0
|
||||||
|
repeat(1000000) { s += e("x" to r.nextDouble()) }
|
||||||
|
println(s)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
package kscience.kmath.asm
|
package kscience.kmath.asm
|
||||||
|
|
||||||
import kscience.kmath.asm.compile
|
|
||||||
import kscience.kmath.ast.mstInField
|
import kscience.kmath.ast.mstInField
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
@ -10,44 +9,44 @@ import kotlin.test.assertEquals
|
|||||||
internal class TestAsmSpecialization {
|
internal class TestAsmSpecialization {
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryPlus() {
|
fun testUnaryPlus() {
|
||||||
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
|
val expr = RealField.mstInField { unaryOperation("+")(symbol("x")) }.compile()
|
||||||
assertEquals(2.0, expr("x" to 2.0))
|
assertEquals(2.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryMinus() {
|
fun testUnaryMinus() {
|
||||||
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
|
val expr = RealField.mstInField { unaryOperation("-")(symbol("x")) }.compile()
|
||||||
assertEquals(-2.0, expr("x" to 2.0))
|
assertEquals(-2.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testAdd() {
|
fun testAdd() {
|
||||||
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
|
val expr = RealField.mstInField { binaryOperation("+")(symbol("x"), symbol("x")) }.compile()
|
||||||
assertEquals(4.0, expr("x" to 2.0))
|
assertEquals(4.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSine() {
|
fun testSine() {
|
||||||
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile()
|
val expr = RealField.mstInField { unaryOperation("sin")(symbol("x")) }.compile()
|
||||||
assertEquals(0.0, expr("x" to 0.0))
|
assertEquals(0.0, expr("x" to 0.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testMinus() {
|
fun testMinus() {
|
||||||
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
|
val expr = RealField.mstInField { binaryOperation("-")(symbol("x"), symbol("x")) }.compile()
|
||||||
assertEquals(0.0, expr("x" to 2.0))
|
assertEquals(0.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testDivide() {
|
fun testDivide() {
|
||||||
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
val expr = RealField.mstInField { binaryOperation("/")(symbol("x"), symbol("x")) }.compile()
|
||||||
assertEquals(1.0, expr("x" to 2.0))
|
assertEquals(1.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPower() {
|
fun testPower() {
|
||||||
val expr = RealField
|
val expr = RealField
|
||||||
.mstInField { binaryOperation("power", symbol("x"), number(2)) }
|
.mstInField { binaryOperation("pow")(symbol("x"), number(2)) }
|
||||||
.compile()
|
.compile()
|
||||||
|
|
||||||
assertEquals(4.0, expr("x" to 2.0))
|
assertEquals(4.0, expr("x" to 2.0))
|
||||||
|
@ -17,6 +17,6 @@ internal class TestAsmVariables {
|
|||||||
@Test
|
@Test
|
||||||
fun testVariableWithoutDefaultFails() {
|
fun testVariableWithoutDefaultFails() {
|
||||||
val expr = ByteRing.mstInRing { symbol("x") }
|
val expr = ByteRing.mstInRing { symbol("x") }
|
||||||
assertFailsWith<IllegalStateException> { expr() }
|
assertFailsWith<NoSuchElementException> { expr() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
package kscience.kmath.ast
|
package kscience.kmath.ast
|
||||||
|
|
||||||
import kscience.kmath.ast.evaluate
|
|
||||||
import kscience.kmath.ast.mstInField
|
|
||||||
import kscience.kmath.ast.parseMath
|
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
import kscience.kmath.operations.Complex
|
import kscience.kmath.operations.Complex
|
||||||
@ -45,12 +42,15 @@ internal class ParserTest {
|
|||||||
val magicalAlgebra = object : Algebra<String> {
|
val magicalAlgebra = object : Algebra<String> {
|
||||||
override fun symbol(value: String): String = value
|
override fun symbol(value: String): String = value
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
|
override fun unaryOperation(operation: String): (arg: String) -> String {
|
||||||
|
throw NotImplementedError()
|
||||||
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
|
|
||||||
"magic" -> "$left ★ $right"
|
|
||||||
else -> throw NotImplementedError()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String): (left: String, right: String) -> String =
|
||||||
|
when (operation) {
|
||||||
|
"magic" -> { left, right -> "$left ★ $right" }
|
||||||
|
else -> throw NotImplementedError()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val mst = "magic(a, b)".parseMath()
|
val mst = "magic(a, b)".parseMath()
|
||||||
|
@ -7,9 +7,8 @@ import kscience.kmath.operations.*
|
|||||||
*
|
*
|
||||||
* @param algebra The algebra to provide for Expressions built.
|
* @param algebra The algebra to provide for Expressions built.
|
||||||
*/
|
*/
|
||||||
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
|
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
|
||||||
public val algebra: A,
|
ExpressionAlgebra<T, Expression<T>> {
|
||||||
) : ExpressionAlgebra<T, Expression<T>> {
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of constant expression which does not depend on arguments.
|
* Builds an Expression of constant expression which does not depend on arguments.
|
||||||
*/
|
*/
|
||||||
@ -25,19 +24,18 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||||
*/
|
*/
|
||||||
public override fun binaryOperation(
|
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||||
operation: String,
|
{ left, right ->
|
||||||
left: Expression<T>,
|
Expression { arguments ->
|
||||||
right: Expression<T>,
|
algebra.binaryOperation(operation)(left.invoke(arguments), right.invoke(arguments))
|
||||||
): Expression<T> = Expression { arguments ->
|
}
|
||||||
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||||
*/
|
*/
|
||||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments ->
|
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
|
||||||
algebra.unaryOperation(operation, arg.invoke(arguments))
|
Expression { arguments -> algebra.unaryOperation(operation)(arg.invoke(arguments)) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +50,7 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
|||||||
* Builds an Expression of addition of two another expressions.
|
* Builds an Expression of addition of two another expressions.
|
||||||
*/
|
*/
|
||||||
public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of multiplication of expression by number.
|
* Builds an Expression of multiplication of expression by number.
|
||||||
@ -66,11 +64,11 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
|||||||
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
||||||
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
||||||
|
|
||||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
|
super<FunctionalExpressionAlgebra>.unaryOperation(operation)
|
||||||
|
|
||||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
|
super<FunctionalExpressionAlgebra>.binaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
|
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
|
||||||
@ -82,16 +80,16 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
|
|||||||
* Builds an Expression of multiplication of two expressions.
|
* Builds an Expression of multiplication of two expressions.
|
||||||
*/
|
*/
|
||||||
public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
binaryOperation(RingOperations.TIMES_OPERATION)(a, b)
|
||||||
|
|
||||||
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||||
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||||
|
|
||||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
|
super<FunctionalExpressionSpace>.unaryOperation(operation)
|
||||||
|
|
||||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
|
super<FunctionalExpressionSpace>.binaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||||
@ -101,49 +99,49 @@ public open class FunctionalExpressionField<T, A>(algebra: A) :
|
|||||||
* Builds an Expression of division an expression by another one.
|
* Builds an Expression of division an expression by another one.
|
||||||
*/
|
*/
|
||||||
public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
binaryOperation(FieldOperations.DIV_OPERATION)(a, b)
|
||||||
|
|
||||||
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||||
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||||
|
|
||||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
|
super<FunctionalExpressionRing>.unaryOperation(operation)
|
||||||
|
|
||||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
|
super<FunctionalExpressionRing>.binaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
||||||
FunctionalExpressionField<T, A>(algebra),
|
FunctionalExpressionField<T, A>(algebra),
|
||||||
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
||||||
public override fun sin(arg: Expression<T>): Expression<T> =
|
public override fun sin(arg: Expression<T>): Expression<T> =
|
||||||
unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
unaryOperation(TrigonometricOperations.SIN_OPERATION)(arg)
|
||||||
|
|
||||||
public override fun cos(arg: Expression<T>): Expression<T> =
|
public override fun cos(arg: Expression<T>): Expression<T> =
|
||||||
unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
unaryOperation(TrigonometricOperations.COS_OPERATION)(arg)
|
||||||
|
|
||||||
public override fun asin(arg: Expression<T>): Expression<T> =
|
public override fun asin(arg: Expression<T>): Expression<T> =
|
||||||
unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
unaryOperation(TrigonometricOperations.ASIN_OPERATION)(arg)
|
||||||
|
|
||||||
public override fun acos(arg: Expression<T>): Expression<T> =
|
public override fun acos(arg: Expression<T>): Expression<T> =
|
||||||
unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
unaryOperation(TrigonometricOperations.ACOS_OPERATION)(arg)
|
||||||
|
|
||||||
public override fun atan(arg: Expression<T>): Expression<T> =
|
public override fun atan(arg: Expression<T>): Expression<T> =
|
||||||
unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
unaryOperation(TrigonometricOperations.ATAN_OPERATION)(arg)
|
||||||
|
|
||||||
public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
|
public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
|
||||||
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow))
|
||||||
|
|
||||||
public override fun exp(arg: Expression<T>): Expression<T> =
|
public override fun exp(arg: Expression<T>): Expression<T> =
|
||||||
unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
unaryOperation(ExponentialOperations.EXP_OPERATION)(arg)
|
||||||
|
|
||||||
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION)(arg)
|
||||||
|
|
||||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionField>.unaryOperation(operation, arg)
|
super<FunctionalExpressionField>.unaryOperation(operation)
|
||||||
|
|
||||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionField>.binaryOperation(operation, left, right)
|
super<FunctionalExpressionField>.binaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =
|
public inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =
|
||||||
|
@ -19,10 +19,11 @@ public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Mat
|
|||||||
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M
|
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public override fun binaryOperation(operation: String, left: Matrix<T>, right: Matrix<T>): M = when (operation) {
|
public override fun binaryOperation(operation: String): (left: Matrix<T>, right: Matrix<T>) -> M =
|
||||||
"dot" -> left dot right
|
when (operation) {
|
||||||
else -> super.binaryOperation(operation, left, right) as M
|
"dot" -> { left, right -> left dot right }
|
||||||
}
|
else -> super.binaryOperation(operation) as (Matrix<T>, Matrix<T>) -> M
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Computes the dot product of this matrix and another one.
|
* Computes the dot product of this matrix and another one.
|
||||||
|
@ -13,19 +13,21 @@ public annotation class KMathContext
|
|||||||
*/
|
*/
|
||||||
public interface Algebra<T> {
|
public interface Algebra<T> {
|
||||||
/**
|
/**
|
||||||
* Wrap raw string or variable
|
* Wraps raw string or variable.
|
||||||
*/
|
*/
|
||||||
public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
|
public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dynamic call of unary operation with name [operation] on [arg]
|
* Dynamically dispatches an unary operation with name [operation].
|
||||||
*/
|
*/
|
||||||
public fun unaryOperation(operation: String, arg: T): T
|
public fun unaryOperation(operation: String): (arg: T) -> T =
|
||||||
|
error("Unary operation $operation not defined in $this")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dynamic call of binary operation [operation] on [left] and [right]
|
* Dynamically dispatches a binary operation with name [operation].
|
||||||
*/
|
*/
|
||||||
public fun binaryOperation(operation: String, left: T, right: T): T
|
public fun binaryOperation(operation: String): (left: T, right: T) -> T =
|
||||||
|
error("Binary operation $operation not defined in $this")
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -40,16 +42,28 @@ public interface NumericAlgebra<T> : Algebra<T> {
|
|||||||
public fun number(value: Number): T
|
public fun number(value: Number): T
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number].
|
* Dynamically dispatches a binary operation with name [operation] where the left argument is [Number].
|
||||||
*/
|
*/
|
||||||
public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
public fun leftSideNumberOperation(operation: String): (left: Number, right: T) -> T =
|
||||||
binaryOperation(operation, number(left), right)
|
{ l, r -> binaryOperation(operation)(number(l), r) }
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * Dynamically calls a binary operation with name [operation] where the left argument is [Number].
|
||||||
|
// */
|
||||||
|
// public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
||||||
|
// leftSideNumberOperation(operation)(left, right)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
|
* Dynamically dispatches a binary operation with name [operation] where the right argument is [Number].
|
||||||
*/
|
*/
|
||||||
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
public fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T =
|
||||||
leftSideNumberOperation(operation, right, left)
|
{ l, r -> binaryOperation(operation)(l, number(r)) }
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * Dynamically calls a binary operation with name [operation] where the right argument is [Number].
|
||||||
|
// */
|
||||||
|
// public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||||
|
// rightSideNumberOperation(operation)(left, right)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -146,16 +160,16 @@ public interface SpaceOperations<T> : Algebra<T> {
|
|||||||
*/
|
*/
|
||||||
public operator fun Number.times(b: T): T = b * this
|
public operator fun Number.times(b: T): T = b * this
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
|
||||||
PLUS_OPERATION -> arg
|
PLUS_OPERATION -> { arg -> arg }
|
||||||
MINUS_OPERATION -> -arg
|
MINUS_OPERATION -> { arg -> -arg }
|
||||||
else -> error("Unary operation $operation not defined in $this")
|
else -> super.unaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||||
PLUS_OPERATION -> add(left, right)
|
PLUS_OPERATION -> ::add
|
||||||
MINUS_OPERATION -> left - right
|
MINUS_OPERATION -> { left, right -> left - right }
|
||||||
else -> error("Binary operation $operation not defined in $this")
|
else -> super.binaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
@ -207,9 +221,9 @@ public interface RingOperations<T> : SpaceOperations<T> {
|
|||||||
*/
|
*/
|
||||||
public operator fun T.times(b: T): T = multiply(this, b)
|
public operator fun T.times(b: T): T = multiply(this, b)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||||
TIMES_OPERATION -> multiply(left, right)
|
TIMES_OPERATION -> ::multiply
|
||||||
else -> super.binaryOperation(operation, left, right)
|
else -> super.binaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
@ -234,20 +248,6 @@ public interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
|
|||||||
|
|
||||||
override fun number(value: Number): T = one * value.toDouble()
|
override fun number(value: Number): T = one * value.toDouble()
|
||||||
|
|
||||||
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
|
|
||||||
SpaceOperations.PLUS_OPERATION -> left + right
|
|
||||||
SpaceOperations.MINUS_OPERATION -> left - right
|
|
||||||
RingOperations.TIMES_OPERATION -> left * right
|
|
||||||
else -> super.leftSideNumberOperation(operation, left, right)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
|
||||||
SpaceOperations.PLUS_OPERATION -> left + right
|
|
||||||
SpaceOperations.MINUS_OPERATION -> left - right
|
|
||||||
RingOperations.TIMES_OPERATION -> left * right
|
|
||||||
else -> super.rightSideNumberOperation(operation, left, right)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Addition of element and scalar.
|
* Addition of element and scalar.
|
||||||
*
|
*
|
||||||
@ -308,9 +308,9 @@ public interface FieldOperations<T> : RingOperations<T> {
|
|||||||
*/
|
*/
|
||||||
public operator fun T.div(b: T): T = divide(this, b)
|
public operator fun T.div(b: T): T = divide(this, b)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||||
DIV_OPERATION -> divide(left, right)
|
DIV_OPERATION -> ::divide
|
||||||
else -> super.binaryOperation(operation, left, right)
|
else -> super.binaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
|
@ -15,23 +15,23 @@ public interface ExtendedFieldOperations<T> :
|
|||||||
public override fun tan(arg: T): T = sin(arg) / cos(arg)
|
public override fun tan(arg: T): T = sin(arg) / cos(arg)
|
||||||
public override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
|
public override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
|
||||||
|
|
||||||
public override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
public override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) {
|
||||||
TrigonometricOperations.COS_OPERATION -> cos(arg)
|
TrigonometricOperations.COS_OPERATION -> ::cos
|
||||||
TrigonometricOperations.SIN_OPERATION -> sin(arg)
|
TrigonometricOperations.SIN_OPERATION -> ::sin
|
||||||
TrigonometricOperations.TAN_OPERATION -> tan(arg)
|
TrigonometricOperations.TAN_OPERATION -> ::tan
|
||||||
TrigonometricOperations.ACOS_OPERATION -> acos(arg)
|
TrigonometricOperations.ACOS_OPERATION -> ::acos
|
||||||
TrigonometricOperations.ASIN_OPERATION -> asin(arg)
|
TrigonometricOperations.ASIN_OPERATION -> ::asin
|
||||||
TrigonometricOperations.ATAN_OPERATION -> atan(arg)
|
TrigonometricOperations.ATAN_OPERATION -> ::atan
|
||||||
HyperbolicOperations.COSH_OPERATION -> cosh(arg)
|
HyperbolicOperations.COSH_OPERATION -> ::cosh
|
||||||
HyperbolicOperations.SINH_OPERATION -> sinh(arg)
|
HyperbolicOperations.SINH_OPERATION -> ::sinh
|
||||||
HyperbolicOperations.TANH_OPERATION -> tanh(arg)
|
HyperbolicOperations.TANH_OPERATION -> ::tanh
|
||||||
HyperbolicOperations.ACOSH_OPERATION -> acosh(arg)
|
HyperbolicOperations.ACOSH_OPERATION -> ::acosh
|
||||||
HyperbolicOperations.ASINH_OPERATION -> asinh(arg)
|
HyperbolicOperations.ASINH_OPERATION -> ::asinh
|
||||||
HyperbolicOperations.ATANH_OPERATION -> atanh(arg)
|
HyperbolicOperations.ATANH_OPERATION -> ::atanh
|
||||||
PowerOperations.SQRT_OPERATION -> sqrt(arg)
|
PowerOperations.SQRT_OPERATION -> ::sqrt
|
||||||
ExponentialOperations.EXP_OPERATION -> exp(arg)
|
ExponentialOperations.EXP_OPERATION -> ::exp
|
||||||
ExponentialOperations.LN_OPERATION -> ln(arg)
|
ExponentialOperations.LN_OPERATION -> ::ln
|
||||||
else -> super.unaryOperation(operation, arg)
|
else -> super<FieldOperations>.unaryOperation(operation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,10 +46,11 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
|||||||
public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
|
public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
|
||||||
public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
|
public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
|
||||||
|
|
||||||
public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
public override fun rightSideNumberOperation(operation: String): (left: T, right: Number) -> T =
|
||||||
PowerOperations.POW_OPERATION -> power(left, right)
|
when (operation) {
|
||||||
else -> super.rightSideNumberOperation(operation, left, right)
|
PowerOperations.POW_OPERATION -> ::power
|
||||||
}
|
else -> super.rightSideNumberOperation(operation)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -80,10 +81,11 @@ public object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
|||||||
public override val one: Double
|
public override val one: Double
|
||||||
get() = 1.0
|
get() = 1.0
|
||||||
|
|
||||||
public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
|
public override fun binaryOperation(operation: String): (left: Double, right: Double) -> Double =
|
||||||
PowerOperations.POW_OPERATION -> left pow right
|
when (operation) {
|
||||||
else -> super.binaryOperation(operation, left, right)
|
PowerOperations.POW_OPERATION -> ::power
|
||||||
}
|
else -> super.binaryOperation(operation)
|
||||||
|
}
|
||||||
|
|
||||||
public override inline fun add(a: Double, b: Double): Double = a + b
|
public override inline fun add(a: Double, b: Double): Double = a + b
|
||||||
public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
|
public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
|
||||||
@ -130,9 +132,9 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
|||||||
public override val one: Float
|
public override val one: Float
|
||||||
get() = 1.0f
|
get() = 1.0f
|
||||||
|
|
||||||
public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
|
public override fun binaryOperation(operation: String): (left: Float, right: Float) -> Float = when (operation) {
|
||||||
PowerOperations.POW_OPERATION -> left pow right
|
PowerOperations.POW_OPERATION -> ::power
|
||||||
else -> super.binaryOperation(operation, left, right)
|
else -> super.binaryOperation(operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
public override inline fun add(a: Float, b: Float): Float = a + b
|
public override inline fun add(a: Float, b: Float): Float = a + b
|
||||||
|
Loading…
Reference in New Issue
Block a user