Migrate to 1.8. Use universal autodiffs
This commit is contained in:
parent
6d47c0ccec
commit
3f4fe9e43b
@ -5,7 +5,7 @@ import space.kscience.kmath.benchmarks.addBenchmarkProperties
|
|||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
kotlin("multiplatform")
|
||||||
kotlin("plugin.allopen")
|
alias(spclibs.plugins.kotlin.plugin.allopen)
|
||||||
id("org.jetbrains.kotlinx.benchmark")
|
id("org.jetbrains.kotlinx.benchmark")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,7 +44,7 @@ kotlin {
|
|||||||
implementation(project(":kmath-tensors"))
|
implementation(project(":kmath-tensors"))
|
||||||
implementation(project(":kmath-multik"))
|
implementation(project(":kmath-multik"))
|
||||||
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||||
implementation(npmlibs.kotlinx.benchmark.runtime)
|
implementation(spclibs.kotlinx.benchmark.runtime)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ allprojects {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group = "space.kscience"
|
group = "space.kscience"
|
||||||
version = "0.3.1-dev-7"
|
version = "0.3.1-dev-8"
|
||||||
}
|
}
|
||||||
|
|
||||||
subprojects {
|
subprojects {
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
plugins {
|
plugins {
|
||||||
`kotlin-dsl`
|
`kotlin-dsl`
|
||||||
`version-catalog`
|
`version-catalog`
|
||||||
kotlin("plugin.serialization") version "1.6.21"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
java.targetCompatibility = JavaVersion.VERSION_11
|
java.targetCompatibility = JavaVersion.VERSION_11
|
||||||
@ -13,18 +12,18 @@ repositories {
|
|||||||
gradlePluginPortal()
|
gradlePluginPortal()
|
||||||
}
|
}
|
||||||
|
|
||||||
val toolsVersion = npmlibs.versions.tools.get()
|
val toolsVersion = spclibs.versions.tools.get()
|
||||||
val kotlinVersion = npmlibs.versions.kotlin.asProvider().get()
|
val kotlinVersion = spclibs.versions.kotlin.asProvider().get()
|
||||||
val benchmarksVersion = npmlibs.versions.kotlinx.benchmark.get()
|
val benchmarksVersion = spclibs.versions.kotlinx.benchmark.get()
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api("space.kscience:gradle-tools:$toolsVersion")
|
api("space.kscience:gradle-tools:$toolsVersion")
|
||||||
api(npmlibs.atomicfu.gradle)
|
|
||||||
//plugins form benchmarks
|
//plugins form benchmarks
|
||||||
api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion")
|
api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion")
|
||||||
api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion")
|
//api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion")
|
||||||
//to be used inside build-script only
|
//to be used inside build-script only
|
||||||
implementation(npmlibs.kotlinx.serialization.json)
|
//implementation(spclibs.kotlinx.serialization.json)
|
||||||
|
implementation("com.fasterxml.jackson.module:jackson-module-kotlin:2.14.+")
|
||||||
}
|
}
|
||||||
|
|
||||||
kotlin.sourceSets.all {
|
kotlin.sourceSets.all {
|
||||||
|
@ -26,7 +26,7 @@ dependencyResolutionManagement {
|
|||||||
}
|
}
|
||||||
|
|
||||||
versionCatalogs {
|
versionCatalogs {
|
||||||
create("npmlibs") {
|
create("spclibs") {
|
||||||
from("space.kscience:version-catalog:$toolsVersion")
|
from("space.kscience:version-catalog:$toolsVersion")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,9 +5,6 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.benchmarks
|
package space.kscience.kmath.benchmarks
|
||||||
|
|
||||||
import kotlinx.serialization.Serializable
|
|
||||||
|
|
||||||
@Serializable
|
|
||||||
data class JmhReport(
|
data class JmhReport(
|
||||||
val jmhVersion: String,
|
val jmhVersion: String,
|
||||||
val benchmark: String,
|
val benchmark: String,
|
||||||
@ -37,7 +34,6 @@ data class JmhReport(
|
|||||||
val scoreUnit: String
|
val scoreUnit: String
|
||||||
}
|
}
|
||||||
|
|
||||||
@Serializable
|
|
||||||
data class PrimaryMetric(
|
data class PrimaryMetric(
|
||||||
override val score: Double,
|
override val score: Double,
|
||||||
override val scoreError: Double,
|
override val scoreError: Double,
|
||||||
@ -48,7 +44,6 @@ data class JmhReport(
|
|||||||
val rawData: List<List<Double>>? = null,
|
val rawData: List<List<Double>>? = null,
|
||||||
) : Metric
|
) : Metric
|
||||||
|
|
||||||
@Serializable
|
|
||||||
data class SecondaryMetric(
|
data class SecondaryMetric(
|
||||||
override val score: Double,
|
override val score: Double,
|
||||||
override val scoreError: Double,
|
override val scoreError: Double,
|
||||||
|
@ -6,8 +6,8 @@
|
|||||||
package space.kscience.kmath.benchmarks
|
package space.kscience.kmath.benchmarks
|
||||||
|
|
||||||
import kotlinx.benchmark.gradle.BenchmarksExtension
|
import kotlinx.benchmark.gradle.BenchmarksExtension
|
||||||
import kotlinx.serialization.decodeFromString
|
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
|
||||||
import kotlinx.serialization.json.Json
|
import com.fasterxml.jackson.module.kotlin.readValue
|
||||||
import org.gradle.api.Project
|
import org.gradle.api.Project
|
||||||
import space.kscience.gradle.KScienceReadmeExtension
|
import space.kscience.gradle.KScienceReadmeExtension
|
||||||
import java.time.LocalDateTime
|
import java.time.LocalDateTime
|
||||||
@ -45,6 +45,8 @@ private val ISO_DATE_TIME: DateTimeFormatter = DateTimeFormatterBuilder().run {
|
|||||||
|
|
||||||
private fun noun(number: Number, singular: String, plural: String) = if (number.toLong() == 1L) singular else plural
|
private fun noun(number: Number, singular: String, plural: String) = if (number.toLong() == 1L) singular else plural
|
||||||
|
|
||||||
|
private val jsonMapper = jacksonObjectMapper()
|
||||||
|
|
||||||
fun Project.addBenchmarkProperties() {
|
fun Project.addBenchmarkProperties() {
|
||||||
val benchmarksProject = this
|
val benchmarksProject = this
|
||||||
rootProject.subprojects.forEach { p ->
|
rootProject.subprojects.forEach { p ->
|
||||||
@ -60,8 +62,7 @@ fun Project.addBenchmarkProperties() {
|
|||||||
if (resDirectory == null || !(resDirectory.resolve("jvm.json")).exists()) {
|
if (resDirectory == null || !(resDirectory.resolve("jvm.json")).exists()) {
|
||||||
"> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**."
|
"> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**."
|
||||||
} else {
|
} else {
|
||||||
val reports =
|
val reports: List<JmhReport> = jsonMapper.readValue<List<JmhReport>>(resDirectory.resolve("jvm.json"))
|
||||||
Json.decodeFromString<List<JmhReport>>(resDirectory.resolve("jvm.json").readText())
|
|
||||||
|
|
||||||
buildString {
|
buildString {
|
||||||
appendLine("<details>")
|
appendLine("<details>")
|
||||||
|
@ -7,9 +7,9 @@ package space.kscience.kmath.fit
|
|||||||
|
|
||||||
import kotlinx.html.br
|
import kotlinx.html.br
|
||||||
import kotlinx.html.h3
|
import kotlinx.html.h3
|
||||||
import space.kscience.kmath.commons.expressions.DSProcessor
|
|
||||||
import space.kscience.kmath.commons.optimization.CMOptimizer
|
import space.kscience.kmath.commons.optimization.CMOptimizer
|
||||||
import space.kscience.kmath.distributions.NormalDistribution
|
import space.kscience.kmath.distributions.NormalDistribution
|
||||||
|
import space.kscience.kmath.expressions.autodiff
|
||||||
import space.kscience.kmath.expressions.chiSquaredExpression
|
import space.kscience.kmath.expressions.chiSquaredExpression
|
||||||
import space.kscience.kmath.expressions.symbol
|
import space.kscience.kmath.expressions.symbol
|
||||||
import space.kscience.kmath.operations.asIterable
|
import space.kscience.kmath.operations.asIterable
|
||||||
@ -67,7 +67,7 @@ suspend fun main() {
|
|||||||
val yErr = y.map { sqrt(it) }//RealVector.same(x.size, sigma)
|
val yErr = y.map { sqrt(it) }//RealVector.same(x.size, sigma)
|
||||||
|
|
||||||
// compute differentiable chi^2 sum for given model ax^2 + bx + c
|
// compute differentiable chi^2 sum for given model ax^2 + bx + c
|
||||||
val chi2 = DSProcessor.chiSquaredExpression(x, y, yErr) { arg ->
|
val chi2 = Double.autodiff.chiSquaredExpression(x, y, yErr) { arg ->
|
||||||
//bind variables to autodiff context
|
//bind variables to autodiff context
|
||||||
val a = bindSymbol(a)
|
val a = bindSymbol(a)
|
||||||
val b = bindSymbol(b)
|
val b = bindSymbol(b)
|
||||||
|
@ -7,10 +7,10 @@ package space.kscience.kmath.fit
|
|||||||
|
|
||||||
import kotlinx.html.br
|
import kotlinx.html.br
|
||||||
import kotlinx.html.h3
|
import kotlinx.html.h3
|
||||||
import space.kscience.kmath.commons.expressions.DSProcessor
|
|
||||||
import space.kscience.kmath.data.XYErrorColumnarData
|
import space.kscience.kmath.data.XYErrorColumnarData
|
||||||
import space.kscience.kmath.distributions.NormalDistribution
|
import space.kscience.kmath.distributions.NormalDistribution
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.expressions.autodiff
|
||||||
import space.kscience.kmath.expressions.binding
|
import space.kscience.kmath.expressions.binding
|
||||||
import space.kscience.kmath.expressions.symbol
|
import space.kscience.kmath.expressions.symbol
|
||||||
import space.kscience.kmath.operations.asIterable
|
import space.kscience.kmath.operations.asIterable
|
||||||
@ -63,7 +63,7 @@ suspend fun main() {
|
|||||||
|
|
||||||
val result = XYErrorColumnarData.of(x, y, yErr).fitWith(
|
val result = XYErrorColumnarData.of(x, y, yErr).fitWith(
|
||||||
QowOptimizer,
|
QowOptimizer,
|
||||||
DSProcessor,
|
Double.autodiff,
|
||||||
mapOf(a to 0.9, b to 1.2, c to 2.0)
|
mapOf(a to 0.9, b to 1.2, c to 2.0)
|
||||||
) { arg ->
|
) { arg ->
|
||||||
//bind variables to autodiff context
|
//bind variables to autodiff context
|
||||||
|
@ -36,7 +36,7 @@ private suspend fun runKMathChained(): Duration {
|
|||||||
return Duration.between(startTime, Instant.now())
|
return Duration.between(startTime, Instant.now())
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun runApacheDirect(): Duration {
|
private fun runCMDirect(): Duration {
|
||||||
val rng = RandomSource.create(RandomSource.MT, 123L)
|
val rng = RandomSource.create(RandomSource.MT, 123L)
|
||||||
|
|
||||||
val sampler = CMGaussianSampler.of(
|
val sampler = CMGaussianSampler.of(
|
||||||
@ -65,7 +65,7 @@ private fun runApacheDirect(): Duration {
|
|||||||
* Comparing chain sampling performance with direct sampling performance
|
* Comparing chain sampling performance with direct sampling performance
|
||||||
*/
|
*/
|
||||||
fun main(): Unit = runBlocking(Dispatchers.Default) {
|
fun main(): Unit = runBlocking(Dispatchers.Default) {
|
||||||
val directJob = async { runApacheDirect() }
|
val directJob = async { runCMDirect() }
|
||||||
val chainJob = async { runKMathChained() }
|
val chainJob = async { runKMathChained() }
|
||||||
println("KMath Chained: ${chainJob.await()}")
|
println("KMath Chained: ${chainJob.await()}")
|
||||||
println("Apache Direct: ${directJob.await()}")
|
println("Apache Direct: ${directJob.await()}")
|
||||||
|
@ -5,12 +5,11 @@
|
|||||||
kotlin.code.style=official
|
kotlin.code.style=official
|
||||||
kotlin.mpp.stability.nowarn=true
|
kotlin.mpp.stability.nowarn=true
|
||||||
kotlin.native.ignoreDisabledTargets=true
|
kotlin.native.ignoreDisabledTargets=true
|
||||||
kotlin.incremental.js.ir=true
|
|
||||||
|
|
||||||
org.gradle.configureondemand=true
|
org.gradle.configureondemand=true
|
||||||
org.gradle.jvmargs=-Xmx4096m
|
org.gradle.jvmargs=-Xmx4096m
|
||||||
|
|
||||||
toolsVersion=0.13.1-kotlin-1.7.20
|
toolsVersion=0.13.4-kotlin-1.8.0
|
||||||
|
|
||||||
|
|
||||||
org.gradle.parallel=true
|
org.gradle.parallel=true
|
||||||
|
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
|||||||
distributionBase=GRADLE_USER_HOME
|
distributionBase=GRADLE_USER_HOME
|
||||||
distributionPath=wrapper/dists
|
distributionPath=wrapper/dists
|
||||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.5-bin.zip
|
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip
|
||||||
zipStoreBase=GRADLE_USER_HOME
|
zipStoreBase=GRADLE_USER_HOME
|
||||||
zipStorePath=wrapper/dists
|
zipStorePath=wrapper/dists
|
||||||
|
@ -18,7 +18,8 @@ import space.kscience.kmath.operations.NumbersAddOps
|
|||||||
* @param bindings The map of bindings values. All bindings are considered free parameters
|
* @param bindings The map of bindings values. All bindings are considered free parameters
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public class DerivativeStructureField(
|
@Deprecated("Use generic DSAlgebra from the core")
|
||||||
|
public class CmDsField(
|
||||||
public val order: Int,
|
public val order: Int,
|
||||||
bindings: Map<Symbol, Double>,
|
bindings: Map<Symbol, Double>,
|
||||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
|
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
|
||||||
@ -108,25 +109,27 @@ public class DerivativeStructureField(
|
|||||||
/**
|
/**
|
||||||
* Auto-diff processor based on Commons-math [DerivativeStructure]
|
* Auto-diff processor based on Commons-math [DerivativeStructure]
|
||||||
*/
|
*/
|
||||||
public object DSProcessor : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
|
@Deprecated("Use generic DSAlgebra from the core")
|
||||||
|
public object CmDsProcessor : AutoDiffProcessor<Double, DerivativeStructure, CmDsField> {
|
||||||
override fun differentiate(
|
override fun differentiate(
|
||||||
function: DerivativeStructureField.() -> DerivativeStructure,
|
function: CmDsField.() -> DerivativeStructure,
|
||||||
): DerivativeStructureExpression = DerivativeStructureExpression(function)
|
): CmDsExpression = CmDsExpression(function)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A constructs that creates a derivative structure with required order on-demand
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
*/
|
*/
|
||||||
public class DerivativeStructureExpression(
|
@Deprecated("Use generic DSAlgebra from the core")
|
||||||
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
public class CmDsExpression(
|
||||||
|
public val function: CmDsField.() -> DerivativeStructure,
|
||||||
) : DifferentiableExpression<Double> {
|
) : DifferentiableExpression<Double> {
|
||||||
override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
DerivativeStructureField(0, arguments).function().value
|
CmDsField(0, arguments).function().value
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the derivative expression with given orders
|
* Get the derivative expression with given orders
|
||||||
*/
|
*/
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments ->
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments ->
|
||||||
with(DerivativeStructureField(symbols.size, arguments)) { function().derivative(symbols) }
|
with(CmDsField(symbols.size, arguments)) { function().derivative(symbols) }
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -3,6 +3,8 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@file:Suppress("DEPRECATION")
|
||||||
|
|
||||||
package space.kscience.kmath.commons.expressions
|
package space.kscience.kmath.commons.expressions
|
||||||
|
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
@ -15,10 +17,10 @@ import kotlin.test.assertFails
|
|||||||
internal inline fun diff(
|
internal inline fun diff(
|
||||||
order: Int,
|
order: Int,
|
||||||
vararg parameters: Pair<Symbol, Double>,
|
vararg parameters: Pair<Symbol, Double>,
|
||||||
block: DerivativeStructureField.() -> Unit,
|
block: CmDsField.() -> Unit,
|
||||||
) {
|
) {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
CmDsField(order, mapOf(*parameters)).run(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class AutoDiffTest {
|
internal class AutoDiffTest {
|
||||||
@ -41,7 +43,7 @@ internal class AutoDiffTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun autoDifTest() {
|
fun autoDifTest() {
|
||||||
val f = DerivativeStructureExpression {
|
val f = CmDsExpression {
|
||||||
val x by binding
|
val x by binding
|
||||||
val y by binding
|
val y by binding
|
||||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||||
|
@ -6,22 +6,25 @@
|
|||||||
package space.kscience.kmath.commons.optimization
|
package space.kscience.kmath.commons.optimization
|
||||||
|
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import space.kscience.kmath.commons.expressions.DSProcessor
|
|
||||||
import space.kscience.kmath.commons.expressions.DerivativeStructureExpression
|
|
||||||
import space.kscience.kmath.distributions.NormalDistribution
|
import space.kscience.kmath.distributions.NormalDistribution
|
||||||
|
import space.kscience.kmath.expressions.DSFieldExpression
|
||||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
import space.kscience.kmath.expressions.Symbol.Companion.x
|
||||||
import space.kscience.kmath.expressions.Symbol.Companion.y
|
import space.kscience.kmath.expressions.Symbol.Companion.y
|
||||||
|
import space.kscience.kmath.expressions.autodiff
|
||||||
import space.kscience.kmath.expressions.chiSquaredExpression
|
import space.kscience.kmath.expressions.chiSquaredExpression
|
||||||
import space.kscience.kmath.expressions.symbol
|
import space.kscience.kmath.expressions.symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.DoubleBufferOps.Companion.map
|
import space.kscience.kmath.operations.DoubleBufferOps.Companion.map
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.optimization.*
|
import space.kscience.kmath.optimization.*
|
||||||
import space.kscience.kmath.random.RandomGenerator
|
import space.kscience.kmath.random.RandomGenerator
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import space.kscience.kmath.structures.asBuffer
|
import space.kscience.kmath.structures.asBuffer
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
internal class OptimizeTest {
|
internal class OptimizeTest {
|
||||||
val normal = DerivativeStructureExpression {
|
val normal = DSFieldExpression(DoubleField) {
|
||||||
exp(-bindSymbol(x).pow(2) / 2) + exp(-bindSymbol(y).pow(2) / 2)
|
exp(-bindSymbol(x).pow(2) / 2) + exp(-bindSymbol(y).pow(2) / 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,7 +63,7 @@ internal class OptimizeTest {
|
|||||||
|
|
||||||
val yErr = DoubleBuffer(x.size) { sigma }
|
val yErr = DoubleBuffer(x.size) { sigma }
|
||||||
|
|
||||||
val chi2 = DSProcessor.chiSquaredExpression(
|
val chi2 = Double.autodiff.chiSquaredExpression(
|
||||||
x, y, yErr
|
x, y, yErr
|
||||||
) { arg ->
|
) { arg ->
|
||||||
val cWithDefault = bindSymbolOrNull(c) ?: one
|
val cWithDefault = bindSymbolOrNull(c) ?: one
|
||||||
|
@ -80,7 +80,6 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
|||||||
public val algebra: A,
|
public val algebra: A,
|
||||||
public val order: Int,
|
public val order: Int,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
public val valueBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory,
|
|
||||||
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
|
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -116,7 +115,6 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
|||||||
|
|
||||||
newCache[p][o] = DSCompiler(
|
newCache[p][o] = DSCompiler(
|
||||||
algebra,
|
algebra,
|
||||||
valueBufferFactory,
|
|
||||||
p,
|
p,
|
||||||
o,
|
o,
|
||||||
valueCompiler,
|
valueCompiler,
|
||||||
@ -141,7 +139,7 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
|||||||
override val symbols: List<Symbol> = bindings.map { it.key }
|
override val symbols: List<Symbol> = bindings.map { it.key }
|
||||||
|
|
||||||
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
|
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
|
||||||
val buffer = valueBufferFactory(compiler.size) { algebra.zero }
|
val buffer = algebra.bufferFactory(compiler.size) { algebra.zero }
|
||||||
buffer[0] = value
|
buffer[0] = value
|
||||||
if (compiler.order > 0) {
|
if (compiler.order > 0) {
|
||||||
// the derivative of the variable with respect to itself is 1.
|
// the derivative of the variable with respect to itself is 1.
|
||||||
@ -207,7 +205,7 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
public override fun const(value: T): DS<T, A> {
|
public override fun const(value: T): DS<T, A> {
|
||||||
val buffer = valueBufferFactory(compiler.size) { algebra.zero }
|
val buffer = algebra.bufferFactory(compiler.size) { algebra.zero }
|
||||||
buffer[0] = value
|
buffer[0] = value
|
||||||
|
|
||||||
return DS(buffer)
|
return DS(buffer)
|
||||||
@ -245,11 +243,14 @@ public open class DSRing<T, A>(
|
|||||||
algebra: A,
|
algebra: A,
|
||||||
order: Int,
|
order: Int,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
valueBufferFactory: MutableBufferFactory<T>,
|
) : DSAlgebra<T, A>(algebra, order, bindings),
|
||||||
) : DSAlgebra<T, A>(algebra, order, bindings, valueBufferFactory),
|
Ring<DS<T, A>>,
|
||||||
Ring<DS<T, A>>, ScaleOperations<DS<T, A>>,
|
ScaleOperations<DS<T, A>>,
|
||||||
NumericAlgebra<DS<T, A>>,
|
NumericAlgebra<DS<T, A>>,
|
||||||
NumbersAddOps<DS<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
NumbersAddOps<DS<T, A>>
|
||||||
|
where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||||
|
|
||||||
|
public val elementBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory
|
||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): DSSymbol? =
|
override fun bindSymbolOrNull(value: String): DSSymbol? =
|
||||||
super<DSAlgebra>.bindSymbolOrNull(value)
|
super<DSAlgebra>.bindSymbolOrNull(value)
|
||||||
@ -261,14 +262,14 @@ public open class DSRing<T, A>(
|
|||||||
*/
|
*/
|
||||||
protected inline fun DS<T, A>.transformDataBuffer(block: A.(MutableBuffer<T>) -> Unit): DS<T, A> {
|
protected inline fun DS<T, A>.transformDataBuffer(block: A.(MutableBuffer<T>) -> Unit): DS<T, A> {
|
||||||
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
val newData = valueBufferFactory(compiler.size) { data[it] }
|
val newData = elementBufferFactory(compiler.size) { data[it] }
|
||||||
algebra.block(newData)
|
algebra.block(newData)
|
||||||
return DS(newData)
|
return DS(newData)
|
||||||
}
|
}
|
||||||
|
|
||||||
protected fun DS<T, A>.mapData(block: A.(T) -> T): DS<T, A> {
|
protected fun DS<T, A>.mapData(block: A.(T) -> T): DS<T, A> {
|
||||||
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
val newData: Buffer<T> = data.mapToBuffer(valueBufferFactory) {
|
val newData: Buffer<T> = data.mapToBuffer(elementBufferFactory) {
|
||||||
algebra.block(it)
|
algebra.block(it)
|
||||||
}
|
}
|
||||||
return DS(newData)
|
return DS(newData)
|
||||||
@ -276,7 +277,7 @@ public open class DSRing<T, A>(
|
|||||||
|
|
||||||
protected fun DS<T, A>.mapDataIndexed(block: (Int, T) -> T): DS<T, A> {
|
protected fun DS<T, A>.mapDataIndexed(block: (Int, T) -> T): DS<T, A> {
|
||||||
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
val newData: Buffer<T> = data.mapIndexedToBuffer(valueBufferFactory, block)
|
val newData: Buffer<T> = data.mapIndexedToBuffer(elementBufferFactory, block)
|
||||||
return DS(newData)
|
return DS(newData)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -329,22 +330,21 @@ public class DerivativeStructureRingExpression<T, A>(
|
|||||||
public val function: DSRing<T, A>.() -> DS<T, A>,
|
public val function: DSRing<T, A>.() -> DS<T, A>,
|
||||||
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
||||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||||
DSRing(algebra, 0, arguments, elementBufferFactory).function().value
|
DSRing(algebra, 0, arguments).function().value
|
||||||
|
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||||
with(
|
with(
|
||||||
DSRing(
|
DSRing(
|
||||||
algebra,
|
algebra,
|
||||||
symbols.size,
|
symbols.size,
|
||||||
arguments,
|
arguments
|
||||||
elementBufferFactory
|
|
||||||
)
|
)
|
||||||
) { function().derivative(symbols) }
|
) { function().derivative(symbols) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field over commons-math [DerivativeStructure].
|
* A field over [DS].
|
||||||
*
|
*
|
||||||
* @property order The derivation order.
|
* @property order The derivation order.
|
||||||
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||||
@ -354,8 +354,7 @@ public class DSField<T, A : ExtendedField<T>>(
|
|||||||
algebra: A,
|
algebra: A,
|
||||||
order: Int,
|
order: Int,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
valueBufferFactory: MutableBufferFactory<T>,
|
) : DSRing<T, A>(algebra, order, bindings), ExtendedField<DS<T, A>> {
|
||||||
) : DSRing<T, A>(algebra, order, bindings, valueBufferFactory), ExtendedField<DS<T, A>> {
|
|
||||||
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
|
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
|
||||||
|
|
||||||
override fun divide(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
|
override fun divide(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
|
||||||
@ -414,6 +413,7 @@ public class DSField<T, A : ExtendedField<T>>(
|
|||||||
is Int -> arg.transformDataBuffer { result ->
|
is Int -> arg.transformDataBuffer { result ->
|
||||||
compiler.pow(arg.data, 0, pow, result, 0)
|
compiler.pow(arg.data, 0, pow, result, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
else -> arg.transformDataBuffer { result ->
|
else -> arg.transformDataBuffer { result ->
|
||||||
compiler.pow(arg.data, 0, pow.toDouble(), result, 0)
|
compiler.pow(arg.data, 0, pow.toDouble(), result, 0)
|
||||||
}
|
}
|
||||||
@ -439,18 +439,29 @@ public class DSField<T, A : ExtendedField<T>>(
|
|||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public class DSFieldExpression<T, A : ExtendedField<T>>(
|
public class DSFieldExpression<T, A : ExtendedField<T>>(
|
||||||
public val algebra: A,
|
public val algebra: A,
|
||||||
private val valueBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory,
|
|
||||||
public val function: DSField<T, A>.() -> DS<T, A>,
|
public val function: DSField<T, A>.() -> DS<T, A>,
|
||||||
) : DifferentiableExpression<T> {
|
) : DifferentiableExpression<T> {
|
||||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||||
DSField(algebra, 0, arguments, valueBufferFactory).function().value
|
DSField(algebra, 0, arguments).function().value
|
||||||
|
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||||
DSField(
|
DSField(
|
||||||
algebra,
|
algebra,
|
||||||
symbols.size,
|
symbols.size,
|
||||||
arguments,
|
arguments,
|
||||||
valueBufferFactory,
|
|
||||||
).run { function().derivative(symbols) }
|
).run { function().derivative(symbols) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DSFieldProcessor<T, A : ExtendedField<T>>(
|
||||||
|
public val algebra: A,
|
||||||
|
) : AutoDiffProcessor<T, DS<T, A>, DSField<T, A>> {
|
||||||
|
override fun differentiate(
|
||||||
|
function: DSField<T, A>.() -> DS<T, A>,
|
||||||
|
): DifferentiableExpression<T> = DSFieldExpression(algebra, function)
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public val Double.Companion.autodiff: DSFieldProcessor<Double, DoubleField> get() = DSFieldProcessor(DoubleField)
|
@ -9,7 +9,6 @@ package space.kscience.kmath.expressions
|
|||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
import space.kscience.kmath.structures.MutableBufferFactory
|
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
|
|
||||||
internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex: Int = size) {
|
internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex: Int = size) {
|
||||||
@ -56,7 +55,6 @@ internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex:
|
|||||||
*/
|
*/
|
||||||
public class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
public class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||||
public val algebra: A,
|
public val algebra: A,
|
||||||
public val bufferFactory: MutableBufferFactory<T>,
|
|
||||||
public val freeParameters: Int,
|
public val freeParameters: Int,
|
||||||
public val order: Int,
|
public val order: Int,
|
||||||
valueCompiler: DSCompiler<T, A>?,
|
valueCompiler: DSCompiler<T, A>?,
|
||||||
|
@ -9,7 +9,6 @@ package space.kscience.kmath.expressions
|
|||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -22,7 +21,7 @@ internal inline fun diff(
|
|||||||
block: DSField<Double, DoubleField>.() -> Unit,
|
block: DSField<Double, DoubleField>.() -> Unit,
|
||||||
) {
|
) {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
DSField(DoubleField, order, mapOf(*parameters), ::DoubleBuffer).block()
|
DSField(DoubleField, order, mapOf(*parameters)).block()
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class DSTest {
|
internal class DSTest {
|
||||||
@ -45,7 +44,7 @@ internal class DSTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun dsExpressionTest() {
|
fun dsExpressionTest() {
|
||||||
val f = DSFieldExpression(DoubleField, ::DoubleBuffer) {
|
val f = DSFieldExpression(DoubleField) {
|
||||||
val x by binding
|
val x by binding
|
||||||
val y by binding
|
val y by binding
|
||||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||||
|
@ -17,7 +17,7 @@ kotlin.sourceSets {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
dokkaPlugin("org.jetbrains.dokka:mathjax-plugin:${npmlibs.versions.dokka.get()}")
|
dokkaPlugin("org.jetbrains.dokka:mathjax-plugin:${spclibs.versions.dokka.get()}")
|
||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
|
@ -12,7 +12,7 @@ kotlin.sourceSets {
|
|||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
api(npmlibs.atomicfu)
|
api(spclibs.atomicfu)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
commonTest {
|
commonTest {
|
||||||
|
@ -14,7 +14,7 @@ kotlin.sourceSets {
|
|||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
api(npmlibs.atomicfu)
|
api(spclibs.atomicfu)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,7 @@ kotlin.sourceSets {
|
|||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
implementation(npmlibs.atomicfu)
|
implementation(spclibs.atomicfu)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ dependencyResolutionManagement {
|
|||||||
}
|
}
|
||||||
|
|
||||||
versionCatalogs {
|
versionCatalogs {
|
||||||
create("npmlibs") {
|
create("spclibs") {
|
||||||
from("space.kscience:version-catalog:$toolsVersion")
|
from("space.kscience:version-catalog:$toolsVersion")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user