Replaced Double in algebra by numbers, DiffExpressions
This commit is contained in:
parent
4c1547ba5c
commit
037735c210
15
kmath-commons/build.gradle.kts
Normal file
15
kmath-commons/build.gradle.kts
Normal file
@ -0,0 +1,15 @@
|
||||
plugins {
|
||||
kotlin("jvm")
|
||||
}
|
||||
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api("org.apache.commons:commons-math3:3.6.1")
|
||||
testImplementation("org.jetbrains.kotlin:kotlin-test")
|
||||
testImplementation("org.jetbrains.kotlin:kotlin-test-junit")
|
||||
}
|
||||
|
||||
//dependencies {
|
||||
//// compile(project(":kmath-core"))
|
||||
//// //compile project(":kmath-coroutines")
|
||||
////}
|
@ -0,0 +1,106 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
import kotlin.properties.ReadOnlyProperty
|
||||
import kotlin.reflect.KProperty
|
||||
|
||||
/**
|
||||
* A field wrapping commons-math derivative structures
|
||||
*/
|
||||
class DerivativeStructureField(val order: Int, val parameters: Map<String, Double>) :
|
||||
ExtendedField<DerivativeStructure> {
|
||||
|
||||
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
|
||||
|
||||
override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) }
|
||||
|
||||
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
|
||||
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
|
||||
}
|
||||
|
||||
val variable = object : ReadOnlyProperty<Any?, DerivativeStructure> {
|
||||
override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure {
|
||||
return variables[property.name] ?: error("A variable with name ${property.name} does not exist")
|
||||
}
|
||||
}
|
||||
|
||||
fun variable(name: String): DerivativeStructure =
|
||||
variables[name] ?: error("A variable with name ${name} does not exist")
|
||||
|
||||
|
||||
fun Number.const() = DerivativeStructure(order, parameters.size, toDouble())
|
||||
|
||||
fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
|
||||
return deriv(mapOf(parName to order))
|
||||
}
|
||||
|
||||
fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
|
||||
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
|
||||
}
|
||||
|
||||
fun DerivativeStructure.deriv(vararg orders: Pair<String, Int>): Double = deriv(mapOf(*orders))
|
||||
|
||||
override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
||||
|
||||
override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
||||
is Double -> a.multiply(k)
|
||||
is Int -> a.multiply(k)
|
||||
else -> a.multiply(k.toDouble())
|
||||
}
|
||||
|
||||
override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
|
||||
|
||||
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
|
||||
|
||||
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
||||
|
||||
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
||||
|
||||
override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
|
||||
is Double -> arg.pow(pow)
|
||||
is Int -> arg.pow(pow)
|
||||
else -> arg.pow(pow.toDouble())
|
||||
}
|
||||
|
||||
fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
|
||||
|
||||
override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
|
||||
|
||||
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
|
||||
|
||||
operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble())
|
||||
operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble())
|
||||
operator fun Number.plus(s: DerivativeStructure) = s + this
|
||||
operator fun Number.minus(s: DerivativeStructure) = s - this
|
||||
}
|
||||
|
||||
/**
|
||||
* A constructs that creates a derivative structure with required order on-demand
|
||||
*/
|
||||
class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> {
|
||||
override fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(0, arguments)
|
||||
.run(function).value
|
||||
|
||||
/**
|
||||
* Get the derivative expression with given orders
|
||||
* TODO make result [DiffExpression]
|
||||
*/
|
||||
fun derivative(orders: Map<String, Int>): Expression<Double> {
|
||||
return object : Expression<Double> {
|
||||
override fun invoke(arguments: Map<String, Double>): Double =
|
||||
DerivativeStructureField(orders.values.max() ?: 0, arguments)
|
||||
.run {
|
||||
function().deriv(orders)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//TODO add gradient and maybe other vector operators
|
||||
}
|
||||
|
||||
fun DiffExpression.derivative(vararg orders: Pair<String, Int>) = derivative(mapOf(*orders))
|
||||
fun DiffExpression.derivative(name: String) = derivative(name to 1)
|
||||
|
||||
|
||||
|
@ -0,0 +1,31 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R) =
|
||||
DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||
|
||||
class AutoDiffTest {
|
||||
@Test
|
||||
fun derivativeStructureFieldTest() {
|
||||
val res = diff(3, "x" to 1.0, "y" to 1.0) {
|
||||
val x by variable
|
||||
val y = variable("y")
|
||||
val z = x * (-sin(x * y) + y)
|
||||
z.deriv("x")
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun autoDifTest() {
|
||||
val f = DiffExpression {
|
||||
val x by variable
|
||||
val y by variable
|
||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||
}
|
||||
|
||||
assertEquals(10.0, f("x" to 1.0, "y" to 2.0))
|
||||
assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0))
|
||||
}
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.operations.Space
|
||||
|
||||
|
||||
@ -30,13 +31,13 @@ internal class SumExpression<T>(val context: Space<T>, val first: Expression<T>,
|
||||
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class ProductExpression<T>(val context: Field<T>, val first: Expression<T>, val second: Expression<T>) :
|
||||
internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<T>, val second: Expression<T>) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T =
|
||||
context.multiply(first.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class ConstProductExpession<T>(val context: Field<T>, val expr: Expression<T>, val const: Double) :
|
||||
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||
}
|
||||
@ -58,7 +59,7 @@ class ExpressionField<T>(val field: Field<T>) : Field<Expression<T>>, Expression
|
||||
|
||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(field, a, b)
|
||||
|
||||
override fun multiply(a: Expression<T>, k: Double): Expression<T> = ConstProductExpession(field, a, k)
|
||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(field, a, k)
|
||||
|
||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
||||
|
||||
|
@ -6,7 +6,6 @@ import scientifik.kmath.operations.Space
|
||||
import scientifik.kmath.operations.SpaceElement
|
||||
import scientifik.kmath.structures.*
|
||||
import scientifik.kmath.structures.Buffer.Companion.DoubleBufferFactory
|
||||
import scientifik.kmath.structures.Buffer.Companion.auto
|
||||
import scientifik.kmath.structures.Buffer.Companion.boxing
|
||||
|
||||
|
||||
@ -19,8 +18,6 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
|
||||
val rowNum: Int
|
||||
val colNum: Int
|
||||
|
||||
val shape get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
/**
|
||||
* Produce a matrix with this context and given dimensions
|
||||
*/
|
||||
@ -38,7 +35,7 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
|
||||
override fun add(a: Matrix<T, R>, b: Matrix<T, R>): Matrix<T, R> =
|
||||
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } }
|
||||
|
||||
override fun multiply(a: Matrix<T, R>, k: Double): Matrix<T, R> =
|
||||
override fun multiply(a: Matrix<T, R>, k: Number): Matrix<T, R> =
|
||||
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } }
|
||||
|
||||
companion object {
|
||||
@ -61,8 +58,8 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
|
||||
/**
|
||||
* Automatic buffered matrix, unboxed if it is possible
|
||||
*/
|
||||
inline fun <reified T : Any, R : Ring<T>> smart(rows: Int, columns: Int, ring: R): MatrixSpace<T, R> =
|
||||
buffered(rows, columns, ring, ::auto)
|
||||
inline fun <reified T : Any, R : Ring<T>> auto(rows: Int, columns: Int, ring: R): MatrixSpace<T, R> =
|
||||
buffered(rows, columns, ring, Buffer.Companion::auto)
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,21 +77,19 @@ interface Matrix<T : Any, R : Ring<T>> : NDStructure<T>, SpaceElement<Matrix<T,
|
||||
|
||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
||||
|
||||
override val shape: IntArray get() = context.shape
|
||||
|
||||
val numRows get() = context.rowNum
|
||||
val numCols get() = context.colNum
|
||||
|
||||
//TODO replace by lazy buffers
|
||||
val rows: List<Point<T>>
|
||||
get() = (0 until numRows).map { i ->
|
||||
val rows: Point<Point<T>>
|
||||
get() = ListBuffer((0 until numRows).map { i ->
|
||||
context.point(numCols) { j -> get(i, j) }
|
||||
}
|
||||
})
|
||||
|
||||
val columns: List<Point<T>>
|
||||
get() = (0 until numCols).map { j ->
|
||||
val columns: Point<Point<T>>
|
||||
get() = ListBuffer((0 until numCols).map { j ->
|
||||
context.point(numRows) { i -> get(i, j) }
|
||||
}
|
||||
})
|
||||
|
||||
val features: Set<MatrixFeature>
|
||||
|
||||
@ -134,7 +129,7 @@ data class StructureMatrixSpace<T : Any, R : Ring<T>>(
|
||||
private val bufferFactory: BufferFactory<T>
|
||||
) : MatrixSpace<T, R> {
|
||||
|
||||
override val shape: IntArray = intArrayOf(rowNum, colNum)
|
||||
val shape: IntArray = intArrayOf(rowNum, colNum)
|
||||
|
||||
private val strides = DefaultStrides(shape)
|
||||
|
||||
|
@ -30,7 +30,7 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
||||
|
||||
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } }
|
||||
|
||||
override fun multiply(a: Point<T>, k: Double): Point<T> = produce { with(space) { a[it] * k } }
|
||||
override fun multiply(a: Point<T>, k: Number): Point<T> = produce { with(space) { a[it] * k } }
|
||||
|
||||
//TODO add basis
|
||||
|
||||
|
@ -7,7 +7,7 @@ package scientifik.kmath.operations
|
||||
* One must note that in some cases context is a singleton class, but in some cases it
|
||||
* works as a context for operations inside it.
|
||||
*
|
||||
* TODO do we need commutative context?
|
||||
* TODO do we need non-commutative context?
|
||||
*/
|
||||
interface Space<T> {
|
||||
/**
|
||||
@ -23,7 +23,7 @@ interface Space<T> {
|
||||
/**
|
||||
* Multiplication operation for context element and real number
|
||||
*/
|
||||
fun multiply(a: T, k: Double): T
|
||||
fun multiply(a: T, k: Number): T
|
||||
|
||||
//Operation to be performed in this context
|
||||
operator fun T.unaryMinus(): T = multiply(this, -1.0)
|
||||
|
@ -12,7 +12,7 @@ object ComplexField : Field<Complex> {
|
||||
|
||||
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
|
||||
|
||||
override fun multiply(a: Complex, k: Double): Complex = Complex(a.re * k, a.im * k)
|
||||
override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble())
|
||||
|
||||
override fun multiply(a: Complex, b: Complex): Complex =
|
||||
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
|
||||
|
@ -34,14 +34,16 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
||||
object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
override val zero: Double = 0.0
|
||||
override fun add(a: Double, b: Double): Double = a + b
|
||||
override fun multiply(a: Double, @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") b: Double): Double = a * b
|
||||
override fun multiply(a: Double, b: Double): Double = a * b
|
||||
override fun multiply(a: Double, k: Number): Double = a * k.toDouble()
|
||||
|
||||
override val one: Double = 1.0
|
||||
override fun divide(a: Double, b: Double): Double = a / b
|
||||
|
||||
override fun sin(arg: Double): Double = kotlin.math.sin(arg)
|
||||
override fun cos(arg: Double): Double = kotlin.math.cos(arg)
|
||||
|
||||
override fun power(arg: Double, pow: Double): Double = arg.pow(pow)
|
||||
override fun power(arg: Double, pow: Number): Double = arg.pow(pow.toDouble())
|
||||
|
||||
override fun exp(arg: Double): Double = kotlin.math.exp(arg)
|
||||
override fun ln(arg: Double): Double = kotlin.math.ln(arg)
|
||||
@ -60,7 +62,7 @@ object IntRing : Ring<Int>, Norm<Int,Int> {
|
||||
override val zero: Int = 0
|
||||
override fun add(a: Int, b: Int): Int = a + b
|
||||
override fun multiply(a: Int, b: Int): Int = a * b
|
||||
override fun multiply(a: Int, k: Double): Int = (k * a).toInt()
|
||||
override fun multiply(a: Int, k: Number): Int = (k * a)
|
||||
override val one: Int = 1
|
||||
|
||||
override fun norm(arg: Int): Int = arg
|
||||
@ -73,7 +75,7 @@ object ShortRing : Ring<Short>, Norm<Short,Short>{
|
||||
override val zero: Short = 0
|
||||
override fun add(a: Short, b: Short): Short = (a + b).toShort()
|
||||
override fun multiply(a: Short, b: Short): Short = (a * b).toShort()
|
||||
override fun multiply(a: Short, k: Double): Short = (a * k).toShort()
|
||||
override fun multiply(a: Short, k: Number): Short = (a * k)
|
||||
override val one: Short = 1
|
||||
|
||||
override fun norm(arg: Short): Short = arg
|
||||
@ -86,7 +88,7 @@ object ByteRing : Ring<Byte>, Norm<Byte,Byte> {
|
||||
override val zero: Byte = 0
|
||||
override fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
||||
override fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
|
||||
override fun multiply(a: Byte, k: Double): Byte = (a * k).toByte()
|
||||
override fun multiply(a: Byte, k: Number): Byte = (a * k)
|
||||
override val one: Byte = 1
|
||||
|
||||
override fun norm(arg: Byte): Byte = arg
|
||||
@ -99,7 +101,7 @@ object LongRing : Ring<Long>, Norm<Long,Long> {
|
||||
override val zero: Long = 0
|
||||
override fun add(a: Long, b: Long): Long = (a + b)
|
||||
override fun multiply(a: Long, b: Long): Long = (a * b)
|
||||
override fun multiply(a: Long, k: Double): Long = (a * k).toLong()
|
||||
override fun multiply(a: Long, k: Number): Long = (a * k)
|
||||
override val one: Long = 1
|
||||
|
||||
override fun norm(arg: Long): Long = arg
|
||||
|
@ -30,7 +30,7 @@ fun <T : MathElement<out TrigonometricOperations<T>>> ctg(arg: T): T = arg.conte
|
||||
* A context extension to include power operations like square roots, etc
|
||||
*/
|
||||
interface PowerOperations<T> {
|
||||
fun power(arg: T, pow: Double): T
|
||||
fun power(arg: T, pow: Number): T
|
||||
fun sqrt(arg: T) = power(arg, 0.5)
|
||||
}
|
||||
|
||||
|
@ -70,7 +70,7 @@ interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T,
|
||||
/**
|
||||
* Multiply all elements by constant
|
||||
*/
|
||||
override fun multiply(a: N, k: Double): N = map(a) { multiply(it, k) }
|
||||
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
|
||||
|
||||
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) }
|
||||
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) }
|
||||
|
@ -65,7 +65,7 @@ class RealNDField(override val shape: IntArray) :
|
||||
override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> =
|
||||
BufferedNDFieldElement(this@RealNDField, buffer)
|
||||
|
||||
override fun power(arg: NDBuffer<Double>, pow: Double) = map(arg) { power(it, pow) }
|
||||
override fun power(arg: NDBuffer<Double>, pow: Number) = map(arg) { power(it, pow) }
|
||||
|
||||
override fun exp(arg: NDBuffer<Double>) = map(arg) { exp(it) }
|
||||
|
||||
|
@ -12,5 +12,6 @@ include(
|
||||
":kmath-core",
|
||||
":kmath-io",
|
||||
":kmath-coroutines",
|
||||
":kmath-commons",
|
||||
":benchmarks"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user