Add adapters of scalar functions to MST and vice versa

This commit is contained in:
Iaroslav Postovalov 2020-10-12 20:26:03 +07:00
parent a12c645416
commit f0fbebd770
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
6 changed files with 82 additions and 5 deletions

View File

@ -12,6 +12,7 @@ allprojects {
maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/hotkeytlt/maven") maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://jitpack.io")
} }
group = "kscience.kmath" group = "kscience.kmath"

View File

@ -0,0 +1,8 @@
plugins {
id("ru.mipt.npm.jvm")
}
dependencies {
api("com.github.breandan:kotlingrad:0.3.2")
api(project(":kmath-ast"))
}

View File

@ -0,0 +1,64 @@
package kscience.kmath.ast.kotlingrad
import edu.umontreal.kotlingrad.experimental.*
import kscience.kmath.ast.MST
import kscience.kmath.ast.MstExtendedField
import kscience.kmath.operations.*
/**
* Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then.
*
* @receiver a scalar function.
* @return the [MST].
*/
public fun <X : SFun<X>> SFun<X>.mst(): MST = MstExtendedField {
when (this@mst) {
is SVar -> symbol(name)
is SConst -> number(doubleValue)
is Sum -> left.mst() + right.mst()
is Prod -> left.mst() * right.mst()
is Power -> power(left.mst(), (right() as SConst<*>).doubleValue)
is Negative -> -input.mst()
is Log -> ln(left.mst()) / ln(right.mst())
is Sine -> sin(input.mst())
is Cosine -> cos(input.mst())
is Tangent -> tan(input.mst())
is DProd -> this@mst().mst()
is SComposition -> this@mst().mst()
is VSumAll<X, *> -> this@mst().mst()
is Derivative -> this@mst().mst()
}
}
/**
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
*
* @receiver an [MST].
* @return the scalar function.
*/
public fun <X : SFun<X>> MST.sfun(proto: X): SFun<X> {
return when (this) {
is MST.Numeric -> SConst(value)
is MST.Symbolic -> SVar(proto, value)
is MST.Unary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> value.sfun(proto)
SpaceOperations.MINUS_OPERATION -> Negative(value.sfun(proto))
TrigonometricOperations.SIN_OPERATION -> Sine(value.sfun(proto))
TrigonometricOperations.COS_OPERATION -> Cosine(value.sfun(proto))
TrigonometricOperations.TAN_OPERATION -> Tangent(value.sfun(proto))
PowerOperations.SQRT_OPERATION -> Power(value.sfun(proto), SConst(0.5))
ExponentialOperations.EXP_OPERATION -> Power(value.sfun(proto), E())
ExponentialOperations.LN_OPERATION -> Log(value.sfun(proto))
else -> error("Unary operation $operation not defined in $this")
}
is MST.Binary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> Sum(left.sfun(proto), right.sfun(proto))
SpaceOperations.MINUS_OPERATION -> Sum(left.sfun(proto), Negative(right.sfun(proto)))
RingOperations.TIMES_OPERATION -> Prod(left.sfun(proto), right.sfun(proto))
FieldOperations.DIV_OPERATION -> Prod(left.sfun(proto), Power(right.sfun(proto), Negative(One())))
else -> error("Binary operation $operation not defined in $this")
}
}
}

View File

@ -1,4 +1,6 @@
plugins { id("ru.mipt.npm.mpp") } plugins {
id("ru.mipt.npm.mpp")
}
kotlin.sourceSets { kotlin.sourceSets {
commonMain { commonMain {

View File

@ -1,4 +1,6 @@
plugins { id("ru.mipt.npm.jvm") } plugins {
id("ru.mipt.npm.jvm")
}
description = "Binding for https://github.com/JetBrains-Research/viktor" description = "Binding for https://github.com/JetBrains-Research/viktor"

View File

@ -7,7 +7,6 @@ pluginManagement {
maven("https://dl.bintray.com/mipt-npm/kscience") maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
} }
val toolsVersion = "0.6.1-dev-1.4.20-M1" val toolsVersion = "0.6.1-dev-1.4.20-M1"
@ -25,11 +24,11 @@ pluginManagement {
} }
rootProject.name = "kmath" rootProject.name = "kmath"
include( include(
":kmath-memory", ":kmath-memory",
":kmath-core", ":kmath-core",
":kmath-functions", ":kmath-functions",
// ":kmath-io",
":kmath-coroutines", ":kmath-coroutines",
":kmath-histograms", ":kmath-histograms",
":kmath-commons", ":kmath-commons",
@ -40,5 +39,6 @@ include(
":kmath-geometry", ":kmath-geometry",
":kmath-ast", ":kmath-ast",
":examples", ":examples",
":kmath-ejml" ":kmath-ejml",
":kmath-ast-kotlingrad"
) )