Migrate to 1.8. Use universal autodiffs
This commit is contained in:
parent
6d47c0ccec
commit
3f4fe9e43b
@ -5,7 +5,7 @@ import space.kscience.kmath.benchmarks.addBenchmarkProperties
|
||||
|
||||
plugins {
|
||||
kotlin("multiplatform")
|
||||
kotlin("plugin.allopen")
|
||||
alias(spclibs.plugins.kotlin.plugin.allopen)
|
||||
id("org.jetbrains.kotlinx.benchmark")
|
||||
}
|
||||
|
||||
@ -44,7 +44,7 @@ kotlin {
|
||||
implementation(project(":kmath-tensors"))
|
||||
implementation(project(":kmath-multik"))
|
||||
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||
implementation(npmlibs.kotlinx.benchmark.runtime)
|
||||
implementation(spclibs.kotlinx.benchmark.runtime)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,7 @@ allprojects {
|
||||
}
|
||||
|
||||
group = "space.kscience"
|
||||
version = "0.3.1-dev-7"
|
||||
version = "0.3.1-dev-8"
|
||||
}
|
||||
|
||||
subprojects {
|
||||
|
@ -1,7 +1,6 @@
|
||||
plugins {
|
||||
`kotlin-dsl`
|
||||
`version-catalog`
|
||||
kotlin("plugin.serialization") version "1.6.21"
|
||||
}
|
||||
|
||||
java.targetCompatibility = JavaVersion.VERSION_11
|
||||
@ -13,18 +12,18 @@ repositories {
|
||||
gradlePluginPortal()
|
||||
}
|
||||
|
||||
val toolsVersion = npmlibs.versions.tools.get()
|
||||
val kotlinVersion = npmlibs.versions.kotlin.asProvider().get()
|
||||
val benchmarksVersion = npmlibs.versions.kotlinx.benchmark.get()
|
||||
val toolsVersion = spclibs.versions.tools.get()
|
||||
val kotlinVersion = spclibs.versions.kotlin.asProvider().get()
|
||||
val benchmarksVersion = spclibs.versions.kotlinx.benchmark.get()
|
||||
|
||||
dependencies {
|
||||
api("space.kscience:gradle-tools:$toolsVersion")
|
||||
api(npmlibs.atomicfu.gradle)
|
||||
//plugins form benchmarks
|
||||
api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion")
|
||||
api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion")
|
||||
//api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion")
|
||||
//to be used inside build-script only
|
||||
implementation(npmlibs.kotlinx.serialization.json)
|
||||
//implementation(spclibs.kotlinx.serialization.json)
|
||||
implementation("com.fasterxml.jackson.module:jackson-module-kotlin:2.14.+")
|
||||
}
|
||||
|
||||
kotlin.sourceSets.all {
|
||||
|
@ -26,7 +26,7 @@ dependencyResolutionManagement {
|
||||
}
|
||||
|
||||
versionCatalogs {
|
||||
create("npmlibs") {
|
||||
create("spclibs") {
|
||||
from("space.kscience:version-catalog:$toolsVersion")
|
||||
}
|
||||
}
|
||||
|
@ -5,9 +5,6 @@
|
||||
|
||||
package space.kscience.kmath.benchmarks
|
||||
|
||||
import kotlinx.serialization.Serializable
|
||||
|
||||
@Serializable
|
||||
data class JmhReport(
|
||||
val jmhVersion: String,
|
||||
val benchmark: String,
|
||||
@ -37,7 +34,6 @@ data class JmhReport(
|
||||
val scoreUnit: String
|
||||
}
|
||||
|
||||
@Serializable
|
||||
data class PrimaryMetric(
|
||||
override val score: Double,
|
||||
override val scoreError: Double,
|
||||
@ -48,7 +44,6 @@ data class JmhReport(
|
||||
val rawData: List<List<Double>>? = null,
|
||||
) : Metric
|
||||
|
||||
@Serializable
|
||||
data class SecondaryMetric(
|
||||
override val score: Double,
|
||||
override val scoreError: Double,
|
||||
|
@ -6,8 +6,8 @@
|
||||
package space.kscience.kmath.benchmarks
|
||||
|
||||
import kotlinx.benchmark.gradle.BenchmarksExtension
|
||||
import kotlinx.serialization.decodeFromString
|
||||
import kotlinx.serialization.json.Json
|
||||
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
|
||||
import com.fasterxml.jackson.module.kotlin.readValue
|
||||
import org.gradle.api.Project
|
||||
import space.kscience.gradle.KScienceReadmeExtension
|
||||
import java.time.LocalDateTime
|
||||
@ -45,6 +45,8 @@ private val ISO_DATE_TIME: DateTimeFormatter = DateTimeFormatterBuilder().run {
|
||||
|
||||
private fun noun(number: Number, singular: String, plural: String) = if (number.toLong() == 1L) singular else plural
|
||||
|
||||
private val jsonMapper = jacksonObjectMapper()
|
||||
|
||||
fun Project.addBenchmarkProperties() {
|
||||
val benchmarksProject = this
|
||||
rootProject.subprojects.forEach { p ->
|
||||
@ -60,8 +62,7 @@ fun Project.addBenchmarkProperties() {
|
||||
if (resDirectory == null || !(resDirectory.resolve("jvm.json")).exists()) {
|
||||
"> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**."
|
||||
} else {
|
||||
val reports =
|
||||
Json.decodeFromString<List<JmhReport>>(resDirectory.resolve("jvm.json").readText())
|
||||
val reports: List<JmhReport> = jsonMapper.readValue<List<JmhReport>>(resDirectory.resolve("jvm.json"))
|
||||
|
||||
buildString {
|
||||
appendLine("<details>")
|
||||
|
@ -7,9 +7,9 @@ package space.kscience.kmath.fit
|
||||
|
||||
import kotlinx.html.br
|
||||
import kotlinx.html.h3
|
||||
import space.kscience.kmath.commons.expressions.DSProcessor
|
||||
import space.kscience.kmath.commons.optimization.CMOptimizer
|
||||
import space.kscience.kmath.distributions.NormalDistribution
|
||||
import space.kscience.kmath.expressions.autodiff
|
||||
import space.kscience.kmath.expressions.chiSquaredExpression
|
||||
import space.kscience.kmath.expressions.symbol
|
||||
import space.kscience.kmath.operations.asIterable
|
||||
@ -67,7 +67,7 @@ suspend fun main() {
|
||||
val yErr = y.map { sqrt(it) }//RealVector.same(x.size, sigma)
|
||||
|
||||
// compute differentiable chi^2 sum for given model ax^2 + bx + c
|
||||
val chi2 = DSProcessor.chiSquaredExpression(x, y, yErr) { arg ->
|
||||
val chi2 = Double.autodiff.chiSquaredExpression(x, y, yErr) { arg ->
|
||||
//bind variables to autodiff context
|
||||
val a = bindSymbol(a)
|
||||
val b = bindSymbol(b)
|
||||
|
@ -7,10 +7,10 @@ package space.kscience.kmath.fit
|
||||
|
||||
import kotlinx.html.br
|
||||
import kotlinx.html.h3
|
||||
import space.kscience.kmath.commons.expressions.DSProcessor
|
||||
import space.kscience.kmath.data.XYErrorColumnarData
|
||||
import space.kscience.kmath.distributions.NormalDistribution
|
||||
import space.kscience.kmath.expressions.Symbol
|
||||
import space.kscience.kmath.expressions.autodiff
|
||||
import space.kscience.kmath.expressions.binding
|
||||
import space.kscience.kmath.expressions.symbol
|
||||
import space.kscience.kmath.operations.asIterable
|
||||
@ -63,7 +63,7 @@ suspend fun main() {
|
||||
|
||||
val result = XYErrorColumnarData.of(x, y, yErr).fitWith(
|
||||
QowOptimizer,
|
||||
DSProcessor,
|
||||
Double.autodiff,
|
||||
mapOf(a to 0.9, b to 1.2, c to 2.0)
|
||||
) { arg ->
|
||||
//bind variables to autodiff context
|
||||
|
@ -36,7 +36,7 @@ private suspend fun runKMathChained(): Duration {
|
||||
return Duration.between(startTime, Instant.now())
|
||||
}
|
||||
|
||||
private fun runApacheDirect(): Duration {
|
||||
private fun runCMDirect(): Duration {
|
||||
val rng = RandomSource.create(RandomSource.MT, 123L)
|
||||
|
||||
val sampler = CMGaussianSampler.of(
|
||||
@ -65,7 +65,7 @@ private fun runApacheDirect(): Duration {
|
||||
* Comparing chain sampling performance with direct sampling performance
|
||||
*/
|
||||
fun main(): Unit = runBlocking(Dispatchers.Default) {
|
||||
val directJob = async { runApacheDirect() }
|
||||
val directJob = async { runCMDirect() }
|
||||
val chainJob = async { runKMathChained() }
|
||||
println("KMath Chained: ${chainJob.await()}")
|
||||
println("Apache Direct: ${directJob.await()}")
|
||||
|
@ -5,12 +5,11 @@
|
||||
kotlin.code.style=official
|
||||
kotlin.mpp.stability.nowarn=true
|
||||
kotlin.native.ignoreDisabledTargets=true
|
||||
kotlin.incremental.js.ir=true
|
||||
|
||||
org.gradle.configureondemand=true
|
||||
org.gradle.jvmargs=-Xmx4096m
|
||||
|
||||
toolsVersion=0.13.1-kotlin-1.7.20
|
||||
toolsVersion=0.13.4-kotlin-1.8.0
|
||||
|
||||
|
||||
org.gradle.parallel=true
|
||||
|
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.5-bin.zip
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
|
@ -18,7 +18,8 @@ import space.kscience.kmath.operations.NumbersAddOps
|
||||
* @param bindings The map of bindings values. All bindings are considered free parameters
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class DerivativeStructureField(
|
||||
@Deprecated("Use generic DSAlgebra from the core")
|
||||
public class CmDsField(
|
||||
public val order: Int,
|
||||
bindings: Map<Symbol, Double>,
|
||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
|
||||
@ -108,25 +109,27 @@ public class DerivativeStructureField(
|
||||
/**
|
||||
* Auto-diff processor based on Commons-math [DerivativeStructure]
|
||||
*/
|
||||
public object DSProcessor : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
|
||||
@Deprecated("Use generic DSAlgebra from the core")
|
||||
public object CmDsProcessor : AutoDiffProcessor<Double, DerivativeStructure, CmDsField> {
|
||||
override fun differentiate(
|
||||
function: DerivativeStructureField.() -> DerivativeStructure,
|
||||
): DerivativeStructureExpression = DerivativeStructureExpression(function)
|
||||
function: CmDsField.() -> DerivativeStructure,
|
||||
): CmDsExpression = CmDsExpression(function)
|
||||
}
|
||||
|
||||
/**
|
||||
* A constructs that creates a derivative structure with required order on-demand
|
||||
*/
|
||||
public class DerivativeStructureExpression(
|
||||
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
||||
@Deprecated("Use generic DSAlgebra from the core")
|
||||
public class CmDsExpression(
|
||||
public val function: CmDsField.() -> DerivativeStructure,
|
||||
) : DifferentiableExpression<Double> {
|
||||
override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||
DerivativeStructureField(0, arguments).function().value
|
||||
CmDsField(0, arguments).function().value
|
||||
|
||||
/**
|
||||
* Get the derivative expression with given orders
|
||||
*/
|
||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments ->
|
||||
with(DerivativeStructureField(symbols.size, arguments)) { function().derivative(symbols) }
|
||||
with(CmDsField(symbols.size, arguments)) { function().derivative(symbols) }
|
||||
}
|
||||
}
|
@ -3,6 +3,8 @@
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
@file:Suppress("DEPRECATION")
|
||||
|
||||
package space.kscience.kmath.commons.expressions
|
||||
|
||||
import space.kscience.kmath.expressions.*
|
||||
@ -15,10 +17,10 @@ import kotlin.test.assertFails
|
||||
internal inline fun diff(
|
||||
order: Int,
|
||||
vararg parameters: Pair<Symbol, Double>,
|
||||
block: DerivativeStructureField.() -> Unit,
|
||||
block: CmDsField.() -> Unit,
|
||||
) {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||
CmDsField(order, mapOf(*parameters)).run(block)
|
||||
}
|
||||
|
||||
internal class AutoDiffTest {
|
||||
@ -41,7 +43,7 @@ internal class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun autoDifTest() {
|
||||
val f = DerivativeStructureExpression {
|
||||
val f = CmDsExpression {
|
||||
val x by binding
|
||||
val y by binding
|
||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||
|
@ -6,22 +6,25 @@
|
||||
package space.kscience.kmath.commons.optimization
|
||||
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import space.kscience.kmath.commons.expressions.DSProcessor
|
||||
import space.kscience.kmath.commons.expressions.DerivativeStructureExpression
|
||||
import space.kscience.kmath.distributions.NormalDistribution
|
||||
import space.kscience.kmath.expressions.DSFieldExpression
|
||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
||||
import space.kscience.kmath.expressions.Symbol.Companion.y
|
||||
import space.kscience.kmath.expressions.autodiff
|
||||
import space.kscience.kmath.expressions.chiSquaredExpression
|
||||
import space.kscience.kmath.expressions.symbol
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.DoubleBufferOps.Companion.map
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.optimization.*
|
||||
import space.kscience.kmath.random.RandomGenerator
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import space.kscience.kmath.structures.asBuffer
|
||||
import kotlin.test.Test
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
internal class OptimizeTest {
|
||||
val normal = DerivativeStructureExpression {
|
||||
val normal = DSFieldExpression(DoubleField) {
|
||||
exp(-bindSymbol(x).pow(2) / 2) + exp(-bindSymbol(y).pow(2) / 2)
|
||||
}
|
||||
|
||||
@ -60,7 +63,7 @@ internal class OptimizeTest {
|
||||
|
||||
val yErr = DoubleBuffer(x.size) { sigma }
|
||||
|
||||
val chi2 = DSProcessor.chiSquaredExpression(
|
||||
val chi2 = Double.autodiff.chiSquaredExpression(
|
||||
x, y, yErr
|
||||
) { arg ->
|
||||
val cWithDefault = bindSymbolOrNull(c) ?: one
|
||||
|
@ -80,7 +80,6 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
||||
public val algebra: A,
|
||||
public val order: Int,
|
||||
bindings: Map<Symbol, T>,
|
||||
public val valueBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory,
|
||||
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
|
||||
|
||||
/**
|
||||
@ -116,7 +115,6 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
||||
|
||||
newCache[p][o] = DSCompiler(
|
||||
algebra,
|
||||
valueBufferFactory,
|
||||
p,
|
||||
o,
|
||||
valueCompiler,
|
||||
@ -141,7 +139,7 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
||||
override val symbols: List<Symbol> = bindings.map { it.key }
|
||||
|
||||
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
|
||||
val buffer = valueBufferFactory(compiler.size) { algebra.zero }
|
||||
val buffer = algebra.bufferFactory(compiler.size) { algebra.zero }
|
||||
buffer[0] = value
|
||||
if (compiler.order > 0) {
|
||||
// the derivative of the variable with respect to itself is 1.
|
||||
@ -207,7 +205,7 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
||||
}
|
||||
|
||||
public override fun const(value: T): DS<T, A> {
|
||||
val buffer = valueBufferFactory(compiler.size) { algebra.zero }
|
||||
val buffer = algebra.bufferFactory(compiler.size) { algebra.zero }
|
||||
buffer[0] = value
|
||||
|
||||
return DS(buffer)
|
||||
@ -245,11 +243,14 @@ public open class DSRing<T, A>(
|
||||
algebra: A,
|
||||
order: Int,
|
||||
bindings: Map<Symbol, T>,
|
||||
valueBufferFactory: MutableBufferFactory<T>,
|
||||
) : DSAlgebra<T, A>(algebra, order, bindings, valueBufferFactory),
|
||||
Ring<DS<T, A>>, ScaleOperations<DS<T, A>>,
|
||||
) : DSAlgebra<T, A>(algebra, order, bindings),
|
||||
Ring<DS<T, A>>,
|
||||
ScaleOperations<DS<T, A>>,
|
||||
NumericAlgebra<DS<T, A>>,
|
||||
NumbersAddOps<DS<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||
NumbersAddOps<DS<T, A>>
|
||||
where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||
|
||||
public val elementBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory
|
||||
|
||||
override fun bindSymbolOrNull(value: String): DSSymbol? =
|
||||
super<DSAlgebra>.bindSymbolOrNull(value)
|
||||
@ -261,14 +262,14 @@ public open class DSRing<T, A>(
|
||||
*/
|
||||
protected inline fun DS<T, A>.transformDataBuffer(block: A.(MutableBuffer<T>) -> Unit): DS<T, A> {
|
||||
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||
val newData = valueBufferFactory(compiler.size) { data[it] }
|
||||
val newData = elementBufferFactory(compiler.size) { data[it] }
|
||||
algebra.block(newData)
|
||||
return DS(newData)
|
||||
}
|
||||
|
||||
protected fun DS<T, A>.mapData(block: A.(T) -> T): DS<T, A> {
|
||||
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||
val newData: Buffer<T> = data.mapToBuffer(valueBufferFactory) {
|
||||
val newData: Buffer<T> = data.mapToBuffer(elementBufferFactory) {
|
||||
algebra.block(it)
|
||||
}
|
||||
return DS(newData)
|
||||
@ -276,7 +277,7 @@ public open class DSRing<T, A>(
|
||||
|
||||
protected fun DS<T, A>.mapDataIndexed(block: (Int, T) -> T): DS<T, A> {
|
||||
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||
val newData: Buffer<T> = data.mapIndexedToBuffer(valueBufferFactory, block)
|
||||
val newData: Buffer<T> = data.mapIndexedToBuffer(elementBufferFactory, block)
|
||||
return DS(newData)
|
||||
}
|
||||
|
||||
@ -329,22 +330,21 @@ public class DerivativeStructureRingExpression<T, A>(
|
||||
public val function: DSRing<T, A>.() -> DS<T, A>,
|
||||
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||
DSRing(algebra, 0, arguments, elementBufferFactory).function().value
|
||||
DSRing(algebra, 0, arguments).function().value
|
||||
|
||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||
with(
|
||||
DSRing(
|
||||
algebra,
|
||||
symbols.size,
|
||||
arguments,
|
||||
elementBufferFactory
|
||||
arguments
|
||||
)
|
||||
) { function().derivative(symbols) }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A field over commons-math [DerivativeStructure].
|
||||
* A field over [DS].
|
||||
*
|
||||
* @property order The derivation order.
|
||||
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||
@ -354,8 +354,7 @@ public class DSField<T, A : ExtendedField<T>>(
|
||||
algebra: A,
|
||||
order: Int,
|
||||
bindings: Map<Symbol, T>,
|
||||
valueBufferFactory: MutableBufferFactory<T>,
|
||||
) : DSRing<T, A>(algebra, order, bindings, valueBufferFactory), ExtendedField<DS<T, A>> {
|
||||
) : DSRing<T, A>(algebra, order, bindings), ExtendedField<DS<T, A>> {
|
||||
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
|
||||
|
||||
override fun divide(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
|
||||
@ -414,6 +413,7 @@ public class DSField<T, A : ExtendedField<T>>(
|
||||
is Int -> arg.transformDataBuffer { result ->
|
||||
compiler.pow(arg.data, 0, pow, result, 0)
|
||||
}
|
||||
|
||||
else -> arg.transformDataBuffer { result ->
|
||||
compiler.pow(arg.data, 0, pow.toDouble(), result, 0)
|
||||
}
|
||||
@ -439,18 +439,29 @@ public class DSField<T, A : ExtendedField<T>>(
|
||||
@UnstableKMathAPI
|
||||
public class DSFieldExpression<T, A : ExtendedField<T>>(
|
||||
public val algebra: A,
|
||||
private val valueBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory,
|
||||
public val function: DSField<T, A>.() -> DS<T, A>,
|
||||
) : DifferentiableExpression<T> {
|
||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||
DSField(algebra, 0, arguments, valueBufferFactory).function().value
|
||||
DSField(algebra, 0, arguments).function().value
|
||||
|
||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||
DSField(
|
||||
algebra,
|
||||
symbols.size,
|
||||
arguments,
|
||||
valueBufferFactory,
|
||||
).run { function().derivative(symbols) }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@UnstableKMathAPI
|
||||
public class DSFieldProcessor<T, A : ExtendedField<T>>(
|
||||
public val algebra: A,
|
||||
) : AutoDiffProcessor<T, DS<T, A>, DSField<T, A>> {
|
||||
override fun differentiate(
|
||||
function: DSField<T, A>.() -> DS<T, A>,
|
||||
): DifferentiableExpression<T> = DSFieldExpression(algebra, function)
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public val Double.Companion.autodiff: DSFieldProcessor<Double, DoubleField> get() = DSFieldProcessor(DoubleField)
|
@ -9,7 +9,6 @@ package space.kscience.kmath.expressions
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
import space.kscience.kmath.structures.MutableBufferFactory
|
||||
import kotlin.math.min
|
||||
|
||||
internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex: Int = size) {
|
||||
@ -56,7 +55,6 @@ internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex:
|
||||
*/
|
||||
public class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||
public val algebra: A,
|
||||
public val bufferFactory: MutableBufferFactory<T>,
|
||||
public val freeParameters: Int,
|
||||
public val order: Int,
|
||||
valueCompiler: DSCompiler<T, A>?,
|
||||
|
@ -9,7 +9,6 @@ package space.kscience.kmath.expressions
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
import kotlin.test.Test
|
||||
@ -22,7 +21,7 @@ internal inline fun diff(
|
||||
block: DSField<Double, DoubleField>.() -> Unit,
|
||||
) {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
DSField(DoubleField, order, mapOf(*parameters), ::DoubleBuffer).block()
|
||||
DSField(DoubleField, order, mapOf(*parameters)).block()
|
||||
}
|
||||
|
||||
internal class DSTest {
|
||||
@ -45,7 +44,7 @@ internal class DSTest {
|
||||
|
||||
@Test
|
||||
fun dsExpressionTest() {
|
||||
val f = DSFieldExpression(DoubleField, ::DoubleBuffer) {
|
||||
val f = DSFieldExpression(DoubleField) {
|
||||
val x by binding
|
||||
val y by binding
|
||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||
|
@ -17,7 +17,7 @@ kotlin.sourceSets {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
dokkaPlugin("org.jetbrains.dokka:mathjax-plugin:${npmlibs.versions.dokka.get()}")
|
||||
dokkaPlugin("org.jetbrains.dokka:mathjax-plugin:${spclibs.versions.dokka.get()}")
|
||||
}
|
||||
|
||||
readme {
|
||||
|
@ -12,7 +12,7 @@ kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api(npmlibs.atomicfu)
|
||||
api(spclibs.atomicfu)
|
||||
}
|
||||
}
|
||||
commonTest {
|
||||
|
@ -14,7 +14,7 @@ kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-coroutines"))
|
||||
api(npmlibs.atomicfu)
|
||||
api(spclibs.atomicfu)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-coroutines"))
|
||||
implementation(npmlibs.atomicfu)
|
||||
implementation(spclibs.atomicfu)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -13,7 +13,7 @@ dependencyResolutionManagement {
|
||||
}
|
||||
|
||||
versionCatalogs {
|
||||
create("npmlibs") {
|
||||
create("spclibs") {
|
||||
from("space.kscience:version-catalog:$toolsVersion")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user