diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt index 5c93fd729..da788372b 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt @@ -3,13 +3,13 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space -open class AsmExpressionSpace(space: Space) : Space>, +open class AsmExpressionSpace(private val space: Space) : Space>, ExpressionSpace> { override val zero: AsmExpression = AsmConstantExpression(space.zero) override fun const(value: T): AsmExpression = AsmConstantExpression(value) override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(a, b) - override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(a, k) + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(space, a, b) + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(space, a, k) operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this @@ -22,8 +22,11 @@ class AsmExpressionField(private val field: Field) : ExpressionField = const(field.run { one * value }) - override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmProductExpression(a, b) - override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmDivExpression(a, b) + + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmProductExpression(field, a, b) + + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmDivExpression(field, a, b) operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) operator fun T.times(arg: AsmExpression): AsmExpression = arg * this diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index 8708a33e1..cc861d363 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -1,6 +1,6 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.* abstract class AsmCompiledExpression internal constructor( @JvmField private val algebra: Algebra, @@ -10,6 +10,7 @@ abstract class AsmCompiledExpression internal constructor( } interface AsmExpression { + fun tryEvaluate(): T? = null fun invoke(gen: AsmGenerationContext) } @@ -20,13 +21,22 @@ internal class AsmVariableExpression(val name: String, val default: T? = null internal class AsmConstantExpression(val value: T) : AsmExpression { + override fun tryEvaluate(): T = value override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) } internal class AsmSumExpression( - val first: AsmExpression, - val second: AsmExpression + private val algebra: SpaceOperations, + first: AsmExpression, + second: AsmExpression ) : AsmExpression { + private val first: AsmExpression = first.optimize() + private val second: AsmExpression = second.optimize() + + override fun tryEvaluate(): T? = algebra { + (first.tryEvaluate() ?: return@algebra null) + (second.tryEvaluate() ?: return@algebra null) + } + override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() first.invoke(gen) @@ -41,9 +51,17 @@ internal class AsmSumExpression( } internal class AsmProductExpression( - val first: AsmExpression, - val second: AsmExpression + private val algebra: RingOperations, + first: AsmExpression, + second: AsmExpression ) : AsmExpression { + private val first: AsmExpression = first.optimize() + private val second: AsmExpression = second.optimize() + + override fun tryEvaluate(): T? = algebra { + (first.tryEvaluate() ?: return@algebra null) * (second.tryEvaluate() ?: return@algebra null) + } + override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() first.invoke(gen) @@ -58,9 +76,14 @@ internal class AsmProductExpression( } internal class AsmConstProductExpression( - val expr: AsmExpression, - val const: Number + private val algebra: SpaceOperations, + expr: AsmExpression, + private val const: Number ) : AsmExpression { + private val expr: AsmExpression = expr.optimize() + + override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const } + override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() gen.visitNumberConstant(const) @@ -75,9 +98,17 @@ internal class AsmConstProductExpression( } internal class AsmDivExpression( - val expr: AsmExpression, - val second: AsmExpression + private val algebra: FieldOperations, + expr: AsmExpression, + second: AsmExpression ) : AsmExpression { + private val expr: AsmExpression = expr.optimize() + private val second: AsmExpression = second.optimize() + + override fun tryEvaluate(): T? = algebra { + (expr.tryEvaluate() ?: return@algebra null) / (second.tryEvaluate() ?: return@algebra null) + } + override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() expr.invoke(gen) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt index bdbab2b42..f8b4afd5b 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt @@ -1,5 +1,6 @@ package scientifik.kmath.expressions +import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space @@ -17,22 +18,19 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String } -inline fun asmSpace( - algebra: Space, - block: AsmExpressionSpace.() -> AsmExpression -): Expression { - val expression = AsmExpressionSpace(algebra).block() +inline fun asm(i: I, algebra: Algebra, block: I.() -> AsmExpression): Expression { + val expression = i.block().optimize() val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) expression.invoke(ctx) return ctx.generate() } +inline fun asmSpace( + algebra: Space, + block: AsmExpressionSpace.() -> AsmExpression +): Expression = asm(AsmExpressionSpace(algebra), algebra, block) + inline fun asmField( algebra: Field, block: AsmExpressionField.() -> AsmExpression -): Expression { - val expression = AsmExpressionField(algebra).block() - val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) - expression.invoke(ctx) - return ctx.generate() -} +): Expression = asm(AsmExpressionField(algebra), algebra, block) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt new file mode 100644 index 000000000..bb7d0476d --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt @@ -0,0 +1,6 @@ +package scientifik.kmath.expressions + +fun AsmExpression.optimize(): AsmExpression { + val a = tryEvaluate() + return if (a == null) this else AsmConstantExpression(a) +}