Update KG and Maven repos, delete symbol delegate provider, implement working differentiable mst expression based on SFun shape to MST conversion

This commit is contained in:
Iaroslav Postovalov 2020-10-30 01:09:11 +07:00
parent 520f6cedeb
commit 29a670483b
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
11 changed files with 125 additions and 78 deletions

View File

@ -9,10 +9,15 @@ internal val githubProject: String by extra("kmath")
allprojects { allprojects {
repositories { repositories {
jcenter() jcenter()
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/hotkeytlt/maven") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://jitpack.io") maven("https://jitpack.io")
mavenCentral()
} }
group = "kscience.kmath" group = "kscience.kmath"

View File

@ -8,14 +8,6 @@ plugins {
} }
allOpen.annotation("org.openjdk.jmh.annotations.State") allOpen.annotation("org.openjdk.jmh.annotations.State")
repositories {
maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
mavenCentral()
}
sourceSets.register("benchmarks") sourceSets.register("benchmarks")
dependencies { dependencies {

View File

@ -13,7 +13,7 @@ internal class ExpressionsInterpretersBenchmark {
private val algebra: Field<Double> = RealField private val algebra: Field<Double> = RealField
fun functionalExpression() { fun functionalExpression() {
val expr = algebra.expressionInField { val expr = algebra.expressionInField {
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
} }
invokeAndSum(expr) invokeAndSum(expr)

View File

@ -1,11 +1,9 @@
package kscience.kmath.ast package kscience.kmath.ast
import edu.umontreal.kotlingrad.experimental.DoublePrecision
import kscience.kmath.asm.compile import kscience.kmath.asm.compile
import kscience.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import kscience.kmath.kotlingrad.toMst import kscience.kmath.expressions.symbol
import kscience.kmath.kotlingrad.toSFun import kscience.kmath.kotlingrad.DifferentiableMstExpression
import kscience.kmath.kotlingrad.toSVar
import kscience.kmath.operations.RealField import kscience.kmath.operations.RealField
/** /**
@ -13,10 +11,12 @@ import kscience.kmath.operations.RealField
* valid derivative. * valid derivative.
*/ */
fun main() { fun main() {
val proto = DoublePrecision.prototype val x by symbol
val x by MstAlgebra.symbol("x").toSVar(proto)
val quadratic = "x^2-4*x-44".parseMath().toSFun(proto) val actualDerivative = DifferentiableMstExpression(RealField, "x^2-4*x-44".parseMath())
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile() .derivativeOrNull(listOf(x))
.compile()
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0)) assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
} }

View File

@ -13,7 +13,7 @@ import kotlin.contracts.contract
* @property mst the [MST] node. * @property mst the [MST] node.
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> { public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> { private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
@ -21,8 +21,9 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
override fun binaryOperation(operation: String, left: T, right: T): T = override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right) algebra.binaryOperation(operation, left, right)
override fun number(value: Number): T = if (algebra is NumericAlgebra) @Suppress("UNCHECKED_CAST")
algebra.number(value) override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else else
error("Numeric nodes are not supported by $this") error("Numeric nodes are not supported by $this")
} }
@ -38,14 +39,14 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst( public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E, mstAlgebra: E,
block: E.() -> MST, block: E.() -> MST,
): MstExpression<T> = MstExpression(this, mstAlgebra.block()) ): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
/** /**
* Builds [MstExpression] over [Space]. * Builds [MstExpression] over [Space].
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Space<T>> A.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstSpace.block()) return MstExpression(this, MstSpace.block())
} }
@ -55,7 +56,7 @@ public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MS
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstRing.block()) return MstExpression(this, MstRing.block())
} }
@ -65,7 +66,7 @@ public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST):
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstField.block()) return MstExpression(this, MstField.block())
} }
@ -75,7 +76,7 @@ public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MS
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstExtendedField.block()) return MstExpression(this, MstExtendedField.block())
} }
@ -85,7 +86,7 @@ public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtend
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInSpace(block) return algebra.mstInSpace(block)
} }
@ -95,7 +96,7 @@ public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInRing(block) return algebra.mstInRing(block)
} }
@ -105,7 +106,7 @@ public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInField(block) return algebra.mstInField(block)
} }
@ -117,7 +118,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
*/ */
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField( public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
block: MstExtendedField.() -> MST, block: MstExtendedField.() -> MST,
): MstExpression<T> { ): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block) return algebra.mstInExtendedField(block)
} }

View File

@ -1,12 +0,0 @@
package kscience.kmath.ast
import kscience.kmath.operations.Algebra
import kotlin.properties.PropertyDelegateProvider
import kotlin.properties.ReadOnlyProperty
/**
* Returns [PropertyDelegateProvider] providing [ReadOnlyProperty] of [MST.Symbolic] with its value equal to the name
* of the property.
*/
public val Algebra<MST>.symbol: PropertyDelegateProvider<Algebra<MST>, ReadOnlyProperty<Algebra<MST>, MST.Symbolic>>
get() = PropertyDelegateProvider { _, _ -> ReadOnlyProperty { _, p -> MST.Symbolic(p.name) } }

View File

@ -69,4 +69,5 @@ public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<
* *
* @author Alexander Nozik. * @author Alexander Nozik.
*/ */
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class.java, algebra) public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
mst.compileWith(T::class.java, algebra)

