Migrate to 1.8. Use universal autodiffs

This commit is contained in:
Alexander Nozik 2022-12-31 15:02:52 +03:00
parent 6d47c0ccec
commit 3f4fe9e43b
22 changed files with 84 additions and 74 deletions

View File

@ -5,7 +5,7 @@ import space.kscience.kmath.benchmarks.addBenchmarkProperties
plugins { plugins {
kotlin("multiplatform") kotlin("multiplatform")
kotlin("plugin.allopen") alias(spclibs.plugins.kotlin.plugin.allopen)
id("org.jetbrains.kotlinx.benchmark") id("org.jetbrains.kotlinx.benchmark")
} }
@ -44,7 +44,7 @@ kotlin {
implementation(project(":kmath-tensors")) implementation(project(":kmath-tensors"))
implementation(project(":kmath-multik")) implementation(project(":kmath-multik"))
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion") implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
implementation(npmlibs.kotlinx.benchmark.runtime) implementation(spclibs.kotlinx.benchmark.runtime)
} }
} }

View File

@ -15,7 +15,7 @@ allprojects {
} }
group = "space.kscience" group = "space.kscience"
version = "0.3.1-dev-7" version = "0.3.1-dev-8"
} }
subprojects { subprojects {

View File

@ -1,7 +1,6 @@
plugins { plugins {
`kotlin-dsl` `kotlin-dsl`
`version-catalog` `version-catalog`
kotlin("plugin.serialization") version "1.6.21"
} }
java.targetCompatibility = JavaVersion.VERSION_11 java.targetCompatibility = JavaVersion.VERSION_11
@ -13,18 +12,18 @@ repositories {
gradlePluginPortal() gradlePluginPortal()
} }
val toolsVersion = npmlibs.versions.tools.get() val toolsVersion = spclibs.versions.tools.get()
val kotlinVersion = npmlibs.versions.kotlin.asProvider().get() val kotlinVersion = spclibs.versions.kotlin.asProvider().get()
val benchmarksVersion = npmlibs.versions.kotlinx.benchmark.get() val benchmarksVersion = spclibs.versions.kotlinx.benchmark.get()
dependencies { dependencies {
api("space.kscience:gradle-tools:$toolsVersion") api("space.kscience:gradle-tools:$toolsVersion")
api(npmlibs.atomicfu.gradle)
//plugins form benchmarks //plugins form benchmarks
api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion") api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion")
api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion") //api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion")
//to be used inside build-script only //to be used inside build-script only
implementation(npmlibs.kotlinx.serialization.json) //implementation(spclibs.kotlinx.serialization.json)
implementation("com.fasterxml.jackson.module:jackson-module-kotlin:2.14.+")
} }
kotlin.sourceSets.all { kotlin.sourceSets.all {

View File

@ -26,7 +26,7 @@ dependencyResolutionManagement {
} }
versionCatalogs { versionCatalogs {
create("npmlibs") { create("spclibs") {
from("space.kscience:version-catalog:$toolsVersion") from("space.kscience:version-catalog:$toolsVersion")
} }
} }

View File

@ -5,9 +5,6 @@
package space.kscience.kmath.benchmarks package space.kscience.kmath.benchmarks
import kotlinx.serialization.Serializable
@Serializable
data class JmhReport( data class JmhReport(
val jmhVersion: String, val jmhVersion: String,
val benchmark: String, val benchmark: String,
@ -37,7 +34,6 @@ data class JmhReport(
val scoreUnit: String val scoreUnit: String
} }
@Serializable
data class PrimaryMetric( data class PrimaryMetric(
override val score: Double, override val score: Double,
override val scoreError: Double, override val scoreError: Double,
@ -48,7 +44,6 @@ data class JmhReport(
val rawData: List<List<Double>>? = null, val rawData: List<List<Double>>? = null,
) : Metric ) : Metric
@Serializable
data class SecondaryMetric( data class SecondaryMetric(
override val score: Double, override val score: Double,
override val scoreError: Double, override val scoreError: Double,

View File

@ -6,8 +6,8 @@
package space.kscience.kmath.benchmarks package space.kscience.kmath.benchmarks
import kotlinx.benchmark.gradle.BenchmarksExtension import kotlinx.benchmark.gradle.BenchmarksExtension
import kotlinx.serialization.decodeFromString import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import kotlinx.serialization.json.Json import com.fasterxml.jackson.module.kotlin.readValue
import org.gradle.api.Project import org.gradle.api.Project
import space.kscience.gradle.KScienceReadmeExtension import space.kscience.gradle.KScienceReadmeExtension
import java.time.LocalDateTime import java.time.LocalDateTime
@ -45,6 +45,8 @@ private val ISO_DATE_TIME: DateTimeFormatter = DateTimeFormatterBuilder().run {
private fun noun(number: Number, singular: String, plural: String) = if (number.toLong() == 1L) singular else plural private fun noun(number: Number, singular: String, plural: String) = if (number.toLong() == 1L) singular else plural
private val jsonMapper = jacksonObjectMapper()
fun Project.addBenchmarkProperties() { fun Project.addBenchmarkProperties() {
val benchmarksProject = this val benchmarksProject = this
rootProject.subprojects.forEach { p -> rootProject.subprojects.forEach { p ->
@ -60,8 +62,7 @@ fun Project.addBenchmarkProperties() {
if (resDirectory == null || !(resDirectory.resolve("jvm.json")).exists()) { if (resDirectory == null || !(resDirectory.resolve("jvm.json")).exists()) {
"> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**." "> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**."
} else { } else {
val reports = val reports: List<JmhReport> = jsonMapper.readValue<List<JmhReport>>(resDirectory.resolve("jvm.json"))
Json.decodeFromString<List<JmhReport>>(resDirectory.resolve("jvm.json").readText())
buildString { buildString {
appendLine("<details>") appendLine("<details>")

View File

@ -7,9 +7,9 @@ package space.kscience.kmath.fit
import kotlinx.html.br import kotlinx.html.br
import kotlinx.html.h3 import kotlinx.html.h3
import space.kscience.kmath.commons.expressions.DSProcessor
import space.kscience.kmath.commons.optimization.CMOptimizer import space.kscience.kmath.commons.optimization.CMOptimizer
import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.expressions.autodiff
import space.kscience.kmath.expressions.chiSquaredExpression import space.kscience.kmath.expressions.chiSquaredExpression
import space.kscience.kmath.expressions.symbol import space.kscience.kmath.expressions.symbol
import space.kscience.kmath.operations.asIterable import space.kscience.kmath.operations.asIterable
@ -67,7 +67,7 @@ suspend fun main() {
val yErr = y.map { sqrt(it) }//RealVector.same(x.size, sigma) val yErr = y.map { sqrt(it) }//RealVector.same(x.size, sigma)
// compute differentiable chi^2 sum for given model ax^2 + bx + c // compute differentiable chi^2 sum for given model ax^2 + bx + c
val chi2 = DSProcessor.chiSquaredExpression(x, y, yErr) { arg -> val chi2 = Double.autodiff.chiSquaredExpression(x, y, yErr) { arg ->
//bind variables to autodiff context //bind variables to autodiff context
val a = bindSymbol(a) val a = bindSymbol(a)
val b = bindSymbol(b) val b = bindSymbol(b)

View File

@ -7,10 +7,10 @@ package space.kscience.kmath.fit
import kotlinx.html.br import kotlinx.html.br
import kotlinx.html.h3 import kotlinx.html.h3
import space.kscience.kmath.commons.expressions.DSProcessor
import space.kscience.kmath.data.XYErrorColumnarData import space.kscience.kmath.data.XYErrorColumnarData
import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.autodiff
import space.kscience.kmath.expressions.binding import space.kscience.kmath.expressions.binding
import space.kscience.kmath.expressions.symbol import space.kscience.kmath.expressions.symbol
import space.kscience.kmath.operations.asIterable import space.kscience.kmath.operations.asIterable
@ -63,7 +63,7 @@ suspend fun main() {
val result = XYErrorColumnarData.of(x, y, yErr).fitWith( val result = XYErrorColumnarData.of(x, y, yErr).fitWith(
QowOptimizer, QowOptimizer,
DSProcessor, Double.autodiff,
mapOf(a to 0.9, b to 1.2, c to 2.0) mapOf(a to 0.9, b to 1.2, c to 2.0)
) { arg -> ) { arg ->
//bind variables to autodiff context //bind variables to autodiff context

View File

@ -36,7 +36,7 @@ private suspend fun runKMathChained(): Duration {
return Duration.between(startTime, Instant.now()) return Duration.between(startTime, Instant.now())
} }
private fun runApacheDirect(): Duration { private fun runCMDirect(): Duration {
val rng = RandomSource.create(RandomSource.MT, 123L) val rng = RandomSource.create(RandomSource.MT, 123L)
val sampler = CMGaussianSampler.of( val sampler = CMGaussianSampler.of(
@ -65,7 +65,7 @@ private fun runApacheDirect(): Duration {
* Comparing chain sampling performance with direct sampling performance * Comparing chain sampling performance with direct sampling performance
*/ */
fun main(): Unit = runBlocking(Dispatchers.Default) { fun main(): Unit = runBlocking(Dispatchers.Default) {
val directJob = async { runApacheDirect() } val directJob = async { runCMDirect() }
val chainJob = async { runKMathChained() } val chainJob = async { runKMathChained() }
println("KMath Chained: ${chainJob.await()}") println("KMath Chained: ${chainJob.await()}")
println("Apache Direct: ${directJob.await()}") println("Apache Direct: ${directJob.await()}")

View File

@ -5,12 +5,11 @@
kotlin.code.style=official kotlin.code.style=official
kotlin.mpp.stability.nowarn=true kotlin.mpp.stability.nowarn=true
kotlin.native.ignoreDisabledTargets=true kotlin.native.ignoreDisabledTargets=true
kotlin.incremental.js.ir=true
org.gradle.configureondemand=true org.gradle.configureondemand=true
org.gradle.jvmargs=-Xmx4096m org.gradle.jvmargs=-Xmx4096m
toolsVersion=0.13.1-kotlin-1.7.20 toolsVersion=0.13.4-kotlin-1.8.0
org.gradle.parallel=true org.gradle.parallel=true

View File

@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.5-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

View File

@ -18,7 +18,8 @@ import space.kscience.kmath.operations.NumbersAddOps
* @param bindings The map of bindings values. All bindings are considered free parameters * @param bindings The map of bindings values. All bindings are considered free parameters
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public class DerivativeStructureField( @Deprecated("Use generic DSAlgebra from the core")
public class CmDsField(
public val order: Int, public val order: Int,
bindings: Map<Symbol, Double>, bindings: Map<Symbol, Double>,
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>, ) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
@ -108,25 +109,27 @@ public class DerivativeStructureField(
/** /**
* Auto-diff processor based on Commons-math [DerivativeStructure] * Auto-diff processor based on Commons-math [DerivativeStructure]
*/ */
public object DSProcessor : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> { @Deprecated("Use generic DSAlgebra from the core")
public object CmDsProcessor : AutoDiffProcessor<Double, DerivativeStructure, CmDsField> {
override fun differentiate( override fun differentiate(
function: DerivativeStructureField.() -> DerivativeStructure, function: CmDsField.() -> DerivativeStructure,
): DerivativeStructureExpression = DerivativeStructureExpression(function) ): CmDsExpression = CmDsExpression(function)
} }
/** /**
* A constructs that creates a derivative structure with required order on-demand * A constructs that creates a derivative structure with required order on-demand
*/ */
public class DerivativeStructureExpression( @Deprecated("Use generic DSAlgebra from the core")
public val function: DerivativeStructureField.() -> DerivativeStructure, public class CmDsExpression(
public val function: CmDsField.() -> DerivativeStructure,
) : DifferentiableExpression<Double> { ) : DifferentiableExpression<Double> {
override operator fun invoke(arguments: Map<Symbol, Double>): Double = override operator fun invoke(arguments: Map<Symbol, Double>): Double =
DerivativeStructureField(0, arguments).function().value CmDsField(0, arguments).function().value
/** /**
* Get the derivative expression with given orders * Get the derivative expression with given orders
*/ */
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments -> override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments ->
with(DerivativeStructureField(symbols.size, arguments)) { function().derivative(symbols) } with(CmDsField(symbols.size, arguments)) { function().derivative(symbols) }
} }
} }

View File

@ -3,6 +3,8 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
@file:Suppress("DEPRECATION")
package space.kscience.kmath.commons.expressions package space.kscience.kmath.commons.expressions
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
@ -15,10 +17,10 @@ import kotlin.test.assertFails
internal inline fun diff( internal inline fun diff(
order: Int, order: Int,
vararg parameters: Pair<Symbol, Double>, vararg parameters: Pair<Symbol, Double>,
block: DerivativeStructureField.() -> Unit, block: CmDsField.() -> Unit,
) { ) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
DerivativeStructureField(order, mapOf(*parameters)).run(block) CmDsField(order, mapOf(*parameters)).run(block)
} }
internal class AutoDiffTest { internal class AutoDiffTest {
@ -41,7 +43,7 @@ internal class AutoDiffTest {
@Test @Test
fun autoDifTest() { fun autoDifTest() {
val f = DerivativeStructureExpression { val f = CmDsExpression {
val x by binding val x by binding
val y by binding val y by binding
x.pow(2) + 2 * x * y + y.pow(2) + 1 x.pow(2) + 2 * x * y + y.pow(2) + 1

View File

@ -6,22 +6,25 @@
package space.kscience.kmath.commons.optimization package space.kscience.kmath.commons.optimization
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import space.kscience.kmath.commons.expressions.DSProcessor
import space.kscience.kmath.commons.expressions.DerivativeStructureExpression
import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.expressions.DSFieldExpression
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.Symbol.Companion.y import space.kscience.kmath.expressions.Symbol.Companion.y
import space.kscience.kmath.expressions.autodiff
import space.kscience.kmath.expressions.chiSquaredExpression import space.kscience.kmath.expressions.chiSquaredExpression
import space.kscience.kmath.expressions.symbol import space.kscience.kmath.expressions.symbol
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleBufferOps.Companion.map import space.kscience.kmath.operations.DoubleBufferOps.Companion.map
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.optimization.* import space.kscience.kmath.optimization.*
import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.asBuffer import space.kscience.kmath.structures.asBuffer
import kotlin.test.Test import kotlin.test.Test
@OptIn(UnstableKMathAPI::class)
internal class OptimizeTest { internal class OptimizeTest {
val normal = DerivativeStructureExpression { val normal = DSFieldExpression(DoubleField) {
exp(-bindSymbol(x).pow(2) / 2) + exp(-bindSymbol(y).pow(2) / 2) exp(-bindSymbol(x).pow(2) / 2) + exp(-bindSymbol(y).pow(2) / 2)
} }
@ -60,7 +63,7 @@ internal class OptimizeTest {
val yErr = DoubleBuffer(x.size) { sigma } val yErr = DoubleBuffer(x.size) { sigma }
val chi2 = DSProcessor.chiSquaredExpression( val chi2 = Double.autodiff.chiSquaredExpression(
x, y, yErr x, y, yErr
) { arg -> ) { arg ->
val cWithDefault = bindSymbolOrNull(c) ?: one val cWithDefault = bindSymbolOrNull(c) ?: one

View File

@ -80,7 +80,6 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
public val algebra: A, public val algebra: A,
public val order: Int, public val order: Int,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
public val valueBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory,
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer { ) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
/** /**
@ -116,7 +115,6 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
newCache[p][o] = DSCompiler( newCache[p][o] = DSCompiler(
algebra, algebra,
valueBufferFactory,
p, p,
o, o,
valueCompiler, valueCompiler,
@ -141,7 +139,7 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
override val symbols: List<Symbol> = bindings.map { it.key } override val symbols: List<Symbol> = bindings.map { it.key }
private fun bufferForVariable(index: Int, value: T): Buffer<T> { private fun bufferForVariable(index: Int, value: T): Buffer<T> {
val buffer = valueBufferFactory(compiler.size) { algebra.zero } val buffer = algebra.bufferFactory(compiler.size) { algebra.zero }
buffer[0] = value buffer[0] = value
if (compiler.order > 0) { if (compiler.order > 0) {
// the derivative of the variable with respect to itself is 1. // the derivative of the variable with respect to itself is 1.
@ -207,7 +205,7 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
} }
public override fun const(value: T): DS<T, A> { public override fun const(value: T): DS<T, A> {
val buffer = valueBufferFactory(compiler.size) { algebra.zero } val buffer = algebra.bufferFactory(compiler.size) { algebra.zero }
buffer[0] = value buffer[0] = value
return DS(buffer) return DS(buffer)
@ -245,11 +243,14 @@ public open class DSRing<T, A>(
algebra: A, algebra: A,
order: Int, order: Int,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
valueBufferFactory: MutableBufferFactory<T>, ) : DSAlgebra<T, A>(algebra, order, bindings),
) : DSAlgebra<T, A>(algebra, order, bindings, valueBufferFactory), Ring<DS<T, A>>,
Ring<DS<T, A>>, ScaleOperations<DS<T, A>>, ScaleOperations<DS<T, A>>,
NumericAlgebra<DS<T, A>>, NumericAlgebra<DS<T, A>>,
NumbersAddOps<DS<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> { NumbersAddOps<DS<T, A>>
where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
public val elementBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory
override fun bindSymbolOrNull(value: String): DSSymbol? = override fun bindSymbolOrNull(value: String): DSSymbol? =
super<DSAlgebra>.bindSymbolOrNull(value) super<DSAlgebra>.bindSymbolOrNull(value)
@ -261,14 +262,14 @@ public open class DSRing<T, A>(
*/ */
protected inline fun DS<T, A>.transformDataBuffer(block: A.(MutableBuffer<T>) -> Unit): DS<T, A> { protected inline fun DS<T, A>.transformDataBuffer(block: A.(MutableBuffer<T>) -> Unit): DS<T, A> {
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" } require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
val newData = valueBufferFactory(compiler.size) { data[it] } val newData = elementBufferFactory(compiler.size) { data[it] }
algebra.block(newData) algebra.block(newData)
return DS(newData) return DS(newData)
} }
protected fun DS<T, A>.mapData(block: A.(T) -> T): DS<T, A> { protected fun DS<T, A>.mapData(block: A.(T) -> T): DS<T, A> {
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" } require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
val newData: Buffer<T> = data.mapToBuffer(valueBufferFactory) { val newData: Buffer<T> = data.mapToBuffer(elementBufferFactory) {
algebra.block(it) algebra.block(it)
} }
return DS(newData) return DS(newData)
@ -276,7 +277,7 @@ public open class DSRing<T, A>(
protected fun DS<T, A>.mapDataIndexed(block: (Int, T) -> T): DS<T, A> { protected fun DS<T, A>.mapDataIndexed(block: (Int, T) -> T): DS<T, A> {
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" } require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
val newData: Buffer<T> = data.mapIndexedToBuffer(valueBufferFactory, block) val newData: Buffer<T> = data.mapIndexedToBuffer(elementBufferFactory, block)
return DS(newData) return DS(newData)
} }
@ -329,22 +330,21 @@ public class DerivativeStructureRingExpression<T, A>(
public val function: DSRing<T, A>.() -> DS<T, A>, public val function: DSRing<T, A>.() -> DS<T, A>,
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> { ) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
override operator fun invoke(arguments: Map<Symbol, T>): T = override operator fun invoke(arguments: Map<Symbol, T>): T =
DSRing(algebra, 0, arguments, elementBufferFactory).function().value DSRing(algebra, 0, arguments).function().value
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments -> override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
with( with(
DSRing( DSRing(
algebra, algebra,
symbols.size, symbols.size,
arguments, arguments
elementBufferFactory
) )
) { function().derivative(symbols) } ) { function().derivative(symbols) }
} }
} }
/** /**
* A field over commons-math [DerivativeStructure]. * A field over [DS].
* *
* @property order The derivation order. * @property order The derivation order.
* @param bindings The map of bindings values. All bindings are considered free parameters. * @param bindings The map of bindings values. All bindings are considered free parameters.
@ -354,8 +354,7 @@ public class DSField<T, A : ExtendedField<T>>(
algebra: A, algebra: A,
order: Int, order: Int,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
valueBufferFactory: MutableBufferFactory<T>, ) : DSRing<T, A>(algebra, order, bindings), ExtendedField<DS<T, A>> {
) : DSRing<T, A>(algebra, order, bindings, valueBufferFactory), ExtendedField<DS<T, A>> {
override fun number(value: Number): DS<T, A> = const(algebra.number(value)) override fun number(value: Number): DS<T, A> = const(algebra.number(value))
override fun divide(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result -> override fun divide(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
@ -414,6 +413,7 @@ public class DSField<T, A : ExtendedField<T>>(
is Int -> arg.transformDataBuffer { result -> is Int -> arg.transformDataBuffer { result ->
compiler.pow(arg.data, 0, pow, result, 0) compiler.pow(arg.data, 0, pow, result, 0)
} }
else -> arg.transformDataBuffer { result -> else -> arg.transformDataBuffer { result ->
compiler.pow(arg.data, 0, pow.toDouble(), result, 0) compiler.pow(arg.data, 0, pow.toDouble(), result, 0)
} }
@ -439,18 +439,29 @@ public class DSField<T, A : ExtendedField<T>>(
@UnstableKMathAPI @UnstableKMathAPI
public class DSFieldExpression<T, A : ExtendedField<T>>( public class DSFieldExpression<T, A : ExtendedField<T>>(
public val algebra: A, public val algebra: A,
private val valueBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory,
public val function: DSField<T, A>.() -> DS<T, A>, public val function: DSField<T, A>.() -> DS<T, A>,
) : DifferentiableExpression<T> { ) : DifferentiableExpression<T> {
override operator fun invoke(arguments: Map<Symbol, T>): T = override operator fun invoke(arguments: Map<Symbol, T>): T =
DSField(algebra, 0, arguments, valueBufferFactory).function().value DSField(algebra, 0, arguments).function().value
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments -> override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
DSField( DSField(
algebra, algebra,
symbols.size, symbols.size,
arguments, arguments,
valueBufferFactory,
).run { function().derivative(symbols) } ).run { function().derivative(symbols) }
} }
} }
@UnstableKMathAPI
public class DSFieldProcessor<T, A : ExtendedField<T>>(
public val algebra: A,
) : AutoDiffProcessor<T, DS<T, A>, DSField<T, A>> {
override fun differentiate(
function: DSField<T, A>.() -> DS<T, A>,
): DifferentiableExpression<T> = DSFieldExpression(algebra, function)
}
@UnstableKMathAPI
public val Double.Companion.autodiff: DSFieldProcessor<Double, DoubleField> get() = DSFieldProcessor(DoubleField)

View File

@ -9,7 +9,6 @@ package space.kscience.kmath.expressions
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.MutableBufferFactory
import kotlin.math.min import kotlin.math.min
internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex: Int = size) { internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex: Int = size) {
@ -56,7 +55,6 @@ internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex:
*/ */
public class DSCompiler<T, out A : Algebra<T>> internal constructor( public class DSCompiler<T, out A : Algebra<T>> internal constructor(
public val algebra: A, public val algebra: A,
public val bufferFactory: MutableBufferFactory<T>,
public val freeParameters: Int, public val freeParameters: Int,
public val order: Int, public val order: Int,
valueCompiler: DSCompiler<T, A>?, valueCompiler: DSCompiler<T, A>?,

View File

@ -9,7 +9,6 @@ package space.kscience.kmath.expressions
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.test.Test import kotlin.test.Test
@ -22,7 +21,7 @@ internal inline fun diff(
block: DSField<Double, DoubleField>.() -> Unit, block: DSField<Double, DoubleField>.() -> Unit,
) { ) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
DSField(DoubleField, order, mapOf(*parameters), ::DoubleBuffer).block() DSField(DoubleField, order, mapOf(*parameters)).block()
} }
internal class DSTest { internal class DSTest {
@ -45,7 +44,7 @@ internal class DSTest {
@Test @Test
fun dsExpressionTest() { fun dsExpressionTest() {
val f = DSFieldExpression(DoubleField, ::DoubleBuffer) { val f = DSFieldExpression(DoubleField) {
val x by binding val x by binding
val y by binding val y by binding
x.pow(2) + 2 * x * y + y.pow(2) + 1 x.pow(2) + 2 * x * y + y.pow(2) + 1

View File

@ -17,7 +17,7 @@ kotlin.sourceSets {
} }
dependencies { dependencies {
dokkaPlugin("org.jetbrains.dokka:mathjax-plugin:${npmlibs.versions.dokka.get()}") dokkaPlugin("org.jetbrains.dokka:mathjax-plugin:${spclibs.versions.dokka.get()}")
} }
readme { readme {

View File

@ -12,7 +12,7 @@ kotlin.sourceSets {
commonMain { commonMain {
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
api(npmlibs.atomicfu) api(spclibs.atomicfu)
} }
} }
commonTest { commonTest {

View File

@ -14,7 +14,7 @@ kotlin.sourceSets {
commonMain { commonMain {
dependencies { dependencies {
api(project(":kmath-coroutines")) api(project(":kmath-coroutines"))
api(npmlibs.atomicfu) api(spclibs.atomicfu)
} }
} }
} }

View File

@ -10,7 +10,7 @@ kotlin.sourceSets {
commonMain { commonMain {
dependencies { dependencies {
api(project(":kmath-coroutines")) api(project(":kmath-coroutines"))
implementation(npmlibs.atomicfu) implementation(spclibs.atomicfu)
} }
} }

View File

@ -13,7 +13,7 @@ dependencyResolutionManagement {
} }
versionCatalogs { versionCatalogs {
create("npmlibs") { create("spclibs") {
from("space.kscience:version-catalog:$toolsVersion") from("space.kscience:version-catalog:$toolsVersion")
} }
} }