forked from kscience/kmath
Merge remote-tracking branch 'mipt-npm/adv-expr' into adv-expr-improved-trigonometry
This commit is contained in:
commit
41a81e7a14
@ -7,6 +7,11 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
|
// all {
|
||||||
|
// languageSettings.apply{
|
||||||
|
// enableLanguageFeature("NewInference")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
|
@ -1,33 +1,76 @@
|
|||||||
@file:Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE")
|
|
||||||
|
|
||||||
package scientifik.kmath.ast
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
import scientifik.kmath.operations.*
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
object MSTAlgebra : NumericAlgebra<MST> {
|
object MSTAlgebra : NumericAlgebra<MST> {
|
||||||
|
override fun number(value: Number): MST = MST.Numeric(value)
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
override fun symbol(value: String): MST = MST.Symbolic(value)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MST.Unary(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST =
|
||||||
|
MST.Unary(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right)
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MST.Binary(operation, left, right)
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
object MSTSpace : Space<MST>, NumericAlgebra<MST> by MSTAlgebra {
|
object MSTSpace : Space<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST = number(0.0)
|
override val zero: MST = number(0.0)
|
||||||
|
|
||||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
override fun number(value: Number): MST = MST.Numeric(value)
|
||||||
|
|
||||||
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
override fun symbol(value: String): MST = MST.Symbolic(value)
|
||||||
|
|
||||||
|
override fun add(a: MST, b: MST): MST =
|
||||||
|
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun multiply(a: MST, k: Number): MST =
|
||||||
|
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MSTAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
object MSTRing : Ring<MST>, Space<MST> by MSTSpace {
|
object MSTRing : Ring<MST>, NumericAlgebra<MST> {
|
||||||
|
override fun number(value: Number): MST = MST.Numeric(value)
|
||||||
|
override fun symbol(value: String): MST = MST.Symbolic(value)
|
||||||
|
|
||||||
|
override val zero: MST = MSTSpace.number(0.0)
|
||||||
override val one: MST = number(1.0)
|
override val one: MST = number(1.0)
|
||||||
|
override fun add(a: MST, b: MST): MST =
|
||||||
|
MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
override fun multiply(a: MST, k: Number): MST =
|
||||||
|
MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k))
|
||||||
|
|
||||||
|
override fun multiply(a: MST, b: MST): MST =
|
||||||
|
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MSTAlgebra.binaryOperation(operation, left, right)
|
||||||
}
|
}
|
||||||
|
|
||||||
object MSTField : Field<MST>, Ring<MST> by MSTRing {
|
object MSTField : Field<MST>{
|
||||||
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
override fun symbol(value: String): MST = MST.Symbolic(value)
|
||||||
|
override fun number(value: Number): MST = MST.Numeric(value)
|
||||||
|
|
||||||
|
override val zero: MST = MSTSpace.number(0.0)
|
||||||
|
override val one: MST = number(1.0)
|
||||||
|
override fun add(a: MST, b: MST): MST =
|
||||||
|
MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
override fun multiply(a: MST, k: Number): MST =
|
||||||
|
MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k))
|
||||||
|
|
||||||
|
override fun multiply(a: MST, b: MST): MST =
|
||||||
|
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun divide(a: MST, b: MST): MST =
|
||||||
|
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MSTAlgebra.binaryOperation(operation, left, right)
|
||||||
}
|
}
|
@ -1,19 +1,55 @@
|
|||||||
package scientifik.kmath.ast
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.NumericAlgebra
|
import scientifik.kmath.expressions.FunctionalExpressionField
|
||||||
|
import scientifik.kmath.expressions.FunctionalExpressionRing
|
||||||
|
import scientifik.kmath.expressions.FunctionalExpressionSpace
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
|
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
|
||||||
*/
|
*/
|
||||||
class MSTExpression<T>(val algebra: NumericAlgebra<T>, val mst: MST) : Expression<T> {
|
class MSTExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Substitute algebra raw value
|
* Substitute algebra raw value
|
||||||
*/
|
*/
|
||||||
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> by algebra {
|
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T>{
|
||||||
override fun symbol(value: String): T = arguments[value] ?: super.symbol(value)
|
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
||||||
|
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: T, right: T): T =algebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun number(value: Number): T = if(algebra is NumericAlgebra){
|
||||||
|
algebra.number(value)
|
||||||
|
} else{
|
||||||
|
error("Numeric nodes are not supported by $this")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||||
|
mstAlgebra: E,
|
||||||
|
block: E.() -> MST
|
||||||
|
): MSTExpression<T> = MSTExpression(this, mstAlgebra.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Space<T>.mstInSpace(block: MSTSpace.() -> MST): MSTExpression<T> =
|
||||||
|
MSTExpression(this, MSTSpace.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Ring<T>.mstInRing(block: MSTRing.() -> MST): MSTExpression<T> =
|
||||||
|
MSTExpression(this, MSTRing.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Field<T>.mstInField(block: MSTField.() -> MST): MSTExpression<T> =
|
||||||
|
MSTExpression(this, MSTField.block())
|
||||||
|
|
||||||
|
inline fun <reified T: Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MSTSpace.() -> MST): MSTExpression<T> =
|
||||||
|
algebra.mstInSpace(block)
|
||||||
|
|
||||||
|
inline fun <reified T: Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MSTRing.() -> MST): MSTExpression<T> =
|
||||||
|
algebra.mstInRing(block)
|
||||||
|
|
||||||
|
inline fun <reified T: Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MSTField.() -> MST): MSTExpression<T> =
|
||||||
|
algebra.mstInField(block)
|
@ -5,11 +5,10 @@ import scientifik.kmath.asm.internal.buildName
|
|||||||
import scientifik.kmath.asm.internal.hasSpecific
|
import scientifik.kmath.asm.internal.hasSpecific
|
||||||
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
||||||
import scientifik.kmath.ast.MST
|
import scientifik.kmath.ast.MST
|
||||||
import scientifik.kmath.ast.MSTField
|
import scientifik.kmath.ast.MSTExpression
|
||||||
import scientifik.kmath.ast.MSTRing
|
|
||||||
import scientifik.kmath.ast.MSTSpace
|
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.*
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import scientifik.kmath.operations.NumericAlgebra
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -71,18 +70,12 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
|||||||
return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
/**
|
||||||
|
* Compile an [MST] to ASM using given algebra
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> Algebra<T>.expresion(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
||||||
|
|
||||||
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.asm(
|
/**
|
||||||
mstAlgebra: E,
|
* Optimize performance of an [MSTExpression] using ASM codegen
|
||||||
block: E.() -> MST
|
*/
|
||||||
): Expression<T> = mstAlgebra.block().compileWith(T::class, this)
|
inline fun <reified T : Any> MSTExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
||||||
|
|
||||||
inline fun <reified T : Any, A : Space<T>> A.asmInSpace(block: MSTSpace.() -> MST): Expression<T> =
|
|
||||||
MSTSpace.block().compileWith(T::class, this)
|
|
||||||
|
|
||||||
inline fun <reified T : Any, A : Ring<T>> A.asmInRing(block: MSTRing.() -> MST): Expression<T> =
|
|
||||||
MSTRing.block().compileWith(T::class, this)
|
|
||||||
|
|
||||||
inline fun <reified T : Any, A : Field<T>> A.asmInField(block: MSTField.() -> MST): Expression<T> =
|
|
||||||
MSTField.block().compileWith(T::class, this)
|
|
@ -1,6 +1,8 @@
|
|||||||
package scietifik.kmath.asm
|
package scietifik.kmath.asm
|
||||||
|
|
||||||
import scientifik.kmath.asm.asmInField
|
import scientifik.kmath.asm.compile
|
||||||
|
import scientifik.kmath.ast.mstInField
|
||||||
|
import scientifik.kmath.ast.mstInSpace
|
||||||
import scientifik.kmath.expressions.invoke
|
import scientifik.kmath.expressions.invoke
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -9,13 +11,14 @@ import kotlin.test.assertEquals
|
|||||||
class TestAsmExpressions {
|
class TestAsmExpressions {
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryOperationInvocation() {
|
fun testUnaryOperationInvocation() {
|
||||||
val res = RealField.asmInField { -symbol("x") }("x" to 2.0)
|
val expression = RealField.mstInSpace { -symbol("x") }.compile()
|
||||||
|
val res = expression("x" to 2.0)
|
||||||
assertEquals(-2.0, res)
|
assertEquals(-2.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testConstProductInvocation() {
|
fun testConstProductInvocation() {
|
||||||
val res = RealField.asmInField { symbol("x") * 2 }("x" to 2.0)
|
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
|
||||||
assertEquals(4.0, res)
|
assertEquals(4.0, res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
package scietifik.kmath.ast
|
package scietifik.kmath.ast
|
||||||
|
|
||||||
import scientifik.kmath.asm.compile
|
import scientifik.kmath.ast.evaluate
|
||||||
import scientifik.kmath.ast.parseMath
|
import scientifik.kmath.ast.parseMath
|
||||||
import scientifik.kmath.expressions.invoke
|
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -12,7 +11,7 @@ class AsmTest {
|
|||||||
@Test
|
@Test
|
||||||
fun parsedExpression() {
|
fun parsedExpression() {
|
||||||
val mst = "2+2*(2+2)".parseMath()
|
val mst = "2+2*(2+2)".parseMath()
|
||||||
val res = ComplexField.compile(mst)()
|
val res = ComplexField.evaluate(mst)
|
||||||
assertEquals(Complex(10.0, 0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,17 +7,17 @@ import scientifik.kmath.operations.Space
|
|||||||
/**
|
/**
|
||||||
* Create a functional expression on this [Space]
|
* Create a functional expression on this [Space]
|
||||||
*/
|
*/
|
||||||
fun <T> Space<T>.buildExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> =
|
fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> =
|
||||||
FunctionalExpressionSpace(this).run(block)
|
FunctionalExpressionSpace(this).run(block)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a functional expression on this [Ring]
|
* Create a functional expression on this [Ring]
|
||||||
*/
|
*/
|
||||||
fun <T> Ring<T>.buildExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> =
|
fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> =
|
||||||
FunctionalExpressionRing(this).run(block)
|
FunctionalExpressionRing(this).run(block)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a functional expression on this [Field]
|
* Create a functional expression on this [Field]
|
||||||
*/
|
*/
|
||||||
fun <T> Field<T>.buildExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> =
|
fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> =
|
||||||
FunctionalExpressionField(this).run(block)
|
FunctionalExpressionField(this).run(block)
|
||||||
|
@ -8,11 +8,14 @@ import scientifik.kmath.operations.Algebra
|
|||||||
interface Expression<T> {
|
interface Expression<T> {
|
||||||
operator fun invoke(arguments: Map<String, T>): T
|
operator fun invoke(arguments: Map<String, T>): T
|
||||||
|
|
||||||
companion object {
|
companion object
|
||||||
operator fun <T> invoke(block: (Map<String, T>) -> T): Expression<T> = object : Expression<T> {
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create simple lazily evaluated expression inside given algebra
|
||||||
|
*/
|
||||||
|
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> = object: Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T = block(arguments)
|
override fun invoke(arguments: Map<String, T>): T = block(arguments)
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
|
operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
|
||||||
|
@ -69,10 +69,10 @@ interface FunctionalExpressionAlgebra<T, A : Algebra<T>> : ExpressionAlgebra<T,
|
|||||||
/**
|
/**
|
||||||
* A context class for [Expression] construction for [Space] algebras.
|
* A context class for [Expression] construction for [Space] algebras.
|
||||||
*/
|
*/
|
||||||
open class FunctionalExpressionSpace<T, A>(override val algebra: A) : FunctionalExpressionAlgebra<T, A>,
|
open class FunctionalExpressionSpace<T, A : Space<T>>(override val algebra: A) :
|
||||||
Space<Expression<T>> where A : Space<T> {
|
FunctionalExpressionAlgebra<T, A>, Space<Expression<T>> {
|
||||||
override val zero: Expression<T>
|
|
||||||
get() = const(algebra.zero)
|
override val zero: Expression<T> get() = const(algebra.zero)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of addition of two another expressions.
|
* Builds an Expression of addition of two another expressions.
|
||||||
|
Loading…
Reference in New Issue
Block a user