Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116
@ -12,16 +12,22 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
|||||||
*/
|
*/
|
||||||
public class DerivativeStructureField(
|
public class DerivativeStructureField(
|
||||||
public val order: Int,
|
public val order: Int,
|
||||||
private val bindings: Map<Symbol, Double>
|
bindings: Map<Symbol, Double>,
|
||||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
||||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
|
public val numberOfVariables: Int = bindings.size
|
||||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
|
|
||||||
|
public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
||||||
|
public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A class that implements both [DerivativeStructure] and a [Symbol]
|
* A class that implements both [DerivativeStructure] and a [Symbol]
|
||||||
*/
|
*/
|
||||||
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
public inner class DerivativeStructureSymbol(
|
||||||
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
size: Int,
|
||||||
|
index: Int,
|
||||||
|
symbol: Symbol,
|
||||||
|
value: Double,
|
||||||
|
) : DerivativeStructure(size, order, index, value), Symbol {
|
||||||
override val identity: String = symbol.identity
|
override val identity: String = symbol.identity
|
||||||
override fun toString(): String = identity
|
override fun toString(): String = identity
|
||||||
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||||
@ -31,27 +37,26 @@ public class DerivativeStructureField(
|
|||||||
/**
|
/**
|
||||||
* Identity-based symbol bindings map
|
* Identity-based symbol bindings map
|
||||||
*/
|
*/
|
||||||
private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
|
private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.mapIndexed { index, (key, value) ->
|
||||||
key.identity to DerivativeStructureSymbol(key, value)
|
key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value)
|
||||||
}
|
}.toMap()
|
||||||
|
|
||||||
override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value)
|
override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value)
|
||||||
|
|
||||||
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
||||||
|
|
||||||
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
||||||
|
|
||||||
//public fun Number.const(): DerivativeStructure = const(toDouble())
|
override fun symbol(value: String): DerivativeStructureSymbol = bind(StringSymbol(value))
|
||||||
|
|
||||||
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
|
public fun DerivativeStructure.derivative(symbols: List<Symbol>): Double {
|
||||||
return derivative(mapOf(parameter to order))
|
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
|
||||||
|
val ordersCount = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||||
|
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DerivativeStructure.derivative(orders: Map<Symbol, Int>): Double {
|
public fun DerivativeStructure.derivative(vararg symbols: Symbol): Double = derivative(symbols.toList())
|
||||||
return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DerivativeStructure.derivative(vararg orders: Pair<Symbol, Int>): Double = derivative(mapOf(*orders))
|
|
||||||
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
||||||
|
|
||||||
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
||||||
@ -97,6 +102,7 @@ public class DerivativeStructureField(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A constructs that creates a derivative structure with required order on-demand
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
*/
|
*/
|
||||||
@ -109,7 +115,7 @@ public class DerivativeStructureExpression(
|
|||||||
/**
|
/**
|
||||||
* Get the derivative expression with given orders
|
* Get the derivative expression with given orders
|
||||||
*/
|
*/
|
||||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
public override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments ->
|
||||||
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
|
with(DerivativeStructureField(symbols.size, arguments)) { function().derivative(symbols) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,14 +5,15 @@ import kotlin.contracts.InvocationKind
|
|||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertFails
|
||||||
|
|
||||||
internal inline fun <R> diff(
|
internal inline fun diff(
|
||||||
order: Int,
|
order: Int,
|
||||||
vararg parameters: Pair<Symbol, Double>,
|
vararg parameters: Pair<Symbol, Double>,
|
||||||
block: DerivativeStructureField.() -> R,
|
block: DerivativeStructureField.() -> Unit,
|
||||||
): R {
|
): Unit {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class AutoDiffTest {
|
internal class AutoDiffTest {
|
||||||
@ -21,13 +22,16 @@ internal class AutoDiffTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun derivativeStructureFieldTest() {
|
fun derivativeStructureFieldTest() {
|
||||||
val res: Double = diff(3, x to 1.0, y to 1.0) {
|
diff(2, x to 1.0, y to 1.0) {
|
||||||
val x = bind(x)//by binding()
|
val x = bind(x)//by binding()
|
||||||
val y = symbol("y")
|
val y = symbol("y")
|
||||||
val z = x * (-sin(x * y) + y)
|
val z = x * (-sin(x * y) + y) + 2.0
|
||||||
z.derivative(x)
|
println(z.derivative(x))
|
||||||
|
println(z.derivative(y,x))
|
||||||
|
assertEquals(z.derivative(x, y), z.derivative(y, x))
|
||||||
|
//check that improper order cause failure
|
||||||
|
assertFails { z.derivative(x,x,y) }
|
||||||
}
|
}
|
||||||
println(res)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -40,5 +44,7 @@ internal class AutoDiffTest {
|
|||||||
|
|
||||||
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
||||||
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
||||||
|
assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0))
|
||||||
|
assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,6 @@ import kscience.kmath.stat.Distribution
|
|||||||
import kscience.kmath.stat.Fitting
|
import kscience.kmath.stat.Fitting
|
||||||
import kscience.kmath.stat.RandomGenerator
|
import kscience.kmath.stat.RandomGenerator
|
||||||
import kscience.kmath.stat.normal
|
import kscience.kmath.stat.normal
|
||||||
import kscience.kmath.structures.asBuffer
|
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
@ -53,7 +52,7 @@ internal class OptimizeTest {
|
|||||||
it.pow(2) + it + 1 + chain.nextDouble()
|
it.pow(2) + it + 1 + chain.nextDouble()
|
||||||
}
|
}
|
||||||
val yErr = x.map { sigma }
|
val yErr = x.map { sigma }
|
||||||
val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
|
val chi2 = Fitting.chiSquared(x, y, yErr) { x ->
|
||||||
val cWithDefault = bindOrNull(c) ?: one
|
val cWithDefault = bindOrNull(c) ?: one
|
||||||
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
|
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
|
||||||
}
|
}
|
||||||
|
@ -4,19 +4,17 @@ package kscience.kmath.expressions
|
|||||||
* An expression that provides derivatives
|
* An expression that provides derivatives
|
||||||
*/
|
*/
|
||||||
public interface DifferentiableExpression<T> : Expression<T> {
|
public interface DifferentiableExpression<T> : Expression<T> {
|
||||||
public fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>?
|
public fun derivativeOrNull(symbols: List<Symbol>): Expression<T>?
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(orders: Map<Symbol, Int>): Expression<T> =
|
public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> =
|
||||||
derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided")
|
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> =
|
||||||
derivative(mapOf(*orders))
|
derivative(symbols.toList())
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
||||||
derivative(StringSymbol(name) to 1)
|
derivative(StringSymbol(name))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A [DifferentiableExpression] that defines only first derivatives
|
* A [DifferentiableExpression] that defines only first derivatives
|
||||||
@ -25,8 +23,8 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
|
|||||||
|
|
||||||
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
||||||
|
|
||||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>? {
|
public override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
|
||||||
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
|
val dSymbol = symbols.firstOrNull() ?: return null
|
||||||
return derivativeOrNull(dSymbol)
|
return derivativeOrNull(dSymbol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user