forked from kscience/kmath
Update KG and Maven repos, delete symbol delegate provider, implement working differentiable mst expression based on SFun shape to MST conversion
This commit is contained in:
parent
520f6cedeb
commit
29a670483b
@ -9,10 +9,15 @@ internal val githubProject: String by extra("kmath")
|
||||
allprojects {
|
||||
repositories {
|
||||
jcenter()
|
||||
maven("https://clojars.org/repo")
|
||||
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
maven("https://jitpack.io")
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
group = "kscience.kmath"
|
||||
|
@ -8,14 +8,6 @@ plugins {
|
||||
}
|
||||
|
||||
allOpen.annotation("org.openjdk.jmh.annotations.State")
|
||||
|
||||
repositories {
|
||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
sourceSets.register("benchmarks")
|
||||
|
||||
dependencies {
|
||||
|
@ -13,7 +13,7 @@ internal class ExpressionsInterpretersBenchmark {
|
||||
private val algebra: Field<Double> = RealField
|
||||
fun functionalExpression() {
|
||||
val expr = algebra.expressionInField {
|
||||
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
|
||||
symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
|
||||
}
|
||||
|
||||
invokeAndSum(expr)
|
||||
|
@ -1,11 +1,9 @@
|
||||
package kscience.kmath.ast
|
||||
|
||||
import edu.umontreal.kotlingrad.experimental.DoublePrecision
|
||||
import kscience.kmath.asm.compile
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.kotlingrad.toMst
|
||||
import kscience.kmath.kotlingrad.toSFun
|
||||
import kscience.kmath.kotlingrad.toSVar
|
||||
import kscience.kmath.expressions.symbol
|
||||
import kscience.kmath.kotlingrad.DifferentiableMstExpression
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
/**
|
||||
@ -13,10 +11,12 @@ import kscience.kmath.operations.RealField
|
||||
* valid derivative.
|
||||
*/
|
||||
fun main() {
|
||||
val proto = DoublePrecision.prototype
|
||||
val x by MstAlgebra.symbol("x").toSVar(proto)
|
||||
val quadratic = "x^2-4*x-44".parseMath().toSFun(proto)
|
||||
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
||||
val x by symbol
|
||||
|
||||
val actualDerivative = DifferentiableMstExpression(RealField, "x^2-4*x-44".parseMath())
|
||||
.derivativeOrNull(listOf(x))
|
||||
.compile()
|
||||
|
||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ import kotlin.contracts.contract
|
||||
* @property mst the [MST] node.
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
||||
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
|
||||
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||
@ -21,8 +21,9 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||
algebra.binaryOperation(operation, left, right)
|
||||
|
||||
override fun number(value: Number): T = if (algebra is NumericAlgebra)
|
||||
algebra.number(value)
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||
(algebra as NumericAlgebra<T>).number(value)
|
||||
else
|
||||
error("Numeric nodes are not supported by $this")
|
||||
}
|
||||
@ -38,14 +39,14 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
||||
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||
mstAlgebra: E,
|
||||
block: E.() -> MST,
|
||||
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||
): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [Space].
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Space<T>> A.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstSpace.block())
|
||||
}
|
||||
@ -55,7 +56,7 @@ public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MS
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstRing.block())
|
||||
}
|
||||
@ -65,7 +66,7 @@ public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST):
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstField.block())
|
||||
}
|
||||
@ -75,7 +76,7 @@ public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MS
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstExtendedField.block())
|
||||
}
|
||||
@ -85,7 +86,7 @@ public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtend
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInSpace(block)
|
||||
}
|
||||
@ -95,7 +96,7 @@ public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInRing(block)
|
||||
}
|
||||
@ -105,7 +106,7 @@ public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInField(block)
|
||||
}
|
||||
@ -117,7 +118,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
|
||||
*/
|
||||
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||
block: MstExtendedField.() -> MST,
|
||||
): MstExpression<T> {
|
||||
): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInExtendedField(block)
|
||||
}
|
||||
|
@ -1,12 +0,0 @@
|
||||
package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kotlin.properties.PropertyDelegateProvider
|
||||
import kotlin.properties.ReadOnlyProperty
|
||||
|
||||
/**
|
||||
* Returns [PropertyDelegateProvider] providing [ReadOnlyProperty] of [MST.Symbolic] with its value equal to the name
|
||||
* of the property.
|
||||
*/
|
||||
public val Algebra<MST>.symbol: PropertyDelegateProvider<Algebra<MST>, ReadOnlyProperty<Algebra<MST>, MST.Symbolic>>
|
||||
get() = PropertyDelegateProvider { _, _ -> ReadOnlyProperty { _, p -> MST.Symbolic(p.name) } }
|
@ -69,4 +69,5 @@ public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<
|
||||
*
|
||||
* @author Alexander Nozik.
|
||||
*/
|
||||
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class.java, algebra)
|
||||
public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
|
||||
mst.compileWith(T::class.java, algebra)
|
||||
|
@ -3,6 +3,7 @@ plugins {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
api("com.github.breandan:kotlingrad:0.3.7")
|
||||
implementation("com.github.breandan:kaliningraph:0.1.2")
|
||||
implementation("com.github.breandan:kotlingrad:0.3.7")
|
||||
api(project(":kmath-ast"))
|
||||
}
|
||||
|
@ -0,0 +1,53 @@
|
||||
package kscience.kmath.kotlingrad
|
||||
|
||||
import edu.umontreal.kotlingrad.experimental.SFun
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MstAlgebra
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.expressions.DifferentiableExpression
|
||||
import kscience.kmath.expressions.Symbol
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
/**
|
||||
* Represents wrapper of [MstExpression] implementing [DifferentiableExpression].
|
||||
*
|
||||
* The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin∇, then converting
|
||||
* [SFun] back to [MST].
|
||||
*
|
||||
* @param T the type of number.
|
||||
* @param A the [NumericAlgebra] of [T].
|
||||
* @property expr the underlying [MstExpression].
|
||||
*/
|
||||
public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpression<T, A>) :
|
||||
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
|
||||
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
|
||||
|
||||
/**
|
||||
* The [MstExpression.algebra] of [expr].
|
||||
*/
|
||||
public val algebra: A
|
||||
get() = expr.algebra
|
||||
|
||||
/**
|
||||
* The [MstExpression.mst] of [expr].
|
||||
*/
|
||||
public val mst: MST
|
||||
get() = expr.mst
|
||||
|
||||
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
||||
|
||||
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
|
||||
algebra,
|
||||
symbols.map(Symbol::identity)
|
||||
.map(MstAlgebra::symbol)
|
||||
.map { it.toSVar<KMathNumber<T, A>>() }
|
||||
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
|
||||
.toMst(),
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [MstExpression] into [DifferentiableMstExpression].
|
||||
*/
|
||||
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
|
||||
DifferentiableMstExpression(this)
|
@ -0,0 +1,18 @@
|
||||
package kscience.kmath.kotlingrad
|
||||
|
||||
import edu.umontreal.kotlingrad.experimental.RealNumber
|
||||
import edu.umontreal.kotlingrad.experimental.SConst
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
/**
|
||||
* Implements [RealNumber] by delegating its functionality to [NumericAlgebra].
|
||||
*
|
||||
* @param T the type of number.
|
||||
* @param A the [NumericAlgebra] of [T].
|
||||
* @property algebra the algebra.
|
||||
* @param value the value of this number.
|
||||
*/
|
||||
public class KMathNumber<T, A>(public val algebra: A, value: T) :
|
||||
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
|
||||
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
|
||||
}
|
@ -3,14 +3,26 @@ package kscience.kmath.kotlingrad
|
||||
import edu.umontreal.kotlingrad.experimental.*
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MstAlgebra
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.ast.MstExtendedField
|
||||
import kscience.kmath.ast.MstExtendedField.unaryMinus
|
||||
import kscience.kmath.expressions.DifferentiableExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.Symbol
|
||||
import kscience.kmath.operations.*
|
||||
|
||||
/**
|
||||
* Maps [SVar] to [MST.Symbolic] directly.
|
||||
*
|
||||
* @receiver the variable.
|
||||
* @return a node.
|
||||
*/
|
||||
public fun <X : SFun<X>> SVar<X>.toMst(): MST.Symbolic = MstAlgebra.symbol(name)
|
||||
|
||||
/**
|
||||
* Maps [SVar] to [MST.Numeric] directly.
|
||||
*
|
||||
* @receiver the constant.
|
||||
* @return a node.
|
||||
*/
|
||||
public fun <X : SFun<X>> SConst<X>.toMst(): MST.Numeric = MstAlgebra.number(doubleValue)
|
||||
|
||||
/**
|
||||
* Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then.
|
||||
* [Power] operation is limited to constant right-hand side arguments.
|
||||
@ -37,8 +49,8 @@ import kscience.kmath.operations.*
|
||||
*/
|
||||
public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
|
||||
when (this@toMst) {
|
||||
is SVar -> symbol(name)
|
||||
is SConst -> number(doubleValue)
|
||||
is SVar -> toMst()
|
||||
is SConst -> toMst()
|
||||
is Sum -> left.toMst() + right.toMst()
|
||||
is Prod -> left.toMst() * right.toMst()
|
||||
is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue)
|
||||
@ -69,7 +81,7 @@ public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
|
||||
* @param proto the prototype instance.
|
||||
* @return a new variable.
|
||||
*/
|
||||
public fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
|
||||
internal fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
|
||||
|
||||
/**
|
||||
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
|
||||
@ -90,12 +102,12 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
||||
is MST.Symbolic -> toSVar()
|
||||
|
||||
is MST.Unary -> when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> value.toSFun<X>()
|
||||
SpaceOperations.MINUS_OPERATION -> (-value).toSFun()
|
||||
SpaceOperations.PLUS_OPERATION -> +value.toSFun<X>()
|
||||
SpaceOperations.MINUS_OPERATION -> -value.toSFun<X>()
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
|
||||
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
|
||||
PowerOperations.SQRT_OPERATION -> value.toSFun<X>().sqrt()
|
||||
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
|
||||
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
|
||||
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
|
||||
else -> error("Unary operation $operation not defined in $this")
|
||||
@ -110,27 +122,3 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
}
|
||||
}
|
||||
|
||||
public class KMathNumber<T, A>(public val algebra: A, value: T) :
|
||||
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
|
||||
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
|
||||
}
|
||||
|
||||
public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) :
|
||||
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
|
||||
public val expr by lazy { MstExpression(algebra, mst) }
|
||||
|
||||
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
||||
|
||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> {
|
||||
TODO()
|
||||
}
|
||||
|
||||
public fun derivativeOrNull(orders: List<Symbol>): Expression<T> {
|
||||
orders.map { MstAlgebra.symbol(it.identity).toSVar<KMathNumber<T, A>>() }
|
||||
.fold<SVar<KMathNumber<T, A>>, SFun<KMathNumber<T, A>>>(mst.toSFun()) { result, sVar -> result.d(sVar) }
|
||||
.toMst()
|
||||
|
||||
TODO()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user