diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 272501729..c593f5103 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -1,9 +1,6 @@ package kscience.kmath.commons.expressions -import kscience.kmath.expressions.DifferentiableExpression -import kscience.kmath.expressions.Expression -import kscience.kmath.expressions.ExpressionAlgebra -import kscience.kmath.expressions.Symbol +import kscience.kmath.expressions.* import kscience.kmath.operations.ExtendedField import org.apache.commons.math3.analysis.differentiation.DerivativeStructure @@ -92,6 +89,12 @@ public class DerivativeStructureField( public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this + + public companion object : AutoDiffProcessor { + override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression { + return DerivativeStructureExpression(function) + } + } } /** diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt index a62630ed3..3143dcca5 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt @@ -71,12 +71,8 @@ public object CMFit { public fun Expression.optimize( vararg symbols: Symbol, configuration: CMOptimizationProblem.() -> Unit, -): OptimizationResult { - require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) - problem.expression(this) - return problem.optimize() -} +): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) + /** * Optimize differentiable expression @@ -84,12 +80,7 @@ public fun Expression.optimize( public fun DifferentiableExpression.optimize( vararg symbols: Symbol, configuration: CMOptimizationProblem.() -> Unit, -): OptimizationResult { - require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) - problem.diffExpression(this) - return problem.optimize() -} +): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) public fun DifferentiableExpression.minimize( vararg startPoint: Pair, diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt index 2ca907d05..0d96faaa3 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt @@ -33,7 +33,7 @@ public class CMOptimizationProblem( public fun exportOptimizationData(): List = optimizationData.values.toList() - public fun initialGuess(map: Map): Unit { + public override fun initialGuess(map: Map): Unit { addOptimizationData(InitialGuess(map.toDoubleArray())) } @@ -94,10 +94,12 @@ public class CMOptimizationProblem( return OptimizationResult(point.toMap(), value, setOf(this)) } - public companion object { + public companion object : OptimizationProblemFactory { public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4 public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4 public const val DEFAULT_MAX_ITER: Int = 1000 + + override fun build(symbols: List): CMOptimizationProblem = CMOptimizationProblem(symbols) } } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt deleted file mode 100644 index a246a817b..000000000 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt +++ /dev/null @@ -1,37 +0,0 @@ -package kscience.kmath.commons.optimization - -import kscience.kmath.expressions.DifferentiableExpression -import kscience.kmath.expressions.Expression -import kscience.kmath.expressions.Symbol - -public interface OptimizationFeature - -//TODO move to prob/stat - -public class OptimizationResult( - public val point: Map, - public val value: T, - public val features: Set = emptySet(), -){ - override fun toString(): String { - return "OptimizationResult(point=$point, value=$value)" - } -} - -/** - * A configuration builder for optimization problem - */ -public interface OptimizationProblem { - /** - * Set an objective function expression - */ - public fun expression(expression: Expression): Unit - - /** - * - */ - public fun diffExpression(expression: DifferentiableExpression): Unit - public fun update(result: OptimizationResult) - public fun optimize(): OptimizationResult -} - diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt index 5fe31caca..705839b57 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -23,10 +23,9 @@ public fun DifferentiableExpression.derivative(symbol: Symbol): Expressio public fun DifferentiableExpression.derivative(name: String): Expression = derivative(StringSymbol(name) to 1) -//public interface DifferentiableExpressionBuilder>: ExpressionBuilder { -// public override fun expression(block: A.() -> E): DifferentiableExpression -//} - +/** + * A [DifferentiableExpression] that defines only first derivatives + */ public abstract class FirstDerivativeExpression : DifferentiableExpression { public abstract fun derivativeOrNull(symbol: Symbol): Expression? @@ -35,4 +34,11 @@ public abstract class FirstDerivativeExpression : DifferentiableExpression val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null return derivativeOrNull(dSymbol) } +} + +/** + * A factory that converts an expression in autodiff variables to a [DifferentiableExpression] + */ +public interface AutoDiffProcessor> { + public fun process(function: A.() -> I): DifferentiableExpression } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index 6231a40c1..e66832fdb 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -47,7 +47,7 @@ public fun DerivationResult.grad(vararg variables: Symbol): Point DerivationResult.grad(vararg variables: Symbol): Point> F.simpleAutoDiff( bindings: Map, - body: AutoDiffField.() -> AutoDiffValue, + body: SimpleAutoDiffField.() -> AutoDiffValue, ): DerivationResult { contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } - return AutoDiffContext(this, bindings).derivate(body) + return SimpleAutoDiffField(this, bindings).derivate(body) } public fun > F.simpleAutoDiff( vararg bindings: Pair, - body: AutoDiffField.() -> AutoDiffValue, + body: SimpleAutoDiffField.() -> AutoDiffValue, ): DerivationResult = simpleAutoDiff(bindings.toMap(), body) /** * Represents field in context of which functions can be derived. */ -public abstract class AutoDiffField> - : Field>, ExpressionAlgebra> { +public open class SimpleAutoDiffField>( + public val context: F, + bindings: Map, +) : Field>, ExpressionAlgebra> { - public abstract val context: F + // this stack contains pairs of blocks and values to apply them to + private var stack: Array = arrayOfNulls(8) + private var sp: Int = 0 + private val derivatives: MutableMap, T> = hashMapOf() + + /** + * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result + * with respect to this variable. + * + * @param T the non-nullable type of value. + * @property value The value of this variable. + */ + private class AutoDiffVariableWithDerivative( + override val identity: String, + value: T, + var d: T, + ) : AutoDiffValue(value), Symbol { + override fun toString(): String = identity + override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity + override fun hashCode(): Int = identity.hashCode() + } + + private val bindings: Map> = bindings.entries.associate { + it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero) + } + + override fun bindOrNull(symbol: Symbol): AutoDiffValue? = bindings[symbol.identity] + + private fun getDerivative(variable: AutoDiffValue): T = + (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero + + private fun setDerivative(variable: AutoDiffValue, value: T) { + if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value + } + + + @Suppress("UNCHECKED_CAST") + private fun runBackwardPass() { + while (sp > 0) { + val value = stack[--sp] + val block = stack[--sp] as F.(Any?) -> Unit + context.block(value) + } + } + + override val zero: AutoDiffValue get() = const(context.zero) + override val one: AutoDiffValue get() = const(context.one) + + override fun const(value: T): AutoDiffValue = AutoDiffValue(value) /** * A variable accessing inner state of derivatives. * Use this value in inner builders to avoid creating additional derivative bindings. */ - public abstract var AutoDiffValue.d: T + public var AutoDiffValue.d: T + get() = getDerivative(this) + set(value) = setDerivative(this, value) + + public inline fun const(block: F.() -> T): AutoDiffValue = const(context.block()) /** * Performs update of derivative after the rest of the formula in the back-pass. @@ -101,9 +155,22 @@ public abstract class AutoDiffField> * } * ``` */ - public abstract fun derive(value: R, block: F.(R) -> Unit): R + @Suppress("UNCHECKED_CAST") + public fun derive(value: R, block: F.(R) -> Unit): R { + // save block to stack for backward pass + if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) + stack[sp++] = block + stack[sp++] = value + return value + } - public inline fun const(block: F.() -> T): AutoDiffValue = const(context.block()) + + internal fun derivate(function: SimpleAutoDiffField.() -> AutoDiffValue): DerivationResult { + val result = function() + result.d = context.one // computing derivative w.r.t result + runBackwardPass() + return DerivationResult(result.value, bindings.mapValues { it.value.d }, context) + } // Overloads for Double constants @@ -119,68 +186,7 @@ public abstract class AutoDiffField> override operator fun AutoDiffValue.minus(b: Number): AutoDiffValue = derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } -} -/** - * Automatic Differentiation context class. - */ -private class AutoDiffContext>( - override val context: F, - bindings: Map, -) : AutoDiffField() { - // this stack contains pairs of blocks and values to apply them to - private var stack: Array = arrayOfNulls(8) - private var sp: Int = 0 - private val derivatives: MutableMap, T> = hashMapOf() - override val zero: AutoDiffValue get() = const(context.zero) - override val one: AutoDiffValue get() = const(context.one) - - /** - * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result - * with respect to this variable. - * - * @param T the non-nullable type of value. - * @property value The value of this variable. - */ - private class AutoDiffVariableWithDeriv( - override val identity: String, - value: T, - var d: T, - ) : AutoDiffValue(value), Symbol{ - override fun toString(): String = identity - override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity - override fun hashCode(): Int = identity.hashCode() - } - - private val bindings: Map> = bindings.entries.associate { - it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero) - } - - override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv? = bindings[symbol.identity] - - override fun const(value: T): AutoDiffValue = AutoDiffValue(value) - - override var AutoDiffValue.d: T - get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero - set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value - - @Suppress("UNCHECKED_CAST") - override fun derive(value: R, block: F.(R) -> Unit): R { - // save block to stack for backward pass - if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) - stack[sp++] = block - stack[sp++] = value - return value - } - - @Suppress("UNCHECKED_CAST") - fun runBackwardPass() { - while (sp > 0) { - val value = stack[--sp] - val block = stack[--sp] as F.(Any?) -> Unit - context.block(value) - } - } // Basic math (+, -, *, /) @@ -206,13 +212,6 @@ private class AutoDiffContext>( derive(const { k.toDouble() * a.value }) { z -> a.d += z.d * k.toDouble() } - - inline fun derivate(function: AutoDiffField.() -> AutoDiffValue): DerivationResult { - val result = function() - result.d = context.one // computing derivative w.r.t result - runBackwardPass() - return DerivationResult(result.value, bindings.mapValues { it.value.d }, context) - } } /** @@ -220,99 +219,178 @@ private class AutoDiffContext>( */ public class SimpleAutoDiffExpression>( public val field: F, - public val function: AutoDiffField.() -> AutoDiffValue, + public val function: SimpleAutoDiffField.() -> AutoDiffValue, ) : FirstDerivativeExpression() { public override operator fun invoke(arguments: Map): T { //val bindings = arguments.entries.map { it.key.bind(it.value) } - return AutoDiffContext(field, arguments).function().value + return SimpleAutoDiffField(field, arguments).function().value } override fun derivativeOrNull(symbol: Symbol): Expression = Expression { arguments -> //val bindings = arguments.entries.map { it.key.bind(it.value) } - val derivationResult = AutoDiffContext(field, arguments).derivate(function) + val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function) derivationResult.derivative(symbol) } } +/** + * Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression] + */ +public fun > simpleAutoDiff(field: F): AutoDiffProcessor, SimpleAutoDiffField> { + return object : AutoDiffProcessor, SimpleAutoDiffField> { + override fun process(function: SimpleAutoDiffField.() -> AutoDiffValue): DifferentiableExpression { + return SimpleAutoDiffExpression(field, function) + } + } +} + // Extensions for differentiation of various basic mathematical functions // x ^ 2 -public fun > AutoDiffField.sqr(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.sqr(x: AutoDiffValue): AutoDiffValue = derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } // x ^ 1/2 -public fun > AutoDiffField.sqrt(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.sqrt(x: AutoDiffValue): AutoDiffValue = derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } // x ^ y (const) -public fun > AutoDiffField.pow( +public fun > SimpleAutoDiffField.pow( x: AutoDiffValue, y: Double, ): AutoDiffValue = derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } -public fun > AutoDiffField.pow( +public fun > SimpleAutoDiffField.pow( x: AutoDiffValue, y: Int, ): AutoDiffValue = pow(x, y.toDouble()) // exp(x) -public fun > AutoDiffField.exp(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.exp(x: AutoDiffValue): AutoDiffValue = derive(const { exp(x.value) }) { z -> x.d += z.d * z.value } // ln(x) -public fun > AutoDiffField.ln(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.ln(x: AutoDiffValue): AutoDiffValue = derive(const { ln(x.value) }) { z -> x.d += z.d / x.value } // x ^ y (any) -public fun > AutoDiffField.pow( +public fun > SimpleAutoDiffField.pow( x: AutoDiffValue, y: AutoDiffValue, ): AutoDiffValue = exp(y * ln(x)) // sin(x) -public fun > AutoDiffField.sin(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.sin(x: AutoDiffValue): AutoDiffValue = derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } // cos(x) -public fun > AutoDiffField.cos(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.cos(x: AutoDiffValue): AutoDiffValue = derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } -public fun > AutoDiffField.tan(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.tan(x: AutoDiffValue): AutoDiffValue = derive(const { tan(x.value) }) { z -> val c = cos(x.value) x.d += z.d / (c * c) } -public fun > AutoDiffField.asin(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.asin(x: AutoDiffValue): AutoDiffValue = derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) } -public fun > AutoDiffField.acos(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.acos(x: AutoDiffValue): AutoDiffValue = derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) } -public fun > AutoDiffField.atan(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.atan(x: AutoDiffValue): AutoDiffValue = derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) } -public fun > AutoDiffField.sinh(x: AutoDiffValue): AutoDiffValue = - derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) } +public fun > SimpleAutoDiffField.sinh(x: AutoDiffValue): AutoDiffValue = + derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) } -public fun > AutoDiffField.cosh(x: AutoDiffValue): AutoDiffValue = - derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) } +public fun > SimpleAutoDiffField.cosh(x: AutoDiffValue): AutoDiffValue = + derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) } -public fun > AutoDiffField.tanh(x: AutoDiffValue): AutoDiffValue = - derive(const { tan(x.value) }) { z -> +public fun > SimpleAutoDiffField.tanh(x: AutoDiffValue): AutoDiffValue = + derive(const { tanh(x.value) }) { z -> val c = cosh(x.value) x.d += z.d / (c * c) } -public fun > AutoDiffField.asinh(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.asinh(x: AutoDiffValue): AutoDiffValue = derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) } -public fun > AutoDiffField.acosh(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.acosh(x: AutoDiffValue): AutoDiffValue = derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) } -public fun > AutoDiffField.atanh(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.atanh(x: AutoDiffValue): AutoDiffValue = derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) } +public class SimpleAutoDiffExtendedField>( + context: F, + bindings: Map, +) : ExtendedField>, SimpleAutoDiffField(context, bindings) { + // x ^ 2 + public fun sqr(x: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sqr(x) + + // x ^ 1/2 + public override fun sqrt(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sqrt(arg) + + // x ^ y (const) + public override fun power(arg: AutoDiffValue, pow: Number): AutoDiffValue = + (this as SimpleAutoDiffField).pow(arg, pow.toDouble()) + + // exp(x) + public override fun exp(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).exp(arg) + + // ln(x) + public override fun ln(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).ln(arg) + + // x ^ y (any) + public fun pow( + x: AutoDiffValue, + y: AutoDiffValue, + ): AutoDiffValue = exp(y * ln(x)) + + // sin(x) + public override fun sin(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sin(arg) + + // cos(x) + public override fun cos(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).cos(arg) + + public override fun tan(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).tan(arg) + + public override fun asin(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).asin(arg) + + public override fun acos(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).acos(arg) + + public override fun atan(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).atan(arg) + + public override fun sinh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sinh(arg) + + public override fun cosh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).cosh(arg) + + public override fun tanh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).tanh(arg) + + public override fun asinh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).asinh(arg) + + public override fun acosh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).acosh(arg) + + public override fun atanh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).atanh(arg) +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt index defbb14ad..1603bc21d 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt @@ -8,11 +8,6 @@ import kotlin.contracts.InvocationKind import kotlin.contracts.contract -//public interface ExpressionBuilder> { -// public fun expression(block: A.() -> E): Expression -//} - - /** * Creates a functional expression with this [Space]. */ diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt index ca8ec1e17..510ed23a9 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -2,6 +2,7 @@ package kscience.kmath.expressions import kscience.kmath.operations.RealField import kscience.kmath.structures.asBuffer +import kotlin.math.E import kotlin.math.PI import kotlin.math.pow import kotlin.math.sqrt @@ -13,18 +14,18 @@ class SimpleAutoDiffTest { fun dx( xBinding: Pair, - body: AutoDiffField.(x: AutoDiffValue) -> AutoDiffValue, + body: SimpleAutoDiffField.(x: AutoDiffValue) -> AutoDiffValue, ): DerivationResult = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) } fun dxy( xBinding: Pair, yBinding: Pair, - body: AutoDiffField.(x: AutoDiffValue, y: AutoDiffValue) -> AutoDiffValue, + body: SimpleAutoDiffField.(x: AutoDiffValue, y: AutoDiffValue) -> AutoDiffValue, ): DerivationResult = RealField.simpleAutoDiff(xBinding, yBinding) { body(bind(xBinding.first), bind(yBinding.first)) } - fun diff(block: AutoDiffField.() -> AutoDiffValue): SimpleAutoDiffExpression { + fun diff(block: SimpleAutoDiffField.() -> AutoDiffValue): SimpleAutoDiffExpression { return SimpleAutoDiffExpression(RealField, block) } @@ -45,7 +46,7 @@ class SimpleAutoDiffTest { @Test fun testPlusX2Expr() { - val expr = diff{ + val expr = diff { val x = bind(x) x + x } @@ -245,9 +246,9 @@ class SimpleAutoDiffTest { @Test fun testTanh() { - val y = dx(x to PI / 6) { x -> tanh(x) } - assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6) - assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2 + val y = dx(x to 1.0) { x -> tanh(x) } + assertApprox((E * E - 1) / (E * E + 1), y.value) // y = tanh(pi/6) + assertApprox(1.0 / kotlin.math.cosh(1.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2 } @Test