Replaced Double in algebra by numbers, DiffExpressions

This commit is contained in:
Alexander Nozik 2019-01-13 10:42:53 +03:00
parent 4c1547ba5c
commit 037735c210
13 changed files with 192 additions and 41 deletions

View File

@ -0,0 +1,15 @@
plugins {
kotlin("jvm")
}
dependencies {
api(project(":kmath-core"))
api("org.apache.commons:commons-math3:3.6.1")
testImplementation("org.jetbrains.kotlin:kotlin-test")
testImplementation("org.jetbrains.kotlin:kotlin-test-junit")
}
//dependencies {
//// compile(project(":kmath-core"))
//// //compile project(":kmath-coroutines")
////}

View File

@ -0,0 +1,106 @@
package scientifik.kmath.expressions
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import scientifik.kmath.operations.ExtendedField
import kotlin.properties.ReadOnlyProperty
import kotlin.reflect.KProperty
/**
* A field wrapping commons-math derivative structures
*/
class DerivativeStructureField(val order: Int, val parameters: Map<String, Double>) :
ExtendedField<DerivativeStructure> {
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) }
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
}
val variable = object : ReadOnlyProperty<Any?, DerivativeStructure> {
override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure {
return variables[property.name] ?: error("A variable with name ${property.name} does not exist")
}
}
fun variable(name: String): DerivativeStructure =
variables[name] ?: error("A variable with name ${name} does not exist")
fun Number.const() = DerivativeStructure(order, parameters.size, toDouble())
fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
return deriv(mapOf(parName to order))
}
fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
}
fun DerivativeStructure.deriv(vararg orders: Pair<String, Int>): Double = deriv(mapOf(*orders))
override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
is Double -> a.multiply(k)
is Int -> a.multiply(k)
else -> a.multiply(k.toDouble())
}
override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
is Double -> arg.pow(pow)
is Int -> arg.pow(pow)
else -> arg.pow(pow.toDouble())
}
fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble())
operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble())
operator fun Number.plus(s: DerivativeStructure) = s + this
operator fun Number.minus(s: DerivativeStructure) = s - this
}
/**
* A constructs that creates a derivative structure with required order on-demand
*/
class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> {
override fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(0, arguments)
.run(function).value
/**
* Get the derivative expression with given orders
* TODO make result [DiffExpression]
*/
fun derivative(orders: Map<String, Int>): Expression<Double> {
return object : Expression<Double> {
override fun invoke(arguments: Map<String, Double>): Double =
DerivativeStructureField(orders.values.max() ?: 0, arguments)
.run {
function().deriv(orders)
}
}
}
//TODO add gradient and maybe other vector operators
}
fun DiffExpression.derivative(vararg orders: Pair<String, Int>) = derivative(mapOf(*orders))
fun DiffExpression.derivative(name: String) = derivative(name to 1)

View File

@ -0,0 +1,31 @@
package scientifik.kmath.expressions
import org.junit.Test
import kotlin.test.assertEquals
inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R) =
DerivativeStructureField(order, mapOf(*parameters)).run(block)
class AutoDiffTest {
@Test
fun derivativeStructureFieldTest() {
val res = diff(3, "x" to 1.0, "y" to 1.0) {
val x by variable
val y = variable("y")
val z = x * (-sin(x * y) + y)
z.deriv("x")
}
}
@Test
fun autoDifTest() {
val f = DiffExpression {
val x by variable
val y by variable
x.pow(2) + 2 * x * y + y.pow(2) + 1
}
assertEquals(10.0, f("x" to 1.0, "y" to 2.0))
assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0))
}
}

View File

@ -1,6 +1,7 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
@ -18,7 +19,7 @@ interface ExpressionContext<T> {
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class ConstantExpression<T>(val value: T) : Expression<T> {
@ -30,13 +31,13 @@ internal class SumExpression<T>(val context: Space<T>, val first: Expression<T>,
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
}
internal class ProductExpression<T>(val context: Field<T>, val first: Expression<T>, val second: Expression<T>) :
internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
context.multiply(first.invoke(arguments), second.invoke(arguments))
}
internal class ConstProductExpession<T>(val context: Field<T>, val expr: Expression<T>, val const: Double) :
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
@ -58,7 +59,7 @@ class ExpressionField<T>(val field: Field<T>) : Field<Expression<T>>, Expression
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(field, a, b)
override fun multiply(a: Expression<T>, k: Double): Expression<T> = ConstProductExpession(field, a, k)
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(field, a, k)
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)

