forked from kscience/kmath
# Conflicts: # README.md # build.gradle.kts # examples/src/main/kotlin/space/kscience/kmath/stat/DistributionBenchmark.kt # examples/src/main/kotlin/space/kscience/kmath/structures/ParallelRealNDField.kt # kmath-ast/README.md # kmath-ast/build.gradle.kts # kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstExpression.kt # kmath-complex/README.md # kmath-complex/build.gradle.kts # kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexNDField.kt # kmath-core/README.md # kmath-core/build.gradle.kts # kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt # kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt # kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/RealNDField.kt # kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/RealBufferField.kt # kmath-coroutines/build.gradle.kts # kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingRealChain.kt # kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/structures/LazyNDStructure.kt # kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt # kmath-for-real/README.md # kmath-functions/README.md # kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/Interpolator.kt # kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/LoessInterpolator.kt # kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/XYPointSet.kt # kmath-geometry/build.gradle.kts # kmath-nd4j/README.md # kmath-nd4j/build.gradle.kts # kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt # kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Distribution.kt # kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Fitting.kt # kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/OptimizationProblem.kt # kmath-stat/src/jvmMain/kotlin/space/kscience/kmath/stat/distributions.kt # 
kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorNDStructure.kt
96 lines
3.4 KiB
Kotlin
96 lines
3.4 KiB
Kotlin
/*
|
|
* Copyright 2018-2021 KMath contributors.
|
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
|
*/
|
|
|
|
package space.kscience.kmath.asm
|
|
|
|
import space.kscience.kmath.asm.internal.AsmBuilder
|
|
import space.kscience.kmath.asm.internal.buildName
|
|
import space.kscience.kmath.expressions.Expression
|
|
import space.kscience.kmath.expressions.MST
|
|
import space.kscience.kmath.expressions.MST.*
|
|
import space.kscience.kmath.expressions.invoke
|
|
import space.kscience.kmath.misc.Symbol
|
|
import space.kscience.kmath.operations.Algebra
|
|
import space.kscience.kmath.operations.NumericAlgebra
|
|
|
|
/**
 * Compiles this [MST] into an [Expression] by emitting bytecode through [AsmBuilder].
 *
 * Each node kind is translated as follows:
 * - a symbol that [algebra] can bind via `bindSymbolOrNull` is embedded as an object constant;
 *   any other symbol becomes a variable load;
 * - a numeric literal is loaded with `loadNumberConstant`;
 * - if [algebra] is a [NumericAlgebra] and every operand of an operation is a numeric literal,
 *   the operation is evaluated during compilation and its result is embedded as an object constant;
 * - if [algebra] is a [NumericAlgebra] and exactly one operand of a binary operation is a numeric literal,
 *   the algebra's left-side or right-side number operation is called instead;
 * - every other operation is emitted as a call to the algebra's unary/binary operation function.
 *
 * @param type the target type.
 * @param algebra the target algebra.
 * @return the compiled expression.
 * @author Alexander Nozik
 */
@PublishedApi
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
    fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
        is Symbolic -> {
            val bound = algebra.bindSymbolOrNull(node.value)
            if (bound == null) loadVariable(node.value) else loadObjectConstant(bound as Any)
        }

        is Numeric -> loadNumberConstant(node.value)

        is Unary -> {
            // The node's properties are declared in another module, so copy into a local to allow smart casts.
            val operand = node.value

            if (algebra is NumericAlgebra && operand is Numeric) {
                // Literal operand: apply the operation now and embed only the result.
                val operation = algebra.unaryOperationFunction(node.operation)
                loadObjectConstant(operation(algebra.number(operand.value)))
            } else {
                buildCall(algebra.unaryOperationFunction(node.operation)) { visit(operand) }
            }
        }

        is Binary -> {
            val lhs = node.left
            val rhs = node.right

            when {
                algebra !is NumericAlgebra -> buildCall(algebra.binaryOperationFunction(node.operation)) {
                    visit(lhs)
                    visit(rhs)
                }

                lhs is Numeric && rhs is Numeric -> {
                    // Both operands are literals: fold the operation at compile time.
                    val operation = algebra.binaryOperationFunction(node.operation)
                    loadObjectConstant(operation(algebra.number(lhs.value), algebra.number(rhs.value)))
                }

                lhs is Numeric -> buildCall(algebra.leftSideNumberOperationFunction(node.operation)) {
                    visit(lhs)
                    visit(rhs)
                }

                rhs is Numeric -> buildCall(algebra.rightSideNumberOperationFunction(node.operation)) {
                    visit(lhs)
                    visit(rhs)
                }

                else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
                    visit(lhs)
                    visit(rhs)
                }
            }
        }
    }

    val builder = AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }
    return builder.instance
}
|
|
|
|
|
|
/**
 * Compiles this [MST] with the ASM backend into an [Expression] over [algebra].
 *
 * The JVM class of [T] is obtained from the reified type parameter and passed to the internal compiler.
 *
 * @param algebra the algebra that provides symbol bindings, number conversion and operations.
 * @return the compiled expression.
 */
public inline fun <reified T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> {
    val targetClass: Class<T> = T::class.java
    return compileWith(targetClass, algebra)
}
|
|
|
|
|
|
/**
 * Compiles this [MST] over [algebra] and evaluates the result once with the bindings in [arguments].
 *
 * Compilation runs again on every call. To evaluate the same tree repeatedly, compile it once with
 * [compileToExpression] and reuse the returned expression.
 *
 * @param algebra the algebra used for compilation.
 * @param arguments values for the free symbols of the tree.
 * @return the value of the compiled expression.
 */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T {
    val expression: Expression<T> = compileToExpression(algebra)
    return expression.invoke(arguments)
}
|
|
|
|
/**
 * Compiles this [MST] over [algebra] and evaluates the result once with the symbol/value pairs in [arguments].
 *
 * Compilation runs again on every call. To evaluate the same tree repeatedly, compile it once with
 * [compileToExpression] and reuse the returned expression.
 *
 * @param algebra the algebra used for compilation.
 * @param arguments symbol/value pairs for the free symbols of the tree.
 * @return the value of the compiled expression.
 */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T {
    val expression: Expression<T> = compileToExpression(algebra)
    return expression.invoke(*arguments)
}
|