forked from kscience/kmath
Merge remote-tracking branch 'origin/dev' into dev
This commit is contained in:
commit
b522e5919e
@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels:
|
||||
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
|
||||
break any moment. You can still use it, but be sure to fix the specific version.
|
||||
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
|
||||
with `@UnstableKmathAPI` or other stability warning annotations.
|
||||
with `@UnstableKMathAPI` or other stability warning annotations.
|
||||
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
|
||||
versions, but not in patch versions. API is protected
|
||||
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
||||
|
2
docs/templates/README-TEMPLATE.md
vendored
2
docs/templates/README-TEMPLATE.md
vendored
@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels:
|
||||
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
|
||||
break any moment. You can still use it, but be sure to fix the specific version.
|
||||
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
|
||||
with `@UnstableKmathAPI` or other stability warning annotations.
|
||||
with `@UnstableKMathAPI` or other stability warning annotations.
|
||||
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
|
||||
versions, but not in patch versions. API is protected
|
||||
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
||||
|
@ -0,0 +1,458 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.expressions
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
import space.kscience.kmath.structures.MutableBufferFactory
|
||||
import space.kscience.kmath.structures.asBuffer
|
||||
import kotlin.math.max
|
||||
import kotlin.math.min
|
||||
|
||||
/**
|
||||
* Class representing both the value and the differentials of a function.
|
||||
*
|
||||
* This class is the workhorse of the differentiation package.
|
||||
*
|
||||
* This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper
|
||||
* [Doubly Recursive Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf),
|
||||
* Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used
|
||||
* throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's
|
||||
* derivative structures hold all partial derivatives up to any specified order, with respect to any number of free
|
||||
* parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free
|
||||
* parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters.
|
||||
*
|
||||
* Derived from
|
||||
* [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java).
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public interface DS<T, A : Ring<T>> {
|
||||
public val derivativeAlgebra: DSAlgebra<T, A>
|
||||
public val data: Buffer<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a partial derivative.
|
||||
*
|
||||
* @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned).
|
||||
* @return partial derivative.
|
||||
* @see value
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
private fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
|
||||
data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)]
|
||||
|
||||
/**
|
||||
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : Ring<T>> DS<T, A>.derivative(symbols: List<Symbol>): T {
|
||||
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
|
||||
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||
}
|
||||
|
||||
/**
|
||||
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : Ring<T>> DS<T, A>.derivative(vararg symbols: Symbol): T {
|
||||
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
|
||||
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||
}
|
||||
|
||||
/**
|
||||
* The value part of the derivative structure.
|
||||
*
|
||||
* @see getPartialDerivative
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public val <T, A : Ring<T>> DS<T, A>.value: T get() = data[0]
|
||||
|
||||
@UnstableKMathAPI
|
||||
public abstract class DSAlgebra<T, A : Ring<T>>(
|
||||
public val algebra: A,
|
||||
public val bufferFactory: MutableBufferFactory<T>,
|
||||
public val order: Int,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
|
||||
|
||||
/**
|
||||
* Get the compiler for number of free parameters and order.
|
||||
*
|
||||
* @return cached rules set.
|
||||
*/
|
||||
@PublishedApi
|
||||
internal val compiler: DSCompiler<T, A> by lazy {
|
||||
// get the cached compilers
|
||||
val cache: Array<Array<DSCompiler<T, A>?>>? = null
|
||||
|
||||
// we need to create more compilers
|
||||
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
|
||||
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
|
||||
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
|
||||
|
||||
if (cache != null) {
|
||||
// preserve the already created compilers
|
||||
for (i in cache.indices) {
|
||||
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
|
||||
}
|
||||
}
|
||||
|
||||
// create the array in increasing diagonal order
|
||||
for (diag in 0..numberOfVariables + order) {
|
||||
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
|
||||
val p: Int = diag - o
|
||||
if (newCache[p][o] == null) {
|
||||
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
|
||||
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
|
||||
|
||||
newCache[p][o] = DSCompiler(
|
||||
algebra,
|
||||
bufferFactory,
|
||||
p,
|
||||
o,
|
||||
valueCompiler,
|
||||
derivativeCompiler,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return@lazy newCache[numberOfVariables][order]!!
|
||||
}
|
||||
|
||||
private val variables: Map<Symbol, DSSymbol> by lazy {
|
||||
bindings.entries.mapIndexed { index, (key, value) ->
|
||||
key to DSSymbol(
|
||||
index,
|
||||
key,
|
||||
value,
|
||||
)
|
||||
}.toMap()
|
||||
}
|
||||
override val symbols: List<Symbol> = bindings.map { it.key }
|
||||
|
||||
public val numberOfVariables: Int get() = symbols.size
|
||||
|
||||
|
||||
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
|
||||
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
||||
buffer[0] = value
|
||||
if (compiler.order > 0) {
|
||||
// the derivative of the variable with respect to itself is 1.
|
||||
|
||||
val indexOfDerivative = compiler.getPartialDerivativeIndex(*IntArray(numberOfVariables).apply {
|
||||
set(index, 1)
|
||||
})
|
||||
|
||||
buffer[indexOfDerivative] = algebra.one
|
||||
}
|
||||
return buffer
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
private inner class DSImpl(
|
||||
override val data: Buffer<T>,
|
||||
) : DS<T, A> {
|
||||
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
|
||||
}
|
||||
|
||||
protected fun DS(data: Buffer<T>): DS<T, A> = DSImpl(data)
|
||||
|
||||
|
||||
/**
|
||||
* Build an instance representing a variable.
|
||||
*
|
||||
* Instances built using this constructor are considered to be the free variables with respect to which
|
||||
* differentials are computed. As such, their differential with respect to themselves is +1.
|
||||
*/
|
||||
public fun variable(
|
||||
index: Int,
|
||||
value: T,
|
||||
): DS<T, A> {
|
||||
require(index < compiler.freeParameters) { "number is too large: $index >= ${compiler.freeParameters}" }
|
||||
return DS(bufferForVariable(index, value))
|
||||
}
|
||||
|
||||
/**
|
||||
* Build an instance from all its derivatives.
|
||||
*
|
||||
* @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex].
|
||||
*/
|
||||
public fun ofDerivatives(
|
||||
vararg derivatives: T,
|
||||
): DS<T, A> {
|
||||
require(derivatives.size == compiler.size) { "dimension mismatch: ${derivatives.size} and ${compiler.size}" }
|
||||
val data = derivatives.asBuffer()
|
||||
|
||||
return DS(data)
|
||||
}
|
||||
|
||||
/**
|
||||
* A class implementing both [DS] and [Symbol].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public inner class DSSymbol internal constructor(
|
||||
index: Int,
|
||||
symbol: Symbol,
|
||||
value: T,
|
||||
) : Symbol by symbol, DS<T, A> {
|
||||
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
|
||||
override val data: Buffer<T> = bufferForVariable(index, value)
|
||||
}
|
||||
|
||||
public override fun const(value: T): DS<T, A> {
|
||||
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
||||
buffer[0] = value
|
||||
|
||||
return DS(buffer)
|
||||
}
|
||||
|
||||
override fun bindSymbolOrNull(value: String): DSSymbol? = variables[StringSymbol(value)]
|
||||
|
||||
override fun bindSymbol(value: String): DSSymbol =
|
||||
bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this")
|
||||
|
||||
public fun bindSymbolOrNull(symbol: Symbol): DSSymbol? = variables[symbol.identity]
|
||||
|
||||
public fun bindSymbol(symbol: Symbol): DSSymbol =
|
||||
bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this")
|
||||
|
||||
public fun DS<T, A>.derivative(symbols: List<Symbol>): T {
|
||||
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
|
||||
val ordersCount = symbols.groupBy { it }.mapValues { it.value.size }
|
||||
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||
}
|
||||
|
||||
public fun DS<T, A>.derivative(vararg symbols: Symbol): T = derivative(symbols.toList())
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A ring over [DS].
|
||||
*
|
||||
* @property order The derivation order.
|
||||
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public open class DSRing<T, A>(
|
||||
algebra: A,
|
||||
bufferFactory: MutableBufferFactory<T>,
|
||||
order: Int,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : DSAlgebra<T, A>(algebra, bufferFactory, order, bindings),
|
||||
Ring<DS<T, A>>, ScaleOperations<DS<T, A>>,
|
||||
NumericAlgebra<DS<T, A>>,
|
||||
NumbersAddOps<DS<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||
|
||||
override fun bindSymbolOrNull(value: String): DSSymbol? =
|
||||
super<DSAlgebra>.bindSymbolOrNull(value)
|
||||
|
||||
override fun DS<T, A>.unaryMinus(): DS<T, A> = mapData { -it }
|
||||
|
||||
/**
|
||||
* Create a copy of given [Buffer] and modify it according to [block]
|
||||
*/
|
||||
protected inline fun DS<T, A>.transformDataBuffer(block: A.(MutableBuffer<T>) -> Unit): DS<T, A> {
|
||||
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||
val newData = bufferFactory(compiler.size) { data[it] }
|
||||
algebra.block(newData)
|
||||
return DS(newData)
|
||||
}
|
||||
|
||||
protected fun DS<T, A>.mapData(block: A.(T) -> T): DS<T, A> {
|
||||
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||
val newData: Buffer<T> = data.map(bufferFactory) {
|
||||
algebra.block(it)
|
||||
}
|
||||
return DS(newData)
|
||||
}
|
||||
|
||||
protected fun DS<T, A>.mapDataIndexed(block: (Int, T) -> T): DS<T, A> {
|
||||
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||
val newData: Buffer<T> = data.mapIndexed(bufferFactory, block)
|
||||
return DS(newData)
|
||||
}
|
||||
|
||||
override val zero: DS<T, A> by lazy {
|
||||
const(algebra.zero)
|
||||
}
|
||||
|
||||
override val one: DS<T, A> by lazy {
|
||||
const(algebra.one)
|
||||
}
|
||||
|
||||
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
|
||||
|
||||
override fun add(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
|
||||
require(right.derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||
compiler.add(left.data, 0, right.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun scale(a: DS<T, A>, value: Double): DS<T, A> = a.mapData {
|
||||
it.times(value)
|
||||
}
|
||||
|
||||
override fun multiply(
|
||||
left: DS<T, A>,
|
||||
right: DS<T, A>,
|
||||
): DS<T, A> = left.transformDataBuffer { result ->
|
||||
compiler.multiply(left.data, 0, right.data, 0, result, 0)
|
||||
}
|
||||
//
|
||||
// override fun DS<T, A>.minus(arg: DS): DS<T, A> = transformDataBuffer { result ->
|
||||
// subtract(data, 0, arg.data, 0, result, 0)
|
||||
// }
|
||||
|
||||
override operator fun DS<T, A>.plus(other: Number): DS<T, A> = transformDataBuffer {
|
||||
it[0] += number(other)
|
||||
}
|
||||
|
||||
//
|
||||
// override operator fun DS<T, A>.minus(other: Number): DS<T, A> =
|
||||
// this + (-other.toDouble())
|
||||
|
||||
override operator fun Number.plus(other: DS<T, A>): DS<T, A> = other + this
|
||||
override operator fun Number.minus(other: DS<T, A>): DS<T, A> = other - this
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public class DerivativeStructureRingExpression<T, A>(
|
||||
public val algebra: A,
|
||||
public val bufferFactory: MutableBufferFactory<T>,
|
||||
public val function: DSRing<T, A>.() -> DS<T, A>,
|
||||
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||
DSRing(algebra, bufferFactory, 0, arguments).function().value
|
||||
|
||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||
with(
|
||||
DSRing(
|
||||
algebra,
|
||||
bufferFactory,
|
||||
symbols.size,
|
||||
arguments
|
||||
)
|
||||
) { function().derivative(symbols) }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A field over commons-math [DerivativeStructure].
|
||||
*
|
||||
* @property order The derivation order.
|
||||
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class DSField<T, A : ExtendedField<T>>(
|
||||
algebra: A,
|
||||
bufferFactory: MutableBufferFactory<T>,
|
||||
order: Int,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : DSRing<T, A>(algebra, bufferFactory, order, bindings), ExtendedField<DS<T, A>> {
|
||||
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
|
||||
|
||||
override fun divide(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
|
||||
compiler.divide(left.data, 0, right.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun sin(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.sin(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun cos(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.cos(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun tan(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.tan(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun asin(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.asin(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun acos(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.acos(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun atan(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.atan(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun sinh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.sinh(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun cosh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.cosh(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun tanh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.tanh(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun asinh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.asinh(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun acosh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.acosh(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun atanh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.atanh(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun power(arg: DS<T, A>, pow: Number): DS<T, A> = when (pow) {
|
||||
is Int -> arg.transformDataBuffer { result ->
|
||||
compiler.pow(arg.data, 0, pow, result, 0)
|
||||
}
|
||||
else -> arg.transformDataBuffer { result ->
|
||||
compiler.pow(arg.data, 0, pow.toDouble(), result, 0)
|
||||
}
|
||||
}
|
||||
|
||||
override fun sqrt(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.sqrt(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
public fun power(arg: DS<T, A>, pow: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.pow(arg.data, 0, pow.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun exp(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.exp(arg.data, 0, result, 0)
|
||||
}
|
||||
|
||||
override fun ln(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||
compiler.ln(arg.data, 0, result, 0)
|
||||
}
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public class DSFieldExpression<T, A : ExtendedField<T>>(
|
||||
public val algebra: A,
|
||||
public val bufferFactory: MutableBufferFactory<T>,
|
||||
public val function: DSField<T, A>.() -> DS<T, A>,
|
||||
) : DifferentiableExpression<T> {
|
||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||
DSField(algebra, bufferFactory, 0, arguments).function().value
|
||||
|
||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||
DSField(
|
||||
algebra,
|
||||
bufferFactory,
|
||||
symbols.size,
|
||||
arguments,
|
||||
).run { function().derivative(symbols) }
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -188,7 +188,7 @@ public interface LinearSpace<T, out A : Ring<T>> {
|
||||
*/
|
||||
public fun <T : Any, A : Ring<T>> buffered(
|
||||
algebra: A,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
|
||||
): LinearSpace<T, A> = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory))
|
||||
|
||||
@Deprecated("use DoubleField.linearSpace")
|
||||
|
@ -27,5 +27,5 @@ public annotation class UnstableKMathAPI
|
||||
RequiresOptIn.Level.WARNING,
|
||||
)
|
||||
public annotation class PerformancePitfall(
|
||||
val message: String = "Potential performance problem"
|
||||
val message: String = "Potential performance problem",
|
||||
)
|
||||
|
@ -69,7 +69,7 @@ public class MutableBufferND<T>(
|
||||
* Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND]
|
||||
*/
|
||||
public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
|
||||
factory: MutableBufferFactory<R> = MutableBuffer.Companion::auto,
|
||||
factory: MutableBufferFactory<R> = MutableBufferFactory(MutableBuffer.Companion::auto),
|
||||
crossinline transform: (T) -> R,
|
||||
): MutableBufferND<R> {
|
||||
return if (this is MutableBufferND<T>)
|
||||
|
@ -120,7 +120,7 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
|
||||
*/
|
||||
public fun <T> buffered(
|
||||
strides: Strides,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
|
||||
initializer: (IntArray) -> T,
|
||||
): BufferND<T> = BufferND(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
@ -140,7 +140,7 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
|
||||
|
||||
public fun <T> buffered(
|
||||
shape: IntArray,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
|
||||
initializer: (IntArray) -> T,
|
||||
): BufferND<T> = buffered(DefaultStrides(shape), bufferFactory, initializer)
|
||||
|
||||
|
@ -6,12 +6,10 @@
|
||||
package space.kscience.kmath.operations
|
||||
|
||||
import space.kscience.kmath.linear.Point
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.BufferFactory
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import space.kscience.kmath.structures.asBuffer
|
||||
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
@ -21,7 +19,7 @@ public abstract class DoubleBufferOps : BufferAlgebra<Double, DoubleField>, Exte
|
||||
Norm<Buffer<Double>, Double> {
|
||||
|
||||
override val elementAlgebra: DoubleField get() = DoubleField
|
||||
override val bufferFactory: BufferFactory<Double> get() = ::DoubleBuffer
|
||||
override val bufferFactory: BufferFactory<Double> get() = BufferFactory(::DoubleBuffer)
|
||||
|
||||
override fun Buffer<Double>.map(block: DoubleField.(Double) -> Double): DoubleBuffer =
|
||||
mapInline { DoubleField.block(it) }
|
||||
|
@ -61,31 +61,39 @@ public inline fun <reified T> Buffer<T>.toTypedArray(): Array<T> = Array(size, :
|
||||
/**
|
||||
* Create a new buffer from this one with the given mapping function and using [Buffer.Companion.auto] buffer factory.
|
||||
*/
|
||||
public inline fun <T : Any, reified R : Any> Buffer<T>.map(block: (T) -> R): Buffer<R> =
|
||||
public inline fun <T, reified R : Any> Buffer<T>.map(block: (T) -> R): Buffer<R> =
|
||||
Buffer.auto(size) { block(get(it)) }
|
||||
|
||||
/**
|
||||
* Create a new buffer from this one with the given mapping function.
|
||||
* Provided [bufferFactory] is used to construct the new buffer.
|
||||
*/
|
||||
public inline fun <T : Any, R : Any> Buffer<T>.map(
|
||||
public inline fun <T, R> Buffer<T>.map(
|
||||
bufferFactory: BufferFactory<R>,
|
||||
crossinline block: (T) -> R,
|
||||
): Buffer<R> = bufferFactory(size) { block(get(it)) }
|
||||
|
||||
/**
|
||||
* Create a new buffer from this one with the given indexed mapping function.
|
||||
* Provided [BufferFactory] is used to construct the new buffer.
|
||||
* Create a new buffer from this one with the given mapping (indexed) function.
|
||||
* Provided [bufferFactory] is used to construct the new buffer.
|
||||
*/
|
||||
public inline fun <T : Any, reified R : Any> Buffer<T>.mapIndexed(
|
||||
bufferFactory: BufferFactory<R> = Buffer.Companion::auto,
|
||||
public inline fun <T, R> Buffer<T>.mapIndexed(
|
||||
bufferFactory: BufferFactory<R>,
|
||||
crossinline block: (index: Int, value: T) -> R,
|
||||
): Buffer<R> = bufferFactory(size) { block(it, get(it)) }
|
||||
|
||||
/**
|
||||
* Create a new buffer from this one with the given indexed mapping function.
|
||||
* Provided [BufferFactory] is used to construct the new buffer.
|
||||
*/
|
||||
public inline fun <T, reified R : Any> Buffer<T>.mapIndexed(
|
||||
crossinline block: (index: Int, value: T) -> R,
|
||||
): Buffer<R> = BufferFactory<R>(Buffer.Companion::auto).invoke(size) { block(it, get(it)) }
|
||||
|
||||
/**
|
||||
* Fold given buffer according to [operation]
|
||||
*/
|
||||
public inline fun <T : Any, R> Buffer<T>.fold(initial: R, operation: (acc: R, T) -> R): R {
|
||||
public inline fun <T, R> Buffer<T>.fold(initial: R, operation: (acc: R, T) -> R): R {
|
||||
var accumulator = initial
|
||||
for (index in this.indices) accumulator = operation(accumulator, get(index))
|
||||
return accumulator
|
||||
@ -95,9 +103,9 @@ public inline fun <T : Any, R> Buffer<T>.fold(initial: R, operation: (acc: R, T)
|
||||
* Zip two buffers using given [transform].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public inline fun <T1 : Any, T2 : Any, reified R : Any> Buffer<T1>.zip(
|
||||
public inline fun <T1, T2 : Any, reified R : Any> Buffer<T1>.zip(
|
||||
other: Buffer<T2>,
|
||||
bufferFactory: BufferFactory<R> = Buffer.Companion::auto,
|
||||
bufferFactory: BufferFactory<R> = BufferFactory(Buffer.Companion::auto),
|
||||
crossinline transform: (T1, T2) -> R,
|
||||
): Buffer<R> {
|
||||
require(size == other.size) { "Buffer size mismatch in zip: expected $size but found ${other.size}" }
|
||||
|
@ -14,14 +14,18 @@ import kotlin.reflect.KClass
|
||||
*
|
||||
* @param T the type of buffer.
|
||||
*/
|
||||
public typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
||||
public fun interface BufferFactory<T> {
|
||||
public operator fun invoke(size: Int, builder: (Int) -> T): Buffer<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Function that produces [MutableBuffer] from its size and function that supplies values.
|
||||
*
|
||||
* @param T the type of buffer.
|
||||
*/
|
||||
public typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
|
||||
public fun interface MutableBufferFactory<T>: BufferFactory<T>{
|
||||
override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* A generic read-only random-access structure for both primitives and objects.
|
||||
|
@ -0,0 +1,59 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
@file:OptIn(UnstableKMathAPI::class)
|
||||
|
||||
package space.kscience.kmath.expressions
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFails
|
||||
|
||||
internal inline fun diff(
|
||||
order: Int,
|
||||
vararg parameters: Pair<Symbol, Double>,
|
||||
block: DSField<Double, DoubleField>.() -> Unit,
|
||||
) {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
DSField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block()
|
||||
}
|
||||
|
||||
internal class AutoDiffTest {
|
||||
private val x by symbol
|
||||
private val y by symbol
|
||||
|
||||
@Test
|
||||
fun dsAlgebraTest() {
|
||||
diff(2, x to 1.0, y to 1.0) {
|
||||
val x = bindSymbol(x)//by binding()
|
||||
val y = bindSymbol("y")
|
||||
val z = x * (-sin(x * y) + y) + 2.0
|
||||
println(z.derivative(x))
|
||||
println(z.derivative(y, x))
|
||||
assertEquals(z.derivative(x, y), z.derivative(y, x))
|
||||
// check improper order cause failure
|
||||
assertFails { z.derivative(x, x, y) }
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun dsExpressionTest() {
|
||||
val f = DSFieldExpression(DoubleField, ::DoubleBuffer) {
|
||||
val x by binding
|
||||
val y by binding
|
||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||
}
|
||||
|
||||
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
||||
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
||||
assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0))
|
||||
assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0))
|
||||
}
|
||||
}
|
@ -28,7 +28,7 @@ public class UniformHistogramGroupND<V : Any, A : Field<V>>(
|
||||
private val lower: Buffer<Double>,
|
||||
private val upper: Buffer<Double>,
|
||||
private val binNums: IntArray = IntArray(lower.size) { 20 },
|
||||
private val bufferFactory: BufferFactory<V> = Buffer.Companion::boxing,
|
||||
private val bufferFactory: BufferFactory<V> = BufferFactory(Buffer.Companion::boxing),
|
||||
) : HistogramGroupND<Double, HyperSquareDomain, V> {
|
||||
|
||||
init {
|
||||
@ -114,7 +114,7 @@ public class UniformHistogramGroupND<V : Any, A : Field<V>>(
|
||||
public fun <V : Any, A : Field<V>> Histogram.Companion.uniformNDFromRanges(
|
||||
valueAlgebraND: FieldOpsND<V, A>,
|
||||
vararg ranges: ClosedFloatingPointRange<Double>,
|
||||
bufferFactory: BufferFactory<V> = Buffer.Companion::boxing,
|
||||
bufferFactory: BufferFactory<V> = BufferFactory(Buffer.Companion::boxing),
|
||||
): UniformHistogramGroupND<V, A> = UniformHistogramGroupND(
|
||||
valueAlgebraND,
|
||||
ranges.map(ClosedFloatingPointRange<Double>::start).asBuffer(),
|
||||
@ -140,7 +140,7 @@ public fun Histogram.Companion.uniformDoubleNDFromRanges(
|
||||
public fun <V : Any, A : Field<V>> Histogram.Companion.uniformNDFromRanges(
|
||||
valueAlgebraND: FieldOpsND<V, A>,
|
||||
vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>,
|
||||
bufferFactory: BufferFactory<V> = Buffer.Companion::boxing,
|
||||
bufferFactory: BufferFactory<V> = BufferFactory(Buffer.Companion::boxing),
|
||||
): UniformHistogramGroupND<V, A> = UniformHistogramGroupND(
|
||||
valueAlgebraND,
|
||||
ListBuffer(
|
||||
|
@ -6,6 +6,7 @@
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.ExponentialOperations
|
||||
@ -22,10 +23,13 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleFi
|
||||
|
||||
override fun tan(arg: StructureND<Double>): MultikTensor<Double> = sin(arg) / cos(arg)
|
||||
|
||||
@PerformancePitfall
|
||||
override fun asin(arg: StructureND<Double>): MultikTensor<Double> = arg.map { asin(it) }
|
||||
|
||||
@PerformancePitfall
|
||||
override fun acos(arg: StructureND<Double>): MultikTensor<Double> = arg.map { acos(it) }
|
||||
|
||||
@PerformancePitfall
|
||||
override fun atan(arg: StructureND<Double>): MultikTensor<Double> = arg.map { atan(it) }
|
||||
|
||||
override fun exp(arg: StructureND<Double>): MultikTensor<Double> = multikMath.mathEx.exp(arg.asMultik().array).wrap()
|
||||
@ -42,10 +46,13 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleFi
|
||||
return (expPlus - expMinus) / (expPlus + expMinus)
|
||||
}
|
||||
|
||||
@PerformancePitfall
|
||||
override fun asinh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { asinh(it) }
|
||||
|
||||
@PerformancePitfall
|
||||
override fun acosh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { acosh(it) }
|
||||
|
||||
@PerformancePitfall
|
||||
override fun atanh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { atanh(it) }
|
||||
}
|
||||
|
||||
|
@ -48,7 +48,8 @@ internal object InternalUtils {
|
||||
cache.copyInto(
|
||||
logFactorials,
|
||||
BEGIN_LOG_FACTORIALS,
|
||||
BEGIN_LOG_FACTORIALS, endCopy
|
||||
BEGIN_LOG_FACTORIALS,
|
||||
endCopy,
|
||||
)
|
||||
} else
|
||||
// All values to be computed
|
||||
|
@ -35,7 +35,7 @@ public fun interface Sampler<out T : Any> {
|
||||
public fun <T : Any> Sampler<T>.sampleBuffer(
|
||||
generator: RandomGenerator,
|
||||
size: Int,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
|
||||
): Chain<Buffer<T>> {
|
||||
require(size > 1)
|
||||
//creating temporary storage once
|
||||
|
Loading…
Reference in New Issue
Block a user