View File

@ -6,7 +6,6 @@ import scientifik.kmath.operations.Space
import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.structures.*
import scientifik.kmath.structures.Buffer.Companion.DoubleBufferFactory
import scientifik.kmath.structures.Buffer.Companion.auto
import scientifik.kmath.structures.Buffer.Companion.boxing
@ -19,8 +18,6 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
val rowNum: Int
val colNum: Int
val shape get() = intArrayOf(rowNum, colNum)
/**
* Produce a matrix with this context and given dimensions
*/
@ -38,7 +35,7 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
override fun add(a: Matrix<T, R>, b: Matrix<T, R>): Matrix<T, R> =
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } }
override fun multiply(a: Matrix<T, R>, k: Double): Matrix<T, R> =
override fun multiply(a: Matrix<T, R>, k: Number): Matrix<T, R> =
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } }
companion object {
@ -61,8 +58,8 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
/**
* Automatic buffered matrix, unboxed if it is possible
*/
inline fun <reified T : Any, R : Ring<T>> smart(rows: Int, columns: Int, ring: R): MatrixSpace<T, R> =
buffered(rows, columns, ring, ::auto)
inline fun <reified T : Any, R : Ring<T>> auto(rows: Int, columns: Int, ring: R): MatrixSpace<T, R> =
buffered(rows, columns, ring, Buffer.Companion::auto)
}
}
@ -80,21 +77,19 @@ interface Matrix<T : Any, R : Ring<T>> : NDStructure<T>, SpaceElement<Matrix<T,
override fun get(index: IntArray): T = get(index[0], index[1])
override val shape: IntArray get() = context.shape
val numRows get() = context.rowNum
val numCols get() = context.colNum
//TODO replace by lazy buffers
val rows: List<Point<T>>
get() = (0 until numRows).map { i ->
val rows: Point<Point<T>>
get() = ListBuffer((0 until numRows).map { i ->
context.point(numCols) { j -> get(i, j) }
}
})
val columns: List<Point<T>>
get() = (0 until numCols).map { j ->
val columns: Point<Point<T>>
get() = ListBuffer((0 until numCols).map { j ->
context.point(numRows) { i -> get(i, j) }
}
})
val features: Set<MatrixFeature>
@ -134,7 +129,7 @@ data class StructureMatrixSpace<T : Any, R : Ring<T>>(
private val bufferFactory: BufferFactory<T>
) : MatrixSpace<T, R> {
override val shape: IntArray = intArrayOf(rowNum, colNum)
val shape: IntArray = intArrayOf(rowNum, colNum)
private val strides = DefaultStrides(shape)

View File

@ -30,7 +30,7 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } }
override fun multiply(a: Point<T>, k: Double): Point<T> = produce { with(space) { a[it] * k } }
override fun multiply(a: Point<T>, k: Number): Point<T> = produce { with(space) { a[it] * k } }
//TODO add basis

View File

@ -7,7 +7,7 @@ package scientifik.kmath.operations
* One must note that in some cases context is a singleton class, but in some cases it
* works as a context for operations inside it.
*
* TODO do we need commutative context?
* TODO do we need non-commutative context?
*/
interface Space<T> {
/**
@ -23,7 +23,7 @@ interface Space<T> {
/**
* Multiplication operation for context element and real number
*/
fun multiply(a: T, k: Double): T
fun multiply(a: T, k: Number): T
//Operation to be performed in this context
operator fun T.unaryMinus(): T = multiply(this, -1.0)

View File

@ -12,7 +12,7 @@ object ComplexField : Field<Complex> {
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
override fun multiply(a: Complex, k: Double): Complex = Complex(a.re * k, a.im * k)
override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble())
override fun multiply(a: Complex, b: Complex): Complex =
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)

View File

