Safe modification of autodiff
This commit is contained in:
parent
6f9b704aa7
commit
c3f0dbe161
@ -1,5 +1,7 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
import scientifik.kmath.structures.asBuffer
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
@ -12,12 +14,26 @@ import kotlin.math.sqrt
|
|||||||
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
||||||
* with respect to this variable.
|
* with respect to this variable.
|
||||||
*/
|
*/
|
||||||
open class Variable(val x: Double) {
|
open class Variable(val value: Double) {
|
||||||
constructor(x: Number) : this(x.toDouble())
|
constructor(x: Number) : this(x.toDouble())
|
||||||
}
|
}
|
||||||
|
|
||||||
class DerivationResult(x: Double, val deriv: Map<Variable, Double>): Variable(x) {
|
class DerivationResult(value: Double, val deriv: Map<Variable, Double>) : Variable(value) {
|
||||||
fun deriv(variable: Variable) = deriv[variable] ?: 0.0
|
fun deriv(variable: Variable) = deriv[variable] ?: 0.0
|
||||||
|
|
||||||
|
/**
|
||||||
|
* compute divergence
|
||||||
|
*/
|
||||||
|
fun div() = deriv.values.sum()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute a gradient for variables in given order
|
||||||
|
*/
|
||||||
|
fun grad(vararg variables: Variable): Point<Double> = if (variables.isEmpty()) {
|
||||||
|
error("Variable order is not provided for gradient construction")
|
||||||
|
} else {
|
||||||
|
variables.map(::deriv).toDoubleArray().asBuffer()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -38,7 +54,7 @@ fun deriv(body: AutoDiffField.() -> Variable): DerivationResult =
|
|||||||
val result = body()
|
val result = body()
|
||||||
result.d = 1.0 // computing derivative w.r.t result
|
result.d = 1.0 // computing derivative w.r.t result
|
||||||
runBackwardPass()
|
runBackwardPass()
|
||||||
DerivationResult(result.x, derivatives)
|
DerivationResult(result.value, derivatives)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -61,19 +77,21 @@ abstract class AutoDiffField : Field<Variable> {
|
|||||||
*/
|
*/
|
||||||
abstract var Variable.d: Double
|
abstract var Variable.d: Double
|
||||||
|
|
||||||
|
abstract fun variable(value: Double): Variable
|
||||||
|
|
||||||
// Overloads for Double constants
|
// Overloads for Double constants
|
||||||
|
|
||||||
operator fun Number.plus(that: Variable): Variable = derive(Variable(this.toDouble() + that.x)) { z ->
|
operator fun Number.plus(that: Variable): Variable = derive(variable(this.toDouble() + that.value)) { z ->
|
||||||
that.d += z.d
|
that.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun Variable.plus(b: Number): Variable = b.plus(this)
|
operator fun Variable.plus(b: Number): Variable = b.plus(this)
|
||||||
|
|
||||||
operator fun Number.minus(that: Variable): Variable = derive(Variable(this.toDouble() - that.x)) { z ->
|
operator fun Number.minus(that: Variable): Variable = derive(variable(this.toDouble() - that.value)) { z ->
|
||||||
that.d -= z.d
|
that.d -= z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun Variable.minus(that: Number): Variable = derive(Variable(this.x - that.toDouble())) { z ->
|
operator fun Variable.minus(that: Number): Variable = derive(variable(this.value - that.toDouble())) { z ->
|
||||||
this.d += z.d
|
this.d += z.d
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -89,10 +107,22 @@ private class AutoDiffContext : AutoDiffField() {
|
|||||||
|
|
||||||
internal val derivatives = HashMap<Variable, Double>()
|
internal val derivatives = HashMap<Variable, Double>()
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A variable coupled with its derivative. For internal use only
|
||||||
|
*/
|
||||||
|
class VariableWithDeriv(x: Double, var d: Double = 0.0): Variable(x)
|
||||||
|
|
||||||
|
override fun variable(value: Double): Variable = VariableWithDeriv(value)
|
||||||
|
|
||||||
override var Variable.d: Double
|
override var Variable.d: Double
|
||||||
get() = derivatives[this] ?: 0.0
|
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: 0.0
|
||||||
set(value) {
|
set(value) {
|
||||||
derivatives[this] = value
|
if(this is VariableWithDeriv){
|
||||||
|
d = value
|
||||||
|
}else {
|
||||||
|
derivatives[this] = value
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
@ -117,25 +147,25 @@ private class AutoDiffContext : AutoDiffField() {
|
|||||||
|
|
||||||
|
|
||||||
override fun add(a: Variable, b: Variable): Variable =
|
override fun add(a: Variable, b: Variable): Variable =
|
||||||
derive(Variable(a.x + b.x)) { z ->
|
derive(variable(a.value + b.value)) { z ->
|
||||||
a.d += z.d
|
a.d += z.d
|
||||||
b.d += z.d
|
b.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Variable, b: Variable): Variable =
|
override fun multiply(a: Variable, b: Variable): Variable =
|
||||||
derive(Variable(a.x * b.x)) { z ->
|
derive(variable(a.value * b.value)) { z ->
|
||||||
a.d += z.d * b.x
|
a.d += z.d * b.value
|
||||||
b.d += z.d * a.x
|
b.d += z.d * a.value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun divide(a: Variable, b: Variable): Variable =
|
override fun divide(a: Variable, b: Variable): Variable =
|
||||||
derive(Variable(a.x / b.x)) { z ->
|
derive(Variable(a.value / b.value)) { z ->
|
||||||
a.d += z.d / b.x
|
a.d += z.d / b.value
|
||||||
b.d -= z.d * a.x / (b.x * b.x)
|
b.d -= z.d * a.value / (b.value * b.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Variable, k: Number): Variable =
|
override fun multiply(a: Variable, k: Number): Variable =
|
||||||
derive(Variable(k.toDouble() * a.x)) { z ->
|
derive(variable(k.toDouble() * a.value)) { z ->
|
||||||
a.d += z.d * k.toDouble()
|
a.d += z.d * k.toDouble()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,41 +176,41 @@ private class AutoDiffContext : AutoDiffField() {
|
|||||||
// Extensions for differentiation of various basic mathematical functions
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
// x ^ 2
|
// x ^ 2
|
||||||
fun AutoDiffField.sqr(x: Variable): Variable = derive(Variable(x.x * x.x)) { z ->
|
fun AutoDiffField.sqr(x: Variable): Variable = derive(variable(x.value * x.value)) { z ->
|
||||||
x.d += z.d * 2 * x.x
|
x.d += z.d * 2 * x.value
|
||||||
}
|
}
|
||||||
|
|
||||||
// x ^ 1/2
|
// x ^ 1/2
|
||||||
fun AutoDiffField.sqrt(x: Variable): Variable = derive(Variable(sqrt(x.x))) { z ->
|
fun AutoDiffField.sqrt(x: Variable): Variable = derive(variable(sqrt(x.value))) { z ->
|
||||||
x.d += z.d * 0.5 / z.x
|
x.d += z.d * 0.5 / z.value
|
||||||
}
|
}
|
||||||
|
|
||||||
// x ^ y (const)
|
// x ^ y (const)
|
||||||
fun AutoDiffField.pow(x: Variable, y: Double): Variable = derive(Variable(x.x.pow(y))) { z ->
|
fun AutoDiffField.pow(x: Variable, y: Double): Variable = derive(variable(x.value.pow(y))) { z ->
|
||||||
x.d += z.d * y * x.x.pow(y - 1)
|
x.d += z.d * y * x.value.pow(y - 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble())
|
fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble())
|
||||||
|
|
||||||
// exp(x)
|
// exp(x)
|
||||||
fun AutoDiffField.exp(x: Variable): Variable = derive(Variable(kotlin.math.exp(x.x))) { z ->
|
fun AutoDiffField.exp(x: Variable): Variable = derive(variable(kotlin.math.exp(x.value))) { z ->
|
||||||
x.d += z.d * z.x
|
x.d += z.d * z.value
|
||||||
}
|
}
|
||||||
|
|
||||||
// ln(x)
|
// ln(x)
|
||||||
fun AutoDiffField.ln(x: Variable): Variable = derive(Variable(kotlin.math.ln(x.x))) { z ->
|
fun AutoDiffField.ln(x: Variable): Variable = derive(Variable(kotlin.math.ln(x.value))) { z ->
|
||||||
x.d += z.d / x.x
|
x.d += z.d / x.value
|
||||||
}
|
}
|
||||||
|
|
||||||
// x ^ y (any)
|
// x ^ y (any)
|
||||||
fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x))
|
fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x))
|
||||||
|
|
||||||
// sin(x)
|
// sin(x)
|
||||||
fun AutoDiffField.sin(x: Variable): Variable = derive(Variable(kotlin.math.sin(x.x))) { z ->
|
fun AutoDiffField.sin(x: Variable): Variable = derive(variable(kotlin.math.sin(x.value))) { z ->
|
||||||
x.d += z.d * kotlin.math.cos(x.x)
|
x.d += z.d * kotlin.math.cos(x.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// cos(x)
|
// cos(x)
|
||||||
fun AutoDiffField.cos(x: Variable): Variable = derive(Variable(kotlin.math.cos(x.x))) { z ->
|
fun AutoDiffField.cos(x: Variable): Variable = derive(variable(kotlin.math.cos(x.value))) { z ->
|
||||||
x.d -= z.d * kotlin.math.sin(x.x)
|
x.d -= z.d * kotlin.math.sin(x.value)
|
||||||
}
|
}
|
@ -37,6 +37,11 @@ interface Buffer<T> {
|
|||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
|
||||||
|
inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer {
|
||||||
|
val array = DoubleArray(size) { initializer(it) }
|
||||||
|
return DoubleBuffer(array)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a boxing buffer of given type
|
* Create a boxing buffer of given type
|
||||||
*/
|
*/
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.asBuffer
|
||||||
import kotlin.math.PI
|
import kotlin.math.PI
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
class AutoDiffTest {
|
class AutoDiffTest {
|
||||||
@Test
|
@Test
|
||||||
fun testPlusX2() {
|
fun testPlusX2() {
|
||||||
val x = Variable(3) // diff w.r.t this x at 3
|
val x = Variable(3) // diff w.r.t this x at 3
|
||||||
val y = deriv { x + x }
|
val y = deriv { x + x }
|
||||||
assertEquals(6.0, y.x) // y = x + x = 6
|
assertEquals(6.0, y.value) // y = x + x = 6
|
||||||
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
|
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -19,7 +21,7 @@ class AutoDiffTest {
|
|||||||
val x = Variable(2)
|
val x = Variable(2)
|
||||||
val y = Variable(3)
|
val y = Variable(3)
|
||||||
val z = deriv { x + y }
|
val z = deriv { x + y }
|
||||||
assertEquals(5.0, z.x) // z = x + y = 5
|
assertEquals(5.0, z.value) // z = x + y = 5
|
||||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
||||||
assertEquals(1.0, z.deriv(y)) // dz/dy = 1
|
assertEquals(1.0, z.deriv(y)) // dz/dy = 1
|
||||||
}
|
}
|
||||||
@ -30,7 +32,7 @@ class AutoDiffTest {
|
|||||||
val x = Variable(7)
|
val x = Variable(7)
|
||||||
val y = Variable(3)
|
val y = Variable(3)
|
||||||
val z = deriv { x - y }
|
val z = deriv { x - y }
|
||||||
assertEquals(4.0, z.x) // z = x - y = 4
|
assertEquals(4.0, z.value) // z = x - y = 4
|
||||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
||||||
assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
|
assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
|
||||||
}
|
}
|
||||||
@ -39,7 +41,7 @@ class AutoDiffTest {
|
|||||||
fun testMulX2() {
|
fun testMulX2() {
|
||||||
val x = Variable(3) // diff w.r.t this x at 3
|
val x = Variable(3) // diff w.r.t this x at 3
|
||||||
val y = deriv { x * x }
|
val y = deriv { x * x }
|
||||||
assertEquals(9.0, y.x) // y = x * x = 9
|
assertEquals(9.0, y.value) // y = x * x = 9
|
||||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -47,7 +49,7 @@ class AutoDiffTest {
|
|||||||
fun testSqr() {
|
fun testSqr() {
|
||||||
val x = Variable(3)
|
val x = Variable(3)
|
||||||
val y = deriv { sqr(x) }
|
val y = deriv { sqr(x) }
|
||||||
assertEquals(9.0, y.x) // y = x ^ 2 = 9
|
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
||||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,7 +57,7 @@ class AutoDiffTest {
|
|||||||
fun testSqrSqr() {
|
fun testSqrSqr() {
|
||||||
val x = Variable(2)
|
val x = Variable(2)
|
||||||
val y = deriv { sqr(sqr(x)) }
|
val y = deriv { sqr(sqr(x)) }
|
||||||
assertEquals(16.0, y.x) // y = x ^ 4 = 16
|
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
||||||
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
|
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,7 +65,7 @@ class AutoDiffTest {
|
|||||||
fun testX3() {
|
fun testX3() {
|
||||||
val x = Variable(2) // diff w.r.t this x at 2
|
val x = Variable(2) // diff w.r.t this x at 2
|
||||||
val y = deriv { x * x * x }
|
val y = deriv { x * x * x }
|
||||||
assertEquals(8.0, y.x) // y = x * x * x = 8
|
assertEquals(8.0, y.value) // y = x * x * x = 8
|
||||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
|
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -72,7 +74,7 @@ class AutoDiffTest {
|
|||||||
val x = Variable(5)
|
val x = Variable(5)
|
||||||
val y = Variable(2)
|
val y = Variable(2)
|
||||||
val z = deriv { x / y }
|
val z = deriv { x / y }
|
||||||
assertEquals(2.5, z.x) // z = x / y = 2.5
|
assertEquals(2.5, z.value) // z = x / y = 2.5
|
||||||
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
|
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
|
||||||
assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
|
assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
|
||||||
}
|
}
|
||||||
@ -81,7 +83,7 @@ class AutoDiffTest {
|
|||||||
fun testPow3() {
|
fun testPow3() {
|
||||||
val x = Variable(2) // diff w.r.t this x at 2
|
val x = Variable(2) // diff w.r.t this x at 2
|
||||||
val y = deriv { pow(x, 3) }
|
val y = deriv { pow(x, 3) }
|
||||||
assertEquals(8.0, y.x) // y = x ^ 3 = 8
|
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
||||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
|
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,7 +92,7 @@ class AutoDiffTest {
|
|||||||
val x = Variable(2)
|
val x = Variable(2)
|
||||||
val y = Variable(3)
|
val y = Variable(3)
|
||||||
val z = deriv { pow(x, y) }
|
val z = deriv { pow(x, y) }
|
||||||
assertApprox(8.0, z.x) // z = x ^ y = 8
|
assertApprox(8.0, z.value) // z = x ^ y = 8
|
||||||
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
|
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
|
||||||
assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
|
assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
|
||||||
}
|
}
|
||||||
@ -99,7 +101,7 @@ class AutoDiffTest {
|
|||||||
fun testFromPaper() {
|
fun testFromPaper() {
|
||||||
val x = Variable(3)
|
val x = Variable(3)
|
||||||
val y = deriv { 2 * x + x * x * x }
|
val y = deriv { 2 * x + x * x * x }
|
||||||
assertEquals(33.0, y.x) // y = 2 * x + x * x * x = 33
|
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
||||||
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
|
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,7 +111,7 @@ class AutoDiffTest {
|
|||||||
val y = deriv {
|
val y = deriv {
|
||||||
Variable(1) * x
|
Variable(1) * x
|
||||||
}
|
}
|
||||||
assertEquals(1.0, y.x) // y = x ^ n = 1
|
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||||
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,7 +124,7 @@ class AutoDiffTest {
|
|||||||
for (i in 1..n) res *= x
|
for (i in 1..n) res *= x
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
assertEquals(1.0, y.x) // y = x ^ n = 1
|
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||||
assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,7 +132,7 @@ class AutoDiffTest {
|
|||||||
fun testExample() {
|
fun testExample() {
|
||||||
val x = Variable(2)
|
val x = Variable(2)
|
||||||
val y = deriv { sqr(x) + 5 * x + 3 }
|
val y = deriv { sqr(x) + 5 * x + 3 }
|
||||||
assertEquals(17.0, y.x) // the value of result (y)
|
assertEquals(17.0, y.value) // the value of result (y)
|
||||||
assertEquals(9.0, y.deriv(x)) // dy/dx
|
assertEquals(9.0, y.deriv(x)) // dy/dx
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,7 +140,7 @@ class AutoDiffTest {
|
|||||||
fun testSqrt() {
|
fun testSqrt() {
|
||||||
val x = Variable(16)
|
val x = Variable(16)
|
||||||
val y = deriv { sqrt(x) }
|
val y = deriv { sqrt(x) }
|
||||||
assertEquals(4.0, y.x) // y = x ^ 1/2 = 4
|
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
||||||
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,7 +148,7 @@ class AutoDiffTest {
|
|||||||
fun testSin() {
|
fun testSin() {
|
||||||
val x = Variable(PI / 6)
|
val x = Variable(PI / 6)
|
||||||
val y = deriv { sin(x) }
|
val y = deriv { sin(x) }
|
||||||
assertApprox(0.5, y.x) // y = sin(PI/6) = 0.5
|
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
||||||
assertApprox(kotlin.math.sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(PI/6) = sqrt(3)/2
|
assertApprox(kotlin.math.sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(PI/6) = sqrt(3)/2
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -154,10 +156,19 @@ class AutoDiffTest {
|
|||||||
fun testCos() {
|
fun testCos() {
|
||||||
val x = Variable(PI / 6)
|
val x = Variable(PI / 6)
|
||||||
val y = deriv { cos(x) }
|
val y = deriv { cos(x) }
|
||||||
assertApprox(kotlin.math.sqrt(3.0) / 2, y.x) // y = cos(PI/6) = sqrt(3)/2
|
assertApprox(kotlin.math.sqrt(3.0) / 2, y.value) // y = cos(PI/6) = sqrt(3)/2
|
||||||
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(PI/6) = -0.5
|
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(PI/6) = -0.5
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDivGrad() {
|
||||||
|
val x = Variable(1.0)
|
||||||
|
val y = Variable(2.0)
|
||||||
|
val res = deriv { x * x + y * y }
|
||||||
|
assertEquals(6.0, res.div())
|
||||||
|
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
||||||
|
}
|
||||||
|
|
||||||
private fun assertApprox(a: Double, b: Double) {
|
private fun assertApprox(a: Double, b: Double) {
|
||||||
if ((a - b) > 1e-10) assertEquals(a, b)
|
if ((a - b) > 1e-10) assertEquals(a, b)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user