Merge pull request #402 from mipt-npm/commandertvis/diff

Copy `DerivativeStructure` from Commons Math to multiplatform
This commit is contained in:
Alexander Nozik 2022-07-16 10:05:59 +03:00 committed by GitHub
commit fc036215a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 2048 additions and 26 deletions

View File

@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels:
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could * **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
break any moment. You can still use it, but be sure to fix the specific version. break any moment. You can still use it, but be sure to fix the specific version.
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked * **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
with `@UnstableKmathAPI` or other stability warning annotations. with `@UnstableKMathAPI` or other stability warning annotations.
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor * **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
versions, but not in patch versions. API is protected versions, but not in patch versions. API is protected
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool. with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.

View File

@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels:
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could * **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
break any moment. You can still use it, but be sure to fix the specific version. break any moment. You can still use it, but be sure to fix the specific version.
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked * **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
with `@UnstableKmathAPI` or other stability warning annotations. with `@UnstableKMathAPI` or other stability warning annotations.
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor * **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
versions, but not in patch versions. API is protected versions, but not in patch versions. API is protected
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool. with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.

View File

@ -0,0 +1,458 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.expressions
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.MutableBufferFactory
import space.kscience.kmath.structures.asBuffer
import kotlin.math.max
import kotlin.math.min
/**
* Class representing both the value and the differentials of a function.
*
* This class is the workhorse of the differentiation package.
*
* This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper
* [Doubly Recursive Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf),
* Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used
* throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's
* derivative structures hold all partial derivatives up to any specified order, with respect to any number of free
* parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free
* parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters.
*
* Derived from
* [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java).
*/
@UnstableKMathAPI
public interface DS<T, A : Ring<T>> {
public val derivativeAlgebra: DSAlgebra<T, A>
public val data: Buffer<T>
}
/**
* Get a partial derivative.
*
* @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned).
* @return partial derivative.
* @see value
*/
@UnstableKMathAPI
private fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)]
/**
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
*/
@UnstableKMathAPI
public fun <T, A : Ring<T>> DS<T, A>.derivative(symbols: List<Symbol>): T {
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
}
/**
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
*/
@UnstableKMathAPI
public fun <T, A : Ring<T>> DS<T, A>.derivative(vararg symbols: Symbol): T {
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
}
/**
* The value part of the derivative structure.
*
* @see getPartialDerivative
*/
@UnstableKMathAPI
public val <T, A : Ring<T>> DS<T, A>.value: T get() = data[0]
@UnstableKMathAPI
public abstract class DSAlgebra<T, A : Ring<T>>(
public val algebra: A,
public val bufferFactory: MutableBufferFactory<T>,
public val order: Int,
bindings: Map<Symbol, T>,
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
/**
* Get the compiler for number of free parameters and order.
*
* @return cached rules set.
*/
@PublishedApi
internal val compiler: DSCompiler<T, A> by lazy {
// get the cached compilers
val cache: Array<Array<DSCompiler<T, A>?>>? = null
// we need to create more compilers
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
if (cache != null) {
// preserve the already created compilers
for (i in cache.indices) {
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
}
}
// create the array in increasing diagonal order
for (diag in 0..numberOfVariables + order) {
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
val p: Int = diag - o
if (newCache[p][o] == null) {
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
newCache[p][o] = DSCompiler(
algebra,
bufferFactory,
p,
o,
valueCompiler,
derivativeCompiler,
)
}
}
}
return@lazy newCache[numberOfVariables][order]!!
}
private val variables: Map<Symbol, DSSymbol> by lazy {
bindings.entries.mapIndexed { index, (key, value) ->
key to DSSymbol(
index,
key,
value,
)
}.toMap()
}
override val symbols: List<Symbol> = bindings.map { it.key }
public val numberOfVariables: Int get() = symbols.size
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
val buffer = bufferFactory(compiler.size) { algebra.zero }
buffer[0] = value
if (compiler.order > 0) {
// the derivative of the variable with respect to itself is 1.
val indexOfDerivative = compiler.getPartialDerivativeIndex(*IntArray(numberOfVariables).apply {
set(index, 1)
})
buffer[indexOfDerivative] = algebra.one
}
return buffer
}
@UnstableKMathAPI
private inner class DSImpl(
override val data: Buffer<T>,
) : DS<T, A> {
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
}
protected fun DS(data: Buffer<T>): DS<T, A> = DSImpl(data)
/**
* Build an instance representing a variable.
*
* Instances built using this constructor are considered to be the free variables with respect to which
* differentials are computed. As such, their differential with respect to themselves is +1.
*/
public fun variable(
index: Int,
value: T,
): DS<T, A> {
require(index < compiler.freeParameters) { "number is too large: $index >= ${compiler.freeParameters}" }
return DS(bufferForVariable(index, value))
}
/**
* Build an instance from all its derivatives.
*
* @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex].
*/
public fun ofDerivatives(
vararg derivatives: T,
): DS<T, A> {
require(derivatives.size == compiler.size) { "dimension mismatch: ${derivatives.size} and ${compiler.size}" }
val data = derivatives.asBuffer()
return DS(data)
}
/**
* A class implementing both [DS] and [Symbol].
*/
@UnstableKMathAPI
public inner class DSSymbol internal constructor(
index: Int,
symbol: Symbol,
value: T,
) : Symbol by symbol, DS<T, A> {
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
override val data: Buffer<T> = bufferForVariable(index, value)
}
public override fun const(value: T): DS<T, A> {
val buffer = bufferFactory(compiler.size) { algebra.zero }
buffer[0] = value
return DS(buffer)
}
override fun bindSymbolOrNull(value: String): DSSymbol? = variables[StringSymbol(value)]
override fun bindSymbol(value: String): DSSymbol =
bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this")
public fun bindSymbolOrNull(symbol: Symbol): DSSymbol? = variables[symbol.identity]
public fun bindSymbol(symbol: Symbol): DSSymbol =
bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this")
public fun DS<T, A>.derivative(symbols: List<Symbol>): T {
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
val ordersCount = symbols.groupBy { it }.mapValues { it.value.size }
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
}
public fun DS<T, A>.derivative(vararg symbols: Symbol): T = derivative(symbols.toList())
}
/**
* A ring over [DS].
*
* @property order The derivation order.
* @param bindings The map of bindings values. All bindings are considered free parameters.
*/
@UnstableKMathAPI
public open class DSRing<T, A>(
algebra: A,
bufferFactory: MutableBufferFactory<T>,
order: Int,
bindings: Map<Symbol, T>,
) : DSAlgebra<T, A>(algebra, bufferFactory, order, bindings),
Ring<DS<T, A>>, ScaleOperations<DS<T, A>>,
NumericAlgebra<DS<T, A>>,
NumbersAddOps<DS<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
override fun bindSymbolOrNull(value: String): DSSymbol? =
super<DSAlgebra>.bindSymbolOrNull(value)
override fun DS<T, A>.unaryMinus(): DS<T, A> = mapData { -it }
/**
* Create a copy of given [Buffer] and modify it according to [block]
*/
protected inline fun DS<T, A>.transformDataBuffer(block: A.(MutableBuffer<T>) -> Unit): DS<T, A> {
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
val newData = bufferFactory(compiler.size) { data[it] }
algebra.block(newData)
return DS(newData)
}
protected fun DS<T, A>.mapData(block: A.(T) -> T): DS<T, A> {
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
val newData: Buffer<T> = data.map(bufferFactory) {
algebra.block(it)
}
return DS(newData)
}
protected fun DS<T, A>.mapDataIndexed(block: (Int, T) -> T): DS<T, A> {
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
val newData: Buffer<T> = data.mapIndexed(bufferFactory, block)
return DS(newData)
}
override val zero: DS<T, A> by lazy {
const(algebra.zero)
}
override val one: DS<T, A> by lazy {
const(algebra.one)
}
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
override fun add(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
require(right.derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
compiler.add(left.data, 0, right.data, 0, result, 0)
}
override fun scale(a: DS<T, A>, value: Double): DS<T, A> = a.mapData {
it.times(value)
}
override fun multiply(
left: DS<T, A>,
right: DS<T, A>,
): DS<T, A> = left.transformDataBuffer { result ->
compiler.multiply(left.data, 0, right.data, 0, result, 0)
}
//
// override fun DS<T, A>.minus(arg: DS): DS<T, A> = transformDataBuffer { result ->
// subtract(data, 0, arg.data, 0, result, 0)
// }
override operator fun DS<T, A>.plus(other: Number): DS<T, A> = transformDataBuffer {
it[0] += number(other)
}
//
// override operator fun DS<T, A>.minus(other: Number): DS<T, A> =
// this + (-other.toDouble())
override operator fun Number.plus(other: DS<T, A>): DS<T, A> = other + this
override operator fun Number.minus(other: DS<T, A>): DS<T, A> = other - this
}
@UnstableKMathAPI
public class DerivativeStructureRingExpression<T, A>(
public val algebra: A,
public val bufferFactory: MutableBufferFactory<T>,
public val function: DSRing<T, A>.() -> DS<T, A>,
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
override operator fun invoke(arguments: Map<Symbol, T>): T =
DSRing(algebra, bufferFactory, 0, arguments).function().value
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
with(
DSRing(
algebra,
bufferFactory,
symbols.size,
arguments
)
) { function().derivative(symbols) }
}
}
/**
* A field over commons-math [DerivativeStructure].
*
* @property order The derivation order.
* @param bindings The map of bindings values. All bindings are considered free parameters.
*/
@UnstableKMathAPI
public class DSField<T, A : ExtendedField<T>>(
algebra: A,
bufferFactory: MutableBufferFactory<T>,
order: Int,
bindings: Map<Symbol, T>,
) : DSRing<T, A>(algebra, bufferFactory, order, bindings), ExtendedField<DS<T, A>> {
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
override fun divide(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
compiler.divide(left.data, 0, right.data, 0, result, 0)
}
override fun sin(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.sin(arg.data, 0, result, 0)
}
override fun cos(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.cos(arg.data, 0, result, 0)
}
override fun tan(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.tan(arg.data, 0, result, 0)
}
override fun asin(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.asin(arg.data, 0, result, 0)
}
override fun acos(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.acos(arg.data, 0, result, 0)
}
override fun atan(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.atan(arg.data, 0, result, 0)
}
override fun sinh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.sinh(arg.data, 0, result, 0)
}
override fun cosh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.cosh(arg.data, 0, result, 0)
}
override fun tanh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.tanh(arg.data, 0, result, 0)
}
override fun asinh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.asinh(arg.data, 0, result, 0)
}
override fun acosh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.acosh(arg.data, 0, result, 0)
}
override fun atanh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.atanh(arg.data, 0, result, 0)
}
override fun power(arg: DS<T, A>, pow: Number): DS<T, A> = when (pow) {
is Int -> arg.transformDataBuffer { result ->
compiler.pow(arg.data, 0, pow, result, 0)
}
else -> arg.transformDataBuffer { result ->
compiler.pow(arg.data, 0, pow.toDouble(), result, 0)
}
}
override fun sqrt(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.sqrt(arg.data, 0, result, 0)
}
public fun power(arg: DS<T, A>, pow: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.pow(arg.data, 0, pow.data, 0, result, 0)
}
override fun exp(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.exp(arg.data, 0, result, 0)
}
override fun ln(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
compiler.ln(arg.data, 0, result, 0)
}
}
@UnstableKMathAPI
public class DSFieldExpression<T, A : ExtendedField<T>>(
public val algebra: A,
public val bufferFactory: MutableBufferFactory<T>,
public val function: DSField<T, A>.() -> DS<T, A>,
) : DifferentiableExpression<T> {
override operator fun invoke(arguments: Map<Symbol, T>): T =
DSField(algebra, bufferFactory, 0, arguments).function().value
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
DSField(
algebra,
bufferFactory,
symbols.size,
arguments,
).run { function().derivative(symbols) }
}
}