View File

@ -3,6 +3,7 @@ plugins {
} }
dependencies { dependencies {
api("com.github.breandan:kotlingrad:0.3.7") implementation("com.github.breandan:kaliningraph:0.1.2")
implementation("com.github.breandan:kotlingrad:0.3.7")
api(project(":kmath-ast")) api(project(":kmath-ast"))
} }

View File

@ -0,0 +1,53 @@
package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.SFun
import kscience.kmath.ast.MST
import kscience.kmath.ast.MstAlgebra
import kscience.kmath.ast.MstExpression
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Symbol
import kscience.kmath.operations.NumericAlgebra
/**
* Represents wrapper of [MstExpression] implementing [DifferentiableExpression].
*
* The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin, then converting
* [SFun] back to [MST].
*
* @param T the type of number.
* @param A the [NumericAlgebra] of [T].
* @property expr the underlying [MstExpression].
*/
public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpression<T, A>) :
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
/**
* The [MstExpression.algebra] of [expr].
*/
public val algebra: A
get() = expr.algebra
/**
* The [MstExpression.mst] of [expr].
*/
public val mst: MST
get() = expr.mst
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
algebra,
symbols.map(Symbol::identity)
.map(MstAlgebra::symbol)
.map { it.toSVar<KMathNumber<T, A>>() }
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
.toMst(),
)
}
/**
* Wraps this [MstExpression] into [DifferentiableMstExpression].
*/
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(this)

View File

@ -0,0 +1,18 @@
package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.RealNumber
import edu.umontreal.kotlingrad.experimental.SConst
import kscience.kmath.operations.NumericAlgebra
/**
* Implements [RealNumber] by delegating its functionality to [NumericAlgebra].
*
* @param T the type of number.
* @param A the [NumericAlgebra] of [T].
* @property algebra the algebra.
* @param value the value of this number.
*/
public class KMathNumber<T, A>(public val algebra: A, value: T) :
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
}

View File

@ -3,14 +3,26 @@ package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.* import edu.umontreal.kotlingrad.experimental.*
import kscience.kmath.ast.MST import kscience.kmath.ast.MST
import kscience.kmath.ast.MstAlgebra import kscience.kmath.ast.MstAlgebra
import kscience.kmath.ast.MstExpression
import kscience.kmath.ast.MstExtendedField import kscience.kmath.ast.MstExtendedField
import kscience.kmath.ast.MstExtendedField.unaryMinus import kscience.kmath.ast.MstExtendedField.unaryMinus
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.Symbol
import kscience.kmath.operations.* import kscience.kmath.operations.*
/**
* Maps [SVar] to [MST.Symbolic] directly.
*
* @receiver the variable.
* @return a node.
*/
public fun <X : SFun<X>> SVar<X>.toMst(): MST.Symbolic = MstAlgebra.symbol(name)
/**
* Maps [SVar] to [MST.Numeric] directly.
*
* @receiver the constant.
* @return a node.
*/
public fun <X : SFun<X>> SConst<X>.toMst(): MST.Numeric = MstAlgebra.number(doubleValue)
/** /**
* Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then. * Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then.
* [Power] operation is limited to constant right-hand side arguments. * [Power] operation is limited to constant right-hand side arguments.
@ -37,8 +49,8 @@ import kscience.kmath.operations.*
*/ */
public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField { public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
when (this@toMst) { when (this@toMst) {
is SVar -> symbol(name) is SVar -> toMst()
is SConst -> number(doubleValue) is SConst -> toMst()
is Sum -> left.toMst() + right.toMst() is Sum -> left.toMst() + right.toMst()
is Prod -> left.toMst() * right.toMst() is Prod -> left.toMst() * right.toMst()
is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue) is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue)
@ -69,7 +81,7 @@ public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
* @param proto the prototype instance. * @param proto the prototype instance.
* @return a new variable. * @return a new variable.
*/ */
public fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value) internal fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
/** /**
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
@ -90,12 +102,12 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
is MST.Symbolic -> toSVar() is MST.Symbolic -> toSVar()
is MST.Unary -> when (operation) { is MST.Unary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> value.toSFun<X>() SpaceOperations.PLUS_OPERATION -> +value.toSFun<X>()
SpaceOperations.MINUS_OPERATION -> (-value).toSFun() SpaceOperations.MINUS_OPERATION -> -value.toSFun<X>()
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun()) TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun()) TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun()) TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
PowerOperations.SQRT_OPERATION -> value.toSFun<X>().sqrt() PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun()) ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln() ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
else -> error("Unary operation $operation not defined in $this") else -> error("Unary operation $operation not defined in $this")
@ -110,27 +122,3 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
else -> error("Binary operation $operation not defined in $this") else -> error("Binary operation $operation not defined in $this")
} }
} }
public class KMathNumber<T, A>(public val algebra: A, value: T) :
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
}
public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) :
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
public val expr by lazy { MstExpression(algebra, mst) }
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> {
TODO()
}
public fun derivativeOrNull(orders: List<Symbol>): Expression<T> {
orders.map { MstAlgebra.symbol(it.identity).toSVar<KMathNumber<T, A>>() }
.fold<SVar<KMathNumber<T, A>>, SFun<KMathNumber<T, A>>>(mst.toSFun()) { result, sVar -> result.d(sVar) }
.toMst()
TODO()
}
}