Rename and refactor MstAlgebra (ex-MSTAlgebra) (and its subclasses), MstExpression (ex-MSTExpression)
This commit is contained in:
parent
0ee1d31571
commit
d962ab4d11
@ -1,55 +0,0 @@
|
|||||||
package scientifik.kmath.ast
|
|
||||||
|
|
||||||
import scientifik.kmath.expressions.Expression
|
|
||||||
import scientifik.kmath.expressions.FunctionalExpressionField
|
|
||||||
import scientifik.kmath.expressions.FunctionalExpressionRing
|
|
||||||
import scientifik.kmath.expressions.FunctionalExpressionSpace
|
|
||||||
import scientifik.kmath.operations.*
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
|
|
||||||
*/
|
|
||||||
class MSTExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Substitute algebra raw value
|
|
||||||
*/
|
|
||||||
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T>{
|
|
||||||
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
|
||||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T =algebra.binaryOperation(operation, left, right)
|
|
||||||
|
|
||||||
override fun number(value: Number): T = if(algebra is NumericAlgebra){
|
|
||||||
algebra.number(value)
|
|
||||||
} else{
|
|
||||||
error("Numeric nodes are not supported by $this")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
|
||||||
mstAlgebra: E,
|
|
||||||
block: E.() -> MST
|
|
||||||
): MSTExpression<T> = MSTExpression(this, mstAlgebra.block())
|
|
||||||
|
|
||||||
inline fun <reified T : Any> Space<T>.mstInSpace(block: MSTSpace.() -> MST): MSTExpression<T> =
|
|
||||||
MSTExpression(this, MSTSpace.block())
|
|
||||||
|
|
||||||
inline fun <reified T : Any> Ring<T>.mstInRing(block: MSTRing.() -> MST): MSTExpression<T> =
|
|
||||||
MSTExpression(this, MSTRing.block())
|
|
||||||
|
|
||||||
inline fun <reified T : Any> Field<T>.mstInField(block: MSTField.() -> MST): MSTExpression<T> =
|
|
||||||
MSTExpression(this, MSTField.block())
|
|
||||||
|
|
||||||
inline fun <reified T: Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MSTSpace.() -> MST): MSTExpression<T> =
|
|
||||||
algebra.mstInSpace(block)
|
|
||||||
|
|
||||||
inline fun <reified T: Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MSTRing.() -> MST): MSTExpression<T> =
|
|
||||||
algebra.mstInRing(block)
|
|
||||||
|
|
||||||
inline fun <reified T: Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MSTField.() -> MST): MSTExpression<T> =
|
|
||||||
algebra.mstInField(block)
|
|
@ -2,7 +2,7 @@ package scientifik.kmath.ast
|
|||||||
|
|
||||||
import scientifik.kmath.operations.*
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
object MSTAlgebra : NumericAlgebra<MST> {
|
object MstAlgebra : NumericAlgebra<MST> {
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
override fun number(value: Number): MST = MST.Numeric(value)
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
override fun symbol(value: String): MST = MST.Symbolic(value)
|
||||||
@ -14,12 +14,11 @@ object MSTAlgebra : NumericAlgebra<MST> {
|
|||||||
MST.Binary(operation, left, right)
|
MST.Binary(operation, left, right)
|
||||||
}
|
}
|
||||||
|
|
||||||
object MSTSpace : Space<MST>, NumericAlgebra<MST> {
|
object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST = number(0.0)
|
override val zero: MST = number(0.0)
|
||||||
|
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||||
|
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
|
||||||
|
|
||||||
override fun add(a: MST, b: MST): MST =
|
override fun add(a: MST, b: MST): MST =
|
||||||
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
@ -28,46 +27,46 @@ object MSTSpace : Space<MST>, NumericAlgebra<MST> {
|
|||||||
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
MSTAlgebra.binaryOperation(operation, left, right)
|
MstAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
object MSTRing : Ring<MST>, NumericAlgebra<MST> {
|
object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST = MSTSpace.number(0.0)
|
override val zero: MST = number(0.0)
|
||||||
override val one: MST = number(1.0)
|
override val one: MST = number(1.0)
|
||||||
|
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
override fun multiply(a: MST, k: Number): MST =
|
override fun multiply(a: MST, k: Number): MST =
|
||||||
binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k))
|
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
MSTAlgebra.binaryOperation(operation, left, right)
|
MstAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
object MSTField : Field<MST> {
|
object MstField : Field<MST> {
|
||||||
override val zero: MST = MSTSpace.number(0.0)
|
override val zero: MST = number(0.0)
|
||||||
override val one: MST = number(1.0)
|
override val one: MST = number(1.0)
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
override fun multiply(a: MST, k: Number): MST =
|
override fun multiply(a: MST, k: Number): MST =
|
||||||
binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k))
|
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
MSTAlgebra.binaryOperation(operation, left, right)
|
MstAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||||
}
|
}
|
@ -0,0 +1,55 @@
|
|||||||
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.expressions.FunctionalExpressionField
|
||||||
|
import scientifik.kmath.expressions.FunctionalExpressionRing
|
||||||
|
import scientifik.kmath.expressions.FunctionalExpressionSpace
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
|
||||||
|
*/
|
||||||
|
class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Substitute algebra raw value
|
||||||
|
*/
|
||||||
|
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
|
||||||
|
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
||||||
|
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
|
algebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun number(value: Number): T = if (algebra is NumericAlgebra)
|
||||||
|
algebra.number(value)
|
||||||
|
else
|
||||||
|
error("Numeric nodes are not supported by $this")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||||
|
mstAlgebra: E,
|
||||||
|
block: E.() -> MST
|
||||||
|
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
|
||||||
|
MstExpression(this, MstSpace.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
|
||||||
|
MstExpression(this, MstRing.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> =
|
||||||
|
MstExpression(this, MstField.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
|
||||||
|
algebra.mstInSpace(block)
|
||||||
|
|
||||||
|
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
|
||||||
|
algebra.mstInRing(block)
|
||||||
|
|
||||||
|
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> =
|
||||||
|
algebra.mstInField(block)
|
@ -6,7 +6,7 @@ import scientifik.kmath.asm.internal.buildExpectationStack
|
|||||||
import scientifik.kmath.asm.internal.buildName
|
import scientifik.kmath.asm.internal.buildName
|
||||||
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
||||||
import scientifik.kmath.ast.MST
|
import scientifik.kmath.ast.MST
|
||||||
import scientifik.kmath.ast.MSTExpression
|
import scientifik.kmath.ast.MstExpression
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.NumericAlgebra
|
import scientifik.kmath.operations.NumericAlgebra
|
||||||
@ -80,6 +80,6 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
|||||||
inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize performance of an [MSTExpression] using ASM codegen
|
* Optimize performance of an [MstExpression] using ASM codegen
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any> MSTExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
||||||
|
Loading…
Reference in New Issue
Block a user