forked from kscience/kmath
Implement recursive constants evaluation, improve builders
This commit is contained in:
parent
013030951e
commit
b7d1fe2560
@ -3,13 +3,13 @@ package scientifik.kmath.expressions
|
|||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
|
||||||
open class AsmExpressionSpace<T>(space: Space<T>) : Space<AsmExpression<T>>,
|
open class AsmExpressionSpace<T>(private val space: Space<T>) : Space<AsmExpression<T>>,
|
||||||
ExpressionSpace<T, AsmExpression<T>> {
|
ExpressionSpace<T, AsmExpression<T>> {
|
||||||
override val zero: AsmExpression<T> = AsmConstantExpression(space.zero)
|
override val zero: AsmExpression<T> = AsmConstantExpression(space.zero)
|
||||||
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(value)
|
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(value)
|
||||||
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(name, default)
|
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(name, default)
|
||||||
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmSumExpression(a, b)
|
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmSumExpression(space, a, b)
|
||||||
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(a, k)
|
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(space, a, k)
|
||||||
operator fun AsmExpression<T>.plus(arg: T): AsmExpression<T> = this + const(arg)
|
operator fun AsmExpression<T>.plus(arg: T): AsmExpression<T> = this + const(arg)
|
||||||
operator fun AsmExpression<T>.minus(arg: T): AsmExpression<T> = this - const(arg)
|
operator fun AsmExpression<T>.minus(arg: T): AsmExpression<T> = this - const(arg)
|
||||||
operator fun T.plus(arg: AsmExpression<T>): AsmExpression<T> = arg + this
|
operator fun T.plus(arg: AsmExpression<T>): AsmExpression<T> = arg + this
|
||||||
@ -22,8 +22,11 @@ class AsmExpressionField<T>(private val field: Field<T>) : ExpressionField<T, As
|
|||||||
get() = const(this.field.one)
|
get() = const(this.field.one)
|
||||||
|
|
||||||
override fun number(value: Number): AsmExpression<T> = const(field.run { one * value })
|
override fun number(value: Number): AsmExpression<T> = const(field.run { one * value })
|
||||||
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmProductExpression(a, b)
|
|
||||||
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmDivExpression(a, b)
|
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmProductExpression(field, a, b)
|
||||||
|
|
||||||
|
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmDivExpression(field, a, b)
|
||||||
operator fun AsmExpression<T>.times(arg: T): AsmExpression<T> = this * const(arg)
|
operator fun AsmExpression<T>.times(arg: T): AsmExpression<T> = this * const(arg)
|
||||||
operator fun AsmExpression<T>.div(arg: T): AsmExpression<T> = this / const(arg)
|
operator fun AsmExpression<T>.div(arg: T): AsmExpression<T> = this / const(arg)
|
||||||
operator fun T.times(arg: AsmExpression<T>): AsmExpression<T> = arg * this
|
operator fun T.times(arg: AsmExpression<T>): AsmExpression<T> = arg * this
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
abstract class AsmCompiledExpression<T> internal constructor(
|
abstract class AsmCompiledExpression<T> internal constructor(
|
||||||
@JvmField private val algebra: Algebra<T>,
|
@JvmField private val algebra: Algebra<T>,
|
||||||
@ -10,6 +10,7 @@ abstract class AsmCompiledExpression<T> internal constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
interface AsmExpression<T> {
|
interface AsmExpression<T> {
|
||||||
|
fun tryEvaluate(): T? = null
|
||||||
fun invoke(gen: AsmGenerationContext<T>)
|
fun invoke(gen: AsmGenerationContext<T>)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -20,13 +21,22 @@ internal class AsmVariableExpression<T>(val name: String, val default: T? = null
|
|||||||
|
|
||||||
internal class AsmConstantExpression<T>(val value: T) :
|
internal class AsmConstantExpression<T>(val value: T) :
|
||||||
AsmExpression<T> {
|
AsmExpression<T> {
|
||||||
|
override fun tryEvaluate(): T = value
|
||||||
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromConstants(value)
|
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromConstants(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class AsmSumExpression<T>(
|
internal class AsmSumExpression<T>(
|
||||||
val first: AsmExpression<T>,
|
private val algebra: SpaceOperations<T>,
|
||||||
val second: AsmExpression<T>
|
first: AsmExpression<T>,
|
||||||
|
second: AsmExpression<T>
|
||||||
) : AsmExpression<T> {
|
) : AsmExpression<T> {
|
||||||
|
private val first: AsmExpression<T> = first.optimize()
|
||||||
|
private val second: AsmExpression<T> = second.optimize()
|
||||||
|
|
||||||
|
override fun tryEvaluate(): T? = algebra {
|
||||||
|
(first.tryEvaluate() ?: return@algebra null) + (second.tryEvaluate() ?: return@algebra null)
|
||||||
|
}
|
||||||
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||||
gen.visitLoadAlgebra()
|
gen.visitLoadAlgebra()
|
||||||
first.invoke(gen)
|
first.invoke(gen)
|
||||||
@ -41,9 +51,17 @@ internal class AsmSumExpression<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
internal class AsmProductExpression<T>(
|
internal class AsmProductExpression<T>(
|
||||||
val first: AsmExpression<T>,
|
private val algebra: RingOperations<T>,
|
||||||
val second: AsmExpression<T>
|
first: AsmExpression<T>,
|
||||||
|
second: AsmExpression<T>
|
||||||
) : AsmExpression<T> {
|
) : AsmExpression<T> {
|
||||||
|
private val first: AsmExpression<T> = first.optimize()
|
||||||
|
private val second: AsmExpression<T> = second.optimize()
|
||||||
|
|
||||||
|
override fun tryEvaluate(): T? = algebra {
|
||||||
|
(first.tryEvaluate() ?: return@algebra null) * (second.tryEvaluate() ?: return@algebra null)
|
||||||
|
}
|
||||||
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||||
gen.visitLoadAlgebra()
|
gen.visitLoadAlgebra()
|
||||||
first.invoke(gen)
|
first.invoke(gen)
|
||||||
@ -58,9 +76,14 @@ internal class AsmProductExpression<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
internal class AsmConstProductExpression<T>(
|
internal class AsmConstProductExpression<T>(
|
||||||
val expr: AsmExpression<T>,
|
private val algebra: SpaceOperations<T>,
|
||||||
val const: Number
|
expr: AsmExpression<T>,
|
||||||
|
private val const: Number
|
||||||
) : AsmExpression<T> {
|
) : AsmExpression<T> {
|
||||||
|
private val expr: AsmExpression<T> = expr.optimize()
|
||||||
|
|
||||||
|
override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const }
|
||||||
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||||
gen.visitLoadAlgebra()
|
gen.visitLoadAlgebra()
|
||||||
gen.visitNumberConstant(const)
|
gen.visitNumberConstant(const)
|
||||||
@ -75,9 +98,17 @@ internal class AsmConstProductExpression<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
internal class AsmDivExpression<T>(
|
internal class AsmDivExpression<T>(
|
||||||
val expr: AsmExpression<T>,
|
private val algebra: FieldOperations<T>,
|
||||||
val second: AsmExpression<T>
|
expr: AsmExpression<T>,
|
||||||
|
second: AsmExpression<T>
|
||||||
) : AsmExpression<T> {
|
) : AsmExpression<T> {
|
||||||
|
private val expr: AsmExpression<T> = expr.optimize()
|
||||||
|
private val second: AsmExpression<T> = second.optimize()
|
||||||
|
|
||||||
|
override fun tryEvaluate(): T? = algebra {
|
||||||
|
(expr.tryEvaluate() ?: return@algebra null) / (second.tryEvaluate() ?: return@algebra null)
|
||||||
|
}
|
||||||
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||||
gen.visitLoadAlgebra()
|
gen.visitLoadAlgebra()
|
||||||
expr.invoke(gen)
|
expr.invoke(gen)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
|
||||||
@ -17,22 +18,19 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
inline fun <reified T> asmSpace(
|
inline fun <reified T, I> asm(i: I, algebra: Algebra<T>, block: I.() -> AsmExpression<T>): Expression<T> {
|
||||||
algebra: Space<T>,
|
val expression = i.block().optimize()
|
||||||
block: AsmExpressionSpace<T>.() -> AsmExpression<T>
|
|
||||||
): Expression<T> {
|
|
||||||
val expression = AsmExpressionSpace(algebra).block()
|
|
||||||
val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression))
|
val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression))
|
||||||
expression.invoke(ctx)
|
expression.invoke(ctx)
|
||||||
return ctx.generate()
|
return ctx.generate()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline fun <reified T> asmSpace(
|
||||||
|
algebra: Space<T>,
|
||||||
|
block: AsmExpressionSpace<T>.() -> AsmExpression<T>
|
||||||
|
): Expression<T> = asm(AsmExpressionSpace(algebra), algebra, block)
|
||||||
|
|
||||||
inline fun <reified T> asmField(
|
inline fun <reified T> asmField(
|
||||||
algebra: Field<T>,
|
algebra: Field<T>,
|
||||||
block: AsmExpressionField<T>.() -> AsmExpression<T>
|
block: AsmExpressionField<T>.() -> AsmExpression<T>
|
||||||
): Expression<T> {
|
): Expression<T> = asm(AsmExpressionField(algebra), algebra, block)
|
||||||
val expression = AsmExpressionField(algebra).block()
|
|
||||||
val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression))
|
|
||||||
expression.invoke(ctx)
|
|
||||||
return ctx.generate()
|
|
||||||
}
|
|
||||||
|
@ -0,0 +1,6 @@
|
|||||||
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
|
fun <T> AsmExpression<T>.optimize(): AsmExpression<T> {
|
||||||
|
val a = tryEvaluate()
|
||||||
|
return if (a == null) this else AsmConstantExpression(a)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user