Merge remote-tracking branch 'mipt-npm/adv-expr' into adv-expr-refactor-agc
# Conflicts: # kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt # kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt # kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt # kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt
This commit is contained in:
commit
fe64537cbc
examples/src/main/kotlin/scientifik/kmath/structures
kmath-ast/src
commonMain/kotlin/scientifik/kmath/ast
jvmMain/kotlin/scientifik/kmath/asm
jvmTest/kotlin/scietifik/kmath
kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions
kmath-core/src/commonMain/kotlin/scientifik/kmath
@ -27,7 +27,7 @@ fun main() {
|
||||
|
||||
val complexTime = measureTimeMillis {
|
||||
complexField.run {
|
||||
var res = one
|
||||
var res: NDBuffer<Complex> = one
|
||||
repeat(n) {
|
||||
res += 1.0
|
||||
}
|
||||
|
@ -23,14 +23,14 @@ fun main() {
|
||||
|
||||
measureAndPrint("Automatic field addition") {
|
||||
autoField.run {
|
||||
var res = one
|
||||
var res: NDBuffer<Double> = one
|
||||
repeat(n) {
|
||||
res += 1.0
|
||||
res += number(1.0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
measureAndPrint("Element addition"){
|
||||
measureAndPrint("Element addition") {
|
||||
var res = genericField.one
|
||||
repeat(n) {
|
||||
res += 1.0
|
||||
@ -63,7 +63,7 @@ fun main() {
|
||||
genericField.run {
|
||||
var res: NDBuffer<Double> = one
|
||||
repeat(n) {
|
||||
res += 1.0
|
||||
res += one // con't avoid using `one` due to resolution ambiguity
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
package scientifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.NumericAlgebra
|
||||
import scientifik.kmath.operations.RealField
|
||||
|
||||
@ -40,12 +41,14 @@ sealed class MST {
|
||||
|
||||
//TODO add a function with named arguments
|
||||
|
||||
fun <T> NumericAlgebra<T>.evaluate(node: MST): T {
|
||||
fun <T> Algebra<T>.evaluate(node: MST): T {
|
||||
return when (node) {
|
||||
is MST.Numeric -> number(node.value)
|
||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
is MST.Symbolic -> symbol(node.value)
|
||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||
is MST.Binary -> when {
|
||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||
val number = RealField.binaryOperation(
|
||||
node.operation,
|
||||
@ -59,4 +62,6 @@ fun <T> NumericAlgebra<T>.evaluate(node: MST): T {
|
||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)
|
@ -0,0 +1,33 @@
|
||||
@file:Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE")
|
||||
|
||||
package scientifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
object MSTAlgebra : NumericAlgebra<MST> {
|
||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MST.Unary(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right)
|
||||
|
||||
override fun number(value: Number): MST = MST.Numeric(value)
|
||||
}
|
||||
|
||||
object MSTSpace : Space<MST>, NumericAlgebra<MST> by MSTAlgebra {
|
||||
override val zero: MST = number(0.0)
|
||||
|
||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
|
||||
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||
}
|
||||
|
||||
object MSTRing : Ring<MST>, Space<MST> by MSTSpace {
|
||||
override val one: MST = number(1.0)
|
||||
|
||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
}
|
||||
|
||||
object MSTField : Field<MST>, Ring<MST> by MSTRing {
|
||||
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||
}
|
@ -1,265 +0,0 @@
|
||||
package scientifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.internal.AsmBuilder
|
||||
import scientifik.kmath.asm.internal.hasSpecific
|
||||
import scientifik.kmath.asm.internal.optimize
|
||||
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.expressions.ExpressionAlgebra
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
/**
|
||||
* A function declaration that could be compiled to [AsmBuilder].
|
||||
*
|
||||
* @param T the type the stored function returns.
|
||||
*/
|
||||
abstract class AsmExpression<T> internal constructor() {
|
||||
/**
|
||||
* Tries to evaluate this function without its variables. This method is intended for optimization.
|
||||
*
|
||||
* @return `null` if the function depends on its variables, the value if the function is a constant.
|
||||
*/
|
||||
internal open fun tryEvaluate(): T? = null
|
||||
|
||||
/**
|
||||
* Compiles this declaration.
|
||||
*
|
||||
* @param gen the target [AsmBuilder].
|
||||
*/
|
||||
@PublishedApi
|
||||
internal abstract fun compile(gen: AsmBuilder<T>)
|
||||
}
|
||||
|
||||
internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val name: String, expr: AsmExpression<T>) :
|
||||
AsmExpression<T>() {
|
||||
private val expr: AsmExpression<T> = expr.optimize()
|
||||
override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) }
|
||||
|
||||
override fun compile(gen: AsmBuilder<T>) {
|
||||
gen.loadAlgebra()
|
||||
|
||||
if (!hasSpecific(context, name, 1))
|
||||
gen.loadStringConstant(name)
|
||||
|
||||
expr.compile(gen)
|
||||
|
||||
if (gen.tryInvokeSpecific(context, name, 1))
|
||||
return
|
||||
|
||||
gen.invokeAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
||||
method = "unaryOperation",
|
||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};)" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
internal class AsmBinaryOperation<T>(
|
||||
private val context: Algebra<T>,
|
||||
private val name: String,
|
||||
first: AsmExpression<T>,
|
||||
second: AsmExpression<T>
|
||||
) : AsmExpression<T>() {
|
||||
private val first: AsmExpression<T> = first.optimize()
|
||||
private val second: AsmExpression<T> = second.optimize()
|
||||
|
||||
override fun tryEvaluate(): T? = context {
|
||||
binaryOperation(
|
||||
name,
|
||||
first.tryEvaluate() ?: return@context null,
|
||||
second.tryEvaluate() ?: return@context null
|
||||
)
|
||||
}
|
||||
|
||||
override fun compile(gen: AsmBuilder<T>) {
|
||||
gen.loadAlgebra()
|
||||
|
||||
if (!hasSpecific(context, name, 2))
|
||||
gen.loadStringConstant(name)
|
||||
|
||||
first.compile(gen)
|
||||
second.compile(gen)
|
||||
|
||||
if (gen.tryInvokeSpecific(context, name, 2))
|
||||
return
|
||||
|
||||
gen.invokeAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
||||
method = "binaryOperation",
|
||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};)" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
internal class AsmVariableExpression<T>(private val name: String, private val default: T? = null) :
|
||||
AsmExpression<T>() {
|
||||
override fun compile(gen: AsmBuilder<T>): Unit = gen.loadVariable(name, default)
|
||||
}
|
||||
|
||||
internal class AsmConstantExpression<T>(private val value: T) :
|
||||
AsmExpression<T>() {
|
||||
override fun tryEvaluate(): T = value
|
||||
override fun compile(gen: AsmBuilder<T>): Unit = gen.loadTConstant(value)
|
||||
}
|
||||
|
||||
internal class AsmConstProductExpression<T>(
|
||||
private val context: Space<T>,
|
||||
expr: AsmExpression<T>,
|
||||
private val const: Number
|
||||
) : AsmExpression<T>() {
|
||||
private val expr: AsmExpression<T> = expr.optimize()
|
||||
|
||||
override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const }
|
||||
|
||||
override fun compile(gen: AsmBuilder<T>) {
|
||||
gen.loadAlgebra()
|
||||
gen.loadNumberConstant(const)
|
||||
expr.compile(gen)
|
||||
|
||||
gen.invokeAlgebraOperation(
|
||||
owner = AsmBuilder.SPACE_OPERATIONS_CLASS,
|
||||
method = "multiply",
|
||||
descriptor = "(L${AsmBuilder.OBJECT_CLASS};" +
|
||||
"L${AsmBuilder.NUMBER_CLASS};)" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
internal class AsmNumberExpression<T>(private val context: NumericAlgebra<T>, private val value: Number) :
|
||||
AsmExpression<T>() {
|
||||
override fun tryEvaluate(): T? = context.number(value)
|
||||
|
||||
override fun compile(gen: AsmBuilder<T>): Unit = gen.loadNumberConstant(value)
|
||||
}
|
||||
|
||||
internal abstract class FunctionalCompiledExpression<T> internal constructor(
|
||||
@JvmField protected val algebra: Algebra<T>,
|
||||
@JvmField protected val constants: Array<Any>
|
||||
) : Expression<T> {
|
||||
abstract override fun invoke(arguments: Map<String, T>): T
|
||||
}
|
||||
|
||||
/**
|
||||
* A context class for [AsmExpression] construction.
|
||||
*/
|
||||
interface AsmExpressionAlgebra<T, A : NumericAlgebra<T>> : NumericAlgebra<AsmExpression<T>>,
|
||||
ExpressionAlgebra<T, AsmExpression<T>> {
|
||||
/**
|
||||
* The algebra to provide for AsmExpressions built.
|
||||
*/
|
||||
val algebra: A
|
||||
|
||||
/**
|
||||
* Builds an AsmExpression to wrap a number.
|
||||
*/
|
||||
override fun number(value: Number): AsmExpression<T> = AsmNumberExpression(algebra, value)
|
||||
|
||||
/**
|
||||
* Builds an AsmExpression of constant expression which does not depend on arguments.
|
||||
*/
|
||||
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(value)
|
||||
|
||||
/**
|
||||
* Builds an AsmExpression to access a variable.
|
||||
*/
|
||||
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(name, default)
|
||||
|
||||
/**
|
||||
* Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right].
|
||||
*/
|
||||
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
|
||||
AsmBinaryOperation(algebra, operation, left, right)
|
||||
|
||||
/**
|
||||
* Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg].
|
||||
*/
|
||||
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
|
||||
AsmUnaryOperation(algebra, operation, arg)
|
||||
}
|
||||
|
||||
/**
|
||||
* A context class for [AsmExpression] construction for [Space] algebras.
|
||||
*/
|
||||
open class AsmExpressionSpace<T, A>(override val algebra: A) : AsmExpressionAlgebra<T, A>,
|
||||
Space<AsmExpression<T>> where A : Space<T>, A : NumericAlgebra<T> {
|
||||
override val zero: AsmExpression<T>
|
||||
get() = const(algebra.zero)
|
||||
|
||||
/**
|
||||
* Builds an AsmExpression of addition of two another expressions.
|
||||
*/
|
||||
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||
AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b)
|
||||
|
||||
/**
|
||||
* Builds an AsmExpression of multiplication of expression by number.
|
||||
*/
|
||||
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(algebra, a, k)
|
||||
|
||||
operator fun AsmExpression<T>.plus(arg: T): AsmExpression<T> = this + const(arg)
|
||||
operator fun AsmExpression<T>.minus(arg: T): AsmExpression<T> = this - const(arg)
|
||||
operator fun T.plus(arg: AsmExpression<T>): AsmExpression<T> = arg + this
|
||||
operator fun T.minus(arg: AsmExpression<T>): AsmExpression<T> = arg - this
|
||||
|
||||
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
|
||||
super<AsmExpressionAlgebra>.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
|
||||
super<AsmExpressionAlgebra>.binaryOperation(operation, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* A context class for [AsmExpression] construction for [Ring] algebras.
|
||||
*/
|
||||
open class AsmExpressionRing<T, A>(override val algebra: A) : AsmExpressionSpace<T, A>(algebra),
|
||||
Ring<AsmExpression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
|
||||
override val one: AsmExpression<T>
|
||||
get() = const(algebra.one)
|
||||
|
||||
/**
|
||||
* Builds an AsmExpression of multiplication of two expressions.
|
||||
*/
|
||||
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||
AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b)
|
||||
|
||||
operator fun AsmExpression<T>.times(arg: T): AsmExpression<T> = this * const(arg)
|
||||
operator fun T.times(arg: AsmExpression<T>): AsmExpression<T> = arg * this
|
||||
|
||||
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
|
||||
super<AsmExpressionSpace>.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
|
||||
super<AsmExpressionSpace>.binaryOperation(operation, left, right)
|
||||
|
||||
override fun number(value: Number): AsmExpression<T> = super<AsmExpressionSpace>.number(value)
|
||||
}
|
||||
|
||||
/**
|
||||
* A context class for [AsmExpression] construction for [Field] algebras.
|
||||
*/
|
||||
open class AsmExpressionField<T, A>(override val algebra: A) :
|
||||
AsmExpressionRing<T, A>(algebra),
|
||||
Field<AsmExpression<T>> where A : Field<T>, A : NumericAlgebra<T> {
|
||||
/**
|
||||
* Builds an AsmExpression of division an expression by another one.
|
||||
*/
|
||||
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||
AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b)
|
||||
|
||||
operator fun AsmExpression<T>.div(arg: T): AsmExpression<T> = this / const(arg)
|
||||
operator fun T.div(arg: AsmExpression<T>): AsmExpression<T> = arg / this
|
||||
|
||||
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
|
||||
super<AsmExpressionRing>.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
|
||||
super<AsmExpressionRing>.binaryOperation(operation, left, right)
|
||||
|
||||
override fun number(value: Number): AsmExpression<T> = super<AsmExpressionRing>.number(value)
|
||||
}
|
@ -1,53 +1,103 @@
|
||||
package scientifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.internal.AsmBuilder
|
||||
import scientifik.kmath.asm.internal.hasSpecific
|
||||
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
||||
import scientifik.kmath.ast.MST
|
||||
import scientifik.kmath.ast.evaluate
|
||||
import scientifik.kmath.ast.MSTField
|
||||
import scientifik.kmath.ast.MSTRing
|
||||
import scientifik.kmath.ast.MSTSpace
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.*
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
@PublishedApi
|
||||
internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String {
|
||||
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision"
|
||||
|
||||
try {
|
||||
Class.forName(name)
|
||||
} catch (ignored: ClassNotFoundException) {
|
||||
return name
|
||||
/**
|
||||
* Compile given MST to an Expression using AST compiler
|
||||
*/
|
||||
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||
|
||||
fun buildName(mst: MST, collision: Int = 0): String {
|
||||
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
||||
|
||||
try {
|
||||
Class.forName(name)
|
||||
} catch (ignored: ClassNotFoundException) {
|
||||
return name
|
||||
}
|
||||
|
||||
return buildName(mst, collision + 1)
|
||||
}
|
||||
|
||||
return buildName(expression, collision + 1)
|
||||
fun AsmBuilder<T>.visit(node: MST): Unit {
|
||||
when (node) {
|
||||
is MST.Symbolic -> visitLoadFromVariables(node.value)
|
||||
is MST.Numeric -> {
|
||||
val constant = if (algebra is NumericAlgebra<T>) {
|
||||
algebra.number(node.value)
|
||||
} else {
|
||||
error("Number literals are not supported in $algebra")
|
||||
}
|
||||
visitLoadFromConstants(constant)
|
||||
}
|
||||
is MST.Unary -> {
|
||||
visitLoadAlgebra()
|
||||
|
||||
if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation)
|
||||
|
||||
visit(node.value)
|
||||
|
||||
if (!tryInvokeSpecific(algebra, node.operation, 1)) {
|
||||
visitAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
||||
method = "unaryOperation",
|
||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};)" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
}
|
||||
is MST.Binary -> {
|
||||
visitLoadAlgebra()
|
||||
|
||||
if (!hasSpecific(algebra, node.operation, 2))
|
||||
visitStringConstant(node.operation)
|
||||
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
|
||||
if (!tryInvokeSpecific(algebra, node.operation, 2)) {
|
||||
|
||||
visitAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
||||
method = "binaryOperation",
|
||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};)" +
|
||||
"L${AsmBuilder.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val builder = AsmBuilder(type.java, algebra, buildName(this))
|
||||
builder.visit(this)
|
||||
return builder.generate()
|
||||
}
|
||||
|
||||
@PublishedApi
|
||||
internal inline fun <reified T> AsmExpression<T>.compile(algebra: Algebra<T>): Expression<T> =
|
||||
AsmBuilder(T::class.java, algebra, buildName(this), this).getInstance()
|
||||
inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
||||
|
||||
inline fun <reified T, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm(
|
||||
expressionAlgebra: E,
|
||||
block: E.() -> AsmExpression<T>
|
||||
): Expression<T> = expressionAlgebra.block().compile(expressionAlgebra.algebra)
|
||||
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.asm(
|
||||
mstAlgebra: E,
|
||||
block: E.() -> MST
|
||||
): Expression<T> = mstAlgebra.block().compileWith(T::class, this)
|
||||
|
||||
inline fun <reified T, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm(
|
||||
expressionAlgebra: E,
|
||||
ast: MST
|
||||
): Expression<T> = asm(expressionAlgebra) { evaluate(ast) }
|
||||
inline fun <reified T : Any, A : Space<T>> A.asmInSpace(block: MSTSpace.() -> MST): Expression<T> =
|
||||
MSTSpace.block().compileWith(T::class, this)
|
||||
|
||||
inline fun <reified T, A> A.asmSpace(block: AsmExpressionSpace<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Space<T> =
|
||||
AsmExpressionSpace(this).let { it.block().compile(it.algebra) }
|
||||
|
||||
inline fun <reified T, A> A.asmSpace(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Space<T> =
|
||||
asmSpace { evaluate(ast) }
|
||||
|
||||
inline fun <reified T, A> A.asmRing(block: AsmExpressionRing<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
|
||||
AsmExpressionRing(this).let { it.block().compile(it.algebra) }
|
||||
|
||||
inline fun <reified T, A> A.asmRing(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
|
||||
asmRing { evaluate(ast) }
|
||||
|
||||
inline fun <reified T, A> A.asmField(block: AsmExpressionField<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
|
||||
AsmExpressionField(this).let { it.block().compile(it.algebra) }
|
||||
|
||||
inline fun <reified T, A> A.asmField(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
|
||||
asmRing { evaluate(ast) }
|
||||
inline fun <reified T : Any, A : Ring<T>> A.asmInRing(block: MSTRing.() -> MST): Expression<T> =
|
||||
MSTRing.block().compileWith(T::class, this)
|
||||
|
||||
inline fun <reified T : Any, A : Field<T>> A.asmInField(block: MSTField.() -> MST): Expression<T> =
|
||||
MSTField.block().compileWith(T::class, this)
|
||||
|
@ -1,8 +1,6 @@
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
import org.objectweb.asm.Opcodes
|
||||
import scientifik.kmath.asm.AsmConstantExpression
|
||||
import scientifik.kmath.asm.AsmExpression
|
||||
import scientifik.kmath.operations.Algebra
|
||||
|
||||
private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide")
|
||||
@ -31,7 +29,7 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
|
||||
append("L${AsmBuilder.OBJECT_CLASS};")
|
||||
}
|
||||
|
||||
invokeAlgebraOperation(
|
||||
visitAlgebraOperation(
|
||||
owner = owner,
|
||||
method = aName,
|
||||
descriptor = sig,
|
||||
@ -41,9 +39,8 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@PublishedApi
|
||||
internal fun <T> AsmExpression<T>.optimize(): AsmExpression<T> {
|
||||
val a = tryEvaluate()
|
||||
return if (a == null) this else AsmConstantExpression(a)
|
||||
}
|
||||
//
|
||||
//internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> {
|
||||
// val a = tryEvaluate()
|
||||
// return if (a == null) this else AsmConstantExpression(type, algebra, a)
|
||||
//}
|
||||
|
@ -1,75 +1,66 @@
|
||||
package scietifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.asmField
|
||||
import scientifik.kmath.asm.asmRing
|
||||
import scientifik.kmath.asm.asmSpace
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.ByteRing
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class TestAsmAlgebras {
|
||||
@Test
|
||||
fun space() {
|
||||
val res = ByteRing.asmSpace {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
3.toByte() - (2.toByte() + (multiply(
|
||||
add(const(1), const(1)),
|
||||
2
|
||||
) + 1.toByte()) * 3.toByte() - 1.toByte())
|
||||
),
|
||||
|
||||
number(1)
|
||||
) + variable("x") + zero
|
||||
}("x" to 2.toByte())
|
||||
|
||||
assertEquals(16, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun ring() {
|
||||
val res = ByteRing.asmRing {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
(3.toByte() - (2.toByte() + (multiply(
|
||||
add(const(1), const(1)),
|
||||
2
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
|
||||
number(1)
|
||||
) * const(2)
|
||||
}()
|
||||
|
||||
assertEquals(24, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun field() {
|
||||
val res = RealField.asmField {
|
||||
divide(binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
(3.0 - (2.0 + (multiply(
|
||||
add(const(1.0), const(1.0)),
|
||||
2
|
||||
) + 1.0))) * 3 - 1.0
|
||||
),
|
||||
|
||||
number(1)
|
||||
) / 2, const(2.0)) * one
|
||||
}()
|
||||
|
||||
assertEquals(3.0, res)
|
||||
}
|
||||
}
|
||||
//
|
||||
//class TestAsmAlgebras {
|
||||
// @Test
|
||||
// fun space() {
|
||||
// val res = ByteRing.asmInRing {
|
||||
// binaryOperation(
|
||||
// "+",
|
||||
//
|
||||
// unaryOperation(
|
||||
// "+",
|
||||
// 3.toByte() - (2.toByte() + (multiply(
|
||||
// add(number(1), number(1)),
|
||||
// 2
|
||||
// ) + 1.toByte()) * 3.toByte() - 1.toByte())
|
||||
// ),
|
||||
//
|
||||
// number(1)
|
||||
// ) + symbol("x") + zero
|
||||
// }("x" to 2.toByte())
|
||||
//
|
||||
// assertEquals(16, res)
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// fun ring() {
|
||||
// val res = ByteRing.asmInRing {
|
||||
// binaryOperation(
|
||||
// "+",
|
||||
//
|
||||
// unaryOperation(
|
||||
// "+",
|
||||
// (3.toByte() - (2.toByte() + (multiply(
|
||||
// add(const(1), const(1)),
|
||||
// 2
|
||||
// ) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
// ),
|
||||
//
|
||||
// number(1)
|
||||
// ) * const(2)
|
||||
// }()
|
||||
//
|
||||
// assertEquals(24, res)
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// fun field() {
|
||||
// val res = RealField.asmInField {
|
||||
// +(3 - 2 + 2*(number(1)+1.0)
|
||||
//
|
||||
// unaryOperation(
|
||||
// "+",
|
||||
// (3.0 - (2.0 + (multiply(
|
||||
// add((1.0), const(1.0)),
|
||||
// 2
|
||||
// ) + 1.0))) * 3 - 1.0
|
||||
// )+
|
||||
//
|
||||
// number(1)
|
||||
// ) / 2, const(2.0)) * one
|
||||
// }()
|
||||
//
|
||||
// assertEquals(3.0, res)
|
||||
// }
|
||||
//}
|
||||
|
@ -1,6 +1,6 @@
|
||||
package scietifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.asmField
|
||||
import scientifik.kmath.asm.asmInField
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
@ -9,13 +9,13 @@ import kotlin.test.assertEquals
|
||||
class TestAsmExpressions {
|
||||
@Test
|
||||
fun testUnaryOperationInvocation() {
|
||||
val res = RealField.asmField { unaryOperation("+", variable("x")) }("x" to 2.0)
|
||||
assertEquals(2.0, res)
|
||||
val res = RealField.asmInField { -symbol("x") }("x" to 2.0)
|
||||
assertEquals(-2.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testConstProductInvocation() {
|
||||
val res = RealField.asmField { variable("x") * 2 }("x" to 2.0)
|
||||
val res = RealField.asmInField { symbol("x") * 2 }("x" to 2.0)
|
||||
assertEquals(4.0, res)
|
||||
}
|
||||
}
|
||||
|
@ -1,18 +1,18 @@
|
||||
package scietifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.asm.asmField
|
||||
import scientifik.kmath.asm.compile
|
||||
import scientifik.kmath.ast.parseMath
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.Complex
|
||||
import scientifik.kmath.operations.ComplexField
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class AsmTest {
|
||||
@Test
|
||||
fun parsedExpression() {
|
||||
val mst = "2+2*(2+2)".parseMath()
|
||||
val res = ComplexField.asmField(mst)()
|
||||
val res = ComplexField.compile(mst)()
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
}
|
||||
|
@ -74,10 +74,10 @@ class DerivativeStructureField(
|
||||
|
||||
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
|
||||
|
||||
operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble())
|
||||
operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble())
|
||||
operator fun Number.plus(s: DerivativeStructure) = s + this
|
||||
operator fun Number.minus(s: DerivativeStructure) = s - this
|
||||
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
|
||||
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
||||
override operator fun Number.plus(b: DerivativeStructure) = b + this
|
||||
override operator fun Number.minus(b: DerivativeStructure) = b - this
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -90,20 +90,20 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
operator fun Number.plus(that: Variable<T>): Variable<T> =
|
||||
derive(variable { this@plus.toDouble() * one + that.value }) { z ->
|
||||
that.d += z.d
|
||||
override operator fun Number.plus(b: Variable<T>): Variable<T> =
|
||||
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
||||
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
||||
|
||||
operator fun Number.minus(that: Variable<T>): Variable<T> =
|
||||
derive(variable { this@minus.toDouble() * one - that.value }) { z ->
|
||||
that.d -= z.d
|
||||
override operator fun Number.minus(b: Variable<T>): Variable<T> =
|
||||
derive(variable { this@minus.toDouble() * one - b.value }) { z ->
|
||||
b.d -= z.d
|
||||
}
|
||||
|
||||
operator fun Variable<T>.minus(that: Number): Variable<T> =
|
||||
derive(variable { this@minus.value - one * that.toDouble() }) { z ->
|
||||
override operator fun Variable<T>.minus(b: Number): Variable<T> =
|
||||
derive(variable { this@minus.value - one * b.toDouble() }) { z ->
|
||||
this@minus.d += z.d
|
||||
}
|
||||
}
|
||||
|
@ -12,9 +12,6 @@ interface Algebra<T> {
|
||||
*/
|
||||
fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
|
||||
|
||||
@Deprecated("Symbol is more concise",replaceWith = ReplaceWith("symbol"))
|
||||
fun raw(value: String): T = symbol(value)
|
||||
|
||||
/**
|
||||
* Dynamic call of unary operation with name [operation] on [arg]
|
||||
*/
|
||||
@ -64,6 +61,8 @@ interface SpaceOperations<T> : Algebra<T> {
|
||||
//Operation to be performed in this context. Could be moved to extensions in case of KEEP-176
|
||||
operator fun T.unaryMinus(): T = multiply(this, -1.0)
|
||||
|
||||
operator fun T.unaryPlus(): T = this
|
||||
|
||||
operator fun T.plus(b: T): T = add(this, b)
|
||||
operator fun T.minus(b: T): T = add(this, -b)
|
||||
operator fun T.times(k: Number) = multiply(this, k.toDouble())
|
||||
@ -138,17 +137,25 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
|
||||
override fun number(value: Number): T = one * value.toDouble()
|
||||
|
||||
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left + right
|
||||
SpaceOperations.MINUS_OPERATION -> left - right
|
||||
RingOperations.TIMES_OPERATION -> left * right
|
||||
else -> super.leftSideNumberOperation(operation, left, right)
|
||||
}
|
||||
|
||||
//TODO those operators are blocked by type conflict in RealField
|
||||
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left + right
|
||||
SpaceOperations.MINUS_OPERATION -> left - right
|
||||
RingOperations.TIMES_OPERATION -> left * right
|
||||
else -> super.rightSideNumberOperation(operation, left, right)
|
||||
}
|
||||
|
||||
// operator fun T.plus(b: Number) = this.plus(b * one)
|
||||
// operator fun Number.plus(b: T) = b + this
|
||||
//
|
||||
// operator fun T.minus(b: Number) = this.minus(b * one)
|
||||
// operator fun Number.minus(b: T) = -b + this
|
||||
|
||||
operator fun T.plus(b: Number) = this.plus(number(b))
|
||||
operator fun Number.plus(b: T) = b + this
|
||||
|
||||
operator fun T.minus(b: Number) = this.minus(number(b))
|
||||
operator fun Number.minus(b: T) = -b + this
|
||||
}
|
||||
|
||||
/**
|
||||
|
Loading…
x
Reference in New Issue
Block a user