From 037735c210e437a4ecf02fbbe24fb37132c6ef8a Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sun, 13 Jan 2019 10:42:53 +0300 Subject: [PATCH] Replaced Double in algebra by numbers, DiffExpressions --- kmath-commons/build.gradle.kts | 15 +++ .../kmath/expressions/DiffExpression.kt | 106 ++++++++++++++++++ .../kmath/expressions/AutoDiffTest.kt | 31 +++++ .../kmath/expressions/Expression.kt | 9 +- .../kotlin/scientifik/kmath/linear/Matrix.kt | 25 ++--- .../kotlin/scientifik/kmath/linear/Vector.kt | 2 +- .../scientifik/kmath/operations/Algebra.kt | 4 +- .../scientifik/kmath/operations/Complex.kt | 2 +- .../kmath/operations/NumberAlgebra.kt | 32 +++--- .../kmath/operations/OptionalOperations.kt | 2 +- .../scientifik/kmath/structures/NDAlgebra.kt | 2 +- .../kmath/structures/RealNDField.kt | 2 +- settings.gradle.kts | 1 + 13 files changed, 192 insertions(+), 41 deletions(-) create mode 100644 kmath-commons/build.gradle.kts create mode 100644 kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt create mode 100644 kmath-commons/src/test/kotlin/scientifik/kmath/expressions/AutoDiffTest.kt diff --git a/kmath-commons/build.gradle.kts b/kmath-commons/build.gradle.kts new file mode 100644 index 000000000..7b859a2b0 --- /dev/null +++ b/kmath-commons/build.gradle.kts @@ -0,0 +1,15 @@ +plugins { + kotlin("jvm") +} + +dependencies { + api(project(":kmath-core")) + api("org.apache.commons:commons-math3:3.6.1") + testImplementation("org.jetbrains.kotlin:kotlin-test") + testImplementation("org.jetbrains.kotlin:kotlin-test-junit") +} + +//dependencies { +//// compile(project(":kmath-core")) +//// //compile project(":kmath-coroutines") +////} \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt new file mode 100644 index 000000000..4d5ea1b94 --- /dev/null +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt @@ -0,0 +1,106 @@ +package scientifik.kmath.expressions + +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure +import scientifik.kmath.operations.ExtendedField +import kotlin.properties.ReadOnlyProperty +import kotlin.reflect.KProperty + +/** + * A field wrapping commons-math derivative structures + */ +class DerivativeStructureField(val order: Int, val parameters: Map) : + ExtendedField { + + override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) } + + override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) } + + private val variables: Map = parameters.mapValues { (key, value) -> + DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value) + } + + val variable = object : ReadOnlyProperty { + override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure { + return variables[property.name] ?: error("A variable with name ${property.name} does not exist") + } + } + + fun variable(name: String): DerivativeStructure = + variables[name] ?: error("A variable with name ${name} does not exist") + + + fun Number.const() = DerivativeStructure(order, parameters.size, toDouble()) + + fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double { + return deriv(mapOf(parName to order)) + } + + fun DerivativeStructure.deriv(orders: Map): Double { + return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray()) + } + + fun DerivativeStructure.deriv(vararg orders: Pair): Double = deriv(mapOf(*orders)) + + override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) + + override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) { + is Double -> a.multiply(k) + is Int -> a.multiply(k) + else -> a.multiply(k.toDouble()) + } + + override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b) + + override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) + + override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() + + override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() + + override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { + is Double -> arg.pow(pow) + is Int -> arg.pow(pow) + else -> arg.pow(pow.toDouble()) + } + + fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow) + + override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() + + override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() + + operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble()) + operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble()) + operator fun Number.plus(s: DerivativeStructure) = s + this + operator fun Number.minus(s: DerivativeStructure) = s - this +} + +/** + * A constructs that creates a derivative structure with required order on-demand + */ +class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression { + override fun invoke(arguments: Map): Double = DerivativeStructureField(0, arguments) + .run(function).value + + /** + * Get the derivative expression with given orders + * TODO make result [DiffExpression] + */ + fun derivative(orders: Map): Expression { + return object : Expression { + override fun invoke(arguments: Map): Double = + DerivativeStructureField(orders.values.max() ?: 0, arguments) + .run { + function().deriv(orders) + } + } + } + + //TODO add gradient and maybe other vector operators +} + +fun DiffExpression.derivative(vararg orders: Pair) = derivative(mapOf(*orders)) +fun DiffExpression.derivative(name: String) = derivative(name to 1) + + + diff --git a/kmath-commons/src/test/kotlin/scientifik/kmath/expressions/AutoDiffTest.kt b/kmath-commons/src/test/kotlin/scientifik/kmath/expressions/AutoDiffTest.kt new file mode 100644 index 000000000..2695bb918 --- /dev/null +++ b/kmath-commons/src/test/kotlin/scientifik/kmath/expressions/AutoDiffTest.kt @@ -0,0 +1,31 @@ +package scientifik.kmath.expressions + +import org.junit.Test +import kotlin.test.assertEquals + +inline fun diff(order: Int, vararg parameters: Pair, block: DerivativeStructureField.() -> R) = + DerivativeStructureField(order, mapOf(*parameters)).run(block) + +class AutoDiffTest { + @Test + fun derivativeStructureFieldTest() { + val res = diff(3, "x" to 1.0, "y" to 1.0) { + val x by variable + val y = variable("y") + val z = x * (-sin(x * y) + y) + z.deriv("x") + } + } + + @Test + fun autoDifTest() { + val f = DiffExpression { + val x by variable + val y by variable + x.pow(2) + 2 * x * y + y.pow(2) + 1 + } + + assertEquals(10.0, f("x" to 1.0, "y" to 2.0)) + assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0)) + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index f7ae09d02..2f54ae1b2 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -1,6 +1,7 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space @@ -18,7 +19,7 @@ interface ExpressionContext { internal class VariableExpression(val name: String, val default: T? = null) : Expression { override fun invoke(arguments: Map): T = - arguments[name] ?: default ?: error("Parameter not found: $name") + arguments[name] ?: default ?: error("Parameter not found: $name") } internal class ConstantExpression(val value: T) : Expression { @@ -30,13 +31,13 @@ internal class SumExpression(val context: Space, val first: Expression, override fun invoke(arguments: Map): T = context.add(first.invoke(arguments), second.invoke(arguments)) } -internal class ProductExpression(val context: Field, val first: Expression, val second: Expression) : +internal class ProductExpression(val context: Ring, val first: Expression, val second: Expression) : Expression { override fun invoke(arguments: Map): T = context.multiply(first.invoke(arguments), second.invoke(arguments)) } -internal class ConstProductExpession(val context: Field, val expr: Expression, val const: Double) : +internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : Expression { override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } @@ -58,7 +59,7 @@ class ExpressionField(val field: Field) : Field>, Expression override fun add(a: Expression, b: Expression): Expression = SumExpression(field, a, b) - override fun multiply(a: Expression, k: Double): Expression = ConstProductExpession(field, a, k) + override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(field, a, k) override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index 35ad86379..0edf748e6 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -6,7 +6,6 @@ import scientifik.kmath.operations.Space import scientifik.kmath.operations.SpaceElement import scientifik.kmath.structures.* import scientifik.kmath.structures.Buffer.Companion.DoubleBufferFactory -import scientifik.kmath.structures.Buffer.Companion.auto import scientifik.kmath.structures.Buffer.Companion.boxing @@ -19,8 +18,6 @@ interface MatrixSpace> : Space> { val rowNum: Int val colNum: Int - val shape get() = intArrayOf(rowNum, colNum) - /** * Produce a matrix with this context and given dimensions */ @@ -38,7 +35,7 @@ interface MatrixSpace> : Space> { override fun add(a: Matrix, b: Matrix): Matrix = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } } - override fun multiply(a: Matrix, k: Double): Matrix = + override fun multiply(a: Matrix, k: Number): Matrix = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } } companion object { @@ -61,8 +58,8 @@ interface MatrixSpace> : Space> { /** * Automatic buffered matrix, unboxed if it is possible */ - inline fun > smart(rows: Int, columns: Int, ring: R): MatrixSpace = - buffered(rows, columns, ring, ::auto) + inline fun > auto(rows: Int, columns: Int, ring: R): MatrixSpace = + buffered(rows, columns, ring, Buffer.Companion::auto) } } @@ -80,21 +77,19 @@ interface Matrix> : NDStructure, SpaceElement> - get() = (0 until numRows).map { i -> + val rows: Point> + get() = ListBuffer((0 until numRows).map { i -> context.point(numCols) { j -> get(i, j) } - } + }) - val columns: List> - get() = (0 until numCols).map { j -> + val columns: Point> + get() = ListBuffer((0 until numCols).map { j -> context.point(numRows) { i -> get(i, j) } - } + }) val features: Set @@ -134,7 +129,7 @@ data class StructureMatrixSpace>( private val bufferFactory: BufferFactory ) : MatrixSpace { - override val shape: IntArray = intArrayOf(rowNum, colNum) + val shape: IntArray = intArrayOf(rowNum, colNum) private val strides = DefaultStrides(shape) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt index a673c0197..a7373f647 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt @@ -30,7 +30,7 @@ interface VectorSpace> : Space> { override fun add(a: Point, b: Point): Point = produce { with(space) { a[it] + b[it] } } - override fun multiply(a: Point, k: Double): Point = produce { with(space) { a[it] * k } } + override fun multiply(a: Point, k: Number): Point = produce { with(space) { a[it] * k } } //TODO add basis diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index d77df4b7f..1534f4c82 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -7,7 +7,7 @@ package scientifik.kmath.operations * One must note that in some cases context is a singleton class, but in some cases it * works as a context for operations inside it. * - * TODO do we need commutative context? + * TODO do we need non-commutative context? */ interface Space { /** @@ -23,7 +23,7 @@ interface Space { /** * Multiplication operation for context element and real number */ - fun multiply(a: T, k: Double): T + fun multiply(a: T, k: Number): T //Operation to be performed in this context operator fun T.unaryMinus(): T = multiply(this, -1.0) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 107bb2f24..20e8f47ec 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -12,7 +12,7 @@ object ComplexField : Field { override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) - override fun multiply(a: Complex, k: Double): Complex = Complex(a.re * k, a.im * k) + override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 953f2e326..a57688ad8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -34,21 +34,23 @@ inline class Real(val value: Double) : FieldElement { object RealField : ExtendedField, Norm { override val zero: Double = 0.0 override fun add(a: Double, b: Double): Double = a + b - override fun multiply(a: Double, @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") b: Double): Double = a * b + override fun multiply(a: Double, b: Double): Double = a * b + override fun multiply(a: Double, k: Number): Double = a * k.toDouble() + override val one: Double = 1.0 override fun divide(a: Double, b: Double): Double = a / b override fun sin(arg: Double): Double = kotlin.math.sin(arg) override fun cos(arg: Double): Double = kotlin.math.cos(arg) - override fun power(arg: Double, pow: Double): Double = arg.pow(pow) + override fun power(arg: Double, pow: Number): Double = arg.pow(pow.toDouble()) override fun exp(arg: Double): Double = kotlin.math.exp(arg) override fun ln(arg: Double): Double = kotlin.math.ln(arg) override fun norm(arg: Double): Double = kotlin.math.abs(arg) - override fun Double.unaryMinus(): Double = -this + override fun Double.unaryMinus(): Double = -this override fun Double.minus(b: Double): Double = this - b } @@ -56,51 +58,51 @@ object RealField : ExtendedField, Norm { /** * A field for [Int] without boxing. Does not produce corresponding field element */ -object IntRing : Ring, Norm { +object IntRing : Ring, Norm { override val zero: Int = 0 override fun add(a: Int, b: Int): Int = a + b override fun multiply(a: Int, b: Int): Int = a * b - override fun multiply(a: Int, k: Double): Int = (k * a).toInt() + override fun multiply(a: Int, k: Number): Int = (k * a) override val one: Int = 1 - override fun norm(arg: Int): Int = arg + override fun norm(arg: Int): Int = arg } /** * A field for [Short] without boxing. Does not produce appropriate field element */ -object ShortRing : Ring, Norm{ +object ShortRing : Ring, Norm { override val zero: Short = 0 override fun add(a: Short, b: Short): Short = (a + b).toShort() override fun multiply(a: Short, b: Short): Short = (a * b).toShort() - override fun multiply(a: Short, k: Double): Short = (a * k).toShort() + override fun multiply(a: Short, k: Number): Short = (a * k) override val one: Short = 1 - override fun norm(arg: Short): Short = arg + override fun norm(arg: Short): Short = arg } /** * A field for [Byte] values */ -object ByteRing : Ring, Norm { +object ByteRing : Ring, Norm { override val zero: Byte = 0 override fun add(a: Byte, b: Byte): Byte = (a + b).toByte() override fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() - override fun multiply(a: Byte, k: Double): Byte = (a * k).toByte() + override fun multiply(a: Byte, k: Number): Byte = (a * k) override val one: Byte = 1 - override fun norm(arg: Byte): Byte = arg + override fun norm(arg: Byte): Byte = arg } /** * A field for [Long] values */ -object LongRing : Ring, Norm { +object LongRing : Ring, Norm { override val zero: Long = 0 override fun add(a: Long, b: Long): Long = (a + b) override fun multiply(a: Long, b: Long): Long = (a * b) - override fun multiply(a: Long, k: Double): Long = (a * k).toLong() + override fun multiply(a: Long, k: Number): Long = (a * k) override val one: Long = 1 - override fun norm(arg: Long): Long = arg + override fun norm(arg: Long): Long = arg } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index 7a7866966..66ca205a1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -30,7 +30,7 @@ fun >> ctg(arg: T): T = arg.conte * A context extension to include power operations like square roots, etc */ interface PowerOperations { - fun power(arg: T, pow: Double): T + fun power(arg: T, pow: Number): T fun sqrt(arg: T) = power(arg, 0.5) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt index 63818afcd..097d52723 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt @@ -70,7 +70,7 @@ interface NDSpace, N : NDStructure> : Space, NDAlgebra add(arg, value) } operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 62c2338c6..bc5832e1c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -65,7 +65,7 @@ class RealNDField(override val shape: IntArray) : override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = BufferedNDFieldElement(this@RealNDField, buffer) - override fun power(arg: NDBuffer, pow: Double) = map(arg) { power(it, pow) } + override fun power(arg: NDBuffer, pow: Number) = map(arg) { power(it, pow) } override fun exp(arg: NDBuffer) = map(arg) { exp(it) } diff --git a/settings.gradle.kts b/settings.gradle.kts index 3a8aef730..fa0bb49a0 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -12,5 +12,6 @@ include( ":kmath-core", ":kmath-io", ":kmath-coroutines", + ":kmath-commons", ":benchmarks" )