Merge pull request #140 from mipt-npm/autodiff-update
Upgrade AutoDiff support of trigonometric ops, also fix some problems with MstAlgebra
This commit is contained in:
commit
51b7d4e73e
@ -2,7 +2,7 @@
|
||||
|
||||
## [Unreleased]
|
||||
### Added
|
||||
|
||||
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
|
||||
### Changed
|
||||
|
||||
### Deprecated
|
||||
@ -10,7 +10,7 @@
|
||||
### Removed
|
||||
|
||||
### Fixed
|
||||
|
||||
- `symbol` method in `MstExtendedField` (https://github.com/mipt-npm/kmath/pull/140)
|
||||
### Security
|
||||
## [0.1.4]
|
||||
|
||||
|
@ -44,7 +44,9 @@ object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
override val zero: MST
|
||||
get() = MstSpace.zero
|
||||
|
||||
override val one: MST = number(1.0)
|
||||
|
||||
override fun number(value: Number): MST = MstSpace.number(value)
|
||||
@ -67,8 +69,11 @@ object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
object MstField : Field<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
override val one: MST = number(1.0)
|
||||
override val zero: MST
|
||||
get() = MstRing.zero
|
||||
|
||||
override val one: MST
|
||||
get() = MstRing.one
|
||||
|
||||
override fun symbol(value: String): MST = MstRing.symbol(value)
|
||||
override fun number(value: Number): MST = MstRing.number(value)
|
||||
@ -89,14 +94,25 @@ object MstField : Field<MST> {
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
object MstExtendedField : ExtendedField<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
override val one: MST = number(1.0)
|
||||
override val zero: MST
|
||||
get() = MstField.zero
|
||||
|
||||
override val one: MST
|
||||
get() = MstField.one
|
||||
|
||||
override fun symbol(value: String): MST = MstField.symbol(value)
|
||||
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||
override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
||||
override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||
override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||
override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||
override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
||||
override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
||||
override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
||||
override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
||||
override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
||||
override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
||||
override fun add(a: MST, b: MST): MST = MstField.add(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
|
||||
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
|
||||
|
@ -65,7 +65,6 @@ inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Varia
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
||||
abstract val context: F
|
||||
|
||||
@ -152,7 +151,6 @@ internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) :
|
||||
|
||||
// Basic math (+, -, *, /)
|
||||
|
||||
|
||||
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
|
||||
a.d += z.d
|
||||
b.d += z.d
|
||||
@ -173,38 +171,66 @@ internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) :
|
||||
}
|
||||
}
|
||||
|
||||
// Extensions for differentiation of various basic mathematical functions
|
||||
|
||||
// x ^ 2
|
||||
fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
|
||||
derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
||||
|
||||
// x ^ 1/2
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
|
||||
derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
||||
|
||||
// x ^ y (const)
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
|
||||
derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> =
|
||||
pow(x, y.toDouble())
|
||||
|
||||
// exp(x)
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
|
||||
derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
|
||||
|
||||
// ln(x)
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
|
||||
derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
|
||||
|
||||
// x ^ y (any)
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
|
||||
exp(y * ln(x))
|
||||
|
||||
// sin(x)
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
|
||||
derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
||||
|
||||
// cos(x)
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
|
||||
derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: Variable<T>): Variable<T> =
|
||||
derive(variable { tan(x.value) }) { z ->
|
||||
val c = cos(x.value)
|
||||
x.d += z.d / (c * c)
|
||||
}
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: Variable<T>): Variable<T> =
|
||||
derive(variable { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: Variable<T>): Variable<T> =
|
||||
derive(variable { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: Variable<T>): Variable<T> =
|
||||
derive(variable { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { tan(x.value) }) { z ->
|
||||
val c = cosh(x.value)
|
||||
x.d += z.d / (c * c)
|
||||
}
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
||||
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
||||
|
||||
|
@ -3,19 +3,19 @@ package scientifik.kmath.misc
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class AutoDiffTest {
|
||||
fun Variable(int: Int): Variable<Double> = Variable(int.toDouble())
|
||||
|
||||
fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
|
||||
inline fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
|
||||
RealField.deriv(body)
|
||||
|
||||
@Test
|
||||
fun testPlusX2() {
|
||||
val x = Variable(3) // diff w.r.t this x at 3
|
||||
val x = Variable(3.0) // diff w.r.t this x at 3
|
||||
val y = deriv { x + x }
|
||||
assertEquals(6.0, y.value) // y = x + x = 6
|
||||
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
|
||||
@ -24,8 +24,8 @@ class AutoDiffTest {
|
||||
@Test
|
||||
fun testPlus() {
|
||||
// two variables
|
||||
val x = Variable(2)
|
||||
val y = Variable(3)
|
||||
val x = Variable(2.0)
|
||||
val y = Variable(3.0)
|
||||
val z = deriv { x + y }
|
||||
assertEquals(5.0, z.value) // z = x + y = 5
|
||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
||||
@ -35,8 +35,8 @@ class AutoDiffTest {
|
||||
@Test
|
||||
fun testMinus() {
|
||||
// two variables
|
||||
val x = Variable(7)
|
||||
val y = Variable(3)
|
||||
val x = Variable(7.0)
|
||||
val y = Variable(3.0)
|
||||
val z = deriv { x - y }
|
||||
assertEquals(4.0, z.value) // z = x - y = 4
|
||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
||||
@ -45,7 +45,7 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testMulX2() {
|
||||
val x = Variable(3) // diff w.r.t this x at 3
|
||||
val x = Variable(3.0) // diff w.r.t this x at 3
|
||||
val y = deriv { x * x }
|
||||
assertEquals(9.0, y.value) // y = x * x = 9
|
||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
||||
@ -53,7 +53,7 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testSqr() {
|
||||
val x = Variable(3)
|
||||
val x = Variable(3.0)
|
||||
val y = deriv { sqr(x) }
|
||||
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
||||
@ -61,7 +61,7 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testSqrSqr() {
|
||||
val x = Variable(2)
|
||||
val x = Variable(2.0)
|
||||
val y = deriv { sqr(sqr(x)) }
|
||||
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
||||
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
|
||||
@ -69,7 +69,7 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testX3() {
|
||||
val x = Variable(2) // diff w.r.t this x at 2
|
||||
val x = Variable(2.0) // diff w.r.t this x at 2
|
||||
val y = deriv { x * x * x }
|
||||
assertEquals(8.0, y.value) // y = x * x * x = 8
|
||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
|
||||
@ -77,8 +77,8 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testDiv() {
|
||||
val x = Variable(5)
|
||||
val y = Variable(2)
|
||||
val x = Variable(5.0)
|
||||
val y = Variable(2.0)
|
||||
val z = deriv { x / y }
|
||||
assertEquals(2.5, z.value) // z = x / y = 2.5
|
||||
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
|
||||
@ -87,7 +87,7 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testPow3() {
|
||||
val x = Variable(2) // diff w.r.t this x at 2
|
||||
val x = Variable(2.0) // diff w.r.t this x at 2
|
||||
val y = deriv { pow(x, 3) }
|
||||
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
|
||||
@ -95,8 +95,8 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testPowFull() {
|
||||
val x = Variable(2)
|
||||
val y = Variable(3)
|
||||
val x = Variable(2.0)
|
||||
val y = Variable(3.0)
|
||||
val z = deriv { pow(x, y) }
|
||||
assertApprox(8.0, z.value) // z = x ^ y = 8
|
||||
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
|
||||
@ -105,7 +105,7 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testFromPaper() {
|
||||
val x = Variable(3)
|
||||
val x = Variable(3.0)
|
||||
val y = deriv { 2 * x + x * x * x }
|
||||
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
||||
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
|
||||
@ -113,9 +113,9 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testInnerVariable() {
|
||||
val x = Variable(1)
|
||||
val x = Variable(1.0)
|
||||
val y = deriv {
|
||||
Variable(1) * x
|
||||
Variable(1.0) * x
|
||||
}
|
||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||
@ -124,9 +124,9 @@ class AutoDiffTest {
|
||||
@Test
|
||||
fun testLongChain() {
|
||||
val n = 10_000
|
||||
val x = Variable(1)
|
||||
val x = Variable(1.0)
|
||||
val y = deriv {
|
||||
var res = Variable(1)
|
||||
var res = Variable(1.0)
|
||||
for (i in 1..n) res *= x
|
||||
res
|
||||
}
|
||||
@ -136,7 +136,7 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testExample() {
|
||||
val x = Variable(2)
|
||||
val x = Variable(2.0)
|
||||
val y = deriv { sqr(x) + 5 * x + 3 }
|
||||
assertEquals(17.0, y.value) // the value of result (y)
|
||||
assertEquals(9.0, y.deriv(x)) // dy/dx
|
||||
@ -144,7 +144,7 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testSqrt() {
|
||||
val x = Variable(16)
|
||||
val x = Variable(16.0)
|
||||
val y = deriv { sqrt(x) }
|
||||
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
||||
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
||||
@ -152,18 +152,98 @@ class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testSin() {
|
||||
val x = Variable(PI / 6)
|
||||
val x = Variable(PI / 6.0)
|
||||
val y = deriv { sin(x) }
|
||||
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
||||
assertApprox(kotlin.math.sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(PI/6) = sqrt(3)/2
|
||||
assertApprox(sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCos() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { cos(x) }
|
||||
assertApprox(kotlin.math.sqrt(3.0) / 2, y.value) // y = cos(PI/6) = sqrt(3)/2
|
||||
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(PI/6) = -0.5
|
||||
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
|
||||
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(pi/6) = -0.5
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTan() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { tan(x) }
|
||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
|
||||
assertApprox(4.0 / 3.0, y.deriv(x)) // dy/dx = sec(pi/6)^2 = 4/3
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAsin() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { asin(x) }
|
||||
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
|
||||
assertApprox(6.0 / sqrt(36 - PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(36-pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAcos() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { acos(x) }
|
||||
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
|
||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAtan() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { atan(x) }
|
||||
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
|
||||
assertApprox(36.0 / (36.0 + PI * PI), y.deriv(x)) // dy/dx = 36/(36+pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSinh() {
|
||||
val x = Variable(0.0)
|
||||
val y = deriv { sinh(x) }
|
||||
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
|
||||
assertApprox(kotlin.math.cosh(0.0), y.deriv(x)) // dy/dx = cosh(0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCosh() {
|
||||
val x = Variable(0.0)
|
||||
val y = deriv { cosh(x) }
|
||||
assertApprox(1.0, y.value) //y = cosh(0)
|
||||
assertApprox(0.0, y.deriv(x)) // dy/dx = sinh(0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTanh() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { tanh(x) }
|
||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
||||
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.deriv(x)) // dy/dx = sech(pi/6)^2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAsinh() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { asinh(x) }
|
||||
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
|
||||
assertApprox(6.0 / sqrt(36 + PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(pi^2+36)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAcosh() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { acosh(x) }
|
||||
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
|
||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAtanh() {
|
||||
val x = Variable(PI / 6.0)
|
||||
val y = deriv { atanh(x) }
|
||||
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
|
||||
assertApprox(-36.0 / (PI * PI - 36.0), y.deriv(x)) // dy/dx = -36/(pi^2-36)
|
||||
}
|
||||
|
||||
@Test
|
||||
|
Loading…
Reference in New Issue
Block a user