@ -34,21 +34,23 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
object RealField : ExtendedField<Double>, Norm<Double, Double> {
override val zero: Double = 0.0
override fun add(a: Double, b: Double): Double = a + b
override fun multiply(a: Double, @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") b: Double): Double = a * b
override fun multiply(a: Double, b: Double): Double = a * b
override fun multiply(a: Double, k: Number): Double = a * k.toDouble()
override val one: Double = 1.0
override fun divide(a: Double, b: Double): Double = a / b
override fun sin(arg: Double): Double = kotlin.math.sin(arg)
override fun cos(arg: Double): Double = kotlin.math.cos(arg)
override fun power(arg: Double, pow: Double): Double = arg.pow(pow)
override fun power(arg: Double, pow: Number): Double = arg.pow(pow.toDouble())
override fun exp(arg: Double): Double = kotlin.math.exp(arg)
override fun ln(arg: Double): Double = kotlin.math.ln(arg)
override fun norm(arg: Double): Double = kotlin.math.abs(arg)
override fun Double.unaryMinus(): Double = -this
override fun Double.unaryMinus(): Double = -this
override fun Double.minus(b: Double): Double = this - b
}
@ -56,51 +58,51 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
/**
* A field for [Int] without boxing. Does not produce corresponding field element
*/
object IntRing : Ring<Int>, Norm<Int,Int> {
object IntRing : Ring<Int>, Norm<Int, Int> {
override val zero: Int = 0
override fun add(a: Int, b: Int): Int = a + b
override fun multiply(a: Int, b: Int): Int = a * b
override fun multiply(a: Int, k: Double): Int = (k * a).toInt()
override fun multiply(a: Int, k: Number): Int = (k * a)
override val one: Int = 1
override fun norm(arg: Int): Int = arg
override fun norm(arg: Int): Int = arg
}
/**
* A field for [Short] without boxing. Does not produce appropriate field element
*/
object ShortRing : Ring<Short>, Norm<Short,Short>{
object ShortRing : Ring<Short>, Norm<Short, Short> {
override val zero: Short = 0
override fun add(a: Short, b: Short): Short = (a + b).toShort()
override fun multiply(a: Short, b: Short): Short = (a * b).toShort()
override fun multiply(a: Short, k: Double): Short = (a * k).toShort()
override fun multiply(a: Short, k: Number): Short = (a * k)
override val one: Short = 1
override fun norm(arg: Short): Short = arg
override fun norm(arg: Short): Short = arg
}
/**
* A field for [Byte] values
*/
object ByteRing : Ring<Byte>, Norm<Byte,Byte> {
object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
override val zero: Byte = 0
override fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
override fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
override fun multiply(a: Byte, k: Double): Byte = (a * k).toByte()
override fun multiply(a: Byte, k: Number): Byte = (a * k)
override val one: Byte = 1
override fun norm(arg: Byte): Byte = arg
override fun norm(arg: Byte): Byte = arg
}
/**
* A field for [Long] values
*/
object LongRing : Ring<Long>, Norm<Long,Long> {
object LongRing : Ring<Long>, Norm<Long, Long> {
override val zero: Long = 0
override fun add(a: Long, b: Long): Long = (a + b)
override fun multiply(a: Long, b: Long): Long = (a * b)
override fun multiply(a: Long, k: Double): Long = (a * k).toLong()
override fun multiply(a: Long, k: Number): Long = (a * k)
override val one: Long = 1
override fun norm(arg: Long): Long = arg
override fun norm(arg: Long): Long = arg
}

View File

@ -30,7 +30,7 @@ fun <T : MathElement<out TrigonometricOperations<T>>> ctg(arg: T): T = arg.conte
* A context extension to include power operations like square roots, etc
*/
interface PowerOperations<T> {
fun power(arg: T, pow: Double): T
fun power(arg: T, pow: Number): T
fun sqrt(arg: T) = power(arg, 0.5)
}

View File

@ -70,7 +70,7 @@ interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T,
/**
* Multiply all elements by constant
*/
override fun multiply(a: N, k: Double): N = map(a) { multiply(it, k) }
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) }
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) }

View File

@ -65,7 +65,7 @@ class RealNDField(override val shape: IntArray) :
override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> =
BufferedNDFieldElement(this@RealNDField, buffer)
override fun power(arg: NDBuffer<Double>, pow: Double) = map(arg) { power(it, pow) }
override fun power(arg: NDBuffer<Double>, pow: Number) = map(arg) { power(it, pow) }
override fun exp(arg: NDBuffer<Double>) = map(arg) { exp(it) }

View File

@ -12,5 +12,6 @@ include(
":kmath-core",
":kmath-io",
":kmath-coroutines",
":kmath-commons",
":benchmarks"
)