File diff suppressed because it is too large Load Diff

View File

@ -188,7 +188,7 @@ public interface LinearSpace<T, out A : Ring<T>> {
*/ */
public fun <T : Any, A : Ring<T>> buffered( public fun <T : Any, A : Ring<T>> buffered(
algebra: A, algebra: A,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
): LinearSpace<T, A> = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory)) ): LinearSpace<T, A> = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory))
@Deprecated("use DoubleField.linearSpace") @Deprecated("use DoubleField.linearSpace")

View File

@ -27,5 +27,5 @@ public annotation class UnstableKMathAPI
RequiresOptIn.Level.WARNING, RequiresOptIn.Level.WARNING,
) )
public annotation class PerformancePitfall( public annotation class PerformancePitfall(
val message: String = "Potential performance problem" val message: String = "Potential performance problem",
) )

View File

@ -69,7 +69,7 @@ public class MutableBufferND<T>(
* Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND] * Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND]
*/ */
public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer( public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
factory: MutableBufferFactory<R> = MutableBuffer.Companion::auto, factory: MutableBufferFactory<R> = MutableBufferFactory(MutableBuffer.Companion::auto),
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): MutableBufferND<R> { ): MutableBufferND<R> {
return if (this is MutableBufferND<T>) return if (this is MutableBufferND<T>)

View File

@ -120,7 +120,7 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
*/ */
public fun <T> buffered( public fun <T> buffered(
strides: Strides, strides: Strides,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
initializer: (IntArray) -> T, initializer: (IntArray) -> T,
): BufferND<T> = BufferND(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) ): BufferND<T> = BufferND(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
@ -140,7 +140,7 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
public fun <T> buffered( public fun <T> buffered(
shape: IntArray, shape: IntArray,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
initializer: (IntArray) -> T, initializer: (IntArray) -> T,
): BufferND<T> = buffered(DefaultStrides(shape), bufferFactory, initializer) ): BufferND<T> = buffered(DefaultStrides(shape), bufferFactory, initializer)

View File

@ -6,12 +6,10 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.kmath.linear.Point import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.asBuffer import space.kscience.kmath.structures.asBuffer
import kotlin.math.* import kotlin.math.*
/** /**
@ -21,7 +19,7 @@ public abstract class DoubleBufferOps : BufferAlgebra<Double, DoubleField>, Exte
Norm<Buffer<Double>, Double> { Norm<Buffer<Double>, Double> {
override val elementAlgebra: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
override val bufferFactory: BufferFactory<Double> get() = ::DoubleBuffer override val bufferFactory: BufferFactory<Double> get() = BufferFactory(::DoubleBuffer)
override fun Buffer<Double>.map(block: DoubleField.(Double) -> Double): DoubleBuffer = override fun Buffer<Double>.map(block: DoubleField.(Double) -> Double): DoubleBuffer =
mapInline { DoubleField.block(it) } mapInline { DoubleField.block(it) }

View File

@ -61,31 +61,39 @@ public inline fun <reified T> Buffer<T>.toTypedArray(): Array<T> = Array(size, :
/** /**
* Create a new buffer from this one with the given mapping function and using [Buffer.Companion.auto] buffer factory. * Create a new buffer from this one with the given mapping function and using [Buffer.Companion.auto] buffer factory.
*/ */
public inline fun <T : Any, reified R : Any> Buffer<T>.map(block: (T) -> R): Buffer<R> = public inline fun <T, reified R : Any> Buffer<T>.map(block: (T) -> R): Buffer<R> =
Buffer.auto(size) { block(get(it)) } Buffer.auto(size) { block(get(it)) }
/** /**
* Create a new buffer from this one with the given mapping function. * Create a new buffer from this one with the given mapping function.
* Provided [bufferFactory] is used to construct the new buffer. * Provided [bufferFactory] is used to construct the new buffer.
*/ */
public inline fun <T : Any, R : Any> Buffer<T>.map( public inline fun <T, R> Buffer<T>.map(
bufferFactory: BufferFactory<R>, bufferFactory: BufferFactory<R>,
crossinline block: (T) -> R, crossinline block: (T) -> R,
): Buffer<R> = bufferFactory(size) { block(get(it)) } ): Buffer<R> = bufferFactory(size) { block(get(it)) }
/** /**
* Create a new buffer from this one with the given indexed mapping function. * Create a new buffer from this one with the given mapping (indexed) function.
* Provided [BufferFactory] is used to construct the new buffer. * Provided [bufferFactory] is used to construct the new buffer.
*/ */
public inline fun <T : Any, reified R : Any> Buffer<T>.mapIndexed( public inline fun <T, R> Buffer<T>.mapIndexed(
bufferFactory: BufferFactory<R> = Buffer.Companion::auto, bufferFactory: BufferFactory<R>,
crossinline block: (index: Int, value: T) -> R, crossinline block: (index: Int, value: T) -> R,
): Buffer<R> = bufferFactory(size) { block(it, get(it)) } ): Buffer<R> = bufferFactory(size) { block(it, get(it)) }
/**
* Create a new buffer from this one with the given indexed mapping function.
* Provided [BufferFactory] is used to construct the new buffer.
*/
public inline fun <T, reified R : Any> Buffer<T>.mapIndexed(
crossinline block: (index: Int, value: T) -> R,
): Buffer<R> = BufferFactory<R>(Buffer.Companion::auto).invoke(size) { block(it, get(it)) }
/** /**
* Fold given buffer according to [operation] * Fold given buffer according to [operation]
*/ */
public inline fun <T : Any, R> Buffer<T>.fold(initial: R, operation: (acc: R, T) -> R): R { public inline fun <T, R> Buffer<T>.fold(initial: R, operation: (acc: R, T) -> R): R {
var accumulator = initial var accumulator = initial
for (index in this.indices) accumulator = operation(accumulator, get(index)) for (index in this.indices) accumulator = operation(accumulator, get(index))
return accumulator return accumulator
@ -95,9 +103,9 @@ public inline fun <T : Any, R> Buffer<T>.fold(initial: R, operation: (acc: R, T)
* Zip two buffers using given [transform]. * Zip two buffers using given [transform].
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public inline fun <T1 : Any, T2 : Any, reified R : Any> Buffer<T1>.zip( public inline fun <T1, T2 : Any, reified R : Any> Buffer<T1>.zip(
other: Buffer<T2>, other: Buffer<T2>,
bufferFactory: BufferFactory<R> = Buffer.Companion::auto, bufferFactory: BufferFactory<R> = BufferFactory(Buffer.Companion::auto),
crossinline transform: (T1, T2) -> R, crossinline transform: (T1, T2) -> R,
): Buffer<R> { ): Buffer<R> {
require(size == other.size) { "Buffer size mismatch in zip: expected $size but found ${other.size}" } require(size == other.size) { "Buffer size mismatch in zip: expected $size but found ${other.size}" }

View File

@ -14,14 +14,18 @@ import kotlin.reflect.KClass
* *
* @param T the type of buffer. * @param T the type of buffer.
*/ */
public typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T> public fun interface BufferFactory<T> {
public operator fun invoke(size: Int, builder: (Int) -> T): Buffer<T>
}
/** /**
* Function that produces [MutableBuffer] from its size and function that supplies values. * Function that produces [MutableBuffer] from its size and function that supplies values.
* *
* @param T the type of buffer. * @param T the type of buffer.
*/ */
public typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T> public fun interface MutableBufferFactory<T>: BufferFactory<T>{
override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer<T>
}
/** /**
* A generic read-only random-access structure for both primitives and objects. * A generic read-only random-access structure for both primitives and objects.

View File

@ -0,0 +1,59 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
@file:OptIn(UnstableKMathAPI::class)
package space.kscience.kmath.expressions
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFails
internal inline fun diff(
order: Int,
vararg parameters: Pair<Symbol, Double>,
block: DSField<Double, DoubleField>.() -> Unit,
) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
DSField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block()
}
internal class AutoDiffTest {
private val x by symbol
private val y by symbol
@Test
fun dsAlgebraTest() {
diff(2, x to 1.0, y to 1.0) {
val x = bindSymbol(x)//by binding()
val y = bindSymbol("y")
val z = x * (-sin(x * y) + y) + 2.0
println(z.derivative(x))
println(z.derivative(y, x))
assertEquals(z.derivative(x, y), z.derivative(y, x))
// check improper order cause failure
assertFails { z.derivative(x, x, y) }
}
}
@Test
fun dsExpressionTest() {
val f = DSFieldExpression(DoubleField, ::DoubleBuffer) {
val x by binding
val y by binding
x.pow(2) + 2 * x * y + y.pow(2) + 1
}
assertEquals(10.0, f(x to 1.0, y to 2.0))
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0))
assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0))
}
}

View File

@ -28,7 +28,7 @@ public class UniformHistogramGroupND<V : Any, A : Field<V>>(
private val lower: Buffer<Double>, private val lower: Buffer<Double>,
private val upper: Buffer<Double>, private val upper: Buffer<Double>,
private val binNums: IntArray = IntArray(lower.size) { 20 }, private val binNums: IntArray = IntArray(lower.size) { 20 },
private val bufferFactory: BufferFactory<V> = Buffer.Companion::boxing, private val bufferFactory: BufferFactory<V> = BufferFactory(Buffer.Companion::boxing),
) : HistogramGroupND<Double, HyperSquareDomain, V> { ) : HistogramGroupND<Double, HyperSquareDomain, V> {
init { init {
@ -114,7 +114,7 @@ public class UniformHistogramGroupND<V : Any, A : Field<V>>(
public fun <V : Any, A : Field<V>> Histogram.Companion.uniformNDFromRanges( public fun <V : Any, A : Field<V>> Histogram.Companion.uniformNDFromRanges(
valueAlgebraND: FieldOpsND<V, A>, valueAlgebraND: FieldOpsND<V, A>,
vararg ranges: ClosedFloatingPointRange<Double>, vararg ranges: ClosedFloatingPointRange<Double>,
bufferFactory: BufferFactory<V> = Buffer.Companion::boxing, bufferFactory: BufferFactory<V> = BufferFactory(Buffer.Companion::boxing),
): UniformHistogramGroupND<V, A> = UniformHistogramGroupND( ): UniformHistogramGroupND<V, A> = UniformHistogramGroupND(
valueAlgebraND, valueAlgebraND,
ranges.map(ClosedFloatingPointRange<Double>::start).asBuffer(), ranges.map(ClosedFloatingPointRange<Double>::start).asBuffer(),
@ -140,7 +140,7 @@ public fun Histogram.Companion.uniformDoubleNDFromRanges(
public fun <V : Any, A : Field<V>> Histogram.Companion.uniformNDFromRanges( public fun <V : Any, A : Field<V>> Histogram.Companion.uniformNDFromRanges(
valueAlgebraND: FieldOpsND<V, A>, valueAlgebraND: FieldOpsND<V, A>,
vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>, vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>,
bufferFactory: BufferFactory<V> = Buffer.Companion::boxing, bufferFactory: BufferFactory<V> = BufferFactory(Buffer.Companion::boxing),
): UniformHistogramGroupND<V, A> = UniformHistogramGroupND( ): UniformHistogramGroupND<V, A> = UniformHistogramGroupND(
valueAlgebraND, valueAlgebraND,
ListBuffer( ListBuffer(

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.multik package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.ndarray.data.DataType import org.jetbrains.kotlinx.multik.ndarray.data.DataType
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExponentialOperations import space.kscience.kmath.operations.ExponentialOperations
@ -22,10 +23,13 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleFi
override fun tan(arg: StructureND<Double>): MultikTensor<Double> = sin(arg) / cos(arg) override fun tan(arg: StructureND<Double>): MultikTensor<Double> = sin(arg) / cos(arg)
@PerformancePitfall
override fun asin(arg: StructureND<Double>): MultikTensor<Double> = arg.map { asin(it) } override fun asin(arg: StructureND<Double>): MultikTensor<Double> = arg.map { asin(it) }
@PerformancePitfall
override fun acos(arg: StructureND<Double>): MultikTensor<Double> = arg.map { acos(it) } override fun acos(arg: StructureND<Double>): MultikTensor<Double> = arg.map { acos(it) }
@PerformancePitfall
override fun atan(arg: StructureND<Double>): MultikTensor<Double> = arg.map { atan(it) } override fun atan(arg: StructureND<Double>): MultikTensor<Double> = arg.map { atan(it) }
override fun exp(arg: StructureND<Double>): MultikTensor<Double> = multikMath.mathEx.exp(arg.asMultik().array).wrap() override fun exp(arg: StructureND<Double>): MultikTensor<Double> = multikMath.mathEx.exp(arg.asMultik().array).wrap()
@ -42,10 +46,13 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleFi
return (expPlus - expMinus) / (expPlus + expMinus) return (expPlus - expMinus) / (expPlus + expMinus)
} }
@PerformancePitfall
override fun asinh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { asinh(it) } override fun asinh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { asinh(it) }
@PerformancePitfall
override fun acosh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { acosh(it) } override fun acosh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { acosh(it) }
@PerformancePitfall
override fun atanh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { atanh(it) } override fun atanh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { atanh(it) }
} }

View File

@ -48,7 +48,8 @@ internal object InternalUtils {
cache.copyInto( cache.copyInto(
logFactorials, logFactorials,
BEGIN_LOG_FACTORIALS, BEGIN_LOG_FACTORIALS,
BEGIN_LOG_FACTORIALS, endCopy BEGIN_LOG_FACTORIALS,
endCopy,
) )
} else } else
// All values to be computed // All values to be computed

View File

@ -35,7 +35,7 @@ public fun interface Sampler<out T : Any> {
public fun <T : Any> Sampler<T>.sampleBuffer( public fun <T : Any> Sampler<T>.sampleBuffer(
generator: RandomGenerator, generator: RandomGenerator,
size: Int, size: Int,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
): Chain<Buffer<T>> { ): Chain<Buffer<T>> {
require(size > 1) require(size > 1)
//creating temporary storage once //creating temporary storage once