From f364060acfa20539b5c4ff7b89a89c3b2236efce Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 11 Jun 2020 12:16:22 +0700 Subject: [PATCH 01/69] Add project stub --- kmath-nd4j/build.gradle.kts | 3 +++ settings.gradle.kts | 1 + 2 files changed, 4 insertions(+) create mode 100644 kmath-nd4j/build.gradle.kts diff --git a/kmath-nd4j/build.gradle.kts b/kmath-nd4j/build.gradle.kts new file mode 100644 index 000000000..5da7d66b7 --- /dev/null +++ b/kmath-nd4j/build.gradle.kts @@ -0,0 +1,3 @@ +plugins { + id("scientifik.jvm") +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 57173250b..afb5598b4 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -39,6 +39,7 @@ include( ":kmath-commons", ":kmath-viktor", ":kmath-koma", + ":kmath-nd4j", ":kmath-prob", ":kmath-io", ":kmath-dimensions", From 3df9892de53c0aa719091ad3d0c4aa0405081655 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 11 Jun 2020 14:10:39 +0700 Subject: [PATCH 02/69] Implement the ND4J module for scalars --- build.gradle.kts | 1 + kmath-nd4j/build.gradle.kts | 10 ++++++++-- .../INDArrayScalarsIterator.kt | 19 +++++++++++++++++++ .../scientifik.kmath.nd4j/ND4JStructure.kt | 14 ++++++++++++++ 4 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayScalarsIterator.kt create mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt diff --git a/build.gradle.kts b/build.gradle.kts index 6d102a77a..10e030520 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -11,6 +11,7 @@ allprojects { repositories { jcenter() maven("https://dl.bintray.com/kotlin/kotlinx") + mavenCentral() } group = "scientifik" diff --git a/kmath-nd4j/build.gradle.kts b/kmath-nd4j/build.gradle.kts index 5da7d66b7..59354a8f9 100644 --- a/kmath-nd4j/build.gradle.kts +++ b/kmath-nd4j/build.gradle.kts @@ -1,3 +1,9 @@ -plugins { - id("scientifik.jvm") +plugins { id("scientifik.jvm") } + +dependencies { + api(project(":kmath-core")) + api("org.nd4j:nd4j-api:1.0.0-beta7") + testImplementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7") + testImplementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") + testImplementation("org.slf4j:slf4j-simple:1.7.30") } diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayScalarsIterator.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayScalarsIterator.kt new file mode 100644 index 000000000..2c2dc970f --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayScalarsIterator.kt @@ -0,0 +1,19 @@ +package scientifik.kmath.nd4j + +import org.nd4j.linalg.api.ndarray.INDArray +import org.nd4j.linalg.api.shape.Shape + +internal class INDArrayScalarsIterator(private val iterateOver: INDArray) : Iterator> { + private var i: Int = 0 + + override fun hasNext(): Boolean = i < iterateOver.length() + + override fun next(): Pair { + val idx = if (iterateOver.ordering() == 'c') + Shape.ind2subC(iterateOver, i++.toLong())!! + else + Shape.ind2sub(iterateOver, i++.toLong())!! + + return narrowToIntArray(idx) to iterateOver.getScalar(*idx) + } +} diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt new file mode 100644 index 000000000..eb9f9b80c --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt @@ -0,0 +1,14 @@ +package scientifik.kmath.nd4j + +import org.nd4j.linalg.api.ndarray.INDArray +import scientifik.kmath.structures.NDStructure + +internal fun narrowToIntArray(la: LongArray): IntArray = IntArray(la.size) { la[it].toInt() } + +data class ND4JStructure(val ndArray: INDArray) : NDStructure { + override val shape: IntArray + get() = narrowToIntArray(ndArray.shape()) + + override fun get(index: IntArray): INDArray = ndArray.getScalar(*index) + override fun elements(): Sequence> = Sequence { INDArrayScalarsIterator(ndArray) } +} From 9a4dd315072e6cb27c430c799074b1dd36a06f94 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 11 Jun 2020 14:17:46 +0700 Subject: [PATCH 03/69] Move narrowToIntArray to new file --- kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt | 3 +++ .../src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt new file mode 100644 index 000000000..5d341dd68 --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt @@ -0,0 +1,3 @@ +package scientifik.kmath.nd4j + +internal fun narrowToIntArray(la: LongArray): IntArray = IntArray(la.size) { la[it].toInt() } \ No newline at end of file diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt index eb9f9b80c..1d0301ff9 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt @@ -3,8 +3,6 @@ package scientifik.kmath.nd4j import org.nd4j.linalg.api.ndarray.INDArray import scientifik.kmath.structures.NDStructure -internal fun narrowToIntArray(la: LongArray): IntArray = IntArray(la.size) { la[it].toInt() } - data class ND4JStructure(val ndArray: INDArray) : NDStructure { override val shape: IntArray get() = narrowToIntArray(ndArray.shape()) From d0cc75098bdc523c1c5f945c5a83a460efa00af8 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 11 Jun 2020 14:36:19 +0700 Subject: [PATCH 04/69] Rework with specialized NDStructure implementations --- .../kotlin/scientifik.kmath.nd4j/Arrays.kt | 3 +- .../INDArrayScalarsIterator.kt | 19 ---------- .../scientifik.kmath.nd4j/ND4JStructure.kt | 12 ------ .../scientifik.kmath.nd4j/NDArrayIterators.kt | 37 +++++++++++++++++++ .../ScalarsND4JStructure.kt | 34 +++++++++++++++++ 5 files changed, 73 insertions(+), 32 deletions(-) delete mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayScalarsIterator.kt delete mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt create mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/NDArrayIterators.kt create mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt index 5d341dd68..3d5062a4f 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt @@ -1,3 +1,4 @@ package scientifik.kmath.nd4j -internal fun narrowToIntArray(la: LongArray): IntArray = IntArray(la.size) { la[it].toInt() } \ No newline at end of file +internal fun widenToLongArray(ia: IntArray): LongArray = LongArray(ia.size) { ia[it].toLong() } +internal fun narrowToIntArray(la: LongArray): IntArray = IntArray(la.size) { la[it].toInt() } diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayScalarsIterator.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayScalarsIterator.kt deleted file mode 100644 index 2c2dc970f..000000000 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayScalarsIterator.kt +++ /dev/null @@ -1,19 +0,0 @@ -package scientifik.kmath.nd4j - -import org.nd4j.linalg.api.ndarray.INDArray -import org.nd4j.linalg.api.shape.Shape - -internal class INDArrayScalarsIterator(private val iterateOver: INDArray) : Iterator> { - private var i: Int = 0 - - override fun hasNext(): Boolean = i < iterateOver.length() - - override fun next(): Pair { - val idx = if (iterateOver.ordering() == 'c') - Shape.ind2subC(iterateOver, i++.toLong())!! - else - Shape.ind2sub(iterateOver, i++.toLong())!! - - return narrowToIntArray(idx) to iterateOver.getScalar(*idx) - } -} diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt deleted file mode 100644 index 1d0301ff9..000000000 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ND4JStructure.kt +++ /dev/null @@ -1,12 +0,0 @@ -package scientifik.kmath.nd4j - -import org.nd4j.linalg.api.ndarray.INDArray -import scientifik.kmath.structures.NDStructure - -data class ND4JStructure(val ndArray: INDArray) : NDStructure { - override val shape: IntArray - get() = narrowToIntArray(ndArray.shape()) - - override fun get(index: IntArray): INDArray = ndArray.getScalar(*index) - override fun elements(): Sequence> = Sequence { INDArrayScalarsIterator(ndArray) } -} diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/NDArrayIterators.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/NDArrayIterators.kt new file mode 100644 index 000000000..426b1ec2d --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/NDArrayIterators.kt @@ -0,0 +1,37 @@ +package scientifik.kmath.nd4j + +import org.nd4j.linalg.api.ndarray.INDArray +import org.nd4j.linalg.api.shape.Shape + +internal sealed class INDArrayIteratorBase(protected val iterateOver: INDArray) : Iterator> { + private var i: Int = 0 + + override fun hasNext(): Boolean = i < iterateOver.length() + + abstract fun getSingle(indices: LongArray): T + + final override fun next(): Pair { + val la = if (iterateOver.ordering() == 'c') + Shape.ind2subC(iterateOver, i++.toLong())!! + else + Shape.ind2sub(iterateOver, i++.toLong())!! + + return narrowToIntArray(la) to getSingle(la) + } +} + +internal class INDArrayDoubleIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { + override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices) +} + +internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { + override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices) +} + +internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { + override fun getSingle(indices: LongArray) = iterateOver.getInt(*narrowToIntArray(indices)) +} + +internal class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { + override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices) +} diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt new file mode 100644 index 000000000..ef8c3ec2e --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt @@ -0,0 +1,34 @@ +package scientifik.kmath.nd4j + +import org.nd4j.linalg.api.ndarray.INDArray +import scientifik.kmath.structures.NDStructure + +interface INDArrayStructureBase : NDStructure { + val ndArray: INDArray + + override val shape: IntArray + get() = narrowToIntArray(ndArray.shape()) + + fun elementsIterator(): Iterator> + override fun elements(): Sequence> = Sequence { elementsIterator() } +} + +data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructureBase { + override fun elementsIterator(): Iterator> = INDArrayIntIterator(ndArray) + override fun get(index: IntArray): Int = ndArray.getInt(*index) +} + +data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructureBase { + override fun elementsIterator(): Iterator> = INDArrayLongIterator(ndArray) + override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index)) +} + +data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructureBase { + override fun elementsIterator(): Iterator> = INDArrayDoubleIterator(ndArray) + override fun get(index: IntArray): Double = ndArray.getDouble(*index) +} + +data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructureBase { + override fun elementsIterator(): Iterator> = INDArrayFloatIterator(ndArray) + override fun get(index: IntArray): Float = ndArray.getFloat(*index) +} From bac6451443f5cf9995453fc780cd2a8d037e60cd Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 27 Jun 2020 21:17:40 +0700 Subject: [PATCH 05/69] Add tests --- .../kmath/nd4j/INDArrayStructureTest.kt | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt new file mode 100644 index 000000000..e851f2e80 --- /dev/null +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt @@ -0,0 +1,38 @@ +package scientifik.kmath.nd4j + +import org.nd4j.linalg.factory.Nd4j +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class INDArrayStructureTest { + @Test + fun testElements() { + val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! + val struct = INDArrayDoubleStructure(nd) + val res = struct.elements().map { it.second }.toList() + assertEquals(listOf(1.0, 2.0, 3.0), res) + } + + @Test + fun testShape() { + val nd = Nd4j.rand(10, 2, 3, 6)!! + val struct = INDArrayIntStructure(nd) + assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList()) + } + + @Test + fun testEquals() { + val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! + val struct1 = INDArrayDoubleStructure(nd1) + val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! + val struct2 = INDArrayDoubleStructure(nd2) + assertEquals(struct1, struct2) + } + + @Test + fun testDimension() { + val nd = Nd4j.rand(8, 16, 3, 7, 1)!! + val struct = INDArrayIntStructure(nd) + assertEquals(5, struct.dimension) + } +} From b6bf741dbe72afdcbc311d6ac89ac8700ac81cb4 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 27 Jun 2020 21:19:19 +0700 Subject: [PATCH 06/69] Replace lambdas with references --- .../main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt | 2 +- .../test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt index ef8c3ec2e..3ffcc110d 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt @@ -10,7 +10,7 @@ interface INDArrayStructureBase : NDStructure { get() = narrowToIntArray(ndArray.shape()) fun elementsIterator(): Iterator> - override fun elements(): Sequence> = Sequence { elementsIterator() } + override fun elements(): Sequence> = Sequence(::elementsIterator) } data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructureBase { diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt index e851f2e80..235b65556 100644 --- a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt @@ -9,7 +9,7 @@ internal class INDArrayStructureTest { fun testElements() { val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val struct = INDArrayDoubleStructure(nd) - val res = struct.elements().map { it.second }.toList() + val res = struct.elements().map(Pair::second).toList() assertEquals(listOf(1.0, 2.0, 3.0), res) } From e466f4bdf2448ed028c8895e5d6bb7b147fa0777 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 27 Jun 2020 21:21:16 +0700 Subject: [PATCH 07/69] Add test for get --- .../kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt index 235b65556..239289262 100644 --- a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt @@ -3,6 +3,7 @@ package scientifik.kmath.nd4j import org.nd4j.linalg.factory.Nd4j import kotlin.test.Test import kotlin.test.assertEquals +import scientifik.kmath.structures.get internal class INDArrayStructureTest { @Test @@ -35,4 +36,11 @@ internal class INDArrayStructureTest { val struct = INDArrayIntStructure(nd) assertEquals(5, struct.dimension) } + + @Test + fun testGet() { + val nd = Nd4j.rand(10, 2, 3, 6)!! + val struct = INDArrayIntStructure(nd) + assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0]) + } } From fefa0db86ed15a7583f442aa1d2e958dcdfee5c3 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 28 Jun 2020 17:29:00 +0700 Subject: [PATCH 08/69] Rename files --- .../{NDArrayIterators.kt => INDArrayIterators.kt} | 0 .../{ScalarsND4JStructure.kt => INDArrayStructures.kt} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/{NDArrayIterators.kt => INDArrayIterators.kt} (100%) rename kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/{ScalarsND4JStructure.kt => INDArrayStructures.kt} (100%) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/NDArrayIterators.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt similarity index 100% rename from kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/NDArrayIterators.kt rename to kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt similarity index 100% rename from kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/ScalarsND4JStructure.kt rename to kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt From 5cc56b6ab01dff7c412e93a43feea400f70bbc96 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 28 Jun 2020 17:30:09 +0700 Subject: [PATCH 09/69] Remove Base suffix from class name --- .../kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt index 3ffcc110d..b444def3d 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt @@ -3,7 +3,7 @@ package scientifik.kmath.nd4j import org.nd4j.linalg.api.ndarray.INDArray import scientifik.kmath.structures.NDStructure -interface INDArrayStructureBase : NDStructure { +interface INDArrayStructure : NDStructure { val ndArray: INDArray override val shape: IntArray @@ -13,22 +13,22 @@ interface INDArrayStructureBase : NDStructure { override fun elements(): Sequence> = Sequence(::elementsIterator) } -data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructureBase { +data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure { override fun elementsIterator(): Iterator> = INDArrayIntIterator(ndArray) override fun get(index: IntArray): Int = ndArray.getInt(*index) } -data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructureBase { +data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { override fun elementsIterator(): Iterator> = INDArrayLongIterator(ndArray) override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index)) } -data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructureBase { +data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure { override fun elementsIterator(): Iterator> = INDArrayDoubleIterator(ndArray) override fun get(index: IntArray): Double = ndArray.getDouble(*index) } -data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructureBase { +data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure { override fun elementsIterator(): Iterator> = INDArrayFloatIterator(ndArray) override fun get(index: IntArray): Float = ndArray.getFloat(*index) } From f49c3e4f4d2158123640c442dd7ebe6447ec2474 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 28 Jun 2020 17:33:09 +0700 Subject: [PATCH 10/69] Add final modifier --- .../src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt index 426b1ec2d..f6efdc0ba 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt @@ -6,7 +6,7 @@ import org.nd4j.linalg.api.shape.Shape internal sealed class INDArrayIteratorBase(protected val iterateOver: INDArray) : Iterator> { private var i: Int = 0 - override fun hasNext(): Boolean = i < iterateOver.length() + final override fun hasNext(): Boolean = i < iterateOver.length() abstract fun getSingle(indices: LongArray): T From b41a9588bc6c98842126e89650d6a942f4369bdf Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 28 Jun 2020 18:21:27 +0700 Subject: [PATCH 11/69] Rename file --- .../main/kotlin/scientifik.kmath.nd4j/{Arrays.kt => arrays.kt} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/{Arrays.kt => arrays.kt} (100%) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/arrays.kt similarity index 100% rename from kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/Arrays.kt rename to kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/arrays.kt From 05120929b0d18bb8a2bf7937a6f679e95ba5dc9b Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 28 Jun 2020 19:08:44 +0700 Subject: [PATCH 12/69] Encapsulate classOfT property of AsmBuilder --- .../kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index e550bc563..5531fd5dc 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -16,13 +16,13 @@ import kotlin.reflect.KClass * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * - * @param T the type of AsmExpression to unwrap. - * @param algebra the algebra the applied AsmExpressions use. - * @param className the unique class name of new loaded class. - * @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. + * @property T the type of AsmExpression to unwrap. + * @property algebra the algebra the applied AsmExpressions use. + * @property className the unique class name of new loaded class. + * @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. */ internal class AsmBuilder internal constructor( - internal val classOfT: KClass<*>, + private val classOfT: KClass<*>, private val algebra: Algebra, private val className: String, private val invokeLabel0Visitor: AsmBuilder.() -> Unit From 3b18000f1edcd9d99e15037230885140b64b8001 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 29 Jun 2020 00:14:01 +0700 Subject: [PATCH 13/69] Make several NDStructures mutable --- .../scientifik.kmath.nd4j/INDArrayStructures.kt | 10 +++++++--- .../scientifik/kmath/nd4j/INDArrayStructureTest.kt | 14 +++++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt index b444def3d..f39d84716 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt @@ -1,6 +1,7 @@ package scientifik.kmath.nd4j import org.nd4j.linalg.api.ndarray.INDArray +import scientifik.kmath.structures.MutableNDStructure import scientifik.kmath.structures.NDStructure interface INDArrayStructure : NDStructure { @@ -13,9 +14,10 @@ interface INDArrayStructure : NDStructure { override fun elements(): Sequence> = Sequence(::elementsIterator) } -data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure { +data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayIntIterator(ndArray) override fun get(index: IntArray): Int = ndArray.getInt(*index) + override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } } data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { @@ -23,12 +25,14 @@ data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStruc override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index)) } -data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure { +data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayDoubleIterator(ndArray) override fun get(index: IntArray): Double = ndArray.getDouble(*index) + override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } } -data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure { +data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayFloatIterator(ndArray) override fun get(index: IntArray): Float = ndArray.getFloat(*index) + override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } } diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt index 239289262..77565856a 100644 --- a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt @@ -1,9 +1,9 @@ package scientifik.kmath.nd4j import org.nd4j.linalg.factory.Nd4j +import scientifik.kmath.structures.get import kotlin.test.Test import kotlin.test.assertEquals -import scientifik.kmath.structures.get internal class INDArrayStructureTest { @Test @@ -17,7 +17,7 @@ internal class INDArrayStructureTest { @Test fun testShape() { val nd = Nd4j.rand(10, 2, 3, 6)!! - val struct = INDArrayIntStructure(nd) + val struct = INDArrayLongStructure(nd) assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList()) } @@ -33,7 +33,7 @@ internal class INDArrayStructureTest { @Test fun testDimension() { val nd = Nd4j.rand(8, 16, 3, 7, 1)!! - val struct = INDArrayIntStructure(nd) + val struct = INDArrayFloatStructure(nd) assertEquals(5, struct.dimension) } @@ -43,4 +43,12 @@ internal class INDArrayStructureTest { val struct = INDArrayIntStructure(nd) assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0]) } + + @Test + fun testSet() { + val nd = Nd4j.rand(17, 12, 4, 8)!! + val struct = INDArrayIntStructure(nd) + struct[intArrayOf(1, 2, 3, 4)] = 777 + assertEquals(777, struct[1, 2, 3, 4]) + } } From eb9d40fd2aed7342a5584609fdfb83718093d400 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 29 Jun 2020 00:29:43 +0700 Subject: [PATCH 14/69] Convert INDArray NDStructures implementations to inline classes, add tests to verify equals and hashCode --- .../scientifik.kmath.nd4j/INDArrayStructures.kt | 13 +++++++++---- .../kmath/nd4j/INDArrayStructureTest.kt | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt index f39d84716..66aa00fac 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt @@ -14,25 +14,30 @@ interface INDArrayStructure : NDStructure { override fun elements(): Sequence> = Sequence(::elementsIterator) } -data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +inline class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayIntIterator(ndArray) override fun get(index: IntArray): Int = ndArray.getInt(*index) override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } + override fun toString(): String = "INDArrayIntStructure(ndArray=$ndArray)" } -data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { +inline class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { override fun elementsIterator(): Iterator> = INDArrayLongIterator(ndArray) override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index)) + override fun toString(): String = "INDArrayLongStructure(ndArray=$ndArray)" + } -data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +inline class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayDoubleIterator(ndArray) override fun get(index: IntArray): Double = ndArray.getDouble(*index) override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } + override fun toString(): String = "INDArrayDoubleStructure(ndArray=$ndArray)" } -data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +inline class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayFloatIterator(ndArray) override fun get(index: IntArray): Float = ndArray.getFloat(*index) override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } + override fun toString(): String = "INDArrayFloatStructure(ndArray=$ndArray)" } diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt index 77565856a..ad1cbb585 100644 --- a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt @@ -4,6 +4,7 @@ import org.nd4j.linalg.factory.Nd4j import scientifik.kmath.structures.get import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertNotEquals internal class INDArrayStructureTest { @Test @@ -25,9 +26,25 @@ internal class INDArrayStructureTest { fun testEquals() { val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val struct1 = INDArrayDoubleStructure(nd1) + assertEquals(struct1, struct1) + assertNotEquals(struct1, null as INDArrayDoubleStructure?) val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val struct2 = INDArrayDoubleStructure(nd2) assertEquals(struct1, struct2) + assertEquals(struct2, struct1) + val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! + val struct3 = INDArrayDoubleStructure(nd3) + assertEquals(struct2, struct3) + assertEquals(struct1, struct3) + } + + @Test + fun testHashCode() { + val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! + val struct1 = INDArrayDoubleStructure(nd1) + val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! + val struct2 = INDArrayDoubleStructure(nd2) + assertEquals(struct1.hashCode(), struct2.hashCode()) } @Test From 783087982fbe93f60a2c20e7a3d50627b69b8b58 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 29 Jun 2020 02:50:34 +0700 Subject: [PATCH 15/69] Rollback making Structures inline, implement Algebras for NDArrayStructure --- .../scientifik.kmath.nd4j/INDArrayAlgebra.kt | 89 +++++++++++++++++++ .../INDArrayIterators.kt | 7 +- .../INDArrayStructures.kt | 25 +++--- .../kmath/nd4j/INDArrayAlgebraTest.kt | 30 +++++++ .../kmath/nd4j/INDArrayStructureTest.kt | 14 +-- 5 files changed, 147 insertions(+), 18 deletions(-) create mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt create mode 100644 kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt new file mode 100644 index 000000000..760e3d03e --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt @@ -0,0 +1,89 @@ +package scientifik.kmath.nd4j + +import org.nd4j.linalg.api.ndarray.INDArray +import org.nd4j.linalg.factory.Nd4j +import scientifik.kmath.operations.* +import scientifik.kmath.structures.MutableNDStructure +import scientifik.kmath.structures.NDField +import scientifik.kmath.structures.NDRing + +interface INDArrayRing : + NDRing where F : Ring, N : INDArrayStructure, N : MutableNDStructure { + fun INDArray.wrap(): N + + override val zero: N + get() = Nd4j.zeros(*shape).wrap() + + override val one: N + get() = Nd4j.ones(*shape).wrap() + + override fun produce(initializer: F.(IntArray) -> T): N { + val struct = Nd4j.create(*shape).wrap() + struct.elements().map(Pair::first).forEach { struct[it] = elementContext.initializer(it) } + return struct + } + + override fun map(arg: N, transform: F.(T) -> T): N { + val new = Nd4j.create(*shape) + Nd4j.copy(arg.ndArray, new) + val newStruct = new.wrap() + newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } + return newStruct + } + + override fun mapIndexed(arg: N, transform: F.(index: IntArray, T) -> T): N { + val new = Nd4j.create(*shape).wrap() + new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(idx, arg[idx]) } + return new + } + + override fun combine(a: N, b: N, transform: F.(T, T) -> T): N { + val new = Nd4j.create(*shape).wrap() + new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(a[idx], b[idx]) } + return new + } + + override fun add(a: N, b: N): N = a.ndArray.addi(b.ndArray).wrap() + override fun N.minus(b: N): N = ndArray.subi(b.ndArray).wrap() + override fun N.unaryMinus(): N = ndArray.negi().wrap() + override fun multiply(a: N, b: N): N = a.ndArray.muli(b.ndArray).wrap() + override fun multiply(a: N, k: Number): N = a.ndArray.muli(k).wrap() + override fun N.div(k: Number): N = ndArray.divi(k).wrap() + override fun N.minus(b: Number): N = ndArray.subi(b).wrap() + override fun N.plus(b: Number): N = ndArray.addi(b).wrap() + override fun N.times(k: Number): N = ndArray.muli(k).wrap() +} + +interface INDArrayField : NDField, + INDArrayRing where F : Field, N : INDArrayStructure, N : MutableNDStructure { + override fun divide(a: N, b: N): N = a.ndArray.divi(b.ndArray).wrap() +} + +class RealINDArrayField(override val shape: IntArray, override val elementContext: Field = RealField) : + INDArrayField, INDArrayRealStructure> { + override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure() + override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap() + override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap() + override fun INDArrayRealStructure.div(k: Number): INDArrayRealStructure = ndArray.divi(k).wrap() + override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap() + override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap() +} + +class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field = FloatField) : + INDArrayField, INDArrayFloatStructure> { + override fun INDArray.wrap(): INDArrayFloatStructure = asFloatStructure() + override fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure = ndArray.divi(arg).wrap() + override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap() + override fun INDArrayFloatStructure.div(k: Number): INDArrayFloatStructure = ndArray.divi(k).wrap() + override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap() + override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap() +} + +class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring = IntRing) : + INDArrayRing, INDArrayIntStructure> { + override fun INDArray.wrap(): INDArrayIntStructure = asIntStructure() + override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap() + override fun INDArrayIntStructure.div(k: Number): INDArrayIntStructure = ndArray.divi(k).wrap() + override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap() + override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap() +} diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt index f6efdc0ba..115c78cb9 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt @@ -20,14 +20,19 @@ internal sealed class INDArrayIteratorBase(protected val iterateOver: INDArra } } -internal class INDArrayDoubleIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { +internal class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices) } +internal fun INDArray.realIterator(): INDArrayRealIterator = INDArrayRealIterator(this) + internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices) } +// TODO +//internal fun INDArray.longI + internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getInt(*narrowToIntArray(indices)) } diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt index 66aa00fac..351110485 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt @@ -14,30 +14,35 @@ interface INDArrayStructure : NDStructure { override fun elements(): Sequence> = Sequence(::elementsIterator) } -inline class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayIntIterator(ndArray) override fun get(index: IntArray): Int = ndArray.getInt(*index) override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } - override fun toString(): String = "INDArrayIntStructure(ndArray=$ndArray)" } -inline class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { +fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this) + +data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { override fun elementsIterator(): Iterator> = INDArrayLongIterator(ndArray) override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index)) - override fun toString(): String = "INDArrayLongStructure(ndArray=$ndArray)" - } -inline class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { - override fun elementsIterator(): Iterator> = INDArrayDoubleIterator(ndArray) +fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this) + +data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure, + MutableNDStructure { + override fun elementsIterator(): Iterator> = INDArrayRealIterator(ndArray) override fun get(index: IntArray): Double = ndArray.getDouble(*index) override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } - override fun toString(): String = "INDArrayDoubleStructure(ndArray=$ndArray)" } -inline class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(this) + +data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, + MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayFloatIterator(ndArray) override fun get(index: IntArray): Float = ndArray.getFloat(*index) override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } - override fun toString(): String = "INDArrayFloatStructure(ndArray=$ndArray)" } + +fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(this) diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt new file mode 100644 index 000000000..f971e7871 --- /dev/null +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt @@ -0,0 +1,30 @@ +package scientifik.kmath.nd4j + +import org.nd4j.linalg.factory.Nd4j +import scientifik.kmath.operations.invoke +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class INDArrayAlgebraTest { + @Test + fun testProduce() { + val res = (RealINDArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } + val expected = Nd4j.create(2, 2)!!.asRealStructure() + expected[intArrayOf(0, 0)] = 0.0 + expected[intArrayOf(0, 1)] = 1.0 + expected[intArrayOf(1, 0)] = 1.0 + expected[intArrayOf(1, 1)] = 2.0 + assertEquals(expected, res) + } + + @Test + fun testMap() { + val res = (IntINDArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } } + val expected = Nd4j.create(2, 2)!!.asIntStructure() + expected[intArrayOf(0, 0)] = 3 + expected[intArrayOf(0, 1)] = 3 + expected[intArrayOf(1, 0)] = 3 + expected[intArrayOf(1, 1)] = 3 + assertEquals(expected, res) + } +} diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt index ad1cbb585..dfede6d32 100644 --- a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt @@ -10,7 +10,7 @@ internal class INDArrayStructureTest { @Test fun testElements() { val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct = INDArrayDoubleStructure(nd) + val struct = INDArrayRealStructure(nd) val res = struct.elements().map(Pair::second).toList() assertEquals(listOf(1.0, 2.0, 3.0), res) } @@ -25,15 +25,15 @@ internal class INDArrayStructureTest { @Test fun testEquals() { val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct1 = INDArrayDoubleStructure(nd1) + val struct1 = INDArrayRealStructure(nd1) assertEquals(struct1, struct1) - assertNotEquals(struct1, null as INDArrayDoubleStructure?) + assertNotEquals(struct1, null as INDArrayRealStructure?) val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct2 = INDArrayDoubleStructure(nd2) + val struct2 = INDArrayRealStructure(nd2) assertEquals(struct1, struct2) assertEquals(struct2, struct1) val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct3 = INDArrayDoubleStructure(nd3) + val struct3 = INDArrayRealStructure(nd3) assertEquals(struct2, struct3) assertEquals(struct1, struct3) } @@ -41,9 +41,9 @@ internal class INDArrayStructureTest { @Test fun testHashCode() { val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct1 = INDArrayDoubleStructure(nd1) + val struct1 = INDArrayRealStructure(nd1) val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct2 = INDArrayDoubleStructure(nd2) + val struct2 = INDArrayRealStructure(nd2) assertEquals(struct1.hashCode(), struct2.hashCode()) } From d7949fdb01ac96a617bf7bf8e3c2579de481b946 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 29 Jun 2020 03:39:37 +0700 Subject: [PATCH 16/69] Remove duplicated code --- .../src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt | 3 --- 1 file changed, 3 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt index 760e3d03e..a4ecd09e5 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt @@ -64,7 +64,6 @@ class RealINDArrayField(override val shape: IntArray, override val elementContex override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure() override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap() override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap() - override fun INDArrayRealStructure.div(k: Number): INDArrayRealStructure = ndArray.divi(k).wrap() override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap() override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap() } @@ -74,7 +73,6 @@ class FloatINDArrayField(override val shape: IntArray, override val elementConte override fun INDArray.wrap(): INDArrayFloatStructure = asFloatStructure() override fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure = ndArray.divi(arg).wrap() override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap() - override fun INDArrayFloatStructure.div(k: Number): INDArrayFloatStructure = ndArray.divi(k).wrap() override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap() override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap() } @@ -83,7 +81,6 @@ class IntINDArrayRing(override val shape: IntArray, override val elementContext: INDArrayRing, INDArrayIntStructure> { override fun INDArray.wrap(): INDArrayIntStructure = asIntStructure() override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap() - override fun INDArrayIntStructure.div(k: Number): INDArrayIntStructure = ndArray.divi(k).wrap() override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap() override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap() } From 8a8b314d0a60c027f6abcfcff8d0d924613cd0d1 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 29 Jun 2020 03:48:11 +0700 Subject: [PATCH 17/69] Optimize reverse division for FP INDArrayAlgebra --- .../scientifik.kmath.nd4j/INDArrayAlgebra.kt | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt index a4ecd09e5..44d8f6611 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt @@ -7,8 +7,8 @@ import scientifik.kmath.structures.MutableNDStructure import scientifik.kmath.structures.NDField import scientifik.kmath.structures.NDRing -interface INDArrayRing : - NDRing where F : Ring, N : INDArrayStructure, N : MutableNDStructure { +interface INDArrayRing : + NDRing where R : Ring, N : INDArrayStructure, N : MutableNDStructure { fun INDArray.wrap(): N override val zero: N @@ -17,13 +17,13 @@ interface INDArrayRing : override val one: N get() = Nd4j.ones(*shape).wrap() - override fun produce(initializer: F.(IntArray) -> T): N { + override fun produce(initializer: R.(IntArray) -> T): N { val struct = Nd4j.create(*shape).wrap() struct.elements().map(Pair::first).forEach { struct[it] = elementContext.initializer(it) } return struct } - override fun map(arg: N, transform: F.(T) -> T): N { + override fun map(arg: N, transform: R.(T) -> T): N { val new = Nd4j.create(*shape) Nd4j.copy(arg.ndArray, new) val newStruct = new.wrap() @@ -31,13 +31,13 @@ interface INDArrayRing : return newStruct } - override fun mapIndexed(arg: N, transform: F.(index: IntArray, T) -> T): N { + override fun mapIndexed(arg: N, transform: R.(index: IntArray, T) -> T): N { val new = Nd4j.create(*shape).wrap() new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(idx, arg[idx]) } return new } - override fun combine(a: N, b: N, transform: F.(T, T) -> T): N { + override fun combine(a: N, b: N, transform: R.(T, T) -> T): N { val new = Nd4j.create(*shape).wrap() new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(a[idx], b[idx]) } return new @@ -66,6 +66,7 @@ class RealINDArrayField(override val shape: IntArray, override val elementContex override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap() override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap() override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap() + override fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rdivi(this).wrap() } class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field = FloatField) : @@ -75,6 +76,7 @@ class FloatINDArrayField(override val shape: IntArray, override val elementConte override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap() override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap() override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap() + override fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rdivi(this).wrap() } class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring = IntRing) : From 23b2ba9950cb4fab33f0b53565ae762cfa2f7a45 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 29 Jun 2020 03:49:29 +0700 Subject: [PATCH 18/69] Optimize reverse division for FP INDArrayAlgebra --- .../src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt | 1 + 1 file changed, 1 insertion(+) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt index 44d8f6611..f476af0d5 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt @@ -57,6 +57,7 @@ interface INDArrayRing : interface INDArrayField : NDField, INDArrayRing where F : Field, N : INDArrayStructure, N : MutableNDStructure { override fun divide(a: N, b: N): N = a.ndArray.divi(b.ndArray).wrap() + override fun Number.div(b: N): N = b.ndArray.rdivi(this).wrap() } class RealINDArrayField(override val shape: IntArray, override val elementContext: Field = RealField) : From d87dd3e717cc18851b7922d1e31b7ac81698c082 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 29 Jun 2020 21:31:08 +0700 Subject: [PATCH 19/69] Refactor array functions --- .../kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt | 11 +++++++---- .../scientifik.kmath.nd4j/INDArrayStructures.kt | 4 ++-- .../src/main/kotlin/scientifik.kmath.nd4j/arrays.kt | 4 ++-- .../scientifik/kmath/nd4j/INDArrayAlgebraTest.kt | 11 +++++++++++ 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt index 115c78cb9..bba5089a1 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt @@ -16,7 +16,7 @@ internal sealed class INDArrayIteratorBase(protected val iterateOver: INDArra else Shape.ind2sub(iterateOver, i++.toLong())!! - return narrowToIntArray(la) to getSingle(la) + return la.toIntArray() to getSingle(la) } } @@ -30,13 +30,16 @@ internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBas override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices) } -// TODO -//internal fun INDArray.longI +internal fun INDArray.longIterator(): INDArrayLongIterator = INDArrayLongIterator(this) internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { - override fun getSingle(indices: LongArray) = iterateOver.getInt(*narrowToIntArray(indices)) + override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray()) } +internal fun INDArray.intIterator(): INDArrayIntIterator = INDArrayIntIterator(this) + internal class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices) } + +internal fun INDArray.floatIterator() = INDArrayFloatIterator(this) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt index 351110485..ef7436285 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt @@ -8,7 +8,7 @@ interface INDArrayStructure : NDStructure { val ndArray: INDArray override val shape: IntArray - get() = narrowToIntArray(ndArray.shape()) + get() = ndArray.shape().toIntArray() fun elementsIterator(): Iterator> override fun elements(): Sequence> = Sequence(::elementsIterator) @@ -24,7 +24,7 @@ fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this) data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { override fun elementsIterator(): Iterator> = INDArrayLongIterator(ndArray) - override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index)) + override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray()) } fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/arrays.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/arrays.kt index 3d5062a4f..269fc89c2 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/arrays.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/arrays.kt @@ -1,4 +1,4 @@ package scientifik.kmath.nd4j -internal fun widenToLongArray(ia: IntArray): LongArray = LongArray(ia.size) { ia[it].toLong() } -internal fun narrowToIntArray(la: LongArray): IntArray = IntArray(la.size) { la[it].toInt() } +internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() } +internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() } diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt index f971e7871..4aa40c233 100644 --- a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt @@ -27,4 +27,15 @@ internal class INDArrayAlgebraTest { expected[intArrayOf(1, 1)] = 3 assertEquals(expected, res) } + + @Test + fun testAdd() { + val res = (IntINDArrayRing(intArrayOf(2, 2))) { one + 25 } + val expected = Nd4j.create(2, 2)!!.asIntStructure() + expected[intArrayOf(0, 0)] = 26 + expected[intArrayOf(0, 1)] = 26 + expected[intArrayOf(1, 0)] = 26 + expected[intArrayOf(1, 1)] = 26 + assertEquals(expected, res) + } } From f54e5679cf261dceeff58531e929d3fd0bd04549 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 29 Jun 2020 22:06:13 +0700 Subject: [PATCH 20/69] Add README.md for kmath-nd4j --- kmath-nd4j/README.md | 73 +++++++++++++++++++ .../INDArrayStructures.kt | 8 +- 2 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 kmath-nd4j/README.md diff --git a/kmath-nd4j/README.md b/kmath-nd4j/README.md new file mode 100644 index 000000000..63687a880 --- /dev/null +++ b/kmath-nd4j/README.md @@ -0,0 +1,73 @@ +# ND4J NDStructure implementation (`kmath-nd4j`) + +This subproject implements the following features: + +- NDStructure wrapper for INDArray. +- Optimized NDRing implementation for INDArray storing Ints. +- Optimized NDField implementation for INDArray storing Floats and Doubles. + +> #### Artifact: +> This module is distributed in the artifact `scientifik:kmath-nd4j:0.1.4-dev-8`. +> +> **Gradle:** +> +> ```gradle +> repositories { +> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' } +> maven { url 'https://dl.bintray.com/mipt-npm/dev' } +> } +> +> dependencies { +> implementation 'scientifik:kmath-nd4j:0.1.4-dev-8' +> } +> ``` +> **Gradle Kotlin DSL:** +> +> ```kotlin +> repositories { +> maven("https://dl.bintray.com/mipt-npm/scientifik") +> maven("https://dl.bintray.com/mipt-npm/dev") +> } +> +> dependencies { +> implementation("scientifik:kmath-nd4j:0.1.4-dev-8") +> } +> ``` +> + +## Examples + +NDStructure wrapper for INDArray: + +```kotlin +import org.nd4j.linalg.factory.* +import scientifik.kmath.nd4j.* +import scientifik.kmath.structures.* + +val array = Nd4j.ones(2, 2)!!.asRealStructure() +println(array[0, 0]) // 1.0 +array[intArrayOf(0, 0)] = 24.0 +println(array[0, 0]) // 24.0 +``` + +Fast element-wise arithmetics for INDArray: + +```kotlin +import org.nd4j.linalg.factory.* +import scientifik.kmath.nd4j.* +import scientifik.kmath.operations.* + +val field = RealINDArrayField(intArrayOf(2, 2)) +val array = Nd4j.rand(2, 2)!!.asRealStructure() + +val res = field { + (25.0 / array + 20) * 4 +} + +println(res.ndArray) +// [[ 250.6449, 428.5840], +// [ 269.7913, 202.2077]] +``` + + +Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt index ef7436285..06b0354d8 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt @@ -15,7 +15,7 @@ interface INDArrayStructure : NDStructure { } data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { - override fun elementsIterator(): Iterator> = INDArrayIntIterator(ndArray) + override fun elementsIterator(): Iterator> = ndArray.intIterator() override fun get(index: IntArray): Int = ndArray.getInt(*index) override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } } @@ -23,7 +23,7 @@ data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStruct fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this) data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { - override fun elementsIterator(): Iterator> = INDArrayLongIterator(ndArray) + override fun elementsIterator(): Iterator> = ndArray.longIterator() override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray()) } @@ -31,7 +31,7 @@ fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(th data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { - override fun elementsIterator(): Iterator> = INDArrayRealIterator(ndArray) + override fun elementsIterator(): Iterator> = ndArray.realIterator() override fun get(index: IntArray): Double = ndArray.getDouble(*index) override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } } @@ -40,7 +40,7 @@ fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(th data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { - override fun elementsIterator(): Iterator> = INDArrayFloatIterator(ndArray) + override fun elementsIterator(): Iterator> = ndArray.floatIterator() override fun get(index: IntArray): Float = ndArray.getFloat(*index) override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } } From bf071bcdc1f1040bf644fe9638827c3ffe36491e Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 29 Jun 2020 22:30:08 +0700 Subject: [PATCH 21/69] Minor refactor --- kmath-nd4j/README.md | 2 +- .../scientifik.kmath.nd4j/INDArrayAlgebra.kt | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/kmath-nd4j/README.md b/kmath-nd4j/README.md index 63687a880..ad76799c3 100644 --- a/kmath-nd4j/README.md +++ b/kmath-nd4j/README.md @@ -50,7 +50,7 @@ array[intArrayOf(0, 0)] = 24.0 println(array[0, 0]) // 24.0 ``` -Fast element-wise arithmetics for INDArray: +Fast element-wise and in-place arithmetics for INDArray: ```kotlin import org.nd4j.linalg.factory.* diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt index f476af0d5..14fe202c3 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt @@ -9,24 +9,22 @@ import scientifik.kmath.structures.NDRing interface INDArrayRing : NDRing where R : Ring, N : INDArrayStructure, N : MutableNDStructure { - fun INDArray.wrap(): N - override val zero: N get() = Nd4j.zeros(*shape).wrap() override val one: N get() = Nd4j.ones(*shape).wrap() + fun INDArray.wrap(): N + override fun produce(initializer: R.(IntArray) -> T): N { - val struct = Nd4j.create(*shape).wrap() + val struct = Nd4j.create(*shape)!!.wrap() struct.elements().map(Pair::first).forEach { struct[it] = elementContext.initializer(it) } return struct } override fun map(arg: N, transform: R.(T) -> T): N { - val new = Nd4j.create(*shape) - Nd4j.copy(arg.ndArray, new) - val newStruct = new.wrap() + val newStruct = arg.ndArray.dup().wrap() newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } return newStruct } @@ -52,6 +50,7 @@ interface INDArrayRing : override fun N.minus(b: Number): N = ndArray.subi(b).wrap() override fun N.plus(b: Number): N = ndArray.addi(b).wrap() override fun N.times(k: Number): N = ndArray.muli(k).wrap() + override fun Number.minus(b: N): N = b.ndArray.rsubi(this).wrap() } interface INDArrayField : NDField, @@ -61,13 +60,14 @@ interface INDArrayField : NDField, } class RealINDArrayField(override val shape: IntArray, override val elementContext: Field = RealField) : - INDArrayField, INDArrayRealStructure> { + INDArrayField, INDArrayRealStructure> { override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure() override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap() override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap() override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap() override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap() override fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rdivi(this).wrap() + override fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rsubi(this).wrap() } class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field = FloatField) : @@ -78,6 +78,7 @@ class FloatINDArrayField(override val shape: IntArray, override val elementConte override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap() override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap() override fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rdivi(this).wrap() + override fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rsubi(this).wrap() } class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring = IntRing) : @@ -86,4 +87,5 @@ class IntINDArrayRing(override val shape: IntArray, override val elementContext: override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap() override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap() override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap() + override fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure = arg.ndArray.rsubi(this).wrap() } From 7157878485236ee88688053b6313b710891c0c03 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 15 Aug 2020 18:35:16 +0700 Subject: [PATCH 22/69] Update changelog, document kmath-nd4j, refactor iterators, correct algebra mistakes, separate INDArrayStructureRing to Space, Ring and Algebra --- CHANGELOG.md | 1 + .../kmath/operations/AlgebraElements.kt | 6 +- .../kmath/structures/BufferedNDElement.kt | 6 +- .../scientifik/kmath/structures/NDAlgebra.kt | 170 ++++++++--- kmath-nd4j/README.md | 8 +- .../scientifik.kmath.nd4j/INDArrayAlgebra.kt | 286 ++++++++++++++---- .../INDArrayIterators.kt | 35 ++- .../INDArrayStructures.kt | 50 ++- 8 files changed, 448 insertions(+), 114 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26f9e33ec..e9afb6c26 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ - Blocking chains in `kmath-coroutines` - Full hyperbolic functions support and default implementations within `ExtendedField` - Norm support for `Complex` +- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`. ### Changed - BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt index 197897c14..e1d50c4f0 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt @@ -74,9 +74,9 @@ interface SpaceElement, S : Space> : MathElement /** * The element of [Ring]. * - * @param T the type of space operation results. + * @param T the type of ring operation results. * @param I self type of the element. Needed for static type checking. - * @param R the type of space. + * @param R the type of ring. */ interface RingElement, R : Ring> : SpaceElement { /** @@ -91,7 +91,7 @@ interface RingElement, R : Ring> : SpaceElement>( override val context: BufferedNDField, override val buffer: Buffer ) : BufferedNDElement(), FieldElement, BufferedNDFieldElement, BufferedNDField> { - override fun unwrap(): NDBuffer = this override fun NDBuffer.wrap(): BufferedNDFieldElement { @@ -56,8 +55,9 @@ class BufferedNDFieldElement>( /** * Element by element application of any operation on elements to the whole array. Just like in numpy. */ -operator fun > Function1.invoke(ndElement: BufferedNDElement): MathElement> = - ndElement.context.run { map(ndElement) { invoke(it) }.toElement() } +operator fun > Function1.invoke( + ndElement: BufferedNDElement +): MathElement> = ndElement.context.run { map(ndElement) { invoke(it) }.toElement() } /* plus and minus */ diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt index f09db3c72..14c23a81a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt @@ -5,56 +5,78 @@ import scientifik.kmath.operations.Field import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space - /** - * An exception is thrown when the expected ans actual shape of NDArray differs + * An exception is thrown when the expected ans actual shape of NDArray differs. + * + * @property expected the expected shape. + * @property actual the actual shape. */ -class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : RuntimeException() - +class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : + RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.") /** - * The base interface for all nd-algebra implementations - * @param T the type of nd-structure element - * @param C the type of the element context - * @param N the type of the structure + * The base interface for all ND-algebra implementations. + * + * @param T the type of ND-structure element. + * @param C the type of the element context. + * @param N the type of the structure. */ interface NDAlgebra> { + /** + * The shape of ND-structures this algebra operates on. + */ val shape: IntArray + + /** + * The algebra over elements of ND structure. + */ val elementContext: C /** - * Produce a new [N] structure using given initializer function + * Produces a new [N] structure using given initializer function. */ fun produce(initializer: C.(IntArray) -> T): N /** - * Map elements from one structure to another one + * Maps elements from one structure to another one by applying [transform] to them. */ fun map(arg: N, transform: C.(T) -> T): N /** - * Map indexed elements + * Maps elements from one structure to another one by applying [transform] to them alongside with their indices. */ fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N /** - * Combine two structures into one + * Combines two structures into one. */ fun combine(a: N, b: N, transform: C.(T, T) -> T): N /** - * Check if given elements are consistent with this context + * Checks if given element is consistent with this context. + * + * @param element the structure to check. + * @return the valid structure. */ - fun check(vararg elements: N) { - elements.forEach { - if (!shape.contentEquals(it.shape)) { - throw ShapeMismatchException(shape, it.shape) - } - } + fun check(element: N): N { + if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape) + return element } /** - * element-by-element invoke a function working on [T] on a [NDStructure] + * Checks if given elements are consistent with this context. + * + * @param elements the structures to check. + * @return the array of valid structures. + */ + fun check(vararg elements: N): Array = elements + .map(NDStructure::shape) + .singleOrNull { !shape.contentEquals(it) } + ?.let { throw ShapeMismatchException(shape, it) } + ?: elements + + /** + * Element-wise invocation of function working on [T] on a [NDStructure]. */ operator fun Function1.invoke(structure: N): N = map(structure) { value -> this@invoke(value) } @@ -62,43 +84,107 @@ interface NDAlgebra> { } /** - * An nd-space over element space + * Space of [NDStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param S the type of space of structure elements. */ interface NDSpace, N : NDStructure> : Space, NDAlgebra { /** - * Element-by-element addition + * Element-wise addition. + * + * @param a the addend. + * @param b the augend. + * @return the sum. */ override fun add(a: N, b: N): N = combine(a, b) { aValue, bValue -> add(aValue, bValue) } /** - * Multiply all elements by constant + * Element-wise multiplication by scalar. + * + * @param a the multiplicand. + * @param k the multiplier. + * @return the product. */ override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) } - //TODO move to extensions after KEEP-176 + // TODO move to extensions after KEEP-176 + + /** + * Adds an ND structure to an element of it. + * + * @receiver the addend. + * @param arg the augend. + * @return the sum. + */ operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) } + /** + * Subtracts an element from ND structure of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) } + /** + * Adds an element to ND structure of it. + * + * @receiver the addend. + * @param arg the augend. + * @return the sum. + */ operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) } + + /** + * Subtracts an ND structure from an element of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) } companion object } /** - * An nd-ring over element ring + * Ring of [NDStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param R the type of ring of structure elements. */ interface NDRing, N : NDStructure> : Ring, NDSpace { - /** - * Element-by-element multiplication + * Element-wise multiplication. + * + * @param a the multiplicand. + * @param b the multiplier. + * @return the product. */ override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } //TODO move to extensions after KEEP-176 + + /** + * Multiplies an ND structure by an element of it. + * + * @receiver the multiplicand. + * @param arg the multiplier. + * @return the product. + */ operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) } + /** + * Multiplies an element by a ND structure of it. + * + * @receiver the multiplicand. + * @param arg the multiplier. + * @return the product. + */ operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) } companion object @@ -109,31 +195,47 @@ interface NDRing, N : NDStructure> : Ring, NDSpace * * @param T the type of the element contained in ND structure. * @param N the type of ND structure. - * @param F field of structure elements. + * @param F the type field of structure elements. */ interface NDField, N : NDStructure> : Field, NDRing { - /** - * Element-by-element division + * Element-wise division. + * + * @param a the dividend. + * @param b the divisor. + * @return the quotient. */ override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) } //TODO move to extensions after KEEP-176 + /** + * Divides an ND structure by an element of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) } + /** + * Divides an element by an ND structure of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) } companion object { - - private val realNDFieldCache = HashMap() + private val realNDFieldCache: MutableMap = hashMapOf() /** - * Create a nd-field for [Double] values or pull it from cache if it was created previously + * Create a nd-field for [Double] values or pull it from cache if it was created previously. */ fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) } /** - * Create a nd-field with boxing generic buffer + * Create a ND field with boxing generic buffer. */ fun > boxing( field: F, diff --git a/kmath-nd4j/README.md b/kmath-nd4j/README.md index ad76799c3..f8ca5eed2 100644 --- a/kmath-nd4j/README.md +++ b/kmath-nd4j/README.md @@ -3,8 +3,8 @@ This subproject implements the following features: - NDStructure wrapper for INDArray. -- Optimized NDRing implementation for INDArray storing Ints. -- Optimized NDField implementation for INDArray storing Floats and Doubles. +- Optimized NDRing implementations for INDArray storing Ints and Longs. +- Optimized NDField implementations for INDArray storing Floats and Doubles. > #### Artifact: > This module is distributed in the artifact `scientifik:kmath-nd4j:0.1.4-dev-8`. @@ -44,7 +44,7 @@ import org.nd4j.linalg.factory.* import scientifik.kmath.nd4j.* import scientifik.kmath.structures.* -val array = Nd4j.ones(2, 2)!!.asRealStructure() +val array = Nd4j.ones(2, 2).asRealStructure() println(array[0, 0]) // 1.0 array[intArrayOf(0, 0)] = 24.0 println(array[0, 0]) // 24.0 @@ -58,7 +58,7 @@ import scientifik.kmath.nd4j.* import scientifik.kmath.operations.* val field = RealINDArrayField(intArrayOf(2, 2)) -val array = Nd4j.rand(2, 2)!!.asRealStructure() +val array = Nd4j.rand(2, 2).asRealStructure() val res = field { (25.0 / array + 20) * 4 diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt index 14fe202c3..c24e2ece6 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt @@ -3,89 +3,271 @@ package scientifik.kmath.nd4j import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.factory.Nd4j import scientifik.kmath.operations.* -import scientifik.kmath.structures.MutableNDStructure -import scientifik.kmath.structures.NDField -import scientifik.kmath.structures.NDRing - -interface INDArrayRing : - NDRing where R : Ring, N : INDArrayStructure, N : MutableNDStructure { - override val zero: N - get() = Nd4j.zeros(*shape).wrap() - - override val one: N - get() = Nd4j.ones(*shape).wrap() +import scientifik.kmath.structures.* +/** + * Represents [NDAlgebra] over [INDArrayAlgebra]. + * + * @param T the type of ND-structure element. + * @param C the type of the element context. + * @param N the type of the structure. + */ +interface INDArrayAlgebra : NDAlgebra where N : INDArrayStructure, N : MutableNDStructure { + /** + * Wraps [INDArray] to [N]. + */ fun INDArray.wrap(): N - override fun produce(initializer: R.(IntArray) -> T): N { + override fun produce(initializer: C.(IntArray) -> T): N { val struct = Nd4j.create(*shape)!!.wrap() - struct.elements().map(Pair::first).forEach { struct[it] = elementContext.initializer(it) } + struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) } return struct } - override fun map(arg: N, transform: R.(T) -> T): N { + override fun map(arg: N, transform: C.(T) -> T): N { + check(arg) val newStruct = arg.ndArray.dup().wrap() newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } return newStruct } - override fun mapIndexed(arg: N, transform: R.(index: IntArray, T) -> T): N { + override fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N { + check(arg) val new = Nd4j.create(*shape).wrap() - new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(idx, arg[idx]) } + new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) } return new } - override fun combine(a: N, b: N, transform: R.(T, T) -> T): N { + override fun combine(a: N, b: N, transform: C.(T, T) -> T): N { + check(a, b) val new = Nd4j.create(*shape).wrap() - new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(a[idx], b[idx]) } + new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) } return new } - - override fun add(a: N, b: N): N = a.ndArray.addi(b.ndArray).wrap() - override fun N.minus(b: N): N = ndArray.subi(b.ndArray).wrap() - override fun N.unaryMinus(): N = ndArray.negi().wrap() - override fun multiply(a: N, b: N): N = a.ndArray.muli(b.ndArray).wrap() - override fun multiply(a: N, k: Number): N = a.ndArray.muli(k).wrap() - override fun N.div(k: Number): N = ndArray.divi(k).wrap() - override fun N.minus(b: Number): N = ndArray.subi(b).wrap() - override fun N.plus(b: Number): N = ndArray.addi(b).wrap() - override fun N.times(k: Number): N = ndArray.muli(k).wrap() - override fun Number.minus(b: N): N = b.ndArray.rsubi(this).wrap() } -interface INDArrayField : NDField, - INDArrayRing where F : Field, N : INDArrayStructure, N : MutableNDStructure { - override fun divide(a: N, b: N): N = a.ndArray.divi(b.ndArray).wrap() - override fun Number.div(b: N): N = b.ndArray.rdivi(this).wrap() +/** + * Represents [NDSpace] over [INDArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param S the type of space of structure elements. + */ +interface INDArraySpace : NDSpace, INDArrayAlgebra + where S : Space, N : INDArrayStructure, N : MutableNDStructure { + + override val zero: N + get() = Nd4j.zeros(*shape).wrap() + + override fun add(a: N, b: N): N { + check(a, b) + return a.ndArray.add(b.ndArray).wrap() + } + + override operator fun N.minus(b: N): N { + check(this, b) + return ndArray.sub(b.ndArray).wrap() + } + + override operator fun N.unaryMinus(): N { + check(this) + return ndArray.neg().wrap() + } + + override fun multiply(a: N, k: Number): N { + check(a) + return a.ndArray.mul(k).wrap() + } + + override operator fun N.div(k: Number): N { + check(this) + return ndArray.div(k).wrap() + } + + override operator fun N.times(k: Number): N { + check(this) + return ndArray.mul(k).wrap() + } } +/** + * Represents [NDRing] over [INDArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param R the type of ring of structure elements. + */ +interface INDArrayRing : NDRing, INDArraySpace + where R : Ring, N : INDArrayStructure, N : MutableNDStructure { + + override val one: N + get() = Nd4j.ones(*shape).wrap() + + override fun multiply(a: N, b: N): N { + check(a, b) + return a.ndArray.mul(b.ndArray).wrap() + } + + override operator fun N.minus(b: Number): N { + check(this) + return ndArray.sub(b).wrap() + } + + override operator fun N.plus(b: Number): N { + check(this) + return ndArray.add(b).wrap() + } + + override operator fun Number.minus(b: N): N { + check(b) + return b.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDField] over [INDArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param F the type field of structure elements. + */ +interface INDArrayField : NDField, INDArrayRing + where F : Field, N : INDArrayStructure, N : MutableNDStructure { + override fun divide(a: N, b: N): N { + check(a, b) + return a.ndArray.div(b.ndArray).wrap() + } + + override operator fun Number.div(b: N): N { + check(b) + return b.ndArray.rdiv(this).wrap() + } +} + +/** + * Represents [NDField] over [INDArrayRealStructure]. + */ class RealINDArrayField(override val shape: IntArray, override val elementContext: Field = RealField) : INDArrayField, INDArrayRealStructure> { - override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure() - override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap() - override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap() - override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap() - override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap() - override fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rdivi(this).wrap() - override fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rsubi(this).wrap() + override fun INDArray.wrap(): INDArrayRealStructure = check(asRealStructure()) + override operator fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure { + check(this) + return ndArray.div(arg).wrap() + } + + override operator fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure { + check(this) + return ndArray.add(arg).wrap() + } + + override operator fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + override operator fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + override operator fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure { + check(arg) + return arg.ndArray.rdiv(this).wrap() + } + + override operator fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } } +/** + * Represents [NDField] over [INDArrayFloatStructure]. + */ class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field = FloatField) : INDArrayField, INDArrayFloatStructure> { - override fun INDArray.wrap(): INDArrayFloatStructure = asFloatStructure() - override fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure = ndArray.divi(arg).wrap() - override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap() - override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap() - override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap() - override fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rdivi(this).wrap() - override fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rsubi(this).wrap() + override fun INDArray.wrap(): INDArrayFloatStructure = check(asFloatStructure()) + override operator fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure { + check(this) + return ndArray.div(arg).wrap() + } + + override operator fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure { + check(this) + return ndArray.add(arg).wrap() + } + + override operator fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + override operator fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + override operator fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure { + check(arg) + return arg.ndArray.rdiv(this).wrap() + } + + override operator fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } } +/** + * Represents [NDRing] over [INDArrayIntStructure]. + */ class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring = IntRing) : INDArrayRing, INDArrayIntStructure> { - override fun INDArray.wrap(): INDArrayIntStructure = asIntStructure() - override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap() - override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap() - override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap() - override fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure = arg.ndArray.rsubi(this).wrap() + override fun INDArray.wrap(): INDArrayIntStructure = check(asIntStructure()) + override operator fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure { + check(this) + return ndArray.add(arg).wrap() + } + + override operator fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + override operator fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + override operator fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDRing] over [INDArrayLongStructure]. + */ +class LongINDArrayRing(override val shape: IntArray, override val elementContext: Ring = LongRing) : + INDArrayRing, INDArrayLongStructure> { + override fun INDArray.wrap(): INDArrayLongStructure = check(asLongStructure()) + override operator fun INDArrayLongStructure.plus(arg: Long): INDArrayLongStructure { + check(this) + return ndArray.add(arg).wrap() + } + + override operator fun INDArrayLongStructure.minus(arg: Long): INDArrayLongStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + override operator fun INDArrayLongStructure.times(arg: Long): INDArrayLongStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + override operator fun Long.minus(arg: INDArrayLongStructure): INDArrayLongStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } } diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt index bba5089a1..2759f9fdb 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt @@ -3,7 +3,24 @@ package scientifik.kmath.nd4j import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.shape.Shape -internal sealed class INDArrayIteratorBase(protected val iterateOver: INDArray) : Iterator> { +private class INDArrayIndicesIterator(private val iterateOver: INDArray) : Iterator { + private var i: Int = 0 + + override fun hasNext(): Boolean = i < iterateOver.length() + + override fun next(): IntArray { + val la = if (iterateOver.ordering() == 'c') + Shape.ind2subC(iterateOver, i++.toLong())!! + else + Shape.ind2sub(iterateOver, i++.toLong())!! + + return la.toIntArray() + } +} + +internal fun INDArray.indicesIterator(): Iterator = INDArrayIndicesIterator(this) + +private sealed class INDArrayIteratorBase(protected val iterateOver: INDArray) : Iterator> { private var i: Int = 0 final override fun hasNext(): Boolean = i < iterateOver.length() @@ -20,26 +37,26 @@ internal sealed class INDArrayIteratorBase(protected val iterateOver: INDArra } } -internal class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { +private class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices) } -internal fun INDArray.realIterator(): INDArrayRealIterator = INDArrayRealIterator(this) +internal fun INDArray.realIterator(): Iterator> = INDArrayRealIterator(this) -internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { +private class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices) } -internal fun INDArray.longIterator(): INDArrayLongIterator = INDArrayLongIterator(this) +internal fun INDArray.longIterator(): Iterator> = INDArrayLongIterator(this) -internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { +private class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray()) } -internal fun INDArray.intIterator(): INDArrayIntIterator = INDArrayIntIterator(this) +internal fun INDArray.intIterator(): Iterator> = INDArrayIntIterator(this) -internal class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { +private class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices) } -internal fun INDArray.floatIterator() = INDArrayFloatIterator(this) +internal fun INDArray.floatIterator(): Iterator> = INDArrayFloatIterator(this) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt index 06b0354d8..39cefee3d 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt @@ -4,45 +4,77 @@ import org.nd4j.linalg.api.ndarray.INDArray import scientifik.kmath.structures.MutableNDStructure import scientifik.kmath.structures.NDStructure -interface INDArrayStructure : NDStructure { - val ndArray: INDArray +/** + * Represents a [NDStructure] wrapping an [INDArray] object. + * + * @param T the type of items. + */ +sealed class INDArrayStructure : MutableNDStructure { + /** + * The wrapped [INDArray]. + */ + abstract val ndArray: INDArray override val shape: IntArray get() = ndArray.shape().toIntArray() - fun elementsIterator(): Iterator> + internal abstract fun elementsIterator(): Iterator> + internal fun indicesIterator(): Iterator = ndArray.indicesIterator() override fun elements(): Sequence> = Sequence(::elementsIterator) } -data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +/** + * Represents a [NDStructure] over [INDArray] elements of which are accessed as ints. + */ +data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure() { override fun elementsIterator(): Iterator> = ndArray.intIterator() override fun get(index: IntArray): Int = ndArray.getInt(*index) override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } } +/** + * Wraps this [INDArray] to [INDArrayIntStructure]. + */ fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this) -data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { +/** + * Represents a [NDStructure] over [INDArray] elements of which are accessed as longs. + */ +data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure() { override fun elementsIterator(): Iterator> = ndArray.longIterator() override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray()) + override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) } } +/** + * Wraps this [INDArray] to [INDArrayLongStructure]. + */ fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this) -data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure, - MutableNDStructure { +/** + * Represents a [NDStructure] over [INDArray] elements of which are accessed as reals. + */ +data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure() { override fun elementsIterator(): Iterator> = ndArray.realIterator() override fun get(index: IntArray): Double = ndArray.getDouble(*index) override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } } +/** + * Wraps this [INDArray] to [INDArrayRealStructure]. + */ fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(this) -data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, - MutableNDStructure { +/** + * Represents a [NDStructure] over [INDArray] elements of which are accessed as floats. + */ +data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure() { override fun elementsIterator(): Iterator> = ndArray.floatIterator() override fun get(index: IntArray): Float = ndArray.getFloat(*index) override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } } +/** + * Wraps this [INDArray] to [INDArrayFloatStructure]. + */ fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(this) From 2bc62356d6da378f2aae87cc830a119f13f03af5 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 27 Aug 2020 16:44:58 +0700 Subject: [PATCH 23/69] Fix compilation issues --- .../kotlin/scientifik/kmath/structures/BoxingNDField.kt | 3 ++- .../kotlin/scientifik/kmath/structures/BoxingNDRing.kt | 3 ++- .../kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt index 4cbb565c1..29e0ba276 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt @@ -14,8 +14,9 @@ class BoxingNDField>( fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) - override fun check(vararg elements: NDBuffer) { + override fun check(vararg elements: NDBuffer): Array> { if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") + return elements } override val zero: BufferedNDFieldElement by lazy { produce { zero } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt index f7be95736..8aabe169a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt @@ -14,8 +14,9 @@ class BoxingNDRing>( fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) - override fun check(vararg elements: NDBuffer) { + override fun check(vararg elements: NDBuffer): Array> { if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") + return elements } override val zero: BufferedNDRingElement by lazy { produce { zero } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt index 06922c56f..de8a150c6 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt @@ -5,8 +5,9 @@ import scientifik.kmath.operations.* interface BufferedNDAlgebra : NDAlgebra> { val strides: Strides - override fun check(vararg elements: NDBuffer) { + override fun check(vararg elements: NDBuffer): Array> { if (!elements.all { it.strides == this.strides }) error("Strides mismatch") + return elements } /** From d54e7c3e97e02090ef5abae2cfe1aca32c6d8e0a Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Tue, 15 Sep 2020 17:48:43 +0700 Subject: [PATCH 24/69] Update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 015f0deeb..bf0316c9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## [Unreleased] ### Added +- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`. ### Changed @@ -30,7 +31,6 @@ - Blocking chains in `kmath-coroutines` - Full hyperbolic functions support and default implementations within `ExtendedField` - Norm support for `Complex` -- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`. ### Changed - `readAsMemory` now has `throws IOException` in JVM signature. From 4e5c7ab366e41fe5ec8f392b4bf36d0ba9be6769 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 20 Sep 2020 16:45:26 +0700 Subject: [PATCH 25/69] Make one-liner not a one-liner --- kmath-nd4j/build.gradle.kts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/kmath-nd4j/build.gradle.kts b/kmath-nd4j/build.gradle.kts index 59354a8f9..110a2ac30 100644 --- a/kmath-nd4j/build.gradle.kts +++ b/kmath-nd4j/build.gradle.kts @@ -1,4 +1,6 @@ -plugins { id("scientifik.jvm") } +plugins { + id("scientifik.jvm") +} dependencies { api(project(":kmath-core")) From 2ee5d0f325fb7189462b4190497be9efefe420eb Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 21 Sep 2020 20:53:31 +0700 Subject: [PATCH 26/69] Change package name, simplify exposed API types, update build snippet, minor refactor --- examples/build.gradle.kts | 6 +- .../kmath/structures/BoxingNDField.kt | 5 +- .../kmath/structures/BufferedNDAlgebra.kt | 4 +- .../kscience/kmath/structures/NDAlgebra.kt | 2 +- kmath-nd4j/README.md | 7 + kmath-nd4j/build.gradle.kts | 2 +- .../kscience.kmath.nd4j/INDArrayAlgebra.kt | 284 ++++++++++++++++++ .../INDArrayIterators.kt | 4 +- .../kscience.kmath.nd4j/INDArrayStructures.kt | 68 +++++ .../arrays.kt | 2 +- .../scientifik.kmath.nd4j/INDArrayAlgebra.kt | 273 ----------------- .../INDArrayStructures.kt | 80 ----- .../kmath/nd4j/INDArrayAlgebraTest.kt | 11 +- .../kmath/nd4j/INDArrayStructureTest.kt | 41 +-- 14 files changed, 402 insertions(+), 387 deletions(-) create mode 100644 kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayAlgebra.kt rename kmath-nd4j/src/main/kotlin/{scientifik.kmath.nd4j => kscience.kmath.nd4j}/INDArrayIterators.kt (94%) create mode 100644 kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayStructures.kt rename kmath-nd4j/src/main/kotlin/{scientifik.kmath.nd4j => kscience.kmath.nd4j}/arrays.kt (85%) delete mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt delete mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt rename kmath-nd4j/src/test/kotlin/{scientifik => kscience}/kmath/nd4j/INDArrayAlgebraTest.kt (78%) rename kmath-nd4j/src/test/kotlin/{scientifik => kscience}/kmath/nd4j/INDArrayStructureTest.kt (55%) diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 3d193efce..f0161afbb 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -26,9 +26,13 @@ dependencies { implementation(project(":kmath-prob")) implementation(project(":kmath-viktor")) implementation(project(":kmath-dimensions")) + implementation(project(":kmath-nd4j")) implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20") - "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath + implementation("org.slf4j:slf4j-simple:1.7.30") + implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") + "benchmarksImplementation"("org.jetbrains.kotlinx:kotlinx.benchmark.runtime-jvm:0.2.0-dev-8") + "benchmarksImplementation"(sourceSets.main.get().output + sourceSets.main.get().runtimeClasspath) } // Configure benchmark diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt index ddec7bd25..dc65b12c4 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt @@ -15,8 +15,9 @@ public class BoxingNDField>( public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) - public override fun check(vararg elements: NDBuffer) { - check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } + public override fun check(vararg elements: NDBuffer): Array> { + require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } + return elements } public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement = diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt index 66b4f19e1..251b1bcb5 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt @@ -5,8 +5,10 @@ import kscience.kmath.operations.* public interface BufferedNDAlgebra : NDAlgebra> { public val strides: Strides - public override fun check(vararg elements: NDBuffer): Unit = + public override fun check(vararg elements: NDBuffer): Array> { require(elements.all { it.strides == strides }) { ("Strides mismatch") } + return elements + } /** * Convert any [NDStructure] to buffered structure using strides from this context. diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt index 35a65c487..4315f0423 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt @@ -11,7 +11,7 @@ import kscience.kmath.operations.Space * @property expected the expected shape. * @property actual the actual shape. */ -public class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : +public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) : RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.") /** diff --git a/kmath-nd4j/README.md b/kmath-nd4j/README.md index f8ca5eed2..fac24504a 100644 --- a/kmath-nd4j/README.md +++ b/kmath-nd4j/README.md @@ -13,26 +13,33 @@ This subproject implements the following features: > > ```gradle > repositories { +> mavenCentral() > maven { url 'https://dl.bintray.com/mipt-npm/scientifik' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' } > } > > dependencies { > implementation 'scientifik:kmath-nd4j:0.1.4-dev-8' +> implementation 'org.nd4j:nd4j-native-platform:1.0.0-beta7' > } > ``` > **Gradle Kotlin DSL:** > > ```kotlin > repositories { +> mavenCentral() > maven("https://dl.bintray.com/mipt-npm/scientifik") > maven("https://dl.bintray.com/mipt-npm/dev") > } > > dependencies { > implementation("scientifik:kmath-nd4j:0.1.4-dev-8") +> implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") > } > ``` +> +> This distribution also needs an implementation of ND4J API. The ND4J Native Platform is usually the fastest one, so +> it is included to the snippet. > ## Examples diff --git a/kmath-nd4j/build.gradle.kts b/kmath-nd4j/build.gradle.kts index 110a2ac30..67569b870 100644 --- a/kmath-nd4j/build.gradle.kts +++ b/kmath-nd4j/build.gradle.kts @@ -1,5 +1,5 @@ plugins { - id("scientifik.jvm") + id("ru.mipt.npm.jvm") } dependencies { diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayAlgebra.kt new file mode 100644 index 000000000..728ce3773 --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayAlgebra.kt @@ -0,0 +1,284 @@ +package kscience.kmath.nd4j + +import org.nd4j.linalg.api.ndarray.INDArray +import org.nd4j.linalg.factory.Nd4j +import kscience.kmath.operations.* +import kscience.kmath.structures.* + +/** + * Represents [NDAlgebra] over [INDArrayAlgebra]. + * + * @param T the type of ND-structure element. + * @param C the type of the element context. + */ +public interface INDArrayAlgebra : NDAlgebra> { + /** + * Wraps [INDArray] to [N]. + */ + public fun INDArray.wrap(): INDArrayStructure + + public override fun produce(initializer: C.(IntArray) -> T): INDArrayStructure { + val struct = Nd4j.create(*shape)!!.wrap() + struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) } + return struct + } + + public override fun map(arg: INDArrayStructure, transform: C.(T) -> T): INDArrayStructure { + check(arg) + val newStruct = arg.ndArray.dup().wrap() + newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } + return newStruct + } + + public override fun mapIndexed( + arg: INDArrayStructure, + transform: C.(index: IntArray, T) -> T + ): INDArrayStructure { + check(arg) + val new = Nd4j.create(*shape).wrap() + new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) } + return new + } + + public override fun combine( + a: INDArrayStructure, + b: INDArrayStructure, + transform: C.(T, T) -> T + ): INDArrayStructure { + check(a, b) + val new = Nd4j.create(*shape).wrap() + new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) } + return new + } +} + +/** + * Represents [NDSpace] over [INDArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param S the type of space of structure elements. + */ +public interface INDArraySpace : NDSpace>, INDArrayAlgebra where S : Space { + public override val zero: INDArrayStructure + get() = Nd4j.zeros(*shape).wrap() + + public override fun add(a: INDArrayStructure, b: INDArrayStructure): INDArrayStructure { + check(a, b) + return a.ndArray.add(b.ndArray).wrap() + } + + public override operator fun INDArrayStructure.minus(b: INDArrayStructure): INDArrayStructure { + check(this, b) + return ndArray.sub(b.ndArray).wrap() + } + + public override operator fun INDArrayStructure.unaryMinus(): INDArrayStructure { + check(this) + return ndArray.neg().wrap() + } + + public override fun multiply(a: INDArrayStructure, k: Number): INDArrayStructure { + check(a) + return a.ndArray.mul(k).wrap() + } + + public override operator fun INDArrayStructure.div(k: Number): INDArrayStructure { + check(this) + return ndArray.div(k).wrap() + } + + public override operator fun INDArrayStructure.times(k: Number): INDArrayStructure { + check(this) + return ndArray.mul(k).wrap() + } +} + +/** + * Represents [NDRing] over [INDArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param R the type of ring of structure elements. + */ +public interface INDArrayRing : NDRing>, INDArraySpace where R : Ring { + public override val one: INDArrayStructure + get() = Nd4j.ones(*shape).wrap() + + public override fun multiply(a: INDArrayStructure, b: INDArrayStructure): INDArrayStructure { + check(a, b) + return a.ndArray.mul(b.ndArray).wrap() + } + + public override operator fun INDArrayStructure.minus(b: Number): INDArrayStructure { + check(this) + return ndArray.sub(b).wrap() + } + + public override operator fun INDArrayStructure.plus(b: Number): INDArrayStructure { + check(this) + return ndArray.add(b).wrap() + } + + public override operator fun Number.minus(b: INDArrayStructure): INDArrayStructure { + check(b) + return b.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDField] over [INDArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param F the type field of structure elements. + */ +public interface INDArrayField : NDField>, INDArrayRing where F : Field { + public override fun divide(a: INDArrayStructure, b: INDArrayStructure): INDArrayStructure { + check(a, b) + return a.ndArray.div(b.ndArray).wrap() + } + + public override operator fun Number.div(b: INDArrayStructure): INDArrayStructure { + check(b) + return b.ndArray.rdiv(this).wrap() + } +} + +/** + * Represents [NDField] over [INDArrayRealStructure]. + */ +public class RealINDArrayField(public override val shape: IntArray) : INDArrayField { + public override val elementContext: RealField + get() = RealField + + public override fun INDArray.wrap(): INDArrayStructure = check(asRealStructure()) + + public override operator fun INDArrayStructure.div(arg: Double): INDArrayStructure { + check(this) + return ndArray.div(arg).wrap() + } + + public override operator fun INDArrayStructure.plus(arg: Double): INDArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun INDArrayStructure.minus(arg: Double): INDArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun INDArrayStructure.times(arg: Double): INDArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Double.div(arg: INDArrayStructure): INDArrayStructure { + check(arg) + return arg.ndArray.rdiv(this).wrap() + } + + public override operator fun Double.minus(arg: INDArrayStructure): INDArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDField] over [INDArrayStructure] of [Float]. + */ +public class FloatINDArrayField(public override val shape: IntArray) : INDArrayField { + public override val elementContext: FloatField + get() = FloatField + + public override fun INDArray.wrap(): INDArrayStructure = check(asFloatStructure()) + + public override operator fun INDArrayStructure.div(arg: Float): INDArrayStructure { + check(this) + return ndArray.div(arg).wrap() + } + + public override operator fun INDArrayStructure.plus(arg: Float): INDArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun INDArrayStructure.minus(arg: Float): INDArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun INDArrayStructure.times(arg: Float): INDArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Float.div(arg: INDArrayStructure): INDArrayStructure { + check(arg) + return arg.ndArray.rdiv(this).wrap() + } + + public override operator fun Float.minus(arg: INDArrayStructure): INDArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDRing] over [INDArrayIntStructure]. + */ +public class IntINDArrayRing(public override val shape: IntArray) : INDArrayRing { + public override val elementContext: IntRing + get() = IntRing + + public override fun INDArray.wrap(): INDArrayStructure = check(asIntStructure()) + + public override operator fun INDArrayStructure.plus(arg: Int): INDArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun INDArrayStructure.minus(arg: Int): INDArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun INDArrayStructure.times(arg: Int): INDArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Int.minus(arg: INDArrayStructure): INDArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDRing] over [INDArrayStructure] of [Long]. + */ +public class LongINDArrayRing(public override val shape: IntArray) : INDArrayRing { + public override val elementContext: LongRing + get() = LongRing + + public override fun INDArray.wrap(): INDArrayStructure = check(asLongStructure()) + + public override operator fun INDArrayStructure.plus(arg: Long): INDArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun INDArrayStructure.minus(arg: Long): INDArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun INDArrayStructure.times(arg: Long): INDArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Long.minus(arg: INDArrayStructure): INDArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayIterators.kt similarity index 94% rename from kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt rename to kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayIterators.kt index 2759f9fdb..9e7ef9e16 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayIterators.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.nd4j +package kscience.kmath.nd4j import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.shape.Shape @@ -18,7 +18,7 @@ private class INDArrayIndicesIterator(private val iterateOver: INDArray) : Itera } } -internal fun INDArray.indicesIterator(): Iterator = INDArrayIndicesIterator(this) +internal fun INDArray.indicesIterator(): Iterator = INDArrayIndicesIterator(this) private sealed class INDArrayIteratorBase(protected val iterateOver: INDArray) : Iterator> { private var i: Int = 0 diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayStructures.kt new file mode 100644 index 000000000..5d4e1a979 --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayStructures.kt @@ -0,0 +1,68 @@ +package kscience.kmath.nd4j + +import org.nd4j.linalg.api.ndarray.INDArray +import kscience.kmath.structures.MutableNDStructure +import kscience.kmath.structures.NDStructure + +/** + * Represents a [NDStructure] wrapping an [INDArray] object. + * + * @param T the type of items. + */ +public sealed class INDArrayStructure : MutableNDStructure { + /** + * The wrapped [INDArray]. + */ + public abstract val ndArray: INDArray + + public override val shape: IntArray + get() = ndArray.shape().toIntArray() + + internal abstract fun elementsIterator(): Iterator> + internal fun indicesIterator(): Iterator = ndArray.indicesIterator() + public override fun elements(): Sequence> = Sequence(::elementsIterator) +} + +private data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure() { + override fun elementsIterator(): Iterator> = ndArray.intIterator() + override fun get(index: IntArray): Int = ndArray.getInt(*index) + override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } +} + +/** + * Wraps this [INDArray] to [INDArrayStructure]. + */ +public fun INDArray.asIntStructure(): INDArrayStructure = INDArrayIntStructure(this) + +private data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure() { + override fun elementsIterator(): Iterator> = ndArray.longIterator() + override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray()) + override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) } +} + +/** + * Wraps this [INDArray] to [INDArrayStructure]. + */ +public fun INDArray.asLongStructure(): INDArrayStructure = INDArrayLongStructure(this) + +private data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure() { + override fun elementsIterator(): Iterator> = ndArray.realIterator() + override fun get(index: IntArray): Double = ndArray.getDouble(*index) + override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } +} + +/** + * Wraps this [INDArray] to [INDArrayStructure]. + */ +public fun INDArray.asRealStructure(): INDArrayStructure = INDArrayRealStructure(this) + +private data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure() { + override fun elementsIterator(): Iterator> = ndArray.floatIterator() + override fun get(index: IntArray): Float = ndArray.getFloat(*index) + override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } +} + +/** + * Wraps this [INDArray] to [INDArrayStructure]. + */ +public fun INDArray.asFloatStructure(): INDArrayStructure = INDArrayFloatStructure(this) diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/arrays.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/arrays.kt similarity index 85% rename from kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/arrays.kt rename to kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/arrays.kt index 269fc89c2..798f81c35 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/arrays.kt +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/arrays.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.nd4j +package kscience.kmath.nd4j internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() } internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() } diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt deleted file mode 100644 index c24e2ece6..000000000 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt +++ /dev/null @@ -1,273 +0,0 @@ -package scientifik.kmath.nd4j - -import org.nd4j.linalg.api.ndarray.INDArray -import org.nd4j.linalg.factory.Nd4j -import scientifik.kmath.operations.* -import scientifik.kmath.structures.* - -/** - * Represents [NDAlgebra] over [INDArrayAlgebra]. - * - * @param T the type of ND-structure element. - * @param C the type of the element context. - * @param N the type of the structure. - */ -interface INDArrayAlgebra : NDAlgebra where N : INDArrayStructure, N : MutableNDStructure { - /** - * Wraps [INDArray] to [N]. - */ - fun INDArray.wrap(): N - - override fun produce(initializer: C.(IntArray) -> T): N { - val struct = Nd4j.create(*shape)!!.wrap() - struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) } - return struct - } - - override fun map(arg: N, transform: C.(T) -> T): N { - check(arg) - val newStruct = arg.ndArray.dup().wrap() - newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } - return newStruct - } - - override fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N { - check(arg) - val new = Nd4j.create(*shape).wrap() - new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) } - return new - } - - override fun combine(a: N, b: N, transform: C.(T, T) -> T): N { - check(a, b) - val new = Nd4j.create(*shape).wrap() - new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) } - return new - } -} - -/** - * Represents [NDSpace] over [INDArrayStructure]. - * - * @param T the type of the element contained in ND structure. - * @param N the type of ND structure. - * @param S the type of space of structure elements. - */ -interface INDArraySpace : NDSpace, INDArrayAlgebra - where S : Space, N : INDArrayStructure, N : MutableNDStructure { - - override val zero: N - get() = Nd4j.zeros(*shape).wrap() - - override fun add(a: N, b: N): N { - check(a, b) - return a.ndArray.add(b.ndArray).wrap() - } - - override operator fun N.minus(b: N): N { - check(this, b) - return ndArray.sub(b.ndArray).wrap() - } - - override operator fun N.unaryMinus(): N { - check(this) - return ndArray.neg().wrap() - } - - override fun multiply(a: N, k: Number): N { - check(a) - return a.ndArray.mul(k).wrap() - } - - override operator fun N.div(k: Number): N { - check(this) - return ndArray.div(k).wrap() - } - - override operator fun N.times(k: Number): N { - check(this) - return ndArray.mul(k).wrap() - } -} - -/** - * Represents [NDRing] over [INDArrayStructure]. - * - * @param T the type of the element contained in ND structure. - * @param N the type of ND structure. - * @param R the type of ring of structure elements. - */ -interface INDArrayRing : NDRing, INDArraySpace - where R : Ring, N : INDArrayStructure, N : MutableNDStructure { - - override val one: N - get() = Nd4j.ones(*shape).wrap() - - override fun multiply(a: N, b: N): N { - check(a, b) - return a.ndArray.mul(b.ndArray).wrap() - } - - override operator fun N.minus(b: Number): N { - check(this) - return ndArray.sub(b).wrap() - } - - override operator fun N.plus(b: Number): N { - check(this) - return ndArray.add(b).wrap() - } - - override operator fun Number.minus(b: N): N { - check(b) - return b.ndArray.rsub(this).wrap() - } -} - -/** - * Represents [NDField] over [INDArrayStructure]. - * - * @param T the type of the element contained in ND structure. - * @param N the type of ND structure. - * @param F the type field of structure elements. - */ -interface INDArrayField : NDField, INDArrayRing - where F : Field, N : INDArrayStructure, N : MutableNDStructure { - override fun divide(a: N, b: N): N { - check(a, b) - return a.ndArray.div(b.ndArray).wrap() - } - - override operator fun Number.div(b: N): N { - check(b) - return b.ndArray.rdiv(this).wrap() - } -} - -/** - * Represents [NDField] over [INDArrayRealStructure]. - */ -class RealINDArrayField(override val shape: IntArray, override val elementContext: Field = RealField) : - INDArrayField, INDArrayRealStructure> { - override fun INDArray.wrap(): INDArrayRealStructure = check(asRealStructure()) - override operator fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure { - check(this) - return ndArray.div(arg).wrap() - } - - override operator fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure { - check(this) - return ndArray.add(arg).wrap() - } - - override operator fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure { - check(this) - return ndArray.sub(arg).wrap() - } - - override operator fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure { - check(this) - return ndArray.mul(arg).wrap() - } - - override operator fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure { - check(arg) - return arg.ndArray.rdiv(this).wrap() - } - - override operator fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure { - check(arg) - return arg.ndArray.rsub(this).wrap() - } -} - -/** - * Represents [NDField] over [INDArrayFloatStructure]. - */ -class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field = FloatField) : - INDArrayField, INDArrayFloatStructure> { - override fun INDArray.wrap(): INDArrayFloatStructure = check(asFloatStructure()) - override operator fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure { - check(this) - return ndArray.div(arg).wrap() - } - - override operator fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure { - check(this) - return ndArray.add(arg).wrap() - } - - override operator fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure { - check(this) - return ndArray.sub(arg).wrap() - } - - override operator fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure { - check(this) - return ndArray.mul(arg).wrap() - } - - override operator fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure { - check(arg) - return arg.ndArray.rdiv(this).wrap() - } - - override operator fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure { - check(arg) - return arg.ndArray.rsub(this).wrap() - } -} - -/** - * Represents [NDRing] over [INDArrayIntStructure]. - */ -class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring = IntRing) : - INDArrayRing, INDArrayIntStructure> { - override fun INDArray.wrap(): INDArrayIntStructure = check(asIntStructure()) - override operator fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure { - check(this) - return ndArray.add(arg).wrap() - } - - override operator fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure { - check(this) - return ndArray.sub(arg).wrap() - } - - override operator fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure { - check(this) - return ndArray.mul(arg).wrap() - } - - override operator fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure { - check(arg) - return arg.ndArray.rsub(this).wrap() - } -} - -/** - * Represents [NDRing] over [INDArrayLongStructure]. - */ -class LongINDArrayRing(override val shape: IntArray, override val elementContext: Ring = LongRing) : - INDArrayRing, INDArrayLongStructure> { - override fun INDArray.wrap(): INDArrayLongStructure = check(asLongStructure()) - override operator fun INDArrayLongStructure.plus(arg: Long): INDArrayLongStructure { - check(this) - return ndArray.add(arg).wrap() - } - - override operator fun INDArrayLongStructure.minus(arg: Long): INDArrayLongStructure { - check(this) - return ndArray.sub(arg).wrap() - } - - override operator fun INDArrayLongStructure.times(arg: Long): INDArrayLongStructure { - check(this) - return ndArray.mul(arg).wrap() - } - - override operator fun Long.minus(arg: INDArrayLongStructure): INDArrayLongStructure { - check(arg) - return arg.ndArray.rsub(this).wrap() - } -} diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt deleted file mode 100644 index 39cefee3d..000000000 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt +++ /dev/null @@ -1,80 +0,0 @@ -package scientifik.kmath.nd4j - -import org.nd4j.linalg.api.ndarray.INDArray -import scientifik.kmath.structures.MutableNDStructure -import scientifik.kmath.structures.NDStructure - -/** - * Represents a [NDStructure] wrapping an [INDArray] object. - * - * @param T the type of items. - */ -sealed class INDArrayStructure : MutableNDStructure { - /** - * The wrapped [INDArray]. - */ - abstract val ndArray: INDArray - - override val shape: IntArray - get() = ndArray.shape().toIntArray() - - internal abstract fun elementsIterator(): Iterator> - internal fun indicesIterator(): Iterator = ndArray.indicesIterator() - override fun elements(): Sequence> = Sequence(::elementsIterator) -} - -/** - * Represents a [NDStructure] over [INDArray] elements of which are accessed as ints. - */ -data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure() { - override fun elementsIterator(): Iterator> = ndArray.intIterator() - override fun get(index: IntArray): Int = ndArray.getInt(*index) - override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } -} - -/** - * Wraps this [INDArray] to [INDArrayIntStructure]. - */ -fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this) - -/** - * Represents a [NDStructure] over [INDArray] elements of which are accessed as longs. - */ -data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure() { - override fun elementsIterator(): Iterator> = ndArray.longIterator() - override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray()) - override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) } -} - -/** - * Wraps this [INDArray] to [INDArrayLongStructure]. - */ -fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this) - -/** - * Represents a [NDStructure] over [INDArray] elements of which are accessed as reals. - */ -data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure() { - override fun elementsIterator(): Iterator> = ndArray.realIterator() - override fun get(index: IntArray): Double = ndArray.getDouble(*index) - override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } -} - -/** - * Wraps this [INDArray] to [INDArrayRealStructure]. - */ -fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(this) - -/** - * Represents a [NDStructure] over [INDArray] elements of which are accessed as floats. - */ -data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure() { - override fun elementsIterator(): Iterator> = ndArray.floatIterator() - override fun get(index: IntArray): Float = ndArray.getFloat(*index) - override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } -} - -/** - * Wraps this [INDArray] to [INDArrayFloatStructure]. - */ -fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(this) diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayAlgebraTest.kt similarity index 78% rename from kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt rename to kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayAlgebraTest.kt index 4aa40c233..1a4f4c9f3 100644 --- a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt +++ b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayAlgebraTest.kt @@ -1,15 +1,16 @@ -package scientifik.kmath.nd4j +package kscience.kmath.nd4j import org.nd4j.linalg.factory.Nd4j -import scientifik.kmath.operations.invoke +import kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.fail internal class INDArrayAlgebraTest { @Test fun testProduce() { val res = (RealINDArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } - val expected = Nd4j.create(2, 2)!!.asRealStructure() + val expected = (Nd4j.create(2, 2) ?: fail()).asRealStructure() expected[intArrayOf(0, 0)] = 0.0 expected[intArrayOf(0, 1)] = 1.0 expected[intArrayOf(1, 0)] = 1.0 @@ -20,7 +21,7 @@ internal class INDArrayAlgebraTest { @Test fun testMap() { val res = (IntINDArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } } - val expected = Nd4j.create(2, 2)!!.asIntStructure() + val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() expected[intArrayOf(0, 0)] = 3 expected[intArrayOf(0, 1)] = 3 expected[intArrayOf(1, 0)] = 3 @@ -31,7 +32,7 @@ internal class INDArrayAlgebraTest { @Test fun testAdd() { val res = (IntINDArrayRing(intArrayOf(2, 2))) { one + 25 } - val expected = Nd4j.create(2, 2)!!.asIntStructure() + val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() expected[intArrayOf(0, 0)] = 26 expected[intArrayOf(0, 1)] = 26 expected[intArrayOf(1, 0)] = 26 diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayStructureTest.kt similarity index 55% rename from kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt rename to kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayStructureTest.kt index dfede6d32..63426d7f9 100644 --- a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayStructureTest.kt @@ -1,70 +1,71 @@ -package scientifik.kmath.nd4j +package kscience.kmath.nd4j +import kscience.kmath.structures.get import org.nd4j.linalg.factory.Nd4j -import scientifik.kmath.structures.get import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotEquals +import kotlin.test.fail internal class INDArrayStructureTest { @Test fun testElements() { val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct = INDArrayRealStructure(nd) + val struct = nd.asRealStructure() val res = struct.elements().map(Pair::second).toList() assertEquals(listOf(1.0, 2.0, 3.0), res) } @Test fun testShape() { - val nd = Nd4j.rand(10, 2, 3, 6)!! - val struct = INDArrayLongStructure(nd) + val nd = Nd4j.rand(10, 2, 3, 6) ?: fail() + val struct = nd.asRealStructure() assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList()) } @Test fun testEquals() { - val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct1 = INDArrayRealStructure(nd1) + val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail() + val struct1 = nd1.asRealStructure() assertEquals(struct1, struct1) - assertNotEquals(struct1, null as INDArrayRealStructure?) - val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct2 = INDArrayRealStructure(nd2) + assertNotEquals(struct1 as Any?, null) + val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail() + val struct2 = nd2.asRealStructure() assertEquals(struct1, struct2) assertEquals(struct2, struct1) - val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct3 = INDArrayRealStructure(nd3) + val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail() + val struct3 = nd3.asRealStructure() assertEquals(struct2, struct3) assertEquals(struct1, struct3) } @Test fun testHashCode() { - val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct1 = INDArrayRealStructure(nd1) - val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct2 = INDArrayRealStructure(nd2) + val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail() + val struct1 = nd1.asRealStructure() + val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail() + val struct2 = nd2.asRealStructure() assertEquals(struct1.hashCode(), struct2.hashCode()) } @Test fun testDimension() { val nd = Nd4j.rand(8, 16, 3, 7, 1)!! - val struct = INDArrayFloatStructure(nd) + val struct = nd.asFloatStructure() assertEquals(5, struct.dimension) } @Test fun testGet() { - val nd = Nd4j.rand(10, 2, 3, 6)!! - val struct = INDArrayIntStructure(nd) + val nd = Nd4j.rand(10, 2, 3, 6)?:fail() + val struct = nd.asIntStructure() assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0]) } @Test fun testSet() { val nd = Nd4j.rand(17, 12, 4, 8)!! - val struct = INDArrayIntStructure(nd) + val struct = nd.asLongStructure() struct[intArrayOf(1, 2, 3, 4)] = 777 assertEquals(777, struct[1, 2, 3, 4]) } From f0fbebd770f186a6a2031923cc17688abbb9e4a2 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 20:26:03 +0700 Subject: [PATCH 27/69] Add adapters of scalar functions to MST and vice versa --- build.gradle.kts | 1 + kmath-ast-kotlingrad/build.gradle.kts | 8 +++ .../kmath/ast/kotlingrad/ScalarsAdapters.kt | 64 +++++++++++++++++++ kmath-prob/build.gradle.kts | 4 +- kmath-viktor/build.gradle.kts | 4 +- settings.gradle.kts | 6 +- 6 files changed, 82 insertions(+), 5 deletions(-) create mode 100644 kmath-ast-kotlingrad/build.gradle.kts create mode 100644 kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt diff --git a/build.gradle.kts b/build.gradle.kts index 05e2d5979..0da2212a7 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -12,6 +12,7 @@ allprojects { maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/hotkeytlt/maven") + maven("https://jitpack.io") } group = "kscience.kmath" diff --git a/kmath-ast-kotlingrad/build.gradle.kts b/kmath-ast-kotlingrad/build.gradle.kts new file mode 100644 index 000000000..0fe6e6b93 --- /dev/null +++ b/kmath-ast-kotlingrad/build.gradle.kts @@ -0,0 +1,8 @@ +plugins { + id("ru.mipt.npm.jvm") +} + +dependencies { + api("com.github.breandan:kotlingrad:0.3.2") + api(project(":kmath-ast")) +} diff --git a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt new file mode 100644 index 000000000..70c8c09f8 --- /dev/null +++ b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt @@ -0,0 +1,64 @@ +package kscience.kmath.ast.kotlingrad + +import edu.umontreal.kotlingrad.experimental.* +import kscience.kmath.ast.MST +import kscience.kmath.ast.MstExtendedField +import kscience.kmath.operations.* + +/** + * Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then. + * + * @receiver a scalar function. + * @return the [MST]. + */ +public fun > SFun.mst(): MST = MstExtendedField { + when (this@mst) { + is SVar -> symbol(name) + is SConst -> number(doubleValue) + is Sum -> left.mst() + right.mst() + is Prod -> left.mst() * right.mst() + is Power -> power(left.mst(), (right() as SConst<*>).doubleValue) + is Negative -> -input.mst() + is Log -> ln(left.mst()) / ln(right.mst()) + is Sine -> sin(input.mst()) + is Cosine -> cos(input.mst()) + is Tangent -> tan(input.mst()) + is DProd -> this@mst().mst() + is SComposition -> this@mst().mst() + is VSumAll -> this@mst().mst() + is Derivative -> this@mst().mst() + } +} + +/** + * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. + * + * @receiver an [MST]. + * @return the scalar function. + */ +public fun > MST.sfun(proto: X): SFun { + return when (this) { + is MST.Numeric -> SConst(value) + is MST.Symbolic -> SVar(proto, value) + + is MST.Unary -> when (operation) { + SpaceOperations.PLUS_OPERATION -> value.sfun(proto) + SpaceOperations.MINUS_OPERATION -> Negative(value.sfun(proto)) + TrigonometricOperations.SIN_OPERATION -> Sine(value.sfun(proto)) + TrigonometricOperations.COS_OPERATION -> Cosine(value.sfun(proto)) + TrigonometricOperations.TAN_OPERATION -> Tangent(value.sfun(proto)) + PowerOperations.SQRT_OPERATION -> Power(value.sfun(proto), SConst(0.5)) + ExponentialOperations.EXP_OPERATION -> Power(value.sfun(proto), E()) + ExponentialOperations.LN_OPERATION -> Log(value.sfun(proto)) + else -> error("Unary operation $operation not defined in $this") + } + + is MST.Binary -> when (operation) { + SpaceOperations.PLUS_OPERATION -> Sum(left.sfun(proto), right.sfun(proto)) + SpaceOperations.MINUS_OPERATION -> Sum(left.sfun(proto), Negative(right.sfun(proto))) + RingOperations.TIMES_OPERATION -> Prod(left.sfun(proto), right.sfun(proto)) + FieldOperations.DIV_OPERATION -> Prod(left.sfun(proto), Power(right.sfun(proto), Negative(One()))) + else -> error("Binary operation $operation not defined in $this") + } + } +} diff --git a/kmath-prob/build.gradle.kts b/kmath-prob/build.gradle.kts index 4c9663e5f..186aff944 100644 --- a/kmath-prob/build.gradle.kts +++ b/kmath-prob/build.gradle.kts @@ -1,4 +1,6 @@ -plugins { id("ru.mipt.npm.mpp") } +plugins { + id("ru.mipt.npm.mpp") +} kotlin.sourceSets { commonMain { diff --git a/kmath-viktor/build.gradle.kts b/kmath-viktor/build.gradle.kts index 6fe8ad878..3e5c5912c 100644 --- a/kmath-viktor/build.gradle.kts +++ b/kmath-viktor/build.gradle.kts @@ -1,4 +1,6 @@ -plugins { id("ru.mipt.npm.jvm") } +plugins { + id("ru.mipt.npm.jvm") +} description = "Binding for https://github.com/JetBrains-Research/viktor" diff --git a/settings.gradle.kts b/settings.gradle.kts index 7ece3f25c..5fd072e1a 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -7,7 +7,6 @@ pluginManagement { maven("https://dl.bintray.com/mipt-npm/kscience") maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/kotlin/kotlinx") - maven("https://dl.bintray.com/kotlin/kotlin-dev/") } val toolsVersion = "0.6.1-dev-1.4.20-M1" @@ -25,11 +24,11 @@ pluginManagement { } rootProject.name = "kmath" + include( ":kmath-memory", ":kmath-core", ":kmath-functions", -// ":kmath-io", ":kmath-coroutines", ":kmath-histograms", ":kmath-commons", @@ -40,5 +39,6 @@ include( ":kmath-geometry", ":kmath-ast", ":examples", - ":kmath-ejml" + ":kmath-ejml", + ":kmath-ast-kotlingrad" ) From 84f7535fdd7d59103e2c846eb0413bb2bb482cb7 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 20:36:05 +0700 Subject: [PATCH 28/69] Add pow support --- .../main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt | 1 + 1 file changed, 1 insertion(+) diff --git a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt index 70c8c09f8..741a2534c 100644 --- a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt +++ b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt @@ -58,6 +58,7 @@ public fun > MST.sfun(proto: X): SFun { SpaceOperations.MINUS_OPERATION -> Sum(left.sfun(proto), Negative(right.sfun(proto))) RingOperations.TIMES_OPERATION -> Prod(left.sfun(proto), right.sfun(proto)) FieldOperations.DIV_OPERATION -> Prod(left.sfun(proto), Power(right.sfun(proto), Negative(One()))) + PowerOperations.POW_OPERATION -> Power(left.sfun(proto), SConst((right as MST.Numeric).value)) else -> error("Binary operation $operation not defined in $this") } } From 5de9d69237118889b74d4a5074b48db609db92ae Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 21:06:15 +0700 Subject: [PATCH 29/69] Add more fine-grained converters from MST to SVar and SConst --- .../kmath/ast/kotlingrad/ScalarsAdapters.kt | 70 ++++++++++++------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt index 741a2534c..3a1191e6a 100644 --- a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt +++ b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt @@ -8,8 +8,8 @@ import kscience.kmath.operations.* /** * Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then. * - * @receiver a scalar function. - * @return the [MST]. + * @receiver the scalar function. + * @return a node. */ public fun > SFun.mst(): MST = MstExtendedField { when (this@mst) { @@ -30,36 +30,52 @@ public fun > SFun.mst(): MST = MstExtendedField { } } +/** + * Maps [MST.Numeric] to [SConst] directly. + * + * @receiver the node. + * @return a new constant. + */ +public fun > MST.Numeric.sconst(): SConst = SConst(value) + +/** + * Maps [MST.Symbolic] to [SVar] directly. + * + * @receiver the node. + * @param proto the prototype instance. + * @return a new variable. + */ +public fun > MST.Symbolic.svar(proto: X): SVar = SVar(proto, value) + /** * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. * - * @receiver an [MST]. - * @return the scalar function. + * @receiver the node. + * @param proto the prototype instance. + * @return a scalar function. */ -public fun > MST.sfun(proto: X): SFun { - return when (this) { - is MST.Numeric -> SConst(value) - is MST.Symbolic -> SVar(proto, value) +public fun > MST.sfun(proto: X): SFun = when (this) { + is MST.Numeric -> sconst() + is MST.Symbolic -> svar(proto) - is MST.Unary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> value.sfun(proto) - SpaceOperations.MINUS_OPERATION -> Negative(value.sfun(proto)) - TrigonometricOperations.SIN_OPERATION -> Sine(value.sfun(proto)) - TrigonometricOperations.COS_OPERATION -> Cosine(value.sfun(proto)) - TrigonometricOperations.TAN_OPERATION -> Tangent(value.sfun(proto)) - PowerOperations.SQRT_OPERATION -> Power(value.sfun(proto), SConst(0.5)) - ExponentialOperations.EXP_OPERATION -> Power(value.sfun(proto), E()) - ExponentialOperations.LN_OPERATION -> Log(value.sfun(proto)) - else -> error("Unary operation $operation not defined in $this") - } + is MST.Unary -> when (operation) { + SpaceOperations.PLUS_OPERATION -> value.sfun(proto) + SpaceOperations.MINUS_OPERATION -> Negative(value.sfun(proto)) + TrigonometricOperations.SIN_OPERATION -> Sine(value.sfun(proto)) + TrigonometricOperations.COS_OPERATION -> Cosine(value.sfun(proto)) + TrigonometricOperations.TAN_OPERATION -> Tangent(value.sfun(proto)) + PowerOperations.SQRT_OPERATION -> Power(value.sfun(proto), SConst(0.5)) + ExponentialOperations.EXP_OPERATION -> Power(value.sfun(proto), E()) + ExponentialOperations.LN_OPERATION -> Log(value.sfun(proto)) + else -> error("Unary operation $operation not defined in $this") + } - is MST.Binary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> Sum(left.sfun(proto), right.sfun(proto)) - SpaceOperations.MINUS_OPERATION -> Sum(left.sfun(proto), Negative(right.sfun(proto))) - RingOperations.TIMES_OPERATION -> Prod(left.sfun(proto), right.sfun(proto)) - FieldOperations.DIV_OPERATION -> Prod(left.sfun(proto), Power(right.sfun(proto), Negative(One()))) - PowerOperations.POW_OPERATION -> Power(left.sfun(proto), SConst((right as MST.Numeric).value)) - else -> error("Binary operation $operation not defined in $this") - } + is MST.Binary -> when (operation) { + SpaceOperations.PLUS_OPERATION -> Sum(left.sfun(proto), right.sfun(proto)) + SpaceOperations.MINUS_OPERATION -> Sum(left.sfun(proto), Negative(right.sfun(proto))) + RingOperations.TIMES_OPERATION -> Prod(left.sfun(proto), right.sfun(proto)) + FieldOperations.DIV_OPERATION -> Prod(left.sfun(proto), Power(right.sfun(proto), Negative(One()))) + PowerOperations.POW_OPERATION -> Power(left.sfun(proto), SConst((right as MST.Numeric).value)) + else -> error("Binary operation $operation not defined in $this") } } From 31c71e0fad0fa8a0c8837cca44e10ea12a9fdf21 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 21:21:08 +0700 Subject: [PATCH 30/69] Add comments on mapping of MST-to-SFun converting --- .../kmath/ast/kotlingrad/ScalarsAdapters.kt | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt index 3a1191e6a..cfd1f2702 100644 --- a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt +++ b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt @@ -3,10 +3,29 @@ package kscience.kmath.ast.kotlingrad import edu.umontreal.kotlingrad.experimental.* import kscience.kmath.ast.MST import kscience.kmath.ast.MstExtendedField +import kscience.kmath.ast.MstExtendedField.unaryMinus import kscience.kmath.operations.* /** * Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then. + * [Power] operation is limited to constant right-hand side arguments. + * + * Detailed mapping is: + * + * - [SVar] -> [MstExtendedField.symbol]; + * - [SConst] -> [MstExtendedField.number]; + * - [Sum] -> [MstExtendedField.add]; + * - [Prod] -> [MstExtendedField.multiply]; + * - [Power] -> [MstExtendedField.power] (limited); + * - [Negative] -> [MstExtendedField.unaryMinus]; + * - [Log] -> [MstExtendedField.ln] (left) / [MstExtendedField.ln] (right); + * - [Sine] -> [MstExtendedField.sin]; + * - [Cosine] -> [MstExtendedField.cos]; + * - [Tangent] -> [MstExtendedField.tan]; + * - [DProd] is vector operation, and it is requested to be evaluated; + * - [SComposition] is also requested to be evaluated eagerly; + * - [VSumAll] is requested to be evaluated; + * - [Derivative] is requested to be evaluated. * * @receiver the scalar function. * @return a node. @@ -50,6 +69,13 @@ public fun > MST.Symbolic.svar(proto: X): SVar = SVar(proto, valu /** * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. * + * Detailed mapping is: + * + * - [MST.Numeric] -> [SConst]; + * - [MST.Symbolic] -> [SVar]; + * - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log]; + * - [MST.Binary] -> [Sum], [Prod], [Power]. + * * @receiver the node. * @param proto the prototype instance. * @return a scalar function. From 57bdee49368208c67eea5259b39223f37fcae386 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 22:34:05 +0700 Subject: [PATCH 31/69] Add test, update MstAlgebra a bit to return concrete types --- .../kmath/ast/kotlingrad/ScalarsAdapters.kt | 2 +- .../kmath/ast/kotlingrad/AdaptingTests.kt | 66 ++++++++++ .../kotlin/kscience/kmath/ast/MstAlgebra.kt | 113 +++++++++--------- 3 files changed, 125 insertions(+), 56 deletions(-) create mode 100644 kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt diff --git a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt index cfd1f2702..16c96646a 100644 --- a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt +++ b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt @@ -36,7 +36,7 @@ public fun > SFun.mst(): MST = MstExtendedField { is SConst -> number(doubleValue) is Sum -> left.mst() + right.mst() is Prod -> left.mst() * right.mst() - is Power -> power(left.mst(), (right() as SConst<*>).doubleValue) + is Power -> power(left.mst(), (right as SConst<*>).doubleValue) is Negative -> -input.mst() is Log -> ln(left.mst()) / ln(right.mst()) is Sine -> sin(input.mst()) diff --git a/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt b/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt new file mode 100644 index 000000000..94d25e411 --- /dev/null +++ b/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt @@ -0,0 +1,66 @@ +package kscience.kmath.ast.kotlingrad + +import edu.umontreal.kotlingrad.experimental.* +import kscience.kmath.asm.compile +import kscience.kmath.ast.MstAlgebra +import kscience.kmath.ast.MstExpression +import kscience.kmath.ast.parseMath +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.fail + +internal class AdaptingTests { + private val proto: DReal = DoublePrecision.prototype + + @Test + fun symbol() { + val c1 = MstAlgebra.symbol("x") + assertTrue(c1.svar(proto).name == "x") + val c2 = "kitten".parseMath().sfun(proto) + if (c2 is SVar) assertTrue(c2.name == "kitten") else fail() + } + + @Test + fun number() { + val c1 = MstAlgebra.number(12354324) + assertTrue(c1.sconst().doubleValue == 12354324.0) + val c2 = "0.234".parseMath().sfun(proto) + if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail() + val c3 = "1e-3".parseMath().sfun(proto) + if (c3 is SConst) assertEquals(0.001, c3.value) else fail() + } + + @Test + fun simpleFunctionShape() { + val linear = "2*x+16".parseMath().sfun(proto) + if (linear !is Sum) fail() + if (linear.left !is Prod) fail() + if (linear.right !is SConst) fail() + } + + @Test + fun simpleFunctionDerivative() { + val x = MstAlgebra.symbol("x").svar(proto) + val quadratic = "x^2-4*x-44".parseMath().sfun(proto) + val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile() + val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() + assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0)) + } + + @Test + fun moreComplexDerivative() { + val x = MstAlgebra.symbol("x").svar(proto) + val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sfun(proto) + val actualDerivative = MstExpression(RealField, composition.d(x).mst()).compile() + + val expectedDerivative = MstExpression( + RealField, + "-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath() + ).compile() + + assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1)) + } +} diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt index 64a820b20..6ee6ab9af 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -6,14 +6,14 @@ import kscience.kmath.operations.* * [Algebra] over [MST] nodes. */ public object MstAlgebra : NumericAlgebra { - override fun number(value: Number): MST = MST.Numeric(value) + override fun number(value: Number): MST.Numeric = MST.Numeric(value) - override fun symbol(value: String): MST = MST.Symbolic(value) + override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) - override fun unaryOperation(operation: String, arg: MST): MST = + override fun unaryOperation(operation: String, arg: MST): MST.Unary = MST.Unary(operation, arg) - override fun binaryOperation(operation: String, left: MST, right: MST): MST = + override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = MST.Binary(operation, left, right) } @@ -21,97 +21,100 @@ public object MstAlgebra : NumericAlgebra { * [Space] over [MST] nodes. */ public object MstSpace : Space, NumericAlgebra { - override val zero: MST = number(0.0) + override val zero: MST.Numeric by lazy { number(0.0) } - override fun number(value: Number): MST = MstAlgebra.number(value) - override fun symbol(value: String): MST = MstAlgebra.symbol(value) - override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) + override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) + override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) - override fun binaryOperation(operation: String, left: MST, right: MST): MST = + override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg) } /** * [Ring] over [MST] nodes. */ public object MstRing : Ring, NumericAlgebra { - override val zero: MST + override val zero: MST.Numeric get() = MstSpace.zero - override val one: MST = number(1.0) - override fun number(value: Number): MST = MstSpace.number(value) - override fun symbol(value: String): MST = MstSpace.symbol(value) - override fun add(a: MST, b: MST): MST = MstSpace.add(a, b) + override val one: MST.Numeric by lazy { number(1.0) } - override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k) + override fun number(value: Number): MST.Numeric = MstSpace.number(value) + override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) + override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) + override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) + override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b) - override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = + override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = MstSpace.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg) } /** * [Field] over [MST] nodes. */ public object MstField : Field { - public override val zero: MST + public override val zero: MST.Numeric get() = MstRing.zero - public override val one: MST + public override val one: MST.Numeric get() = MstRing.one - public override fun symbol(value: String): MST = MstRing.symbol(value) - public override fun number(value: Number): MST = MstRing.number(value) - public override fun add(a: MST, b: MST): MST = MstRing.add(a, b) - public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) - public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) + public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value) + public override fun number(value: Number): MST.Numeric = MstRing.number(value) + public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) + public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b) - public override fun binaryOperation(operation: String, left: MST, right: MST): MST = + public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = MstRing.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg) } /** * [ExtendedField] over [MST] nodes. */ public object MstExtendedField : ExtendedField { - override val zero: MST + override val zero: MST.Numeric get() = MstField.zero - override val one: MST + override val one: MST.Numeric get() = MstField.one - override fun symbol(value: String): MST = MstField.symbol(value) - override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) - override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) - override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) - override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) - override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) - override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) - override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) - override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) - override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) - override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) - override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) - override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) - override fun add(a: MST, b: MST): MST = MstField.add(a, b) - override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) - override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) - override fun divide(a: MST, b: MST): MST = MstField.divide(a, b) - override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) - override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) - override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg) + override fun symbol(value: String): MST.Symbolic = MstField.symbol(value) + override fun number(value: Number): MST.Numeric = MstField.number(value) + override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) + override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) + override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) + override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) + override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) + override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) + override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) + override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) + override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) + override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) + override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) + override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) + override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) + override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) + override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) + override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) - override fun binaryOperation(operation: String, left: MST, right: MST): MST = + override fun power(arg: MST, pow: Number): MST.Binary = + binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) + + override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) + override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = MstField.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg) } From 54069fd37ebc77dc8d40a2d46c3bc290583f23cc Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 22:42:34 +0700 Subject: [PATCH 32/69] Add example of new AST API --- examples/build.gradle.kts | 3 +- .../ast/ExpressionsInterpretersBenchmark.kt | 148 ++++++++++-------- .../kscience/kmath/ast/KotlingradSupport.kt | 22 +++ 3 files changed, 103 insertions(+), 70 deletions(-) create mode 100644 examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 900da966b..968b372c3 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -19,7 +19,8 @@ repositories { sourceSets.register("benchmarks") dependencies { -// implementation(project(":kmath-ast")) + implementation(project(":kmath-ast")) + implementation(project(":kmath-ast-kotlingrad")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) diff --git a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt index f0a32e5bd..b25a61e96 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -1,70 +1,80 @@ package kscience.kmath.ast -// -//import kscience.kmath.asm.compile -//import kscience.kmath.expressions.Expression -//import kscience.kmath.expressions.expressionInField -//import kscience.kmath.expressions.invoke -//import kscience.kmath.operations.Field -//import kscience.kmath.operations.RealField -//import kotlin.random.Random -//import kotlin.system.measureTimeMillis -// -//class ExpressionsInterpretersBenchmark { -// private val algebra: Field = RealField -// fun functionalExpression() { -// val expr = algebra.expressionInField { -// variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) -// } -// -// invokeAndSum(expr) -// } -// -// fun mstExpression() { -// val expr = algebra.mstInField { -// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) -// } -// -// invokeAndSum(expr) -// } -// -// fun asmExpression() { -// val expr = algebra.mstInField { -// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) -// }.compile() -// -// invokeAndSum(expr) -// } -// -// private fun invokeAndSum(expr: Expression) { -// val random = Random(0) -// var sum = 0.0 -// -// repeat(1000000) { -// sum += expr("x" to random.nextDouble()) -// } -// -// println(sum) -// } -//} -// -//fun main() { -// val benchmark = ExpressionsInterpretersBenchmark() -// -// val fe = measureTimeMillis { -// benchmark.functionalExpression() -// } -// -// println("fe=$fe") -// -// val mst = measureTimeMillis { -// benchmark.mstExpression() -// } -// -// println("mst=$mst") -// -// val asm = measureTimeMillis { -// benchmark.asmExpression() -// } -// -// println("asm=$asm") -//} + +import kscience.kmath.asm.compile +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.expressionInField +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.Field +import kscience.kmath.operations.RealField +import kotlin.random.Random +import kotlin.system.measureTimeMillis + +internal class ExpressionsInterpretersBenchmark { + private val algebra: Field = RealField + fun functionalExpression() { + val expr = algebra.expressionInField { + variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) + } + + invokeAndSum(expr) + } + + fun mstExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + } + + invokeAndSum(expr) + } + + fun asmExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + }.compile() + + invokeAndSum(expr) + } + + private fun invokeAndSum(expr: Expression) { + val random = Random(0) + var sum = 0.0 + + repeat(1000000) { + sum += expr("x" to random.nextDouble()) + } + + println(sum) + } +} + +/** + * This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and + * core FunctionalExpressions API. + * + * The expected rating is: + * + * 1. ASM. + * 2. MST. + * 3. FE. + */ +fun main() { + val benchmark = ExpressionsInterpretersBenchmark() + + val fe = measureTimeMillis { + benchmark.functionalExpression() + } + + println("fe=$fe") + + val mst = measureTimeMillis { + benchmark.mstExpression() + } + + println("mst=$mst") + + val asm = measureTimeMillis { + benchmark.asmExpression() + } + + println("asm=$asm") +} diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt new file mode 100644 index 000000000..e63b0c9c0 --- /dev/null +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -0,0 +1,22 @@ +package kscience.kmath.ast + +import edu.umontreal.kotlingrad.experimental.DoublePrecision +import kscience.kmath.asm.compile +import kscience.kmath.ast.kotlingrad.mst +import kscience.kmath.ast.kotlingrad.sfun +import kscience.kmath.ast.kotlingrad.svar +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.RealField + +/** + * In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with + * valid derivative. + */ +fun main() { + val proto = DoublePrecision.prototype + val x by MstAlgebra.symbol("x").svar(proto) + val quadratic = "x^2-4*x-44".parseMath().sfun(proto) + val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile() + val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() + assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0)) +} From 4bf430b2c07a351d116e383a0d7806802af026c2 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 23:17:54 +0700 Subject: [PATCH 33/69] Rename converter functions, add symbol delegate provider for MstAlgebra --- .../kscience/kmath/ast/KotlingradSupport.kt | 8 ++--- .../kmath/ast/kotlingrad/ScalarsAdapters.kt | 36 +++++++++---------- .../kmath/ast/kotlingrad/AdaptingTests.kt | 20 +++++------ .../kotlin/kscience/kmath/ast/extensions.kt | 22 ++++++++++++ 4 files changed, 54 insertions(+), 32 deletions(-) create mode 100644 kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt index e63b0c9c0..c2e8456da 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -3,8 +3,8 @@ package kscience.kmath.ast import edu.umontreal.kotlingrad.experimental.DoublePrecision import kscience.kmath.asm.compile import kscience.kmath.ast.kotlingrad.mst -import kscience.kmath.ast.kotlingrad.sfun -import kscience.kmath.ast.kotlingrad.svar +import kscience.kmath.ast.kotlingrad.sFun +import kscience.kmath.ast.kotlingrad.sVar import kscience.kmath.expressions.invoke import kscience.kmath.operations.RealField @@ -14,8 +14,8 @@ import kscience.kmath.operations.RealField */ fun main() { val proto = DoublePrecision.prototype - val x by MstAlgebra.symbol("x").svar(proto) - val quadratic = "x^2-4*x-44".parseMath().sfun(proto) + val x by MstAlgebra.symbol("x").sVar(proto) + val quadratic = "x^2-4*x-44".parseMath().sFun(proto) val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0)) diff --git a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt index 16c96646a..3e7cd0439 100644 --- a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt +++ b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt @@ -55,7 +55,7 @@ public fun > SFun.mst(): MST = MstExtendedField { * @receiver the node. * @return a new constant. */ -public fun > MST.Numeric.sconst(): SConst = SConst(value) +public fun > MST.Numeric.sConst(): SConst = SConst(value) /** * Maps [MST.Symbolic] to [SVar] directly. @@ -64,7 +64,7 @@ public fun > MST.Numeric.sconst(): SConst = SConst(value) * @param proto the prototype instance. * @return a new variable. */ -public fun > MST.Symbolic.svar(proto: X): SVar = SVar(proto, value) +public fun > MST.Symbolic.sVar(proto: X): SVar = SVar(proto, value) /** * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. @@ -80,28 +80,28 @@ public fun > MST.Symbolic.svar(proto: X): SVar = SVar(proto, valu * @param proto the prototype instance. * @return a scalar function. */ -public fun > MST.sfun(proto: X): SFun = when (this) { - is MST.Numeric -> sconst() - is MST.Symbolic -> svar(proto) +public fun > MST.sFun(proto: X): SFun = when (this) { + is MST.Numeric -> sConst() + is MST.Symbolic -> sVar(proto) is MST.Unary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> value.sfun(proto) - SpaceOperations.MINUS_OPERATION -> Negative(value.sfun(proto)) - TrigonometricOperations.SIN_OPERATION -> Sine(value.sfun(proto)) - TrigonometricOperations.COS_OPERATION -> Cosine(value.sfun(proto)) - TrigonometricOperations.TAN_OPERATION -> Tangent(value.sfun(proto)) - PowerOperations.SQRT_OPERATION -> Power(value.sfun(proto), SConst(0.5)) - ExponentialOperations.EXP_OPERATION -> Power(value.sfun(proto), E()) - ExponentialOperations.LN_OPERATION -> Log(value.sfun(proto)) + SpaceOperations.PLUS_OPERATION -> value.sFun(proto) + SpaceOperations.MINUS_OPERATION -> Negative(value.sFun(proto)) + TrigonometricOperations.SIN_OPERATION -> Sine(value.sFun(proto)) + TrigonometricOperations.COS_OPERATION -> Cosine(value.sFun(proto)) + TrigonometricOperations.TAN_OPERATION -> Tangent(value.sFun(proto)) + PowerOperations.SQRT_OPERATION -> Power(value.sFun(proto), SConst(0.5)) + ExponentialOperations.EXP_OPERATION -> Power(value.sFun(proto), E()) + ExponentialOperations.LN_OPERATION -> Log(value.sFun(proto)) else -> error("Unary operation $operation not defined in $this") } is MST.Binary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> Sum(left.sfun(proto), right.sfun(proto)) - SpaceOperations.MINUS_OPERATION -> Sum(left.sfun(proto), Negative(right.sfun(proto))) - RingOperations.TIMES_OPERATION -> Prod(left.sfun(proto), right.sfun(proto)) - FieldOperations.DIV_OPERATION -> Prod(left.sfun(proto), Power(right.sfun(proto), Negative(One()))) - PowerOperations.POW_OPERATION -> Power(left.sfun(proto), SConst((right as MST.Numeric).value)) + SpaceOperations.PLUS_OPERATION -> Sum(left.sFun(proto), right.sFun(proto)) + SpaceOperations.MINUS_OPERATION -> Sum(left.sFun(proto), Negative(right.sFun(proto))) + RingOperations.TIMES_OPERATION -> Prod(left.sFun(proto), right.sFun(proto)) + FieldOperations.DIV_OPERATION -> Prod(left.sFun(proto), Power(right.sFun(proto), Negative(One()))) + PowerOperations.POW_OPERATION -> Power(left.sFun(proto), SConst((right as MST.Numeric).value)) else -> error("Binary operation $operation not defined in $this") } } diff --git a/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt b/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt index 94d25e411..c3c4602ad 100644 --- a/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt +++ b/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt @@ -18,24 +18,24 @@ internal class AdaptingTests { @Test fun symbol() { val c1 = MstAlgebra.symbol("x") - assertTrue(c1.svar(proto).name == "x") - val c2 = "kitten".parseMath().sfun(proto) + assertTrue(c1.sVar(proto).name == "x") + val c2 = "kitten".parseMath().sFun(proto) if (c2 is SVar) assertTrue(c2.name == "kitten") else fail() } @Test fun number() { val c1 = MstAlgebra.number(12354324) - assertTrue(c1.sconst().doubleValue == 12354324.0) - val c2 = "0.234".parseMath().sfun(proto) + assertTrue(c1.sConst().doubleValue == 12354324.0) + val c2 = "0.234".parseMath().sFun(proto) if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail() - val c3 = "1e-3".parseMath().sfun(proto) + val c3 = "1e-3".parseMath().sFun(proto) if (c3 is SConst) assertEquals(0.001, c3.value) else fail() } @Test fun simpleFunctionShape() { - val linear = "2*x+16".parseMath().sfun(proto) + val linear = "2*x+16".parseMath().sFun(proto) if (linear !is Sum) fail() if (linear.left !is Prod) fail() if (linear.right !is SConst) fail() @@ -43,8 +43,8 @@ internal class AdaptingTests { @Test fun simpleFunctionDerivative() { - val x = MstAlgebra.symbol("x").svar(proto) - val quadratic = "x^2-4*x-44".parseMath().sfun(proto) + val x = MstAlgebra.symbol("x").sVar(proto) + val quadratic = "x^2-4*x-44".parseMath().sFun(proto) val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0)) @@ -52,8 +52,8 @@ internal class AdaptingTests { @Test fun moreComplexDerivative() { - val x = MstAlgebra.symbol("x").svar(proto) - val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sfun(proto) + val x = MstAlgebra.symbol("x").sVar(proto) + val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sFun(proto) val actualDerivative = MstExpression(RealField, composition.d(x).mst()).compile() val expectedDerivative = MstExpression( diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt new file mode 100644 index 000000000..cba4bbb13 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt @@ -0,0 +1,22 @@ +package kscience.kmath.ast + +import kscience.kmath.operations.Algebra +import kotlin.properties.ReadOnlyProperty +import kotlin.reflect.KProperty + +/** + * Stores `provideDelegate` method returning property of [MST.Symbolic]. + */ +public object MstSymbolDelegateProvider { + /** + * Returns [ReadOnlyProperty] of [MST.Symbolic] with its value equal to the name of the property. + */ + public operator fun provideDelegate(thisRef: Any?, prop: KProperty<*>): ReadOnlyProperty = + ReadOnlyProperty { _, property -> MST.Symbolic(property.name) } +} + +/** + * Returns [MstSymbolDelegateProvider]. + */ +public val Algebra.symbol: MstSymbolDelegateProvider + get() = MstSymbolDelegateProvider From 06c3ce5aaf8020b95a18eaae95381fe80186bee7 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 23:42:13 +0700 Subject: [PATCH 34/69] Simplify extensions.kt --- .../kotlin/kscience/kmath/ast/extensions.kt | 20 +++++-------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt index cba4bbb13..b790a3a88 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt @@ -1,22 +1,12 @@ package kscience.kmath.ast import kscience.kmath.operations.Algebra +import kotlin.properties.PropertyDelegateProvider import kotlin.properties.ReadOnlyProperty -import kotlin.reflect.KProperty /** - * Stores `provideDelegate` method returning property of [MST.Symbolic]. + * Returns [PropertyDelegateProvider] providing [ReadOnlyProperty] of [MST.Symbolic] with its value equal to the name + * of the property. */ -public object MstSymbolDelegateProvider { - /** - * Returns [ReadOnlyProperty] of [MST.Symbolic] with its value equal to the name of the property. - */ - public operator fun provideDelegate(thisRef: Any?, prop: KProperty<*>): ReadOnlyProperty = - ReadOnlyProperty { _, property -> MST.Symbolic(property.name) } -} - -/** - * Returns [MstSymbolDelegateProvider]. - */ -public val Algebra.symbol: MstSymbolDelegateProvider - get() = MstSymbolDelegateProvider +public val Algebra.symbol: PropertyDelegateProvider, ReadOnlyProperty, MST.Symbolic>> + get() = PropertyDelegateProvider { _, _ -> ReadOnlyProperty { _, p -> MST.Symbolic(p.name) } } From 381137724dc6e6fbb3795b80a988dd64d2578d11 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Tue, 13 Oct 2020 19:47:07 +0700 Subject: [PATCH 35/69] Rename KG module --- examples/build.gradle.kts | 2 +- .../kscience/kmath/ast/KotlingradSupport.kt | 12 ++-- .../build.gradle.kts | 0 .../kmath}/kotlingrad/ScalarsAdapters.kt | 66 +++++++++---------- .../kmath}/kotlingrad/AdaptingTests.kt | 26 ++++---- settings.gradle.kts | 2 +- 6 files changed, 54 insertions(+), 54 deletions(-) rename {kmath-ast-kotlingrad => kmath-kotlingrad}/build.gradle.kts (100%) rename {kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast => kmath-kotlingrad/src/main/kotlin/kscience/kmath}/kotlingrad/ScalarsAdapters.kt (51%) rename {kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast => kmath-kotlingrad/src/test/kotlin/kscience/kmath}/kotlingrad/AdaptingTests.kt (74%) diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 968b372c3..46b677304 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -20,7 +20,7 @@ sourceSets.register("benchmarks") dependencies { implementation(project(":kmath-ast")) - implementation(project(":kmath-ast-kotlingrad")) + implementation(project(":kmath-kotlingrad")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt index c2e8456da..366a2b4fd 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -2,9 +2,9 @@ package kscience.kmath.ast import edu.umontreal.kotlingrad.experimental.DoublePrecision import kscience.kmath.asm.compile -import kscience.kmath.ast.kotlingrad.mst -import kscience.kmath.ast.kotlingrad.sFun -import kscience.kmath.ast.kotlingrad.sVar +import kscience.kmath.kotlingrad.toMst +import kscience.kmath.kotlingrad.tSFun +import kscience.kmath.kotlingrad.toSVar import kscience.kmath.expressions.invoke import kscience.kmath.operations.RealField @@ -14,9 +14,9 @@ import kscience.kmath.operations.RealField */ fun main() { val proto = DoublePrecision.prototype - val x by MstAlgebra.symbol("x").sVar(proto) - val quadratic = "x^2-4*x-44".parseMath().sFun(proto) - val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile() + val x by MstAlgebra.symbol("x").toSVar(proto) + val quadratic = "x^2-4*x-44".parseMath().tSFun(proto) + val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0)) } diff --git a/kmath-ast-kotlingrad/build.gradle.kts b/kmath-kotlingrad/build.gradle.kts similarity index 100% rename from kmath-ast-kotlingrad/build.gradle.kts rename to kmath-kotlingrad/build.gradle.kts diff --git a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt similarity index 51% rename from kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt rename to kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt index 3e7cd0439..f6f0c7a76 100644 --- a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt @@ -1,4 +1,4 @@ -package kscience.kmath.ast.kotlingrad +package kscience.kmath.kotlingrad import edu.umontreal.kotlingrad.experimental.* import kscience.kmath.ast.MST @@ -30,22 +30,22 @@ import kscience.kmath.operations.* * @receiver the scalar function. * @return a node. */ -public fun > SFun.mst(): MST = MstExtendedField { - when (this@mst) { +public fun > SFun.toMst(): MST = MstExtendedField { + when (this@toMst) { is SVar -> symbol(name) is SConst -> number(doubleValue) - is Sum -> left.mst() + right.mst() - is Prod -> left.mst() * right.mst() - is Power -> power(left.mst(), (right as SConst<*>).doubleValue) - is Negative -> -input.mst() - is Log -> ln(left.mst()) / ln(right.mst()) - is Sine -> sin(input.mst()) - is Cosine -> cos(input.mst()) - is Tangent -> tan(input.mst()) - is DProd -> this@mst().mst() - is SComposition -> this@mst().mst() - is VSumAll -> this@mst().mst() - is Derivative -> this@mst().mst() + is Sum -> left.toMst() + right.toMst() + is Prod -> left.toMst() * right.toMst() + is Power -> power(left.toMst(), (right as SConst<*>).doubleValue) + is Negative -> -input.toMst() + is Log -> ln(left.toMst()) / ln(right.toMst()) + is Sine -> sin(input.toMst()) + is Cosine -> cos(input.toMst()) + is Tangent -> tan(input.toMst()) + is DProd -> this@toMst().toMst() + is SComposition -> this@toMst().toMst() + is VSumAll -> this@toMst().toMst() + is Derivative -> this@toMst().toMst() } } @@ -55,7 +55,7 @@ public fun > SFun.mst(): MST = MstExtendedField { * @receiver the node. * @return a new constant. */ -public fun > MST.Numeric.sConst(): SConst = SConst(value) +public fun > MST.Numeric.toSConst(): SConst = SConst(value) /** * Maps [MST.Symbolic] to [SVar] directly. @@ -64,7 +64,7 @@ public fun > MST.Numeric.sConst(): SConst = SConst(value) * @param proto the prototype instance. * @return a new variable. */ -public fun > MST.Symbolic.sVar(proto: X): SVar = SVar(proto, value) +public fun > MST.Symbolic.toSVar(proto: X): SVar = SVar(proto, value) /** * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. @@ -80,28 +80,28 @@ public fun > MST.Symbolic.sVar(proto: X): SVar = SVar(proto, valu * @param proto the prototype instance. * @return a scalar function. */ -public fun > MST.sFun(proto: X): SFun = when (this) { - is MST.Numeric -> sConst() - is MST.Symbolic -> sVar(proto) +public fun > MST.tSFun(proto: X): SFun = when (this) { + is MST.Numeric -> toSConst() + is MST.Symbolic -> toSVar(proto) is MST.Unary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> value.sFun(proto) - SpaceOperations.MINUS_OPERATION -> Negative(value.sFun(proto)) - TrigonometricOperations.SIN_OPERATION -> Sine(value.sFun(proto)) - TrigonometricOperations.COS_OPERATION -> Cosine(value.sFun(proto)) - TrigonometricOperations.TAN_OPERATION -> Tangent(value.sFun(proto)) - PowerOperations.SQRT_OPERATION -> Power(value.sFun(proto), SConst(0.5)) - ExponentialOperations.EXP_OPERATION -> Power(value.sFun(proto), E()) - ExponentialOperations.LN_OPERATION -> Log(value.sFun(proto)) + SpaceOperations.PLUS_OPERATION -> value.tSFun(proto) + SpaceOperations.MINUS_OPERATION -> Negative(value.tSFun(proto)) + TrigonometricOperations.SIN_OPERATION -> Sine(value.tSFun(proto)) + TrigonometricOperations.COS_OPERATION -> Cosine(value.tSFun(proto)) + TrigonometricOperations.TAN_OPERATION -> Tangent(value.tSFun(proto)) + PowerOperations.SQRT_OPERATION -> Power(value.tSFun(proto), SConst(0.5)) + ExponentialOperations.EXP_OPERATION -> Power(value.tSFun(proto), E()) + ExponentialOperations.LN_OPERATION -> Log(value.tSFun(proto)) else -> error("Unary operation $operation not defined in $this") } is MST.Binary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> Sum(left.sFun(proto), right.sFun(proto)) - SpaceOperations.MINUS_OPERATION -> Sum(left.sFun(proto), Negative(right.sFun(proto))) - RingOperations.TIMES_OPERATION -> Prod(left.sFun(proto), right.sFun(proto)) - FieldOperations.DIV_OPERATION -> Prod(left.sFun(proto), Power(right.sFun(proto), Negative(One()))) - PowerOperations.POW_OPERATION -> Power(left.sFun(proto), SConst((right as MST.Numeric).value)) + SpaceOperations.PLUS_OPERATION -> Sum(left.tSFun(proto), right.tSFun(proto)) + SpaceOperations.MINUS_OPERATION -> Sum(left.tSFun(proto), Negative(right.tSFun(proto))) + RingOperations.TIMES_OPERATION -> Prod(left.tSFun(proto), right.tSFun(proto)) + FieldOperations.DIV_OPERATION -> Prod(left.tSFun(proto), Power(right.tSFun(proto), Negative(One()))) + PowerOperations.POW_OPERATION -> Power(left.tSFun(proto), SConst((right as MST.Numeric).value)) else -> error("Binary operation $operation not defined in $this") } } diff --git a/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt b/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt similarity index 74% rename from kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt rename to kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt index c3c4602ad..25bdbf4be 100644 --- a/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt +++ b/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt @@ -1,4 +1,4 @@ -package kscience.kmath.ast.kotlingrad +package kscience.kmath.kotlingrad import edu.umontreal.kotlingrad.experimental.* import kscience.kmath.asm.compile @@ -18,24 +18,24 @@ internal class AdaptingTests { @Test fun symbol() { val c1 = MstAlgebra.symbol("x") - assertTrue(c1.sVar(proto).name == "x") - val c2 = "kitten".parseMath().sFun(proto) + assertTrue(c1.toSVar(proto).name == "x") + val c2 = "kitten".parseMath().tSFun(proto) if (c2 is SVar) assertTrue(c2.name == "kitten") else fail() } @Test fun number() { val c1 = MstAlgebra.number(12354324) - assertTrue(c1.sConst().doubleValue == 12354324.0) - val c2 = "0.234".parseMath().sFun(proto) + assertTrue(c1.toSConst().doubleValue == 12354324.0) + val c2 = "0.234".parseMath().tSFun(proto) if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail() - val c3 = "1e-3".parseMath().sFun(proto) + val c3 = "1e-3".parseMath().tSFun(proto) if (c3 is SConst) assertEquals(0.001, c3.value) else fail() } @Test fun simpleFunctionShape() { - val linear = "2*x+16".parseMath().sFun(proto) + val linear = "2*x+16".parseMath().tSFun(proto) if (linear !is Sum) fail() if (linear.left !is Prod) fail() if (linear.right !is SConst) fail() @@ -43,18 +43,18 @@ internal class AdaptingTests { @Test fun simpleFunctionDerivative() { - val x = MstAlgebra.symbol("x").sVar(proto) - val quadratic = "x^2-4*x-44".parseMath().sFun(proto) - val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile() + val x = MstAlgebra.symbol("x").toSVar(proto) + val quadratic = "x^2-4*x-44".parseMath().tSFun(proto) + val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0)) } @Test fun moreComplexDerivative() { - val x = MstAlgebra.symbol("x").sVar(proto) - val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sFun(proto) - val actualDerivative = MstExpression(RealField, composition.d(x).mst()).compile() + val x = MstAlgebra.symbol("x").toSVar(proto) + val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().tSFun(proto) + val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile() val expectedDerivative = MstExpression( RealField, diff --git a/settings.gradle.kts b/settings.gradle.kts index 5fd072e1a..9343db854 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -40,5 +40,5 @@ include( ":kmath-ast", ":examples", ":kmath-ejml", - ":kmath-ast-kotlingrad" + ":kmath-kotlingrad" ) From 2723c376d9b123cb0d311411d6757d6fb8f379b7 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Tue, 13 Oct 2020 22:09:39 +0700 Subject: [PATCH 36/69] Use KG DSL instead of raw scalar construction --- .../kmath/kotlingrad/ScalarsAdapters.kt | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt index f6f0c7a76..e4777282f 100644 --- a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt @@ -16,7 +16,7 @@ import kscience.kmath.operations.* * - [SConst] -> [MstExtendedField.number]; * - [Sum] -> [MstExtendedField.add]; * - [Prod] -> [MstExtendedField.multiply]; - * - [Power] -> [MstExtendedField.power] (limited); + * - [Power] -> [MstExtendedField.power] (limited to constant exponents only); * - [Negative] -> [MstExtendedField.unaryMinus]; * - [Log] -> [MstExtendedField.ln] (left) / [MstExtendedField.ln] (right); * - [Sine] -> [MstExtendedField.sin]; @@ -36,7 +36,7 @@ public fun > SFun.toMst(): MST = MstExtendedField { is SConst -> number(doubleValue) is Sum -> left.toMst() + right.toMst() is Prod -> left.toMst() * right.toMst() - is Power -> power(left.toMst(), (right as SConst<*>).doubleValue) + is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue) is Negative -> -input.toMst() is Log -> ln(left.toMst()) / ln(right.toMst()) is Sine -> sin(input.toMst()) @@ -86,22 +86,22 @@ public fun > MST.tSFun(proto: X): SFun = when (this) { is MST.Unary -> when (operation) { SpaceOperations.PLUS_OPERATION -> value.tSFun(proto) - SpaceOperations.MINUS_OPERATION -> Negative(value.tSFun(proto)) - TrigonometricOperations.SIN_OPERATION -> Sine(value.tSFun(proto)) - TrigonometricOperations.COS_OPERATION -> Cosine(value.tSFun(proto)) - TrigonometricOperations.TAN_OPERATION -> Tangent(value.tSFun(proto)) - PowerOperations.SQRT_OPERATION -> Power(value.tSFun(proto), SConst(0.5)) - ExponentialOperations.EXP_OPERATION -> Power(value.tSFun(proto), E()) - ExponentialOperations.LN_OPERATION -> Log(value.tSFun(proto)) + SpaceOperations.MINUS_OPERATION -> -value.tSFun(proto) + TrigonometricOperations.SIN_OPERATION -> sin(value.tSFun(proto)) + TrigonometricOperations.COS_OPERATION -> cos(value.tSFun(proto)) + TrigonometricOperations.TAN_OPERATION -> tan(value.tSFun(proto)) + PowerOperations.SQRT_OPERATION -> value.tSFun(proto) pow SConst(0.5) + ExponentialOperations.EXP_OPERATION -> E() pow value.tSFun(proto) + ExponentialOperations.LN_OPERATION -> value.tSFun(proto).ln() else -> error("Unary operation $operation not defined in $this") } is MST.Binary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> Sum(left.tSFun(proto), right.tSFun(proto)) - SpaceOperations.MINUS_OPERATION -> Sum(left.tSFun(proto), Negative(right.tSFun(proto))) - RingOperations.TIMES_OPERATION -> Prod(left.tSFun(proto), right.tSFun(proto)) - FieldOperations.DIV_OPERATION -> Prod(left.tSFun(proto), Power(right.tSFun(proto), Negative(One()))) - PowerOperations.POW_OPERATION -> Power(left.tSFun(proto), SConst((right as MST.Numeric).value)) + SpaceOperations.PLUS_OPERATION -> left.tSFun(proto) + right.tSFun(proto) + SpaceOperations.MINUS_OPERATION -> left.tSFun(proto) - right.tSFun(proto) + RingOperations.TIMES_OPERATION -> left.tSFun(proto) * right.tSFun(proto) + FieldOperations.DIV_OPERATION -> left.tSFun(proto) / right.tSFun(proto) + PowerOperations.POW_OPERATION -> left.tSFun(proto) pow (right as MST.Numeric).toSConst() else -> error("Binary operation $operation not defined in $this") } } From ea0ecc0fba58b31cb9e24020db20332cfb9e73bd Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Tue, 13 Oct 2020 22:18:44 +0700 Subject: [PATCH 37/69] Use postfix op. form --- .../main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt index e4777282f..99ab5e635 100644 --- a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt @@ -90,7 +90,7 @@ public fun > MST.tSFun(proto: X): SFun = when (this) { TrigonometricOperations.SIN_OPERATION -> sin(value.tSFun(proto)) TrigonometricOperations.COS_OPERATION -> cos(value.tSFun(proto)) TrigonometricOperations.TAN_OPERATION -> tan(value.tSFun(proto)) - PowerOperations.SQRT_OPERATION -> value.tSFun(proto) pow SConst(0.5) + PowerOperations.SQRT_OPERATION -> value.tSFun(proto).sqrt() ExponentialOperations.EXP_OPERATION -> E() pow value.tSFun(proto) ExponentialOperations.LN_OPERATION -> value.tSFun(proto).ln() else -> error("Unary operation $operation not defined in $this") From e44423192d8b3923893752d4c32056a384fadcb9 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 13 Oct 2020 20:34:17 +0300 Subject: [PATCH 38/69] Tools version update --- build.gradle.kts | 2 +- .../main/kotlin/kscience/kmath/operations/ComplexDemo.kt | 4 ++-- .../kscience/kmath/commons/expressions/DiffExpression.kt | 7 ++++--- .../kscience/kmath/commons/expressions/AutoDiffTest.kt | 2 +- .../commonMain/kotlin/kscience/kmath/operations/Complex.kt | 1 + settings.gradle.kts | 2 +- 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 05e2d5979..239ea1296 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -24,4 +24,4 @@ subprojects { readme { readmeTemplate = file("docs/templates/README-TEMPLATE.md") -} +} \ No newline at end of file diff --git a/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt b/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt index 34b3c9981..e84fd8df3 100644 --- a/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt +++ b/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt @@ -6,8 +6,8 @@ import kscience.kmath.structures.complex fun main() { // 2d element - val element = NDElement.complex(2, 2) { index: IntArray -> - Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) + val element = NDElement.complex(2, 2) { (i,j) -> + Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble()) } println(element) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt index c39f0d04c..1eca1a773 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt @@ -16,7 +16,7 @@ import kotlin.properties.ReadOnlyProperty */ public class DerivativeStructureField( public val order: Int, - public val parameters: Map + public val parameters: Map, ) : ExtendedField { public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) } public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) } @@ -85,8 +85,9 @@ public class DerivativeStructureField( /** * A constructs that creates a derivative structure with required order on-demand */ -public class DiffExpression(public val function: DerivativeStructureField.() -> DerivativeStructure) : - Expression { +public class DiffExpression( + public val function: DerivativeStructureField.() -> DerivativeStructure, +) : Expression { public override operator fun invoke(arguments: Map): Double = DerivativeStructureField( 0, arguments diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt index f905e6818..197faaf49 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt @@ -18,7 +18,7 @@ internal inline fun diff( internal class AutoDiffTest { @Test fun derivativeStructureFieldTest() { - val res = diff(3, "x" to 1.0, "y" to 1.0) { + val res: Double = diff(3, "x" to 1.0, "y" to 1.0) { val x by variable val y = variable("y") val z = x * (-sin(x * y) + y) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt index 37055a5c8..703931c7c 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt @@ -195,6 +195,7 @@ public data class Complex(val re: Double, val im: Double) : FieldElement Date: Mon, 19 Oct 2020 22:51:33 +0300 Subject: [PATCH 39/69] New Expression API --- README.md | 14 +- build.gradle.kts | 4 + docs/templates/README-TEMPLATE.md | 2 +- .../kscience/kmath/ast/MstExpression.kt | 10 +- .../kmath/asm/internal/mapIntrinsics.kt | 4 +- .../kscience/kmath/asm/TestAsmAlgebras.kt | 2 +- ...on.kt => DerivativeStructureExpression.kt} | 93 ++--- ...t => DerivativeStructureExpressionTest.kt} | 28 +- kmath-core/README.md | 9 +- kmath-core/build.gradle.kts | 2 +- .../kscience/kmath/expressions/Expression.kt | 81 ++++- .../FunctionalExpressionAlgebra.kt | 62 +--- .../kmath/expressions/SimpleAutoDiff.kt | 329 ++++++++++++++++++ .../kotlin/kscience/kmath/misc/AutoDiff.kt | 266 -------------- .../kmath/expressions/ExpressionFieldTest.kt | 22 +- .../kmath/expressions/SimpleAutoDiffTest.kt | 277 +++++++++++++++ .../kscience/kmath/misc/AutoDiffTest.kt | 261 -------------- 17 files changed, 794 insertions(+), 672 deletions(-) rename kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/{DiffExpression.kt => DerivativeStructureExpression.kt} (50%) rename kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/{AutoDiffTest.kt => DerivativeStructureExpressionTest.kt} (51%) create mode 100644 kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt delete mode 100644 kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt create mode 100644 kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt delete mode 100644 kmath-core/src/commonTest/kotlin/kscience/kmath/misc/AutoDiffTest.kt diff --git a/README.md b/README.md index 708bd8eb1..cbdf98afb 100644 --- a/README.md +++ b/README.md @@ -53,9 +53,7 @@ can be used for a wide variety of purposes from high performance calculations to * **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free to submit a feature request if you want something to be done first. - -* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures. - + ## Planned features * **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks. @@ -117,6 +115,12 @@ can be used for a wide variety of purposes from high performance calculations to > **Maturity**: EXPERIMENTAL
+* ### [kmath-ejml](kmath-ejml) +> +> +> **Maturity**: EXPERIMENTAL +
+ * ### [kmath-for-real](kmath-for-real) > > @@ -178,8 +182,8 @@ repositories{ } dependencies{ - api("kscience.kmath:kmath-core:0.2.0-dev-1") - //api("kscience.kmath:kmath-core-jvm:0.2.0-dev-1") for jvm-specific version + api("kscience.kmath:kmath-core:0.2.0-dev-2") + //api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version } ``` diff --git a/build.gradle.kts b/build.gradle.kts index 239ea1296..74b76d731 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -24,4 +24,8 @@ subprojects { readme { readmeTemplate = file("docs/templates/README-TEMPLATE.md") +} + +apiValidation{ + validationDisabled = true } \ No newline at end of file diff --git a/docs/templates/README-TEMPLATE.md b/docs/templates/README-TEMPLATE.md index f451adb24..5117e0694 100644 --- a/docs/templates/README-TEMPLATE.md +++ b/docs/templates/README-TEMPLATE.md @@ -107,4 +107,4 @@ with the same artifact names. ## Contributing -The project requires a lot of additional work. Please feel free to contribute in any way and propose new features. +The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero). \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt index 483bc530c..5ca75e993 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt @@ -14,8 +14,8 @@ import kotlin.contracts.contract * @author Alexander Nozik */ public class MstExpression(public val algebra: Algebra, public val mst: MST) : Expression { - private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { - override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { + override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) override fun binaryOperation(operation: String, left: T, right: T): T = @@ -27,7 +27,7 @@ public class MstExpression(public val algebra: Algebra, public val mst: MS error("Numeric nodes are not supported by $this") } - override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) + override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) } /** @@ -37,7 +37,7 @@ public class MstExpression(public val algebra: Algebra, public val mst: MS */ public inline fun , E : Algebra> A.mst( mstAlgebra: E, - block: E.() -> MST + block: E.() -> MST, ): MstExpression = MstExpression(this, mstAlgebra.block()) /** @@ -116,7 +116,7 @@ public inline fun > FunctionalExpressionField> FunctionalExpressionExtendedField.mstInExtendedField( - block: MstExtendedField.() -> MST + block: MstExtendedField.() -> MST, ): MstExpression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return algebra.mstInExtendedField(block) diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt index 708b3c2b4..09e9a71b0 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt @@ -2,6 +2,8 @@ package kscience.kmath.asm.internal +import kscience.kmath.expressions.StringSymbol + /** * Gets value with given [key] or throws [IllegalStateException] whenever it is not present. * @@ -9,4 +11,4 @@ package kscience.kmath.asm.internal */ @JvmOverloads internal fun Map.getOrFail(key: K, default: V? = null): V = - this[key] ?: default ?: error("Parameter not found: $key") + this[StringSymbol(key.toString())] ?: default ?: error("Parameter not found: $key") diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt index 0cf1307d1..5eebfe43d 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt @@ -1,6 +1,5 @@ package kscience.kmath.asm -import kscience.kmath.asm.compile import kscience.kmath.ast.mstInField import kscience.kmath.ast.mstInRing import kscience.kmath.ast.mstInSpace @@ -11,6 +10,7 @@ import kotlin.test.Test import kotlin.test.assertEquals internal class TestAsmAlgebras { + @Test fun space() { val res1 = ByteRing.mstInSpace { diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt similarity index 50% rename from kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt rename to kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 1eca1a773..9a27e40cd 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -1,48 +1,57 @@ package kscience.kmath.commons.expressions +import kscience.kmath.expressions.DifferentiableExpression import kscience.kmath.expressions.Expression import kscience.kmath.expressions.ExpressionAlgebra +import kscience.kmath.expressions.Symbol import kscience.kmath.operations.ExtendedField -import kscience.kmath.operations.Field -import kscience.kmath.operations.invoke import org.apache.commons.math3.analysis.differentiation.DerivativeStructure -import kotlin.properties.ReadOnlyProperty /** * A field over commons-math [DerivativeStructure]. * * @property order The derivation order. - * @property parameters The map of free parameters. + * @property bindings The map of bindings values. All bindings are considered free parameters */ public class DerivativeStructureField( public val order: Int, - public val parameters: Map, -) : ExtendedField { - public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) } - public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) } + private val bindings: Map +) : ExtendedField, ExpressionAlgebra { + public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) } + public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) } - private val variables: Map = parameters.mapValues { (key, value) -> - DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value) + /** + * A class that implements both [DerivativeStructure] and a [Symbol] + */ + public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) : + DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol { + override val identity: Any = symbol.identity } - public val variable: ReadOnlyProperty = ReadOnlyProperty { _, property -> - variables[property.name] ?: error("A variable with name ${property.name} does not exist") + /** + * Identity-based symbol bindings map + */ + private val variables: Map = bindings.entries.associate { (key, value) -> + key.identity to DerivativeStructureSymbol(key, value) } - public fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure = - variables[name] ?: default ?: error("A variable with name $name does not exist") + override fun const(value: Double): DerivativeStructure = DerivativeStructure(order, bindings.size, value) - public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble()) + public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] - public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double { - return deriv(mapOf(parName to order)) + public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity) + + public fun Number.const(): DerivativeStructure = const(toDouble()) + + public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double { + return derivative(mapOf(parameter to order)) } - public fun DerivativeStructure.deriv(orders: Map): Double { - return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray()) + public fun DerivativeStructure.derivative(orders: Map): Double { + return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray()) } - public fun DerivativeStructure.deriv(vararg orders: Pair): Double = deriv(mapOf(*orders)) + public fun DerivativeStructure.derivative(vararg orders: Pair): Double = derivative(mapOf(*orders)) public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) { @@ -85,48 +94,16 @@ public class DerivativeStructureField( /** * A constructs that creates a derivative structure with required order on-demand */ -public class DiffExpression( +public class DerivativeStructureExpression( public val function: DerivativeStructureField.() -> DerivativeStructure, -) : Expression { - public override operator fun invoke(arguments: Map): Double = DerivativeStructureField( - 0, - arguments - ).function().value +) : DifferentiableExpression { + public override operator fun invoke(arguments: Map): Double = + DerivativeStructureField(0, arguments).function().value /** * Get the derivative expression with given orders - * TODO make result [DiffExpression] */ - public fun derivative(orders: Map): Expression = Expression { arguments -> - (DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().deriv(orders) } + public override fun derivative(orders: Map): Expression = Expression { arguments -> + with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) } } - - //TODO add gradient and maybe other vector operators -} - -public fun DiffExpression.derivative(vararg orders: Pair): Expression = derivative(mapOf(*orders)) -public fun DiffExpression.derivative(name: String): Expression = derivative(name to 1) - -/** - * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) - */ -public object DiffExpressionAlgebra : ExpressionAlgebra, Field { - public override val zero: DiffExpression = DiffExpression { 0.0.const() } - public override val one: DiffExpression = DiffExpression { 1.0.const() } - - public override fun variable(name: String, default: Double?): DiffExpression = - DiffExpression { variable(name, default?.const()) } - - public override fun const(value: Double): DiffExpression = DiffExpression { value.const() } - - public override fun add(a: DiffExpression, b: DiffExpression): DiffExpression = - DiffExpression { a.function(this) + b.function(this) } - - public override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k } - - public override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression = - DiffExpression { a.function(this) * b.function(this) } - - public override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression = - DiffExpression { a.function(this) / b.function(this) } } diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt similarity index 51% rename from kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt rename to kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt index 197faaf49..8886e123f 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt @@ -1,6 +1,6 @@ package kscience.kmath.commons.expressions -import kscience.kmath.expressions.invoke +import kscience.kmath.expressions.* import kotlin.contracts.InvocationKind import kotlin.contracts.contract import kotlin.test.Test @@ -8,33 +8,37 @@ import kotlin.test.assertEquals internal inline fun diff( order: Int, - vararg parameters: Pair, - block: DerivativeStructureField.() -> R + vararg parameters: Pair, + block: DerivativeStructureField.() -> R, ): R { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return DerivativeStructureField(order, mapOf(*parameters)).run(block) } internal class AutoDiffTest { + private val x by symbol + private val y by symbol + @Test fun derivativeStructureFieldTest() { - val res: Double = diff(3, "x" to 1.0, "y" to 1.0) { - val x by variable - val y = variable("y") + val res: Double = diff(3, x to 1.0, y to 1.0) { + val x = bind(x)//by binding() + val y = symbol("y") val z = x * (-sin(x * y) + y) - z.deriv("x") + z.derivative(x) } + println(res) } @Test fun autoDifTest() { - val f = DiffExpression { - val x by variable - val y by variable + val f = DerivativeStructureExpression { + val x by binding() + val y by binding() x.pow(2) + 2 * x * y + y.pow(2) + 1 } - assertEquals(10.0, f("x" to 1.0, "y" to 2.0)) - assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0)) + assertEquals(10.0, f(x to 1.0, y to 2.0)) + assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0)) } } diff --git a/kmath-core/README.md b/kmath-core/README.md index 2cf7ed5dc..6935c0d3c 100644 --- a/kmath-core/README.md +++ b/kmath-core/README.md @@ -12,7 +12,7 @@ The core features of KMath: > #### Artifact: > -> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-1`. +> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`. > > Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion) > @@ -22,25 +22,28 @@ The core features of KMath: > > ```gradle > repositories { +> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } > maven { url 'https://dl.bintray.com/mipt-npm/kscience' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' } > maven { url 'https://dl.bintray.com/hotkeytlt/maven' } + > } > > dependencies { -> implementation 'kscience.kmath:kmath-core:0.2.0-dev-1' +> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2' > } > ``` > **Gradle Kotlin DSL:** > > ```kotlin > repositories { +> maven("https://dl.bintray.com/kotlin/kotlin-eap") > maven("https://dl.bintray.com/mipt-npm/kscience") > maven("https://dl.bintray.com/mipt-npm/dev") > maven("https://dl.bintray.com/hotkeytlt/maven") > } > > dependencies { -> implementation("kscience.kmath:kmath-core:0.2.0-dev-1") +> implementation("kscience.kmath:kmath-core:0.2.0-dev-2") > } > ``` diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index b56151abe..bd254c39d 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -41,6 +41,6 @@ readme { feature( id = "autodif", description = "Automatic differentiation", - ref = "src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt" + ref = "src/commonMain/kotlin/kscience/kmath/misc/SimpleAutoDiff.kt" ) } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index 5ade9e3ca..d64eb5a55 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -1,6 +1,26 @@ package kscience.kmath.expressions import kscience.kmath.operations.Algebra +import kotlin.jvm.JvmName +import kotlin.properties.ReadOnlyProperty + +/** + * A marker interface for a symbol. A symbol mus have an identity + */ +public interface Symbol { + /** + * Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol. + * By default uses object identity + */ + public val identity: Any get() = this +} + +/** + * A [Symbol] with a [String] identity + */ +public inline class StringSymbol(override val identity: String) : Symbol { + override fun toString(): String = identity +} /** * An elementary function that could be invoked on a map of arguments @@ -12,30 +32,81 @@ public fun interface Expression { * @param arguments the map of arguments. * @return the value. */ - public operator fun invoke(arguments: Map): T + public operator fun invoke(arguments: Map): T public companion object } +/** + * Invlode an expression without parameters + */ +public operator fun Expression.invoke(): T = invoke(emptyMap()) +//This method exists to avoid resolution ambiguity of vararg methods + /** * Calls this expression from arguments. * * @param pairs the pair of arguments' names to values. * @return the value. */ -public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) +@JvmName("callBySymbol") +public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) + +@JvmName("callByString") +public operator fun Expression.invoke(vararg pairs: Pair): T = + invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) }) + +/** + * And object that could be differentiated + */ +public interface Differentiable { + public fun derivative(orders: Map): T +} + +public interface DifferentiableExpression : Differentiable>, Expression + +public fun DifferentiableExpression.derivative(vararg orders: Pair): Expression = + derivative(mapOf(*orders)) + +public fun DifferentiableExpression.derivative(symbol: Symbol): Expression = derivative(symbol to 1) + +public fun DifferentiableExpression.derivative(name: String): Expression = derivative(StringSymbol(name) to 1) /** * A context for expression construction + * + * @param T type of the constants for the expression + * @param E type of the actual expression state */ -public interface ExpressionAlgebra : Algebra { +public interface ExpressionAlgebra : Algebra { + /** - * Introduce a variable into expression context + * Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context. */ - public fun variable(name: String, default: T? = null): E + public fun bindOrNull(symbol: Symbol): E? + + /** + * Bind a string to a context using [StringSymbol] + */ + override fun symbol(value: String): E = bind(StringSymbol(value)) /** * A constant expression which does not depend on arguments */ public fun const(value: T): E } + +/** + * Bind a given [Symbol] to this context variable and produce context-specific object. + */ +public fun ExpressionAlgebra.bind(symbol: Symbol): E = + bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this") + +public val symbol: ReadOnlyProperty = ReadOnlyProperty { _, property -> + StringSymbol(property.name) +} + +public fun ExpressionAlgebra.binding(): ReadOnlyProperty = + ReadOnlyProperty { _, property -> + bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist") + } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 5b050dd36..9fd15238a 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -2,39 +2,6 @@ package kscience.kmath.expressions import kscience.kmath.operations.* -internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) : - Expression { - override operator fun invoke(arguments: Map): T = - context.unaryOperation(name, expr.invoke(arguments)) -} - -internal class FunctionalBinaryOperation( - val context: Algebra, - val name: String, - val first: Expression, - val second: Expression -) : Expression { - override operator fun invoke(arguments: Map): T = - context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) -} - -internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression { - override operator fun invoke(arguments: Map): T = - arguments[name] ?: default ?: error("Parameter not found: $name") -} - -internal class FunctionalConstantExpression(val value: T) : Expression { - override operator fun invoke(arguments: Map): T = value -} - -internal class FunctionalConstProductExpression( - val context: Space, - private val expr: Expression, - val const: Number -) : Expression { - override operator fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) -} - /** * A context class for [Expression] construction. * @@ -45,24 +12,32 @@ public abstract class FunctionalExpressionAlgebra>(public val /** * Builds an Expression of constant expression which does not depend on arguments. */ - public override fun const(value: T): Expression = FunctionalConstantExpression(value) + public override fun const(value: T): Expression = Expression { value } /** * Builds an Expression to access a variable. */ - public override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default) + public override fun bindOrNull(symbol: Symbol): Expression? = Expression { arguments -> + arguments[symbol] ?: error("Argument not found: $symbol") + } /** * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. */ - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - FunctionalBinaryOperation(algebra, operation, left, right) + public override fun binaryOperation( + operation: String, + left: Expression, + right: Expression, + ): Expression = Expression { arguments -> + algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments)) + } /** * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. */ - public override fun unaryOperation(operation: String, arg: Expression): Expression = - FunctionalUnaryOperation(algebra, operation, arg) + public override fun unaryOperation(operation: String, arg: Expression): Expression = Expression { arguments -> + algebra.unaryOperation(operation, arg.invoke(arguments)) + } } /** @@ -81,8 +56,9 @@ public open class FunctionalExpressionSpace>(algebra: A) : /** * Builds an Expression of multiplication of expression by number. */ - public override fun multiply(a: Expression, k: Number): Expression = - FunctionalConstProductExpression(algebra, a, k) + public override fun multiply(a: Expression, k: Number): Expression = Expression { arguments -> + algebra.multiply(a.invoke(arguments), k) + } public operator fun Expression.plus(arg: T): Expression = this + const(arg) public operator fun Expression.minus(arg: T): Expression = this - const(arg) @@ -118,8 +94,8 @@ public open class FunctionalExpressionRing(algebra: A) : FunctionalExpress } public open class FunctionalExpressionField(algebra: A) : - FunctionalExpressionRing(algebra), - Field> where A : Field, A : NumericAlgebra { + FunctionalExpressionRing(algebra), Field> + where A : Field, A : NumericAlgebra { /** * Builds an Expression of division an expression by another one. */ diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt new file mode 100644 index 000000000..5e8fe3e99 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -0,0 +1,329 @@ +package kscience.kmath.expressions + +import kscience.kmath.linear.Point +import kscience.kmath.operations.* +import kscience.kmath.structures.asBuffer +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/* + * Implementation of backward-mode automatic differentiation. + * Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d + */ + + +/** + * A [Symbol] with bound value + */ +public interface BoundSymbol : Symbol { + public val value: T +} + +/** + * Bind a [Symbol] to a [value] and produce [BoundSymbol] + */ +public fun Symbol.bind(value: T): BoundSymbol = object : BoundSymbol { + override val identity = this@bind.identity + override val value: T = value +} + +/** + * Represents result of [withAutoDiff] call. + * + * @param T the non-nullable type of value. + * @param value the value of result. + * @property withAutoDiff The mapping of differentiated variables to their derivatives. + * @property context The field over [T]. + */ +public class DerivationResult( + override val value: T, + private val derivativeValues: Map, + public val context: Field, +) : BoundSymbol { + /** + * Returns derivative of [variable] or returns [Ring.zero] in [context]. + */ + public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero + + /** + * Computes the divergence. + */ + public fun div(): T = context { sum(derivativeValues.values) } +} + +/** + * Computes the gradient for variables in given order. + */ +public fun DerivationResult.grad(vararg variables: Symbol): Point { + check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" } + return variables.map(::derivative).asBuffer() +} + +/** + * Runs differentiation and establishes [AutoDiffField] context inside the block of code. + * + * The partial derivatives are placed in argument `d` variable + * + * Example: + * ``` + * val x by symbol // define variable(s) and their values + * val y = RealField.withAutoDiff() { sqr(x) + 5 * x + 3 } // write formulate in deriv context + * assertEquals(17.0, y.x) // the value of result (y) + * assertEquals(9.0, x.d) // dy/dx + * ``` + * + * @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to. + * @return the result of differentiation. + */ +public fun > F.withAutoDiff( + bindings: Collection>, + body: AutoDiffField.() -> BoundSymbol, +): DerivationResult { + contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } + + return AutoDiffContext(this, bindings).derivate(body) +} + +public fun > F.withAutoDiff( + vararg bindings: Pair, + body: AutoDiffField.() -> BoundSymbol, +): DerivationResult = withAutoDiff(bindings.map { it.first.bind(it.second) }, body) + +/** + * Represents field in context of which functions can be derived. + */ +public abstract class AutoDiffField> + : Field>, ExpressionAlgebra> { + + public abstract val context: F + + /** + * A variable accessing inner state of derivatives. + * Use this value in inner builders to avoid creating additional derivative bindings. + */ + public abstract var BoundSymbol.d: T + + /** + * Performs update of derivative after the rest of the formula in the back-pass. + * + * For example, implementation of `sin` function is: + * + * ``` + * fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result + * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function + * } + * ``` + */ + public abstract fun derive(value: R, block: F.(R) -> Unit): R + + public inline fun const(block: F.() -> T): BoundSymbol = const(context.block()) + + // Overloads for Double constants + + override operator fun Number.plus(b: BoundSymbol): BoundSymbol = + derive(const { this@plus.toDouble() * one + b.value }) { z -> + b.d += z.d + } + + override operator fun BoundSymbol.plus(b: Number): BoundSymbol = b.plus(this) + + override operator fun Number.minus(b: BoundSymbol): BoundSymbol = + derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } + + override operator fun BoundSymbol.minus(b: Number): BoundSymbol = + derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } +} + +/** + * Automatic Differentiation context class. + */ +private class AutoDiffContext>( + override val context: F, + bindings: Collection>, +) : AutoDiffField() { + // this stack contains pairs of blocks and values to apply them to + private var stack: Array = arrayOfNulls(8) + private var sp: Int = 0 + private val derivatives: MutableMap = hashMapOf() + override val zero: BoundSymbol get() = const(context.zero) + override val one: BoundSymbol get() = const(context.one) + + /** + * Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result + * with respect to this variable. + * + * @param T the non-nullable type of value. + * @property value The value of this variable. + */ + private class AutoDiffVariableWithDeriv(override val value: T, var d: T) : BoundSymbol + + private val bindings: Map> = bindings.associateBy { it.identity } + + override fun bindOrNull(symbol: Symbol): BoundSymbol? = bindings[symbol.identity] + + override fun const(value: T): BoundSymbol = AutoDiffVariableWithDeriv(value, context.zero) + + override var BoundSymbol.d: T + get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[identity] ?: context.zero + set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[identity] = value + + @Suppress("UNCHECKED_CAST") + override fun derive(value: R, block: F.(R) -> Unit): R { + // save block to stack for backward pass + if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) + stack[sp++] = block + stack[sp++] = value + return value + } + + @Suppress("UNCHECKED_CAST") + fun runBackwardPass() { + while (sp > 0) { + val value = stack[--sp] + val block = stack[--sp] as F.(Any?) -> Unit + context.block(value) + } + } + + // Basic math (+, -, *, /) + + override fun add(a: BoundSymbol, b: BoundSymbol): BoundSymbol = + derive(const { a.value + b.value }) { z -> + a.d += z.d + b.d += z.d + } + + override fun multiply(a: BoundSymbol, b: BoundSymbol): BoundSymbol = + derive(const { a.value * b.value }) { z -> + a.d += z.d * b.value + b.d += z.d * a.value + } + + override fun divide(a: BoundSymbol, b: BoundSymbol): BoundSymbol = + derive(const { a.value / b.value }) { z -> + a.d += z.d / b.value + b.d -= z.d * a.value / (b.value * b.value) + } + + override fun multiply(a: BoundSymbol, k: Number): BoundSymbol = + derive(const { k.toDouble() * a.value }) { z -> + a.d += z.d * k.toDouble() + } + + inline fun derivate(function: AutoDiffField.() -> BoundSymbol): DerivationResult { + val result = function() + result.d = context.one // computing derivative w.r.t result + runBackwardPass() + return DerivationResult(result.value, derivatives, context) + } +} + +/** + * A constructs that creates a derivative structure with required order on-demand + */ +public class SimpleAutoDiffExpression>( + public val field: F, + public val function: AutoDiffField.() -> BoundSymbol, +) : DifferentiableExpression { + public override operator fun invoke(arguments: Map): T { + val bindings = arguments.entries.map { it.key.bind(it.value) } + return AutoDiffContext(field, bindings).function().value + } + + /** + * Get the derivative expression with given orders + */ + public override fun derivative(orders: Map): Expression { + val dSymbol = orders.entries.singleOrNull { it.value == 1 } + ?: error("SimpleAutoDiff supports only first order derivatives") + return Expression { arguments -> + val bindings = arguments.entries.map { it.key.bind(it.value) } + val derivationResult = AutoDiffContext(field, bindings).derivate(function) + derivationResult.derivative(dSymbol.key) + } + } +} + + +// Extensions for differentiation of various basic mathematical functions + +// x ^ 2 +public fun > AutoDiffField.sqr(x: BoundSymbol): BoundSymbol = + derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } + +// x ^ 1/2 +public fun > AutoDiffField.sqrt(x: BoundSymbol): BoundSymbol = + derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } + +// x ^ y (const) +public fun > AutoDiffField.pow( + x: BoundSymbol, + y: Double, +): BoundSymbol = + derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } + +public fun > AutoDiffField.pow( + x: BoundSymbol, + y: Int, +): BoundSymbol = + pow(x, y.toDouble()) + +// exp(x) +public fun > AutoDiffField.exp(x: BoundSymbol): BoundSymbol = + derive(const { exp(x.value) }) { z -> x.d += z.d * z.value } + +// ln(x) +public fun > AutoDiffField.ln(x: BoundSymbol): BoundSymbol = + derive(const { ln(x.value) }) { z -> x.d += z.d / x.value } + +// x ^ y (any) +public fun > AutoDiffField.pow( + x: BoundSymbol, + y: BoundSymbol, +): BoundSymbol = + exp(y * ln(x)) + +// sin(x) +public fun > AutoDiffField.sin(x: BoundSymbol): BoundSymbol = + derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } + +// cos(x) +public fun > AutoDiffField.cos(x: BoundSymbol): BoundSymbol = + derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } + +public fun > AutoDiffField.tan(x: BoundSymbol): BoundSymbol = + derive(const { tan(x.value) }) { z -> + val c = cos(x.value) + x.d += z.d / (c * c) + } + +public fun > AutoDiffField.asin(x: BoundSymbol): BoundSymbol = + derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) } + +public fun > AutoDiffField.acos(x: BoundSymbol): BoundSymbol = + derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) } + +public fun > AutoDiffField.atan(x: BoundSymbol): BoundSymbol = + derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) } + +public fun > AutoDiffField.sinh(x: BoundSymbol): BoundSymbol = + derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) } + +public fun > AutoDiffField.cosh(x: BoundSymbol): BoundSymbol = + derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) } + +public fun > AutoDiffField.tanh(x: BoundSymbol): BoundSymbol = + derive(const { tan(x.value) }) { z -> + val c = cosh(x.value) + x.d += z.d / (c * c) + } + +public fun > AutoDiffField.asinh(x: BoundSymbol): BoundSymbol = + derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) } + +public fun > AutoDiffField.acosh(x: BoundSymbol): BoundSymbol = + derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) } + +public fun > AutoDiffField.atanh(x: BoundSymbol): BoundSymbol = + derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) } + diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt deleted file mode 100644 index bfcd5959f..000000000 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt +++ /dev/null @@ -1,266 +0,0 @@ -package kscience.kmath.misc - -import kscience.kmath.linear.Point -import kscience.kmath.operations.* -import kscience.kmath.structures.asBuffer -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract - -/* - * Implementation of backward-mode automatic differentiation. - * Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d - */ - -/** - * Differentiable variable with value and derivative of differentiation ([deriv]) result - * with respect to this variable. - * - * @param T the non-nullable type of value. - * @property value The value of this variable. - */ -public open class Variable(public val value: T) - -/** - * Represents result of [deriv] call. - * - * @param T the non-nullable type of value. - * @param value the value of result. - * @property deriv The mapping of differentiated variables to their derivatives. - * @property context The field over [T]. - */ -public class DerivationResult( - value: T, - public val deriv: Map, T>, - public val context: Field -) : Variable(value) { - /** - * Returns derivative of [variable] or returns [Ring.zero] in [context]. - */ - public fun deriv(variable: Variable): T = deriv[variable] ?: context.zero - - /** - * Computes the divergence. - */ - public fun div(): T = context { sum(deriv.values) } - - /** - * Computes the gradient for variables in given order. - */ - public fun grad(vararg variables: Variable): Point { - check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" } - return variables.map(::deriv).asBuffer() - } -} - -/** - * Runs differentiation and establishes [AutoDiffField] context inside the block of code. - * - * The partial derivatives are placed in argument `d` variable - * - * Example: - * ``` - * val x = Variable(2) // define variable(s) and their values - * val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context - * assertEquals(17.0, y.x) // the value of result (y) - * assertEquals(9.0, x.d) // dy/dx - * ``` - * - * @param body the action in [AutoDiffField] context returning [Variable] to differentiate with respect to. - * @return the result of differentiation. - */ -public inline fun > F.deriv(body: AutoDiffField.() -> Variable): DerivationResult { - contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } - - return (AutoDiffContext(this)) { - val result = body() - result.d = context.one // computing derivative w.r.t result - runBackwardPass() - DerivationResult(result.value, derivatives, this@deriv) - } -} - -/** - * Represents field in context of which functions can be derived. - */ -public abstract class AutoDiffField> : Field> { - public abstract val context: F - - /** - * A variable accessing inner state of derivatives. - * Use this value in inner builders to avoid creating additional derivative bindings. - */ - public abstract var Variable.d: T - - /** - * Performs update of derivative after the rest of the formula in the back-pass. - * - * For example, implementation of `sin` function is: - * - * ``` - * fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result - * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function - * } - * ``` - */ - public abstract fun derive(value: R, block: F.(R) -> Unit): R - - /** - * - */ - public abstract fun variable(value: T): Variable - - public inline fun variable(block: F.() -> T): Variable = variable(context.block()) - - // Overloads for Double constants - - override operator fun Number.plus(b: Variable): Variable = - derive(variable { this@plus.toDouble() * one + b.value }) { z -> - b.d += z.d - } - - override operator fun Variable.plus(b: Number): Variable = b.plus(this) - - override operator fun Number.minus(b: Variable): Variable = - derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } - - override operator fun Variable.minus(b: Number): Variable = - derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } -} - -/** - * Automatic Differentiation context class. - */ -@PublishedApi -internal class AutoDiffContext>(override val context: F) : AutoDiffField() { - // this stack contains pairs of blocks and values to apply them to - private var stack: Array = arrayOfNulls(8) - private var sp: Int = 0 - val derivatives: MutableMap, T> = hashMapOf() - override val zero: Variable get() = Variable(context.zero) - override val one: Variable get() = Variable(context.one) - - /** - * A variable coupled with its derivative. For internal use only - */ - private class VariableWithDeriv(x: T, var d: T) : Variable(x) - - override fun variable(value: T): Variable = - VariableWithDeriv(value, context.zero) - - override var Variable.d: T - get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero - set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value - - @Suppress("UNCHECKED_CAST") - override fun derive(value: R, block: F.(R) -> Unit): R { - // save block to stack for backward pass - if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) - stack[sp++] = block - stack[sp++] = value - return value - } - - @Suppress("UNCHECKED_CAST") - fun runBackwardPass() { - while (sp > 0) { - val value = stack[--sp] - val block = stack[--sp] as F.(Any?) -> Unit - context.block(value) - } - } - - // Basic math (+, -, *, /) - - override fun add(a: Variable, b: Variable): Variable = derive(variable { a.value + b.value }) { z -> - a.d += z.d - b.d += z.d - } - - override fun multiply(a: Variable, b: Variable): Variable = derive(variable { a.value * b.value }) { z -> - a.d += z.d * b.value - b.d += z.d * a.value - } - - override fun divide(a: Variable, b: Variable): Variable = derive(variable { a.value / b.value }) { z -> - a.d += z.d / b.value - b.d -= z.d * a.value / (b.value * b.value) - } - - override fun multiply(a: Variable, k: Number): Variable = derive(variable { k.toDouble() * a.value }) { z -> - a.d += z.d * k.toDouble() - } -} - -// Extensions for differentiation of various basic mathematical functions - -// x ^ 2 -public fun > AutoDiffField.sqr(x: Variable): Variable = - derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } - -// x ^ 1/2 -public fun > AutoDiffField.sqrt(x: Variable): Variable = - derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } - -// x ^ y (const) -public fun > AutoDiffField.pow(x: Variable, y: Double): Variable = - derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } - -public fun > AutoDiffField.pow(x: Variable, y: Int): Variable = - pow(x, y.toDouble()) - -// exp(x) -public fun > AutoDiffField.exp(x: Variable): Variable = - derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value } - -// ln(x) -public fun > AutoDiffField.ln(x: Variable): Variable = - derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value } - -// x ^ y (any) -public fun > AutoDiffField.pow(x: Variable, y: Variable): Variable = - exp(y * ln(x)) - -// sin(x) -public fun > AutoDiffField.sin(x: Variable): Variable = - derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } - -// cos(x) -public fun > AutoDiffField.cos(x: Variable): Variable = - derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } - -public fun > AutoDiffField.tan(x: Variable): Variable = - derive(variable { tan(x.value) }) { z -> - val c = cos(x.value) - x.d += z.d / (c * c) - } - -public fun > AutoDiffField.asin(x: Variable): Variable = - derive(variable { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) } - -public fun > AutoDiffField.acos(x: Variable): Variable = - derive(variable { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) } - -public fun > AutoDiffField.atan(x: Variable): Variable = - derive(variable { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) } - -public fun > AutoDiffField.sinh(x: Variable): Variable = - derive(variable { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) } - -public fun > AutoDiffField.cosh(x: Variable): Variable = - derive(variable { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) } - -public fun > AutoDiffField.tanh(x: Variable): Variable = - derive(variable { tan(x.value) }) { z -> - val c = cosh(x.value) - x.d += z.d / (c * c) - } - -public fun > AutoDiffField.asinh(x: Variable): Variable = - derive(variable { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) } - -public fun > AutoDiffField.acosh(x: Variable): Variable = - derive(variable { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) } - -public fun > AutoDiffField.atanh(x: Variable): Variable = - derive(variable { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) } - diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/ExpressionFieldTest.kt index 1d3f520f6..484993eef 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/ExpressionFieldTest.kt @@ -6,19 +6,21 @@ import kscience.kmath.operations.RealField import kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFails class ExpressionFieldTest { + val x by symbol @Test fun testExpression() { val context = FunctionalExpressionField(RealField) val expression = context { - val x = variable("x", 2.0) + val x by binding() x * x + 2 * x + one } - assertEquals(expression("x" to 1.0), 4.0) - assertEquals(expression(), 9.0) + assertEquals(expression(x to 1.0), 4.0) + assertFails { expression()} } @Test @@ -26,33 +28,33 @@ class ExpressionFieldTest { val context = FunctionalExpressionField(ComplexField) val expression = context { - val x = variable("x", Complex(2.0, 0.0)) + val x = bind(x) x * x + 2 * x + one } - assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0)) - assertEquals(expression(), Complex(9.0, 0.0)) + assertEquals(expression(x to Complex(1.0, 0.0)), Complex(4.0, 0.0)) + //assertEquals(expression(), Complex(9.0, 0.0)) } @Test fun separateContext() { fun FunctionalExpressionField.expression(): Expression { - val x = variable("x") + val x by binding() return x * x + 2 * x + one } val expression = FunctionalExpressionField(RealField).expression() - assertEquals(expression("x" to 1.0), 4.0) + assertEquals(expression(x to 1.0), 4.0) } @Test fun valueExpression() { val expressionBuilder: FunctionalExpressionField.() -> Expression = { - val x = variable("x") + val x by binding() x * x + 2 * x + one } val expression = FunctionalExpressionField(RealField).expressionBuilder() - assertEquals(expression("x" to 1.0), 4.0) + assertEquals(expression(x to 1.0), 4.0) } } diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt new file mode 100644 index 000000000..ca5b626fd --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -0,0 +1,277 @@ +package kscience.kmath.expressions + +import kscience.kmath.operations.RealField +import kscience.kmath.structures.asBuffer +import kotlin.math.PI +import kotlin.math.pow +import kotlin.math.sqrt +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class SimpleAutoDiffTest { + fun d( + vararg bindings: Pair, + body: AutoDiffField.() -> BoundSymbol, + ): DerivationResult = RealField.withAutoDiff(bindings = bindings, body) + + fun dx( + xBinding: Pair, + body: AutoDiffField.(x: BoundSymbol) -> BoundSymbol, + ): DerivationResult = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) } + + fun dxy( + xBinding: Pair, + yBinding: Pair, + body: AutoDiffField.(x: BoundSymbol, y: BoundSymbol) -> BoundSymbol, + ): DerivationResult = RealField.withAutoDiff(xBinding, yBinding) { + body(bind(xBinding.first), bind(yBinding.first)) + } + + fun diff(block: AutoDiffField.() -> BoundSymbol): SimpleAutoDiffExpression { + return SimpleAutoDiffExpression(RealField, block) + } + + val x by symbol + val y by symbol + val z by symbol + + @Test + fun testPlusX2() { + val y = d(x to 3.0) { + // diff w.r.t this x at 3 + val x = bind(x) + x + x + } + assertEquals(6.0, y.value) // y = x + x = 6 + assertEquals(2.0, y.derivative(x)) // dy/dx = 2 + } + + @Test + fun testPlus() { + // two variables + val z = d(x to 2.0, y to 3.0) { + val x = bind(x) + val y = bind(y) + x + y + } + assertEquals(5.0, z.value) // z = x + y = 5 + assertEquals(1.0, z.derivative(x)) // dz/dx = 1 + assertEquals(1.0, z.derivative(y)) // dz/dy = 1 + } + + @Test + fun testMinus() { + // two variables + val z = d(x to 7.0, y to 3.0) { + val x = bind(x) + val y = bind(y) + + x - y + } + assertEquals(4.0, z.value) // z = x - y = 4 + assertEquals(1.0, z.derivative(x)) // dz/dx = 1 + assertEquals(-1.0, z.derivative(y)) // dz/dy = -1 + } + + @Test + fun testMulX2() { + val y = dx(x to 3.0) { x -> + // diff w.r.t this x at 3 + x * x + } + assertEquals(9.0, y.value) // y = x * x = 9 + assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7 + } + + @Test + fun testSqr() { + val y = dx(x to 3.0) { x -> sqr(x) } + assertEquals(9.0, y.value) // y = x ^ 2 = 9 + assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7 + } + + @Test + fun testSqrSqr() { + val y = dx(x to 2.0) { x -> sqr(sqr(x)) } + assertEquals(16.0, y.value) // y = x ^ 4 = 16 + assertEquals(32.0, y.derivative(x)) // dy/dx = 4 * x^3 = 32 + } + + @Test + fun testX3() { + val y = dx(x to 2.0) { x -> + // diff w.r.t this x at 2 + x * x * x + } + assertEquals(8.0, y.value) // y = x * x * x = 8 + assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x * x = 12 + } + + @Test + fun testDiv() { + val z = dxy(x to 5.0, y to 2.0) { x, y -> + x / y + } + assertEquals(2.5, z.value) // z = x / y = 2.5 + assertEquals(0.5, z.derivative(x)) // dz/dx = 1 / y = 0.5 + assertEquals(-1.25, z.derivative(y)) // dz/dy = -x / y^2 = -1.25 + } + + @Test + fun testPow3() { + val y = dx(x to 2.0) { x -> + // diff w.r.t this x at 2 + pow(x, 3) + } + assertEquals(8.0, y.value) // y = x ^ 3 = 8 + assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x ^ 2 = 12 + } + + @Test + fun testPowFull() { + val z = dxy(x to 2.0, y to 3.0) { x, y -> + pow(x, y) + } + assertApprox(8.0, z.value) // z = x ^ y = 8 + assertApprox(12.0, z.derivative(x)) // dz/dx = y * x ^ (y - 1) = 12 + assertApprox(8.0 * kotlin.math.ln(2.0), z.derivative(y)) // dz/dy = x ^ y * ln(x) + } + + @Test + fun testFromPaper() { + val y = dx(x to 3.0) { x -> 2 * x + x * x * x } + assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33 + assertEquals(29.0, y.derivative(x)) // dy/dx = 2 + 3 * x * x = 29 + } + + @Test + fun testInnerVariable() { + val y = dx(x to 1.0) { x -> + const(1.0) * x + } + assertEquals(1.0, y.value) // y = x ^ n = 1 + assertEquals(1.0, y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1 + } + + @Test + fun testLongChain() { + val n = 10_000 + val y = dx(x to 1.0) { x -> + var res = const(1.0) + for (i in 1..n) res *= x + res + } + assertEquals(1.0, y.value) // y = x ^ n = 1 + assertEquals(n.toDouble(), y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1 + } + + @Test + fun testExample() { + val y = dx(x to 2.0) { x -> sqr(x) + 5 * x + 3 } + assertEquals(17.0, y.value) // the value of result (y) + assertEquals(9.0, y.derivative(x)) // dy/dx + } + + @Test + fun testSqrt() { + val y = dx(x to 16.0) { x -> sqrt(x) } + assertEquals(4.0, y.value) // y = x ^ 1/2 = 4 + assertEquals(1.0 / 8, y.derivative(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8 + } + + @Test + fun testSin() { + val y = dx(x to PI / 6.0) { x -> sin(x) } + assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5 + assertApprox(sqrt(3.0) / 2, y.derivative(x)) // dy/dx = cos(pi/6) = sqrt(3)/2 + } + + @Test + fun testCos() { + val y = dx(x to PI / 6) { x -> cos(x) } + assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2 + assertApprox(-0.5, y.derivative(x)) // dy/dx = -sin(pi/6) = -0.5 + } + + @Test + fun testTan() { + val y = dx(x to PI / 6) { x -> tan(x) } + assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3) + assertApprox(4.0 / 3.0, y.derivative(x)) // dy/dx = sec(pi/6)^2 = 4/3 + } + + @Test + fun testAsin() { + val y = dx(x to PI / 6) { x -> asin(x) } + assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6) + assertApprox(6.0 / sqrt(36 - PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(36-pi^2) + } + + @Test + fun testAcos() { + val y = dx(x to PI / 6) { x -> acos(x) } + assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6) + assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2) + } + + @Test + fun testAtan() { + val y = dx(x to PI / 6) { x -> atan(x) } + assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6) + assertApprox(36.0 / (36.0 + PI * PI), y.derivative(x)) // dy/dx = 36/(36+pi^2) + } + + @Test + fun testSinh() { + val y = dx(x to 0.0) { x -> sinh(x) } + assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0) + assertApprox(kotlin.math.cosh(0.0), y.derivative(x)) // dy/dx = cosh(0) + } + + @Test + fun testCosh() { + val y = dx(x to 0.0) { x -> cosh(x) } + assertApprox(1.0, y.value) //y = cosh(0) + assertApprox(0.0, y.derivative(x)) // dy/dx = sinh(0) + } + + @Test + fun testTanh() { + val y = dx(x to PI / 6) { x -> tanh(x) } + assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6) + assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2 + } + + @Test + fun testAsinh() { + val y = dx(x to PI / 6) { x -> asinh(x) } + assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6) + assertApprox(6.0 / sqrt(36 + PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(pi^2+36) + } + + @Test + fun testAcosh() { + val y = dx(x to PI / 6) { x -> acosh(x) } + assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6) + assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2) + } + + @Test + fun testAtanh() { + val y = dx(x to PI / 6) { x -> atanh(x) } + assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6) + assertApprox(-36.0 / (PI * PI - 36.0), y.derivative(x)) // dy/dx = -36/(pi^2-36) + } + + @Test + fun testDivGrad() { + val res = dxy(x to 1.0, y to 2.0) { x, y -> x * x + y * y } + assertEquals(6.0, res.div()) + assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer())) + } + + private fun assertApprox(a: Double, b: Double) { + if ((a - b) > 1e-10) assertEquals(a, b) + } +} diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/misc/AutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/misc/AutoDiffTest.kt deleted file mode 100644 index 3b1813185..000000000 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/misc/AutoDiffTest.kt +++ /dev/null @@ -1,261 +0,0 @@ -package kscience.kmath.misc - -import kscience.kmath.operations.RealField -import kscience.kmath.structures.asBuffer -import kotlin.math.PI -import kotlin.math.pow -import kotlin.math.sqrt -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertTrue - -class AutoDiffTest { - inline fun deriv(body: AutoDiffField.() -> Variable): DerivationResult = - RealField.deriv(body) - - @Test - fun testPlusX2() { - val x = Variable(3.0) // diff w.r.t this x at 3 - val y = deriv { x + x } - assertEquals(6.0, y.value) // y = x + x = 6 - assertEquals(2.0, y.deriv(x)) // dy/dx = 2 - } - - @Test - fun testPlus() { - // two variables - val x = Variable(2.0) - val y = Variable(3.0) - val z = deriv { x + y } - assertEquals(5.0, z.value) // z = x + y = 5 - assertEquals(1.0, z.deriv(x)) // dz/dx = 1 - assertEquals(1.0, z.deriv(y)) // dz/dy = 1 - } - - @Test - fun testMinus() { - // two variables - val x = Variable(7.0) - val y = Variable(3.0) - val z = deriv { x - y } - assertEquals(4.0, z.value) // z = x - y = 4 - assertEquals(1.0, z.deriv(x)) // dz/dx = 1 - assertEquals(-1.0, z.deriv(y)) // dz/dy = -1 - } - - @Test - fun testMulX2() { - val x = Variable(3.0) // diff w.r.t this x at 3 - val y = deriv { x * x } - assertEquals(9.0, y.value) // y = x * x = 9 - assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7 - } - - @Test - fun testSqr() { - val x = Variable(3.0) - val y = deriv { sqr(x) } - assertEquals(9.0, y.value) // y = x ^ 2 = 9 - assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7 - } - - @Test - fun testSqrSqr() { - val x = Variable(2.0) - val y = deriv { sqr(sqr(x)) } - assertEquals(16.0, y.value) // y = x ^ 4 = 16 - assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32 - } - - @Test - fun testX3() { - val x = Variable(2.0) // diff w.r.t this x at 2 - val y = deriv { x * x * x } - assertEquals(8.0, y.value) // y = x * x * x = 8 - assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12 - } - - @Test - fun testDiv() { - val x = Variable(5.0) - val y = Variable(2.0) - val z = deriv { x / y } - assertEquals(2.5, z.value) // z = x / y = 2.5 - assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5 - assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25 - } - - @Test - fun testPow3() { - val x = Variable(2.0) // diff w.r.t this x at 2 - val y = deriv { pow(x, 3) } - assertEquals(8.0, y.value) // y = x ^ 3 = 8 - assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12 - } - - @Test - fun testPowFull() { - val x = Variable(2.0) - val y = Variable(3.0) - val z = deriv { pow(x, y) } - assertApprox(8.0, z.value) // z = x ^ y = 8 - assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12 - assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x) - } - - @Test - fun testFromPaper() { - val x = Variable(3.0) - val y = deriv { 2 * x + x * x * x } - assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33 - assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29 - } - - @Test - fun testInnerVariable() { - val x = Variable(1.0) - val y = deriv { - Variable(1.0) * x - } - assertEquals(1.0, y.value) // y = x ^ n = 1 - assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1 - } - - @Test - fun testLongChain() { - val n = 10_000 - val x = Variable(1.0) - val y = deriv { - var res = Variable(1.0) - for (i in 1..n) res *= x - res - } - assertEquals(1.0, y.value) // y = x ^ n = 1 - assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1 - } - - @Test - fun testExample() { - val x = Variable(2.0) - val y = deriv { sqr(x) + 5 * x + 3 } - assertEquals(17.0, y.value) // the value of result (y) - assertEquals(9.0, y.deriv(x)) // dy/dx - } - - @Test - fun testSqrt() { - val x = Variable(16.0) - val y = deriv { sqrt(x) } - assertEquals(4.0, y.value) // y = x ^ 1/2 = 4 - assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8 - } - - @Test - fun testSin() { - val x = Variable(PI / 6.0) - val y = deriv { sin(x) } - assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5 - assertApprox(sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(pi/6) = sqrt(3)/2 - } - - @Test - fun testCos() { - val x = Variable(PI / 6) - val y = deriv { cos(x) } - assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2 - assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(pi/6) = -0.5 - } - - @Test - fun testTan() { - val x = Variable(PI / 6) - val y = deriv { tan(x) } - assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3) - assertApprox(4.0 / 3.0, y.deriv(x)) // dy/dx = sec(pi/6)^2 = 4/3 - } - - @Test - fun testAsin() { - val x = Variable(PI / 6) - val y = deriv { asin(x) } - assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6) - assertApprox(6.0 / sqrt(36 - PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(36-pi^2) - } - - @Test - fun testAcos() { - val x = Variable(PI / 6) - val y = deriv { acos(x) } - assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6) - assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2) - } - - @Test - fun testAtan() { - val x = Variable(PI / 6) - val y = deriv { atan(x) } - assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6) - assertApprox(36.0 / (36.0 + PI * PI), y.deriv(x)) // dy/dx = 36/(36+pi^2) - } - - @Test - fun testSinh() { - val x = Variable(0.0) - val y = deriv { sinh(x) } - assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0) - assertApprox(kotlin.math.cosh(0.0), y.deriv(x)) // dy/dx = cosh(0) - } - - @Test - fun testCosh() { - val x = Variable(0.0) - val y = deriv { cosh(x) } - assertApprox(1.0, y.value) //y = cosh(0) - assertApprox(0.0, y.deriv(x)) // dy/dx = sinh(0) - } - - @Test - fun testTanh() { - val x = Variable(PI / 6) - val y = deriv { tanh(x) } - assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6) - assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.deriv(x)) // dy/dx = sech(pi/6)^2 - } - - @Test - fun testAsinh() { - val x = Variable(PI / 6) - val y = deriv { asinh(x) } - assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6) - assertApprox(6.0 / sqrt(36 + PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(pi^2+36) - } - - @Test - fun testAcosh() { - val x = Variable(PI / 6) - val y = deriv { acosh(x) } - assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6) - assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2) - } - - @Test - fun testAtanh() { - val x = Variable(PI / 6.0) - val y = deriv { atanh(x) } - assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6) - assertApprox(-36.0 / (PI * PI - 36.0), y.deriv(x)) // dy/dx = -36/(pi^2-36) - } - - @Test - fun testDivGrad() { - val x = Variable(1.0) - val y = Variable(2.0) - val res = deriv { x * x + y * y } - assertEquals(6.0, res.div()) - assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer())) - } - - private fun assertApprox(a: Double, b: Double) { - if ((a - b) > 1e-10) assertEquals(a, b) - } -} From 6386f2b894cf27f7d257d25d33d7d7b20e679b06 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 20 Oct 2020 10:03:09 +0300 Subject: [PATCH 40/69] Update build tools --- settings.gradle.kts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/settings.gradle.kts b/settings.gradle.kts index 323d5d15f..0f549f9ab 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -10,8 +10,8 @@ pluginManagement { maven("https://dl.bintray.com/kotlin/kotlin-dev/") } - val toolsVersion = "0.6.3-dev-1.4.20-M1" - val kotlinVersion = "1.4.20-M1" + val toolsVersion = "0.6.4-dev-1.4.20-M2" + val kotlinVersion = "1.4.20-M2" plugins { id("kotlinx.benchmark") version "0.2.0-dev-20" @@ -39,6 +39,6 @@ include( ":kmath-for-real", ":kmath-geometry", ":kmath-ast", - ":examples", - ":kmath-ejml" + ":kmath-ejml", + ":examples" ) From ae07652d9ec60ba632d985281c799d9493653e36 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 21 Oct 2020 11:38:28 +0300 Subject: [PATCH 41/69] Symbol identity is always a string --- .../DerivativeStructureExpression.kt | 4 +- .../kscience/kmath/expressions/Expression.kt | 2 +- .../kmath/expressions/SimpleAutoDiff.kt | 140 +++++++++--------- .../kmath/expressions/SimpleAutoDiffTest.kt | 8 +- 4 files changed, 73 insertions(+), 81 deletions(-) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 9a27e40cd..2ec69255e 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -25,13 +25,13 @@ public class DerivativeStructureField( */ public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) : DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol { - override val identity: Any = symbol.identity + override val identity: String = symbol.identity } /** * Identity-based symbol bindings map */ - private val variables: Map = bindings.entries.associate { (key, value) -> + private val variables: Map = bindings.entries.associate { (key, value) -> key.identity to DerivativeStructureSymbol(key, value) } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index d64eb5a55..bd83261f7 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -12,7 +12,7 @@ public interface Symbol { * Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol. * By default uses object identity */ - public val identity: Any get() = this + public val identity: String } /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index 5e8fe3e99..a718154d3 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -12,20 +12,7 @@ import kotlin.contracts.contract */ -/** - * A [Symbol] with bound value - */ -public interface BoundSymbol : Symbol { - public val value: T -} - -/** - * Bind a [Symbol] to a [value] and produce [BoundSymbol] - */ -public fun Symbol.bind(value: T): BoundSymbol = object : BoundSymbol { - override val identity = this@bind.identity - override val value: T = value -} +public open class AutoDiffValue(public val value: T) /** * Represents result of [withAutoDiff] call. @@ -36,10 +23,10 @@ public fun Symbol.bind(value: T): BoundSymbol = object : BoundSymbol { * @property context The field over [T]. */ public class DerivationResult( - override val value: T, - private val derivativeValues: Map, + public val value: T, + private val derivativeValues: Map, public val context: Field, -) : BoundSymbol { +) { /** * Returns derivative of [variable] or returns [Ring.zero] in [context]. */ @@ -76,8 +63,8 @@ public fun DerivationResult.grad(vararg variables: Symbol): Point> F.withAutoDiff( - bindings: Collection>, - body: AutoDiffField.() -> BoundSymbol, + bindings: Map, + body: AutoDiffField.() -> AutoDiffValue, ): DerivationResult { contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } @@ -86,14 +73,14 @@ public fun > F.withAutoDiff( public fun > F.withAutoDiff( vararg bindings: Pair, - body: AutoDiffField.() -> BoundSymbol, -): DerivationResult = withAutoDiff(bindings.map { it.first.bind(it.second) }, body) + body: AutoDiffField.() -> AutoDiffValue, +): DerivationResult = withAutoDiff(bindings.toMap(), body) /** * Represents field in context of which functions can be derived. */ public abstract class AutoDiffField> - : Field>, ExpressionAlgebra> { + : Field>, ExpressionAlgebra> { public abstract val context: F @@ -101,7 +88,7 @@ public abstract class AutoDiffField> * A variable accessing inner state of derivatives. * Use this value in inner builders to avoid creating additional derivative bindings. */ - public abstract var BoundSymbol.d: T + public abstract var AutoDiffValue.d: T /** * Performs update of derivative after the rest of the formula in the back-pass. @@ -116,21 +103,21 @@ public abstract class AutoDiffField> */ public abstract fun derive(value: R, block: F.(R) -> Unit): R - public inline fun const(block: F.() -> T): BoundSymbol = const(context.block()) + public inline fun const(block: F.() -> T): AutoDiffValue = const(context.block()) // Overloads for Double constants - override operator fun Number.plus(b: BoundSymbol): BoundSymbol = + override operator fun Number.plus(b: AutoDiffValue): AutoDiffValue = derive(const { this@plus.toDouble() * one + b.value }) { z -> b.d += z.d } - override operator fun BoundSymbol.plus(b: Number): BoundSymbol = b.plus(this) + override operator fun AutoDiffValue.plus(b: Number): AutoDiffValue = b.plus(this) - override operator fun Number.minus(b: BoundSymbol): BoundSymbol = + override operator fun Number.minus(b: AutoDiffValue): AutoDiffValue = derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } - override operator fun BoundSymbol.minus(b: Number): BoundSymbol = + override operator fun AutoDiffValue.minus(b: Number): AutoDiffValue = derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } } @@ -139,14 +126,14 @@ public abstract class AutoDiffField> */ private class AutoDiffContext>( override val context: F, - bindings: Collection>, + bindings: Map, ) : AutoDiffField() { // this stack contains pairs of blocks and values to apply them to private var stack: Array = arrayOfNulls(8) private var sp: Int = 0 - private val derivatives: MutableMap = hashMapOf() - override val zero: BoundSymbol get() = const(context.zero) - override val one: BoundSymbol get() = const(context.one) + private val derivatives: MutableMap, T> = hashMapOf() + override val zero: AutoDiffValue get() = const(context.zero) + override val one: AutoDiffValue get() = const(context.one) /** * Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result @@ -155,17 +142,23 @@ private class AutoDiffContext>( * @param T the non-nullable type of value. * @property value The value of this variable. */ - private class AutoDiffVariableWithDeriv(override val value: T, var d: T) : BoundSymbol + private class AutoDiffVariableWithDeriv( + override val identity: String, + value: T, + var d: T, + ) : AutoDiffValue(value), Symbol - private val bindings: Map> = bindings.associateBy { it.identity } + private val bindings: Map> = bindings.entries.associate { + it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero) + } - override fun bindOrNull(symbol: Symbol): BoundSymbol? = bindings[symbol.identity] + override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv? = bindings[symbol.identity] - override fun const(value: T): BoundSymbol = AutoDiffVariableWithDeriv(value, context.zero) + override fun const(value: T): AutoDiffValue = AutoDiffValue(value) - override var BoundSymbol.d: T - get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[identity] ?: context.zero - set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[identity] = value + override var AutoDiffValue.d: T + get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero + set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value @Suppress("UNCHECKED_CAST") override fun derive(value: R, block: F.(R) -> Unit): R { @@ -187,34 +180,34 @@ private class AutoDiffContext>( // Basic math (+, -, *, /) - override fun add(a: BoundSymbol, b: BoundSymbol): BoundSymbol = + override fun add(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = derive(const { a.value + b.value }) { z -> a.d += z.d b.d += z.d } - override fun multiply(a: BoundSymbol, b: BoundSymbol): BoundSymbol = + override fun multiply(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = derive(const { a.value * b.value }) { z -> a.d += z.d * b.value b.d += z.d * a.value } - override fun divide(a: BoundSymbol, b: BoundSymbol): BoundSymbol = + override fun divide(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = derive(const { a.value / b.value }) { z -> a.d += z.d / b.value b.d -= z.d * a.value / (b.value * b.value) } - override fun multiply(a: BoundSymbol, k: Number): BoundSymbol = + override fun multiply(a: AutoDiffValue, k: Number): AutoDiffValue = derive(const { k.toDouble() * a.value }) { z -> a.d += z.d * k.toDouble() } - inline fun derivate(function: AutoDiffField.() -> BoundSymbol): DerivationResult { + inline fun derivate(function: AutoDiffField.() -> AutoDiffValue): DerivationResult { val result = function() result.d = context.one // computing derivative w.r.t result runBackwardPass() - return DerivationResult(result.value, derivatives, context) + return DerivationResult(result.value, bindings.mapValues { it.value.d }, context) } } @@ -223,11 +216,11 @@ private class AutoDiffContext>( */ public class SimpleAutoDiffExpression>( public val field: F, - public val function: AutoDiffField.() -> BoundSymbol, + public val function: AutoDiffField.() -> AutoDiffValue, ) : DifferentiableExpression { public override operator fun invoke(arguments: Map): T { - val bindings = arguments.entries.map { it.key.bind(it.value) } - return AutoDiffContext(field, bindings).function().value + //val bindings = arguments.entries.map { it.key.bind(it.value) } + return AutoDiffContext(field, arguments).function().value } /** @@ -237,8 +230,8 @@ public class SimpleAutoDiffExpression>( val dSymbol = orders.entries.singleOrNull { it.value == 1 } ?: error("SimpleAutoDiff supports only first order derivatives") return Expression { arguments -> - val bindings = arguments.entries.map { it.key.bind(it.value) } - val derivationResult = AutoDiffContext(field, bindings).derivate(function) + //val bindings = arguments.entries.map { it.key.bind(it.value) } + val derivationResult = AutoDiffContext(field, arguments).derivate(function) derivationResult.derivative(dSymbol.key) } } @@ -248,82 +241,81 @@ public class SimpleAutoDiffExpression>( // Extensions for differentiation of various basic mathematical functions // x ^ 2 -public fun > AutoDiffField.sqr(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.sqr(x: AutoDiffValue): AutoDiffValue = derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } // x ^ 1/2 -public fun > AutoDiffField.sqrt(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.sqrt(x: AutoDiffValue): AutoDiffValue = derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } // x ^ y (const) public fun > AutoDiffField.pow( - x: BoundSymbol, + x: AutoDiffValue, y: Double, -): BoundSymbol = +): AutoDiffValue = derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } public fun > AutoDiffField.pow( - x: BoundSymbol, + x: AutoDiffValue, y: Int, -): BoundSymbol = - pow(x, y.toDouble()) +): AutoDiffValue = pow(x, y.toDouble()) // exp(x) -public fun > AutoDiffField.exp(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.exp(x: AutoDiffValue): AutoDiffValue = derive(const { exp(x.value) }) { z -> x.d += z.d * z.value } // ln(x) -public fun > AutoDiffField.ln(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.ln(x: AutoDiffValue): AutoDiffValue = derive(const { ln(x.value) }) { z -> x.d += z.d / x.value } // x ^ y (any) public fun > AutoDiffField.pow( - x: BoundSymbol, - y: BoundSymbol, -): BoundSymbol = + x: AutoDiffValue, + y: AutoDiffValue, +): AutoDiffValue = exp(y * ln(x)) // sin(x) -public fun > AutoDiffField.sin(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.sin(x: AutoDiffValue): AutoDiffValue = derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } // cos(x) -public fun > AutoDiffField.cos(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.cos(x: AutoDiffValue): AutoDiffValue = derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } -public fun > AutoDiffField.tan(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.tan(x: AutoDiffValue): AutoDiffValue = derive(const { tan(x.value) }) { z -> val c = cos(x.value) x.d += z.d / (c * c) } -public fun > AutoDiffField.asin(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.asin(x: AutoDiffValue): AutoDiffValue = derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) } -public fun > AutoDiffField.acos(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.acos(x: AutoDiffValue): AutoDiffValue = derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) } -public fun > AutoDiffField.atan(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.atan(x: AutoDiffValue): AutoDiffValue = derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) } -public fun > AutoDiffField.sinh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.sinh(x: AutoDiffValue): AutoDiffValue = derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) } -public fun > AutoDiffField.cosh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.cosh(x: AutoDiffValue): AutoDiffValue = derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) } -public fun > AutoDiffField.tanh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.tanh(x: AutoDiffValue): AutoDiffValue = derive(const { tan(x.value) }) { z -> val c = cosh(x.value) x.d += z.d / (c * c) } -public fun > AutoDiffField.asinh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.asinh(x: AutoDiffValue): AutoDiffValue = derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) } -public fun > AutoDiffField.acosh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.acosh(x: AutoDiffValue): AutoDiffValue = derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) } -public fun > AutoDiffField.atanh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.atanh(x: AutoDiffValue): AutoDiffValue = derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) } diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt index ca5b626fd..ef4a6a06a 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -12,23 +12,23 @@ import kotlin.test.assertTrue class SimpleAutoDiffTest { fun d( vararg bindings: Pair, - body: AutoDiffField.() -> BoundSymbol, + body: AutoDiffField.() -> AutoDiffValue, ): DerivationResult = RealField.withAutoDiff(bindings = bindings, body) fun dx( xBinding: Pair, - body: AutoDiffField.(x: BoundSymbol) -> BoundSymbol, + body: AutoDiffField.(x: AutoDiffValue) -> AutoDiffValue, ): DerivationResult = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) } fun dxy( xBinding: Pair, yBinding: Pair, - body: AutoDiffField.(x: BoundSymbol, y: BoundSymbol) -> BoundSymbol, + body: AutoDiffField.(x: AutoDiffValue, y: AutoDiffValue) -> AutoDiffValue, ): DerivationResult = RealField.withAutoDiff(xBinding, yBinding) { body(bind(xBinding.first), bind(yBinding.first)) } - fun diff(block: AutoDiffField.() -> BoundSymbol): SimpleAutoDiffExpression { + fun diff(block: AutoDiffField.() -> AutoDiffValue): SimpleAutoDiffExpression { return SimpleAutoDiffExpression(RealField, block) } From 04d3f4a99f313132c451697099994d64b1ae1453 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 22 Oct 2020 09:28:18 +0300 Subject: [PATCH 42/69] Fix ASM --- gradle/wrapper/gradle-wrapper.properties | 2 +- .../kscience/kmath/asm/internal/AsmBuilder.kt | 16 ++++------------ .../kscience/kmath/asm/internal/mapIntrinsics.kt | 7 +++---- .../kscience/kmath/expressions/SimpleAutoDiff.kt | 4 +++- 4 files changed, 11 insertions(+), 18 deletions(-) diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 12d38de6a..be52383ef 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.6.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.7-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt index 06f02a94d..a1e482103 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -25,7 +25,7 @@ internal class AsmBuilder internal constructor( private val classOfT: Class<*>, private val algebra: Algebra, private val className: String, - private val invokeLabel0Visitor: AsmBuilder.() -> Unit + private val invokeLabel0Visitor: AsmBuilder.() -> Unit, ) { /** * Internal classloader of [AsmBuilder] with alias to define class from byte array. @@ -379,22 +379,14 @@ internal class AsmBuilder internal constructor( * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be * provided. */ - internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { + internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run { load(invokeArgumentsVar, MAP_TYPE) aconst(name) - if (defaultValue != null) - loadTConstant(defaultValue) - invokestatic( MAP_INTRINSICS_TYPE.internalName, "getOrFail", - - Type.getMethodDescriptor( - OBJECT_TYPE, - MAP_TYPE, - OBJECT_TYPE, - *OBJECT_TYPE.wrapToArrayIf { defaultValue != null }), + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), false ) @@ -429,7 +421,7 @@ internal class AsmBuilder internal constructor( method: String, descriptor: String, expectedArity: Int, - opcode: Int = INVOKEINTERFACE + opcode: Int = INVOKEINTERFACE, ) { run loop@{ repeat(expectedArity) { diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt index 09e9a71b0..588b9611a 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt @@ -3,12 +3,11 @@ package kscience.kmath.asm.internal import kscience.kmath.expressions.StringSymbol +import kscience.kmath.expressions.Symbol /** - * Gets value with given [key] or throws [IllegalStateException] whenever it is not present. + * Gets value with given [key] or throws [NoSuchElementException] whenever it is not present. * * @author Iaroslav Postovalov */ -@JvmOverloads -internal fun Map.getOrFail(key: K, default: V? = null): V = - this[StringSymbol(key.toString())] ?: default ?: error("Parameter not found: $key") +internal fun Map.getOrFail(key: String): V = getValue(StringSymbol(key)) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index a718154d3..af7c8fbf2 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -146,7 +146,9 @@ private class AutoDiffContext>( override val identity: String, value: T, var d: T, - ) : AutoDiffValue(value), Symbol + ) : AutoDiffValue(value), Symbol{ + override fun toString(): String = identity + } private val bindings: Map> = bindings.entries.associate { it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero) From f7614da230a9e1491263b71b955110382f58a9e4 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 22 Oct 2020 11:27:08 +0300 Subject: [PATCH 43/69] Refactoring --- .../DerivativeStructureExpression.kt | 3 +++ .../expressions/DifferentiableExpression.kt | 21 ++++++++++++++++ .../kscience/kmath/expressions/Expression.kt | 19 +++----------- .../FunctionalExpressionAlgebra.kt | 5 ++-- .../kmath/expressions/SimpleAutoDiff.kt | 14 ++++++----- .../kmath/expressions/SimpleAutoDiffTest.kt | 25 ++++++++++++------- 6 files changed, 55 insertions(+), 32 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 2ec69255e..a1ee91419 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -26,6 +26,9 @@ public class DerivativeStructureField( public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) : DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol { override val identity: String = symbol.identity + override fun toString(): String = identity + override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity + override fun hashCode(): Int = identity.hashCode() } /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt new file mode 100644 index 000000000..841531d01 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -0,0 +1,21 @@ +package kscience.kmath.expressions + +/** + * And object that could be differentiated + */ +public interface Differentiable { + public fun derivative(orders: Map): T +} + +public interface DifferentiableExpression : Differentiable>, Expression + +public fun DifferentiableExpression.derivative(vararg orders: Pair): Expression = + derivative(mapOf(*orders)) + +public fun DifferentiableExpression.derivative(symbol: Symbol): Expression = derivative(symbol to 1) + +public fun DifferentiableExpression.derivative(name: String): Expression = derivative(StringSymbol(name) to 1) + +//public interface DifferentiableExpressionBuilder>: ExpressionBuilder { +// public override fun expression(block: A.() -> E): DifferentiableExpression +//} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index bd83261f7..7da5a2529 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -56,21 +56,6 @@ public operator fun Expression.invoke(vararg pairs: Pair): T = public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) }) -/** - * And object that could be differentiated - */ -public interface Differentiable { - public fun derivative(orders: Map): T -} - -public interface DifferentiableExpression : Differentiable>, Expression - -public fun DifferentiableExpression.derivative(vararg orders: Pair): Expression = - derivative(mapOf(*orders)) - -public fun DifferentiableExpression.derivative(symbol: Symbol): Expression = derivative(symbol to 1) - -public fun DifferentiableExpression.derivative(name: String): Expression = derivative(StringSymbol(name) to 1) /** * A context for expression construction @@ -96,6 +81,10 @@ public interface ExpressionAlgebra : Algebra { public fun const(value: T): E } +//public interface ExpressionBuilder> { +// public fun expression(block: A.() -> E): Expression +//} + /** * Bind a given [Symbol] to this context variable and produce context-specific object. */ diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 9fd15238a..0630e8e4b 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -7,8 +7,9 @@ import kscience.kmath.operations.* * * @param algebra The algebra to provide for Expressions built. */ -public abstract class FunctionalExpressionAlgebra>(public val algebra: A) : - ExpressionAlgebra> { +public abstract class FunctionalExpressionAlgebra>( + public val algebra: A, +) : ExpressionAlgebra> { /** * Builds an Expression of constant expression which does not depend on arguments. */ diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index af7c8fbf2..e5ea33c81 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -15,11 +15,11 @@ import kotlin.contracts.contract public open class AutoDiffValue(public val value: T) /** - * Represents result of [withAutoDiff] call. + * Represents result of [simpleAutoDiff] call. * * @param T the non-nullable type of value. * @param value the value of result. - * @property withAutoDiff The mapping of differentiated variables to their derivatives. + * @property simpleAutoDiff The mapping of differentiated variables to their derivatives. * @property context The field over [T]. */ public class DerivationResult( @@ -62,7 +62,7 @@ public fun DerivationResult.grad(vararg variables: Symbol): Point> F.withAutoDiff( +public fun > F.simpleAutoDiff( bindings: Map, body: AutoDiffField.() -> AutoDiffValue, ): DerivationResult { @@ -71,10 +71,10 @@ public fun > F.withAutoDiff( return AutoDiffContext(this, bindings).derivate(body) } -public fun > F.withAutoDiff( +public fun > F.simpleAutoDiff( vararg bindings: Pair, body: AutoDiffField.() -> AutoDiffValue, -): DerivationResult = withAutoDiff(bindings.toMap(), body) +): DerivationResult = simpleAutoDiff(bindings.toMap(), body) /** * Represents field in context of which functions can be derived. @@ -136,7 +136,7 @@ private class AutoDiffContext>( override val one: AutoDiffValue get() = const(context.one) /** - * Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result + * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result * with respect to this variable. * * @param T the non-nullable type of value. @@ -148,6 +148,8 @@ private class AutoDiffContext>( var d: T, ) : AutoDiffValue(value), Symbol{ override fun toString(): String = identity + override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity + override fun hashCode(): Int = identity.hashCode() } private val bindings: Map> = bindings.entries.associate { diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt index ef4a6a06a..ca8ec1e17 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -10,21 +10,17 @@ import kotlin.test.assertEquals import kotlin.test.assertTrue class SimpleAutoDiffTest { - fun d( - vararg bindings: Pair, - body: AutoDiffField.() -> AutoDiffValue, - ): DerivationResult = RealField.withAutoDiff(bindings = bindings, body) fun dx( xBinding: Pair, body: AutoDiffField.(x: AutoDiffValue) -> AutoDiffValue, - ): DerivationResult = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) } + ): DerivationResult = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) } fun dxy( xBinding: Pair, yBinding: Pair, body: AutoDiffField.(x: AutoDiffValue, y: AutoDiffValue) -> AutoDiffValue, - ): DerivationResult = RealField.withAutoDiff(xBinding, yBinding) { + ): DerivationResult = RealField.simpleAutoDiff(xBinding, yBinding) { body(bind(xBinding.first), bind(yBinding.first)) } @@ -38,7 +34,7 @@ class SimpleAutoDiffTest { @Test fun testPlusX2() { - val y = d(x to 3.0) { + val y = RealField.simpleAutoDiff(x to 3.0) { // diff w.r.t this x at 3 val x = bind(x) x + x @@ -47,10 +43,21 @@ class SimpleAutoDiffTest { assertEquals(2.0, y.derivative(x)) // dy/dx = 2 } + @Test + fun testPlusX2Expr() { + val expr = diff{ + val x = bind(x) + x + x + } + assertEquals(6.0, expr(x to 3.0)) // y = x + x = 6 + assertEquals(2.0, expr.derivative(x)(x to 3.0)) // dy/dx = 2 + } + + @Test fun testPlus() { // two variables - val z = d(x to 2.0, y to 3.0) { + val z = RealField.simpleAutoDiff(x to 2.0, y to 3.0) { val x = bind(x) val y = bind(y) x + y @@ -63,7 +70,7 @@ class SimpleAutoDiffTest { @Test fun testMinus() { // two variables - val z = d(x to 7.0, y to 3.0) { + val z = RealField.simpleAutoDiff(x to 7.0, y to 3.0) { val x = bind(x) val y = bind(y) From 94df61cd439f41873d407902566e0ea279fa1330 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sat, 24 Oct 2020 13:05:36 +0300 Subject: [PATCH 44/69] cleanup --- .../kotlin/kscience/kmath/expressions/Expression.kt | 9 +-------- .../kscience/kmath/expressions/expressionBuilders.kt | 6 ++++++ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index 7da5a2529..b523d99b1 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -10,7 +10,6 @@ import kotlin.properties.ReadOnlyProperty public interface Symbol { /** * Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol. - * By default uses object identity */ public val identity: String } @@ -33,8 +32,6 @@ public fun interface Expression { * @return the value. */ public operator fun invoke(arguments: Map): T - - public companion object } /** @@ -81,17 +78,13 @@ public interface ExpressionAlgebra : Algebra { public fun const(value: T): E } -//public interface ExpressionBuilder> { -// public fun expression(block: A.() -> E): Expression -//} - /** * Bind a given [Symbol] to this context variable and produce context-specific object. */ public fun ExpressionAlgebra.bind(symbol: Symbol): E = bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this") -public val symbol: ReadOnlyProperty = ReadOnlyProperty { _, property -> +public val symbol: ReadOnlyProperty = ReadOnlyProperty { _, property -> StringSymbol(property.name) } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt index 1702a5921..defbb14ad 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt @@ -7,6 +7,12 @@ import kscience.kmath.operations.Space import kotlin.contracts.InvocationKind import kotlin.contracts.contract + +//public interface ExpressionBuilder> { +// public fun expression(block: A.() -> E): Expression +//} + + /** * Creates a functional expression with this [Space]. */ From d826dd9e8311d04253ffc08ab47d01a02e37d2dc Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sat, 24 Oct 2020 20:33:19 +0300 Subject: [PATCH 45/69] Initial optimization implementation for CM --- CHANGELOG.md | 4 + .../kmath/commons/optimization/optimize.kt | 103 ++++++++++++++++++ .../commons/optimization/OptimizeTest.kt | 37 +++++++ .../kscience/kmath/expressions/Expression.kt | 2 +- .../kmath/expressions/SymbolIndexer.kt | 45 ++++++++ 5 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt create mode 100644 kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt create mode 100644 kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index 89e02d3b1..109168475 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ - Automatic README generation for features (#139) - Native support for `memory`, `core` and `dimensions` - `kmath-ejml` to supply EJML SimpleMatrix wrapper. +- A separate `Symbol` entity, which is used for global unbound symbol. +- A `Symbol` indexing scope. +- Basic optimization API for Commons-math. ### Changed - Package changed from `scientifik` to `kscience.kmath`. @@ -16,6 +19,7 @@ - `Polynomial` secondary constructor made function. - Kotlin version: 1.3.72 -> 1.4.20-M1 - `kmath-ast` doesn't depend on heavy `kotlin-reflect` library. +- Full autodiff refactoring based on `Symbol` ### Deprecated diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt new file mode 100644 index 000000000..3bf6354ea --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt @@ -0,0 +1,103 @@ +package kscience.kmath.commons.optimization + +import kscience.kmath.expressions.* +import org.apache.commons.math3.optim.* +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType +import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer +import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction +import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient +import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer + +public typealias ParameterSpacePoint = Map + +public class OptimizationResult(public val point: ParameterSpacePoint, public val value: Double) + +public operator fun PointValuePair.component1(): DoubleArray = point +public operator fun PointValuePair.component2(): Double = value + +public object Optimization { + public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4 + public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4 + public const val DEFAULT_MAX_ITER: Int = 1000 +} + + +private fun SymbolIndexer.objectiveFunction(expression: Expression) = ObjectiveFunction { + val args = it.toMap() + expression(args) +} + +private fun SymbolIndexer.objectiveFunctionGradient( + expression: DifferentiableExpression, +) = ObjectiveFunctionGradient { + val args = it.toMap() + DoubleArray(symbols.size) { index -> + expression.derivative(symbols[index])(args) + } +} + +private fun SymbolIndexer.initialGuess(point: ParameterSpacePoint) = InitialGuess(point.toArray()) + +/** + * Optimize expression without derivatives + */ +public fun Expression.optimize( + startingPoint: ParameterSpacePoint, + goalType: GoalType = GoalType.MAXIMIZE, + vararg additionalArguments: OptimizationData, + optimizerBuilder: () -> MultivariateOptimizer = { + SimplexOptimizer( + SimpleValueChecker( + Optimization.DEFAULT_RELATIVE_TOLERANCE, + Optimization.DEFAULT_ABSOLUTE_TOLERANCE, + Optimization.DEFAULT_MAX_ITER + ) + ) + }, +): OptimizationResult = withSymbols(startingPoint.keys) { + val optimizer = optimizerBuilder() + val objectiveFunction = objectiveFunction(this@optimize) + val (point, value) = optimizer.optimize( + objectiveFunction, + initialGuess(startingPoint), + goalType, + MaxEval.unlimited(), + NelderMeadSimplex(symbols.size, 1.0), + *additionalArguments + ) + OptimizationResult(point.toMap(), value) +} + +/** + * Optimize differentiable expression + */ +public fun DifferentiableExpression.optimize( + startingPoint: ParameterSpacePoint, + goalType: GoalType = GoalType.MAXIMIZE, + vararg additionalArguments: OptimizationData, + optimizerBuilder: () -> NonLinearConjugateGradientOptimizer = { + NonLinearConjugateGradientOptimizer( + NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, + SimpleValueChecker( + Optimization.DEFAULT_RELATIVE_TOLERANCE, + Optimization.DEFAULT_ABSOLUTE_TOLERANCE, + Optimization.DEFAULT_MAX_ITER + ) + ) + }, +): OptimizationResult = withSymbols(startingPoint.keys) { + val optimizer = optimizerBuilder() + val objectiveFunction = objectiveFunction(this@optimize) + val objectiveGradient = objectiveFunctionGradient(this@optimize) + val (point, value) = optimizer.optimize( + objectiveFunction, + objectiveGradient, + initialGuess(startingPoint), + goalType, + MaxEval.unlimited(), + *additionalArguments + ) + OptimizationResult(point.toMap(), value) +} \ No newline at end of file diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt new file mode 100644 index 000000000..779f37dad --- /dev/null +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -0,0 +1,37 @@ +package kscience.kmath.commons.optimization + +import kscience.kmath.commons.expressions.DerivativeStructureExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.Symbol +import kscience.kmath.expressions.symbol +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer +import org.junit.jupiter.api.Test + +internal class OptimizeTest { + val x by symbol + val y by symbol + + val normal = DerivativeStructureExpression { + val x = bind(x) + val y = bind(y) + exp(-x.pow(2)/2) + exp(-y.pow(2)/2) + } + + val startingPoint: Map = mapOf(x to 1.0, y to 1.0) + + @Test + fun testOptimization() { + val result = normal.optimize(startingPoint) + println(result.point) + println(result.value) + } + + @Test + fun testSimplexOptimization() { + val result = (normal as Expression).optimize(startingPoint){ + SimplexOptimizer(1e-4,1e-4) + } + println(result.point) + println(result.value) + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index b523d99b1..7e1eb0cd7 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -35,7 +35,7 @@ public fun interface Expression { } /** - * Invlode an expression without parameters + * Invoke an expression without parameters */ public operator fun Expression.invoke(): T = invoke(emptyMap()) //This method exists to avoid resolution ambiguity of vararg methods diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt new file mode 100644 index 000000000..aef30c6dd --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt @@ -0,0 +1,45 @@ +package kscience.kmath.expressions + +/** + * An environment to easy transform indexed variables to symbols and back. + */ +public interface SymbolIndexer { + public val symbols: List + public fun indexOf(symbol: Symbol): Int = symbols.indexOf(symbol) + + public operator fun List.get(symbol: Symbol): T { + require(size == symbols.size) { "The input list size for indexer should be ${symbols.size} but $size found" } + return get(this@SymbolIndexer.indexOf(symbol)) + } + + public operator fun Array.get(symbol: Symbol): T { + require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" } + return get(this@SymbolIndexer.indexOf(symbol)) + } + + public operator fun DoubleArray.get(symbol: Symbol): Double { + require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" } + return get(this@SymbolIndexer.indexOf(symbol)) + } + + public fun DoubleArray.toMap(): Map { + require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" } + return symbols.indices.associate { symbols[it] to get(it) } + } + + + public fun Map.toList(): List = symbols.map { getValue(it) } + + public fun Map.toArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) } +} + +public inline class SimpleSymbolIndexer(override val symbols: List) : SymbolIndexer + +/** + * Execute the block with symbol indexer based on given symbol order + */ +public inline fun withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R = + with(SimpleSymbolIndexer(symbols.toList()), block) + +public inline fun withSymbols(symbols: Collection, block: SymbolIndexer.() -> R): R = + with(SimpleSymbolIndexer(symbols.toList()), block) \ No newline at end of file From 1fbe12149dfeae360f93e1c7d2c5d2da29aaa5eb Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sun, 25 Oct 2020 19:31:12 +0300 Subject: [PATCH 46/69] Advanced configuration API for cm-optimization --- .space.kts | 4 +- .../DerivativeStructureExpression.kt | 2 +- .../optimization/CMOptimizationProblem.kt | 100 +++++++++++++++++ .../optimization/OptimizationProblem.kt | 17 +++ .../kmath/commons/optimization/optimize.kt | 105 +++--------------- .../commons/optimization/OptimizeTest.kt | 18 +-- .../expressions/DifferentiableExpression.kt | 21 +++- .../kmath/expressions/SimpleAutoDiff.kt | 17 +-- .../kmath/expressions/SymbolIndexer.kt | 18 ++- 9 files changed, 188 insertions(+), 114 deletions(-) create mode 100644 kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt create mode 100644 kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt diff --git a/.space.kts b/.space.kts index 9dda0cbf7..d70ad6d59 100644 --- a/.space.kts +++ b/.space.kts @@ -1 +1,3 @@ -job("Build") { gradlew("openjdk:11", "build") } +job("Build") { + gradlew("openjdk:11", "build") +} diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index a1ee91419..376fea7a3 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -106,7 +106,7 @@ public class DerivativeStructureExpression( /** * Get the derivative expression with given orders */ - public override fun derivative(orders: Map): Expression = Expression { arguments -> + public override fun derivativeOrNull(orders: Map): Expression = Expression { arguments -> with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) } } } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt new file mode 100644 index 000000000..f7c136ed2 --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt @@ -0,0 +1,100 @@ +package kscience.kmath.commons.optimization + +import kscience.kmath.expressions.* +import org.apache.commons.math3.optim.* +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType +import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer +import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction +import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient +import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer +import kotlin.reflect.KClass + +public operator fun PointValuePair.component1(): DoubleArray = point +public operator fun PointValuePair.component2(): Double = value + +public class CMOptimizationProblem( + override val symbols: List, +) : OptimizationProblem, SymbolIndexer { + protected val optimizationData: HashMap, OptimizationData> = HashMap() + private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null + + public var convergenceChecker: ConvergenceChecker = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE, + DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER) + + private fun addOptimizationData(data: OptimizationData) { + optimizationData[data::class] = data + } + + init { + addOptimizationData(MaxEval.unlimited()) + } + + public fun initialGuess(map: Map): Unit { + addOptimizationData(InitialGuess(map.toDoubleArray())) + } + + public fun expression(expression: Expression): Unit { + val objectiveFunction = ObjectiveFunction { + val args = it.toMap() + expression(args) + } + addOptimizationData(objectiveFunction) + } + + public fun derivatives(expression: DifferentiableExpression): Unit { + expression(expression) + val gradientFunction = ObjectiveFunctionGradient { + val args = it.toMap() + DoubleArray(symbols.size) { index -> + expression.derivative(symbols[index])(args) + } + } + addOptimizationData(gradientFunction) + if (optimizatorBuilder == null) { + optimizatorBuilder = { + NonLinearConjugateGradientOptimizer( + NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, + convergenceChecker + ) + } + } + } + + public fun simplex(simplex: AbstractSimplex) { + addOptimizationData(simplex) + //Set optimization builder to simplex if it is not present + if (optimizatorBuilder == null) { + optimizatorBuilder = { SimplexOptimizer(convergenceChecker) } + } + } + + public fun simplexSteps(steps: Map) { + simplex(NelderMeadSimplex(steps.toDoubleArray())) + } + + public fun goal(goalType: GoalType) { + addOptimizationData(goalType) + } + + public fun optimizer(block: () -> MultivariateOptimizer) { + optimizatorBuilder = block + } + + override fun optimize(): OptimizationResult { + val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined") + val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray()) + return OptimizationResult(point.toMap(), value) + } + + public companion object { + public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4 + public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4 + public const val DEFAULT_MAX_ITER: Int = 1000 + } +} + +public fun CMOptimizationProblem.initialGuess(vararg pairs: Pair): Unit = initialGuess(pairs.toMap()) +public fun CMOptimizationProblem.simplexSteps(vararg pairs: Pair): Unit = simplexSteps(pairs.toMap()) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt new file mode 100644 index 000000000..56291e09c --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt @@ -0,0 +1,17 @@ +package kscience.kmath.commons.optimization + +import kscience.kmath.expressions.Symbol +import kotlin.reflect.KClass + +public typealias ParameterSpacePoint = Map + +public class OptimizationResult( + public val point: ParameterSpacePoint, + public val value: T, + public val extra: Map, Any> = emptyMap() +) + +public interface OptimizationProblem { + public fun optimize(): OptimizationResult +} + diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt index 3bf6354ea..a49949b93 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt @@ -1,103 +1,32 @@ package kscience.kmath.commons.optimization -import kscience.kmath.expressions.* -import org.apache.commons.math3.optim.* -import org.apache.commons.math3.optim.nonlinear.scalar.GoalType -import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer -import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction -import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient -import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer -import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex -import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.Symbol -public typealias ParameterSpacePoint = Map - -public class OptimizationResult(public val point: ParameterSpacePoint, public val value: Double) - -public operator fun PointValuePair.component1(): DoubleArray = point -public operator fun PointValuePair.component2(): Double = value - -public object Optimization { - public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4 - public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4 - public const val DEFAULT_MAX_ITER: Int = 1000 -} - - -private fun SymbolIndexer.objectiveFunction(expression: Expression) = ObjectiveFunction { - val args = it.toMap() - expression(args) -} - -private fun SymbolIndexer.objectiveFunctionGradient( - expression: DifferentiableExpression, -) = ObjectiveFunctionGradient { - val args = it.toMap() - DoubleArray(symbols.size) { index -> - expression.derivative(symbols[index])(args) - } -} - -private fun SymbolIndexer.initialGuess(point: ParameterSpacePoint) = InitialGuess(point.toArray()) /** * Optimize expression without derivatives */ public fun Expression.optimize( - startingPoint: ParameterSpacePoint, - goalType: GoalType = GoalType.MAXIMIZE, - vararg additionalArguments: OptimizationData, - optimizerBuilder: () -> MultivariateOptimizer = { - SimplexOptimizer( - SimpleValueChecker( - Optimization.DEFAULT_RELATIVE_TOLERANCE, - Optimization.DEFAULT_ABSOLUTE_TOLERANCE, - Optimization.DEFAULT_MAX_ITER - ) - ) - }, -): OptimizationResult = withSymbols(startingPoint.keys) { - val optimizer = optimizerBuilder() - val objectiveFunction = objectiveFunction(this@optimize) - val (point, value) = optimizer.optimize( - objectiveFunction, - initialGuess(startingPoint), - goalType, - MaxEval.unlimited(), - NelderMeadSimplex(symbols.size, 1.0), - *additionalArguments - ) - OptimizationResult(point.toMap(), value) + vararg symbols: Symbol, + configuration: CMOptimizationProblem.() -> Unit, +): OptimizationResult { + require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration) + problem.expression(this) + return problem.optimize() } /** * Optimize differentiable expression */ public fun DifferentiableExpression.optimize( - startingPoint: ParameterSpacePoint, - goalType: GoalType = GoalType.MAXIMIZE, - vararg additionalArguments: OptimizationData, - optimizerBuilder: () -> NonLinearConjugateGradientOptimizer = { - NonLinearConjugateGradientOptimizer( - NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, - SimpleValueChecker( - Optimization.DEFAULT_RELATIVE_TOLERANCE, - Optimization.DEFAULT_ABSOLUTE_TOLERANCE, - Optimization.DEFAULT_MAX_ITER - ) - ) - }, -): OptimizationResult = withSymbols(startingPoint.keys) { - val optimizer = optimizerBuilder() - val objectiveFunction = objectiveFunction(this@optimize) - val objectiveGradient = objectiveFunctionGradient(this@optimize) - val (point, value) = optimizer.optimize( - objectiveFunction, - objectiveGradient, - initialGuess(startingPoint), - goalType, - MaxEval.unlimited(), - *additionalArguments - ) - OptimizationResult(point.toMap(), value) + vararg symbols: Symbol, + configuration: CMOptimizationProblem.() -> Unit, +): OptimizationResult { + require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration) + problem.derivatives(this) + return problem.optimize() } \ No newline at end of file diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index 779f37dad..65d61dcd1 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -1,10 +1,7 @@ package kscience.kmath.commons.optimization import kscience.kmath.commons.expressions.DerivativeStructureExpression -import kscience.kmath.expressions.Expression -import kscience.kmath.expressions.Symbol import kscience.kmath.expressions.symbol -import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer import org.junit.jupiter.api.Test internal class OptimizeTest { @@ -14,22 +11,25 @@ internal class OptimizeTest { val normal = DerivativeStructureExpression { val x = bind(x) val y = bind(y) - exp(-x.pow(2)/2) + exp(-y.pow(2)/2) + exp(-x.pow(2) / 2) + exp(-y.pow(2) / 2) } - val startingPoint: Map = mapOf(x to 1.0, y to 1.0) - @Test fun testOptimization() { - val result = normal.optimize(startingPoint) + val result = normal.optimize(x, y) { + initialGuess(x to 1.0, y to 1.0) + //no need to select optimizer. Gradient optimizer is used by default + } println(result.point) println(result.value) } @Test fun testSimplexOptimization() { - val result = (normal as Expression).optimize(startingPoint){ - SimplexOptimizer(1e-4,1e-4) + val result = normal.optimize(x, y) { + initialGuess(x to 1.0, y to 1.0) + simplexSteps(x to 2.0, y to 0.5) + //this sets simplex optimizer } println(result.point) println(result.value) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt index 841531d01..5fe31caca 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -4,9 +4,15 @@ package kscience.kmath.expressions * And object that could be differentiated */ public interface Differentiable { - public fun derivative(orders: Map): T + public fun derivativeOrNull(orders: Map): T? } +public fun Differentiable.derivative(orders: Map): T = + derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided") + +/** + * An expression that provid + */ public interface DifferentiableExpression : Differentiable>, Expression public fun DifferentiableExpression.derivative(vararg orders: Pair): Expression = @@ -14,8 +20,19 @@ public fun DifferentiableExpression.derivative(vararg orders: Pair DifferentiableExpression.derivative(symbol: Symbol): Expression = derivative(symbol to 1) -public fun DifferentiableExpression.derivative(name: String): Expression = derivative(StringSymbol(name) to 1) +public fun DifferentiableExpression.derivative(name: String): Expression = + derivative(StringSymbol(name) to 1) //public interface DifferentiableExpressionBuilder>: ExpressionBuilder { // public override fun expression(block: A.() -> E): DifferentiableExpression //} + +public abstract class FirstDerivativeExpression : DifferentiableExpression { + + public abstract fun derivativeOrNull(symbol: Symbol): Expression? + + public override fun derivativeOrNull(orders: Map): Expression? { + val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null + return derivativeOrNull(dSymbol) + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index e5ea33c81..6231a40c1 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -221,23 +221,16 @@ private class AutoDiffContext>( public class SimpleAutoDiffExpression>( public val field: F, public val function: AutoDiffField.() -> AutoDiffValue, -) : DifferentiableExpression { +) : FirstDerivativeExpression() { public override operator fun invoke(arguments: Map): T { //val bindings = arguments.entries.map { it.key.bind(it.value) } return AutoDiffContext(field, arguments).function().value } - /** - * Get the derivative expression with given orders - */ - public override fun derivative(orders: Map): Expression { - val dSymbol = orders.entries.singleOrNull { it.value == 1 } - ?: error("SimpleAutoDiff supports only first order derivatives") - return Expression { arguments -> - //val bindings = arguments.entries.map { it.key.bind(it.value) } - val derivationResult = AutoDiffContext(field, arguments).derivate(function) - derivationResult.derivative(dSymbol.key) - } + override fun derivativeOrNull(symbol: Symbol): Expression = Expression { arguments -> + //val bindings = arguments.entries.map { it.key.bind(it.value) } + val derivationResult = AutoDiffContext(field, arguments).derivate(function) + derivationResult.derivative(symbol) } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt index aef30c6dd..6c61c7c7d 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt @@ -1,7 +1,12 @@ package kscience.kmath.expressions +import kscience.kmath.linear.Point +import kscience.kmath.structures.BufferFactory +import kscience.kmath.structures.Structure2D + /** * An environment to easy transform indexed variables to symbols and back. + * TODO requires multi-receivers to be beutiful */ public interface SymbolIndexer { public val symbols: List @@ -22,15 +27,26 @@ public interface SymbolIndexer { return get(this@SymbolIndexer.indexOf(symbol)) } + public operator fun Point.get(symbol: Symbol): T { + require(size == symbols.size) { "The input buffer size for indexer should be ${symbols.size} but $size found" } + return get(this@SymbolIndexer.indexOf(symbol)) + } + public fun DoubleArray.toMap(): Map { require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" } return symbols.indices.associate { symbols[it] to get(it) } } + public operator fun Structure2D.get(rowSymbol: Symbol, columnSymbol: Symbol): T = + get(indexOf(rowSymbol), indexOf(columnSymbol)) + public fun Map.toList(): List = symbols.map { getValue(it) } - public fun Map.toArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) } + public fun Map.toPoint(bufferFactory: BufferFactory): Point = + bufferFactory(symbols.size) { getValue(symbols[it]) } + + public fun Map.toDoubleArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) } } public inline class SimpleSymbolIndexer(override val symbols: List) : SymbolIndexer From 57781678e5e6aa295fca25b00ee3dd2dc2ae8f4f Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sun, 25 Oct 2020 19:46:22 +0300 Subject: [PATCH 47/69] Cleanup --- .../optimization/CMOptimizationProblem.kt | 8 +++++-- .../optimization/OptimizationProblem.kt | 23 +++++++++++++++---- .../kmath/commons/optimization/optimize.kt | 6 ++--- .../commons/optimization/OptimizeTest.kt | 6 ++--- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt index f7c136ed2..b5ea59d6b 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt @@ -36,7 +36,7 @@ public class CMOptimizationProblem( addOptimizationData(InitialGuess(map.toDoubleArray())) } - public fun expression(expression: Expression): Unit { + public override fun expression(expression: Expression): Unit { val objectiveFunction = ObjectiveFunction { val args = it.toMap() expression(args) @@ -44,7 +44,7 @@ public class CMOptimizationProblem( addOptimizationData(objectiveFunction) } - public fun derivatives(expression: DifferentiableExpression): Unit { + public override fun diffExpression(expression: DifferentiableExpression): Unit { expression(expression) val gradientFunction = ObjectiveFunctionGradient { val args = it.toMap() @@ -83,6 +83,10 @@ public class CMOptimizationProblem( optimizatorBuilder = block } + override fun update(result: OptimizationResult) { + initialGuess(result.point) + } + override fun optimize(): OptimizationResult { val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined") val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray()) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt index 56291e09c..e52450be1 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt @@ -1,17 +1,32 @@ package kscience.kmath.commons.optimization +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Symbol -import kotlin.reflect.KClass -public typealias ParameterSpacePoint = Map +public interface OptimizationResultFeature + public class OptimizationResult( - public val point: ParameterSpacePoint, + public val point: Map, public val value: T, - public val extra: Map, Any> = emptyMap() + public val features: Set = emptySet(), ) +/** + * A configuration builder for optimization problem + */ public interface OptimizationProblem { + /** + * Set an objective function expression + */ + public fun expression(expression: Expression): Unit + + /** + * + */ + public fun diffExpression(expression: DifferentiableExpression): Unit + public fun update(result: OptimizationResult) public fun optimize(): OptimizationResult } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt index a49949b93..c4bd5704e 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt @@ -13,7 +13,7 @@ public fun Expression.optimize( configuration: CMOptimizationProblem.() -> Unit, ): OptimizationResult { require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration) + val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) problem.expression(this) return problem.optimize() } @@ -26,7 +26,7 @@ public fun DifferentiableExpression.optimize( configuration: CMOptimizationProblem.() -> Unit, ): OptimizationResult { require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration) - problem.derivatives(this) + val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) + problem.diffExpression(this) return problem.optimize() } \ No newline at end of file diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index 65d61dcd1..bd7870573 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -9,16 +9,14 @@ internal class OptimizeTest { val y by symbol val normal = DerivativeStructureExpression { - val x = bind(x) - val y = bind(y) - exp(-x.pow(2) / 2) + exp(-y.pow(2) / 2) + exp(-bind(x).pow(2) / 2) + exp(- bind(y).pow(2) / 2) } @Test fun testOptimization() { val result = normal.optimize(x, y) { initialGuess(x to 1.0, y to 1.0) - //no need to select optimizer. Gradient optimizer is used by default + //no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function } println(result.point) println(result.value) From 30132964dd35f4d039ee84729a96bdc2a9014086 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 26 Oct 2020 10:01:30 +0300 Subject: [PATCH 48/69] Separate object for fitting. Chi-squared --- CHANGELOG.md | 1 + .../kmath/commons/optimization/CMFit.kt | 103 ++++++++++++++++++ .../kmath/commons/optimization/optimize.kt | 32 ------ .../commons/optimization/OptimizeTest.kt | 2 +- 4 files changed, 105 insertions(+), 33 deletions(-) create mode 100644 kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt delete mode 100644 kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index 109168475..f28041adf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - A separate `Symbol` entity, which is used for global unbound symbol. - A `Symbol` indexing scope. - Basic optimization API for Commons-math. +- Chi squared optimization for array-like data in CM ### Changed - Package changed from `scientifik` to `kscience.kmath`. diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt new file mode 100644 index 000000000..4ffd0559d --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt @@ -0,0 +1,103 @@ +package kscience.kmath.commons.optimization + +import kscience.kmath.commons.expressions.DerivativeStructureExpression +import kscience.kmath.commons.expressions.DerivativeStructureField +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.StringSymbol +import kscience.kmath.expressions.Symbol +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.indices +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType +import kotlin.math.pow + + +public object CMFit { + + /** + * Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives + * TODO move to core/separate module + */ + public fun chiSquaredExpression( + x: Buffer, + y: Buffer, + yErr: Buffer, + model: Expression, + xSymbol: Symbol = StringSymbol("x"), + ): Expression { + require(x.size == y.size) { "X and y buffers should be of the same size" } + require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } + return Expression { arguments -> + x.indices.sumByDouble { + val xValue = x[it] + val yValue = y[it] + val yErrValue = yErr[it] + val modifiedArgs = arguments + (xSymbol to xValue) + val modelValue = model(modifiedArgs) + ((yValue - modelValue) / yErrValue).pow(2) / 2 + } + } + } + + /** + * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation + */ + public fun chiSquaredExpression( + x: Buffer, + y: Buffer, + yErr: Buffer, + model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, + ): DerivativeStructureExpression { + require(x.size == y.size) { "X and y buffers should be of the same size" } + require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } + return DerivativeStructureExpression { + var sum = zero + x.indices.forEach { + val xValue = x[it] + val yValue = y[it] + val yErrValue = yErr[it] + val modelValue = model(const(xValue)) + sum += ((yValue - modelValue) / yErrValue).pow(2) / 2 + } + sum + } + } +} + +/** + * Optimize expression without derivatives + */ +public fun Expression.optimize( + vararg symbols: Symbol, + configuration: CMOptimizationProblem.() -> Unit, +): OptimizationResult { + require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) + problem.expression(this) + return problem.optimize() +} + +/** + * Optimize differentiable expression + */ +public fun DifferentiableExpression.optimize( + vararg symbols: Symbol, + configuration: CMOptimizationProblem.() -> Unit, +): OptimizationResult { + require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) + problem.diffExpression(this) + return problem.optimize() +} + +public fun DifferentiableExpression.minimize( + vararg symbols: Symbol, + configuration: CMOptimizationProblem.() -> Unit, +): OptimizationResult { + require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) + problem.diffExpression(this) + problem.goal(GoalType.MINIMIZE) + return problem.optimize() +} \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt deleted file mode 100644 index c4bd5704e..000000000 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt +++ /dev/null @@ -1,32 +0,0 @@ -package kscience.kmath.commons.optimization - -import kscience.kmath.expressions.DifferentiableExpression -import kscience.kmath.expressions.Expression -import kscience.kmath.expressions.Symbol - - -/** - * Optimize expression without derivatives - */ -public fun Expression.optimize( - vararg symbols: Symbol, - configuration: CMOptimizationProblem.() -> Unit, -): OptimizationResult { - require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) - problem.expression(this) - return problem.optimize() -} - -/** - * Optimize differentiable expression - */ -public fun DifferentiableExpression.optimize( - vararg symbols: Symbol, - configuration: CMOptimizationProblem.() -> Unit, -): OptimizationResult { - require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) - problem.diffExpression(this) - return problem.optimize() -} \ No newline at end of file diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index bd7870573..07bda2aa4 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -13,7 +13,7 @@ internal class OptimizeTest { } @Test - fun testOptimization() { + fun testGradientOptimization() { val result = normal.optimize(x, y) { initialGuess(x to 1.0, y to 1.0) //no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function From 4450c0fcc7941896968283bdc120a3a1661dec09 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 26 Oct 2020 14:44:57 +0300 Subject: [PATCH 49/69] Fix orders in DerivativeStructures --- .../DerivativeStructureExpression.kt | 4 +-- .../kmath/commons/optimization/CMFit.kt | 19 +++++------ .../optimization/CMOptimizationProblem.kt | 11 ++++--- .../optimization/OptimizationProblem.kt | 11 +++++-- .../random/CMRandomGeneratorWrapper.kt | 5 +-- .../commons/optimization/OptimizeTest.kt | 32 ++++++++++++++++++- 6 files changed, 60 insertions(+), 22 deletions(-) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 376fea7a3..272501729 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -38,13 +38,13 @@ public class DerivativeStructureField( key.identity to DerivativeStructureSymbol(key, value) } - override fun const(value: Double): DerivativeStructure = DerivativeStructure(order, bindings.size, value) + override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value) public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity) - public fun Number.const(): DerivativeStructure = const(toDouble()) + //public fun Number.const(): DerivativeStructure = const(toDouble()) public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double { return derivative(mapOf(parameter to order)) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt index 4ffd0559d..a62630ed3 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt @@ -17,9 +17,9 @@ public object CMFit { /** * Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives - * TODO move to core/separate module + * TODO move to prob/stat */ - public fun chiSquaredExpression( + public fun chiSquared( x: Buffer, y: Buffer, yErr: Buffer, @@ -35,7 +35,7 @@ public object CMFit { val yErrValue = yErr[it] val modifiedArgs = arguments + (xSymbol to xValue) val modelValue = model(modifiedArgs) - ((yValue - modelValue) / yErrValue).pow(2) / 2 + ((yValue - modelValue) / yErrValue).pow(2) } } } @@ -43,7 +43,7 @@ public object CMFit { /** * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation */ - public fun chiSquaredExpression( + public fun chiSquared( x: Buffer, y: Buffer, yErr: Buffer, @@ -58,7 +58,7 @@ public object CMFit { val yValue = y[it] val yErrValue = yErr[it] val modelValue = model(const(xValue)) - sum += ((yValue - modelValue) / yErrValue).pow(2) / 2 + sum += ((yValue - modelValue) / yErrValue).pow(2) } sum } @@ -92,12 +92,13 @@ public fun DifferentiableExpression.optimize( } public fun DifferentiableExpression.minimize( - vararg symbols: Symbol, - configuration: CMOptimizationProblem.() -> Unit, + vararg startPoint: Pair, + configuration: CMOptimizationProblem.() -> Unit = {}, ): OptimizationResult { - require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) + require(startPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = CMOptimizationProblem(startPoint.map { it.first }).apply(configuration) problem.diffExpression(this) + problem.initialGuess(startPoint.toMap()) problem.goal(GoalType.MINIMIZE) return problem.optimize() } \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt index b5ea59d6b..2ca907d05 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt @@ -17,14 +17,13 @@ public operator fun PointValuePair.component2(): Double = value public class CMOptimizationProblem( override val symbols: List, -) : OptimizationProblem, SymbolIndexer { - protected val optimizationData: HashMap, OptimizationData> = HashMap() +) : OptimizationProblem, SymbolIndexer, OptimizationFeature { + private val optimizationData: HashMap, OptimizationData> = HashMap() private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null - public var convergenceChecker: ConvergenceChecker = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE, DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER) - private fun addOptimizationData(data: OptimizationData) { + public fun addOptimizationData(data: OptimizationData) { optimizationData[data::class] = data } @@ -32,6 +31,8 @@ public class CMOptimizationProblem( addOptimizationData(MaxEval.unlimited()) } + public fun exportOptimizationData(): List = optimizationData.values.toList() + public fun initialGuess(map: Map): Unit { addOptimizationData(InitialGuess(map.toDoubleArray())) } @@ -90,7 +91,7 @@ public class CMOptimizationProblem( override fun optimize(): OptimizationResult { val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined") val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray()) - return OptimizationResult(point.toMap(), value) + return OptimizationResult(point.toMap(), value, setOf(this)) } public companion object { diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt index e52450be1..a246a817b 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt @@ -4,14 +4,19 @@ import kscience.kmath.expressions.DifferentiableExpression import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Symbol -public interface OptimizationResultFeature +public interface OptimizationFeature +//TODO move to prob/stat public class OptimizationResult( public val point: Map, public val value: T, - public val features: Set = emptySet(), -) + public val features: Set = emptySet(), +){ + override fun toString(): String { + return "OptimizationResult(point=$point, value=$value)" + } +} /** * A configuration builder for optimization problem diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt index 58609deae..9600f6901 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -2,8 +2,9 @@ package kscience.kmath.commons.random import kscience.kmath.prob.RandomGenerator -public class CMRandomGeneratorWrapper(public val factory: (IntArray) -> RandomGenerator) : - org.apache.commons.math3.random.RandomGenerator { +public class CMRandomGeneratorWrapper( + public val factory: (IntArray) -> RandomGenerator, +) : org.apache.commons.math3.random.RandomGenerator { private var generator: RandomGenerator = factory(intArrayOf()) public override fun nextBoolean(): Boolean = generator.nextBoolean() diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index 07bda2aa4..ff5542235 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -2,14 +2,19 @@ package kscience.kmath.commons.optimization import kscience.kmath.commons.expressions.DerivativeStructureExpression import kscience.kmath.expressions.symbol +import kscience.kmath.prob.Distribution +import kscience.kmath.prob.RandomGenerator +import kscience.kmath.prob.normal +import kscience.kmath.structures.asBuffer import org.junit.jupiter.api.Test +import kotlin.math.pow internal class OptimizeTest { val x by symbol val y by symbol val normal = DerivativeStructureExpression { - exp(-bind(x).pow(2) / 2) + exp(- bind(y).pow(2) / 2) + exp(-bind(x).pow(2) / 2) + exp(-bind(y).pow(2) / 2) } @Test @@ -32,4 +37,29 @@ internal class OptimizeTest { println(result.point) println(result.value) } + + @Test + fun testFit() { + val a by symbol + val b by symbol + val c by symbol + + val sigma = 1.0 + val generator = Distribution.normal(0.0, sigma) + val chain = generator.sample(RandomGenerator.default(1126)) + val x = (1..100).map { it.toDouble() } + val y = x.map { it -> + it.pow(2) + it + 1 + chain.nextDouble() + } + val yErr = x.map { sigma } + with(CMFit) { + val chi2 = chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x -> + bind(a) * x.pow(2) + bind(b) * x + bind(c) + } + + val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) + println(result) + println("Chi2/dof = ${result.value / (x.size - 3)}") + } + } } \ No newline at end of file From 9a147d033e02abd2647df344a09bb802d3e359e7 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 27 Oct 2020 17:57:17 +0300 Subject: [PATCH 50/69] Another refactor of SimpleAutoDiff --- .../DerivativeStructureExpression.kt | 11 +- .../kmath/commons/optimization/CMFit.kt | 15 +- .../optimization/CMOptimizationProblem.kt | 6 +- .../optimization/OptimizationProblem.kt | 37 --- .../expressions/DifferentiableExpression.kt | 14 +- .../kmath/expressions/SimpleAutoDiff.kt | 286 +++++++++++------- .../kmath/expressions/expressionBuilders.kt | 5 - .../kmath/expressions/SimpleAutoDiffTest.kt | 15 +- 8 files changed, 214 insertions(+), 175 deletions(-) delete mode 100644 kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 272501729..c593f5103 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -1,9 +1,6 @@ package kscience.kmath.commons.expressions -import kscience.kmath.expressions.DifferentiableExpression -import kscience.kmath.expressions.Expression -import kscience.kmath.expressions.ExpressionAlgebra -import kscience.kmath.expressions.Symbol +import kscience.kmath.expressions.* import kscience.kmath.operations.ExtendedField import org.apache.commons.math3.analysis.differentiation.DerivativeStructure @@ -92,6 +89,12 @@ public class DerivativeStructureField( public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this + + public companion object : AutoDiffProcessor { + override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression { + return DerivativeStructureExpression(function) + } + } } /** diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt index a62630ed3..3143dcca5 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt @@ -71,12 +71,8 @@ public object CMFit { public fun Expression.optimize( vararg symbols: Symbol, configuration: CMOptimizationProblem.() -> Unit, -): OptimizationResult { - require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) - problem.expression(this) - return problem.optimize() -} +): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) + /** * Optimize differentiable expression @@ -84,12 +80,7 @@ public fun Expression.optimize( public fun DifferentiableExpression.optimize( vararg symbols: Symbol, configuration: CMOptimizationProblem.() -> Unit, -): OptimizationResult { - require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) - problem.diffExpression(this) - return problem.optimize() -} +): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) public fun DifferentiableExpression.minimize( vararg startPoint: Pair, diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt index 2ca907d05..0d96faaa3 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt @@ -33,7 +33,7 @@ public class CMOptimizationProblem( public fun exportOptimizationData(): List = optimizationData.values.toList() - public fun initialGuess(map: Map): Unit { + public override fun initialGuess(map: Map): Unit { addOptimizationData(InitialGuess(map.toDoubleArray())) } @@ -94,10 +94,12 @@ public class CMOptimizationProblem( return OptimizationResult(point.toMap(), value, setOf(this)) } - public companion object { + public companion object : OptimizationProblemFactory { public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4 public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4 public const val DEFAULT_MAX_ITER: Int = 1000 + + override fun build(symbols: List): CMOptimizationProblem = CMOptimizationProblem(symbols) } } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt deleted file mode 100644 index a246a817b..000000000 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt +++ /dev/null @@ -1,37 +0,0 @@ -package kscience.kmath.commons.optimization - -import kscience.kmath.expressions.DifferentiableExpression -import kscience.kmath.expressions.Expression -import kscience.kmath.expressions.Symbol - -public interface OptimizationFeature - -//TODO move to prob/stat - -public class OptimizationResult( - public val point: Map, - public val value: T, - public val features: Set = emptySet(), -){ - override fun toString(): String { - return "OptimizationResult(point=$point, value=$value)" - } -} - -/** - * A configuration builder for optimization problem - */ -public interface OptimizationProblem { - /** - * Set an objective function expression - */ - public fun expression(expression: Expression): Unit - - /** - * - */ - public fun diffExpression(expression: DifferentiableExpression): Unit - public fun update(result: OptimizationResult) - public fun optimize(): OptimizationResult -} - diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt index 5fe31caca..705839b57 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -23,10 +23,9 @@ public fun DifferentiableExpression.derivative(symbol: Symbol): Expressio public fun DifferentiableExpression.derivative(name: String): Expression = derivative(StringSymbol(name) to 1) -//public interface DifferentiableExpressionBuilder>: ExpressionBuilder { -// public override fun expression(block: A.() -> E): DifferentiableExpression -//} - +/** + * A [DifferentiableExpression] that defines only first derivatives + */ public abstract class FirstDerivativeExpression : DifferentiableExpression { public abstract fun derivativeOrNull(symbol: Symbol): Expression? @@ -35,4 +34,11 @@ public abstract class FirstDerivativeExpression : DifferentiableExpression val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null return derivativeOrNull(dSymbol) } +} + +/** + * A factory that converts an expression in autodiff variables to a [DifferentiableExpression] + */ +public interface AutoDiffProcessor> { + public fun process(function: A.() -> I): DifferentiableExpression } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index 6231a40c1..e66832fdb 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -47,7 +47,7 @@ public fun DerivationResult.grad(vararg variables: Symbol): Point DerivationResult.grad(vararg variables: Symbol): Point> F.simpleAutoDiff( bindings: Map, - body: AutoDiffField.() -> AutoDiffValue, + body: SimpleAutoDiffField.() -> AutoDiffValue, ): DerivationResult { contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } - return AutoDiffContext(this, bindings).derivate(body) + return SimpleAutoDiffField(this, bindings).derivate(body) } public fun > F.simpleAutoDiff( vararg bindings: Pair, - body: AutoDiffField.() -> AutoDiffValue, + body: SimpleAutoDiffField.() -> AutoDiffValue, ): DerivationResult = simpleAutoDiff(bindings.toMap(), body) /** * Represents field in context of which functions can be derived. */ -public abstract class AutoDiffField> - : Field>, ExpressionAlgebra> { +public open class SimpleAutoDiffField>( + public val context: F, + bindings: Map, +) : Field>, ExpressionAlgebra> { - public abstract val context: F + // this stack contains pairs of blocks and values to apply them to + private var stack: Array = arrayOfNulls(8) + private var sp: Int = 0 + private val derivatives: MutableMap, T> = hashMapOf() + + /** + * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result + * with respect to this variable. + * + * @param T the non-nullable type of value. + * @property value The value of this variable. + */ + private class AutoDiffVariableWithDerivative( + override val identity: String, + value: T, + var d: T, + ) : AutoDiffValue(value), Symbol { + override fun toString(): String = identity + override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity + override fun hashCode(): Int = identity.hashCode() + } + + private val bindings: Map> = bindings.entries.associate { + it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero) + } + + override fun bindOrNull(symbol: Symbol): AutoDiffValue? = bindings[symbol.identity] + + private fun getDerivative(variable: AutoDiffValue): T = + (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero + + private fun setDerivative(variable: AutoDiffValue, value: T) { + if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value + } + + + @Suppress("UNCHECKED_CAST") + private fun runBackwardPass() { + while (sp > 0) { + val value = stack[--sp] + val block = stack[--sp] as F.(Any?) -> Unit + context.block(value) + } + } + + override val zero: AutoDiffValue get() = const(context.zero) + override val one: AutoDiffValue get() = const(context.one) + + override fun const(value: T): AutoDiffValue = AutoDiffValue(value) /** * A variable accessing inner state of derivatives. * Use this value in inner builders to avoid creating additional derivative bindings. */ - public abstract var AutoDiffValue.d: T + public var AutoDiffValue.d: T + get() = getDerivative(this) + set(value) = setDerivative(this, value) + + public inline fun const(block: F.() -> T): AutoDiffValue = const(context.block()) /** * Performs update of derivative after the rest of the formula in the back-pass. @@ -101,9 +155,22 @@ public abstract class AutoDiffField> * } * ``` */ - public abstract fun derive(value: R, block: F.(R) -> Unit): R + @Suppress("UNCHECKED_CAST") + public fun derive(value: R, block: F.(R) -> Unit): R { + // save block to stack for backward pass + if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) + stack[sp++] = block + stack[sp++] = value + return value + } - public inline fun const(block: F.() -> T): AutoDiffValue = const(context.block()) + + internal fun derivate(function: SimpleAutoDiffField.() -> AutoDiffValue): DerivationResult { + val result = function() + result.d = context.one // computing derivative w.r.t result + runBackwardPass() + return DerivationResult(result.value, bindings.mapValues { it.value.d }, context) + } // Overloads for Double constants @@ -119,68 +186,7 @@ public abstract class AutoDiffField> override operator fun AutoDiffValue.minus(b: Number): AutoDiffValue = derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } -} -/** - * Automatic Differentiation context class. - */ -private class AutoDiffContext>( - override val context: F, - bindings: Map, -) : AutoDiffField() { - // this stack contains pairs of blocks and values to apply them to - private var stack: Array = arrayOfNulls(8) - private var sp: Int = 0 - private val derivatives: MutableMap, T> = hashMapOf() - override val zero: AutoDiffValue get() = const(context.zero) - override val one: AutoDiffValue get() = const(context.one) - - /** - * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result - * with respect to this variable. - * - * @param T the non-nullable type of value. - * @property value The value of this variable. - */ - private class AutoDiffVariableWithDeriv( - override val identity: String, - value: T, - var d: T, - ) : AutoDiffValue(value), Symbol{ - override fun toString(): String = identity - override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity - override fun hashCode(): Int = identity.hashCode() - } - - private val bindings: Map> = bindings.entries.associate { - it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero) - } - - override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv? = bindings[symbol.identity] - - override fun const(value: T): AutoDiffValue = AutoDiffValue(value) - - override var AutoDiffValue.d: T - get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero - set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value - - @Suppress("UNCHECKED_CAST") - override fun derive(value: R, block: F.(R) -> Unit): R { - // save block to stack for backward pass - if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) - stack[sp++] = block - stack[sp++] = value - return value - } - - @Suppress("UNCHECKED_CAST") - fun runBackwardPass() { - while (sp > 0) { - val value = stack[--sp] - val block = stack[--sp] as F.(Any?) -> Unit - context.block(value) - } - } // Basic math (+, -, *, /) @@ -206,13 +212,6 @@ private class AutoDiffContext>( derive(const { k.toDouble() * a.value }) { z -> a.d += z.d * k.toDouble() } - - inline fun derivate(function: AutoDiffField.() -> AutoDiffValue): DerivationResult { - val result = function() - result.d = context.one // computing derivative w.r.t result - runBackwardPass() - return DerivationResult(result.value, bindings.mapValues { it.value.d }, context) - } } /** @@ -220,99 +219,178 @@ private class AutoDiffContext>( */ public class SimpleAutoDiffExpression>( public val field: F, - public val function: AutoDiffField.() -> AutoDiffValue, + public val function: SimpleAutoDiffField.() -> AutoDiffValue, ) : FirstDerivativeExpression() { public override operator fun invoke(arguments: Map): T { //val bindings = arguments.entries.map { it.key.bind(it.value) } - return AutoDiffContext(field, arguments).function().value + return SimpleAutoDiffField(field, arguments).function().value } override fun derivativeOrNull(symbol: Symbol): Expression = Expression { arguments -> //val bindings = arguments.entries.map { it.key.bind(it.value) } - val derivationResult = AutoDiffContext(field, arguments).derivate(function) + val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function) derivationResult.derivative(symbol) } } +/** + * Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression] + */ +public fun > simpleAutoDiff(field: F): AutoDiffProcessor, SimpleAutoDiffField> { + return object : AutoDiffProcessor, SimpleAutoDiffField> { + override fun process(function: SimpleAutoDiffField.() -> AutoDiffValue): DifferentiableExpression { + return SimpleAutoDiffExpression(field, function) + } + } +} + // Extensions for differentiation of various basic mathematical functions // x ^ 2 -public fun > AutoDiffField.sqr(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.sqr(x: AutoDiffValue): AutoDiffValue = derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } // x ^ 1/2 -public fun > AutoDiffField.sqrt(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.sqrt(x: AutoDiffValue): AutoDiffValue = derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } // x ^ y (const) -public fun > AutoDiffField.pow( +public fun > SimpleAutoDiffField.pow( x: AutoDiffValue, y: Double, ): AutoDiffValue = derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } -public fun > AutoDiffField.pow( +public fun > SimpleAutoDiffField.pow( x: AutoDiffValue, y: Int, ): AutoDiffValue = pow(x, y.toDouble()) // exp(x) -public fun > AutoDiffField.exp(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.exp(x: AutoDiffValue): AutoDiffValue = derive(const { exp(x.value) }) { z -> x.d += z.d * z.value } // ln(x) -public fun > AutoDiffField.ln(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.ln(x: AutoDiffValue): AutoDiffValue = derive(const { ln(x.value) }) { z -> x.d += z.d / x.value } // x ^ y (any) -public fun > AutoDiffField.pow( +public fun > SimpleAutoDiffField.pow( x: AutoDiffValue, y: AutoDiffValue, ): AutoDiffValue = exp(y * ln(x)) // sin(x) -public fun > AutoDiffField.sin(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.sin(x: AutoDiffValue): AutoDiffValue = derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } // cos(x) -public fun > AutoDiffField.cos(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.cos(x: AutoDiffValue): AutoDiffValue = derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } -public fun > AutoDiffField.tan(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.tan(x: AutoDiffValue): AutoDiffValue = derive(const { tan(x.value) }) { z -> val c = cos(x.value) x.d += z.d / (c * c) } -public fun > AutoDiffField.asin(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.asin(x: AutoDiffValue): AutoDiffValue = derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) } -public fun > AutoDiffField.acos(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.acos(x: AutoDiffValue): AutoDiffValue = derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) } -public fun > AutoDiffField.atan(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.atan(x: AutoDiffValue): AutoDiffValue = derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) } -public fun > AutoDiffField.sinh(x: AutoDiffValue): AutoDiffValue = - derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) } +public fun > SimpleAutoDiffField.sinh(x: AutoDiffValue): AutoDiffValue = + derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) } -public fun > AutoDiffField.cosh(x: AutoDiffValue): AutoDiffValue = - derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) } +public fun > SimpleAutoDiffField.cosh(x: AutoDiffValue): AutoDiffValue = + derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) } -public fun > AutoDiffField.tanh(x: AutoDiffValue): AutoDiffValue = - derive(const { tan(x.value) }) { z -> +public fun > SimpleAutoDiffField.tanh(x: AutoDiffValue): AutoDiffValue = + derive(const { tanh(x.value) }) { z -> val c = cosh(x.value) x.d += z.d / (c * c) } -public fun > AutoDiffField.asinh(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.asinh(x: AutoDiffValue): AutoDiffValue = derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) } -public fun > AutoDiffField.acosh(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.acosh(x: AutoDiffValue): AutoDiffValue = derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) } -public fun > AutoDiffField.atanh(x: AutoDiffValue): AutoDiffValue = +public fun > SimpleAutoDiffField.atanh(x: AutoDiffValue): AutoDiffValue = derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) } +public class SimpleAutoDiffExtendedField>( + context: F, + bindings: Map, +) : ExtendedField>, SimpleAutoDiffField(context, bindings) { + // x ^ 2 + public fun sqr(x: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sqr(x) + + // x ^ 1/2 + public override fun sqrt(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sqrt(arg) + + // x ^ y (const) + public override fun power(arg: AutoDiffValue, pow: Number): AutoDiffValue = + (this as SimpleAutoDiffField).pow(arg, pow.toDouble()) + + // exp(x) + public override fun exp(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).exp(arg) + + // ln(x) + public override fun ln(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).ln(arg) + + // x ^ y (any) + public fun pow( + x: AutoDiffValue, + y: AutoDiffValue, + ): AutoDiffValue = exp(y * ln(x)) + + // sin(x) + public override fun sin(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sin(arg) + + // cos(x) + public override fun cos(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).cos(arg) + + public override fun tan(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).tan(arg) + + public override fun asin(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).asin(arg) + + public override fun acos(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).acos(arg) + + public override fun atan(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).atan(arg) + + public override fun sinh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sinh(arg) + + public override fun cosh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).cosh(arg) + + public override fun tanh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).tanh(arg) + + public override fun asinh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).asinh(arg) + + public override fun acosh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).acosh(arg) + + public override fun atanh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).atanh(arg) +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt index defbb14ad..1603bc21d 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt @@ -8,11 +8,6 @@ import kotlin.contracts.InvocationKind import kotlin.contracts.contract -//public interface ExpressionBuilder> { -// public fun expression(block: A.() -> E): Expression -//} - - /** * Creates a functional expression with this [Space]. */ diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt index ca8ec1e17..510ed23a9 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -2,6 +2,7 @@ package kscience.kmath.expressions import kscience.kmath.operations.RealField import kscience.kmath.structures.asBuffer +import kotlin.math.E import kotlin.math.PI import kotlin.math.pow import kotlin.math.sqrt @@ -13,18 +14,18 @@ class SimpleAutoDiffTest { fun dx( xBinding: Pair, - body: AutoDiffField.(x: AutoDiffValue) -> AutoDiffValue, + body: SimpleAutoDiffField.(x: AutoDiffValue) -> AutoDiffValue, ): DerivationResult = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) } fun dxy( xBinding: Pair, yBinding: Pair, - body: AutoDiffField.(x: AutoDiffValue, y: AutoDiffValue) -> AutoDiffValue, + body: SimpleAutoDiffField.(x: AutoDiffValue, y: AutoDiffValue) -> AutoDiffValue, ): DerivationResult = RealField.simpleAutoDiff(xBinding, yBinding) { body(bind(xBinding.first), bind(yBinding.first)) } - fun diff(block: AutoDiffField.() -> AutoDiffValue): SimpleAutoDiffExpression { + fun diff(block: SimpleAutoDiffField.() -> AutoDiffValue): SimpleAutoDiffExpression { return SimpleAutoDiffExpression(RealField, block) } @@ -45,7 +46,7 @@ class SimpleAutoDiffTest { @Test fun testPlusX2Expr() { - val expr = diff{ + val expr = diff { val x = bind(x) x + x } @@ -245,9 +246,9 @@ class SimpleAutoDiffTest { @Test fun testTanh() { - val y = dx(x to PI / 6) { x -> tanh(x) } - assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6) - assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2 + val y = dx(x to 1.0) { x -> tanh(x) } + assertApprox((E * E - 1) / (E * E + 1), y.value) // y = tanh(pi/6) + assertApprox(1.0 / kotlin.math.cosh(1.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2 } @Test From 1c1580c8e6c411fe792143882b2e5b67ea2b2c46 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 27 Oct 2020 17:57:49 +0300 Subject: [PATCH 51/69] Generification of autodiff and chi2 --- kmath-commons/build.gradle.kts | 2 +- .../commons/optimization/OptimizeTest.kt | 5 +- .../kscience/kmath/functions/Polynomial.kt | 12 +-- .../kscience/kmath/functions/functions.kt | 34 ------- .../kotlin/kscience/kmath/prob/Fit.kt | 36 ++++++++ .../kmath/prob/OptimizationProblem.kt | 91 +++++++++++++++++++ 6 files changed, 133 insertions(+), 47 deletions(-) delete mode 100644 kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/functions.kt create mode 100644 kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fit.kt create mode 100644 kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/OptimizationProblem.kt diff --git a/kmath-commons/build.gradle.kts b/kmath-commons/build.gradle.kts index ed6452ad8..f0b20e82f 100644 --- a/kmath-commons/build.gradle.kts +++ b/kmath-commons/build.gradle.kts @@ -7,6 +7,6 @@ dependencies { api(project(":kmath-core")) api(project(":kmath-coroutines")) api(project(":kmath-prob")) -// api(project(":kmath-functions")) + api(project(":kmath-functions")) api("org.apache.commons:commons-math3:3.6.1") } diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index ff5542235..d9fc5ebef 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -46,7 +46,7 @@ internal class OptimizeTest { val sigma = 1.0 val generator = Distribution.normal(0.0, sigma) - val chain = generator.sample(RandomGenerator.default(1126)) + val chain = generator.sample(RandomGenerator.default(112667)) val x = (1..100).map { it.toDouble() } val y = x.map { it -> it.pow(2) + it + 1 + chain.nextDouble() @@ -54,7 +54,8 @@ internal class OptimizeTest { val yErr = x.map { sigma } with(CMFit) { val chi2 = chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x -> - bind(a) * x.pow(2) + bind(b) * x + bind(c) + val cWithDefault = bindOrNull(c)?: one + bind(a) * x.pow(2) + bind(b) * x + cWithDefault } val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) diff --git a/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/Polynomial.kt b/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/Polynomial.kt index c513a6889..820076c4c 100644 --- a/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/Polynomial.kt +++ b/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/Polynomial.kt @@ -8,13 +8,13 @@ import kotlin.contracts.contract import kotlin.math.max import kotlin.math.pow -// TODO make `inline`, when KT-41771 gets fixed /** * Polynomial coefficients without fixation on specific context they are applied to * @param coefficients constant is the leftmost coefficient */ public inline class Polynomial(public val coefficients: List) +@Suppress("FunctionName") public fun Polynomial(vararg coefficients: T): Polynomial = Polynomial(coefficients.toList()) public fun Polynomial.value(): Double = coefficients.reduceIndexed { index, acc, d -> acc + d.pow(index) } @@ -33,14 +33,6 @@ public fun > Polynomial.value(ring: C, arg: T): T = ring res } -/** - * Represent a polynomial as a context-dependent function - */ -public fun > Polynomial.asMathFunction(): MathFunction = - object : MathFunction { - override fun C.invoke(arg: T): T = value(this, arg) - } - /** * Represent the polynomial as a regular context-less function */ @@ -49,7 +41,7 @@ public fun > Polynomial.asFunction(ring: C): (T) -> T = /** * An algebra for polynomials */ -public class PolynomialSpace>(public val ring: C) : Space> { +public class PolynomialSpace>(private val ring: C) : Space> { public override val zero: Polynomial = Polynomial(emptyList()) public override fun add(a: Polynomial, b: Polynomial): Polynomial { diff --git a/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/functions.kt b/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/functions.kt deleted file mode 100644 index d780c16f3..000000000 --- a/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/functions.kt +++ /dev/null @@ -1,34 +0,0 @@ -package kscience.kmath.functions - -import kscience.kmath.operations.Algebra -import kscience.kmath.operations.RealField - -// TODO make fun interface when KT-41770 is fixed -/** - * A regular function that could be called only inside specific algebra context - * @param T source type - * @param C source algebra constraint - * @param R result type - */ -public /*fun*/ interface MathFunction, R> { - public operator fun C.invoke(arg: T): R -} - -public fun MathFunction.invoke(arg: Double): R = RealField.invoke(arg) - -/** - * A suspendable function defined in algebraic context - */ -// TODO make fun interface, when the new JVM IR is enabled -public interface SuspendableMathFunction, R> { - public suspend operator fun C.invoke(arg: T): R -} - -public suspend fun SuspendableMathFunction.invoke(arg: Double): R = RealField.invoke(arg) - -/** - * A parametric function with parameter - */ -public fun interface ParametricFunction> { - public operator fun C.invoke(arg: T, parameter: P): T -} diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fit.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fit.kt new file mode 100644 index 000000000..efe582212 --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fit.kt @@ -0,0 +1,36 @@ +package kscience.kmath.prob + +import kscience.kmath.expressions.AutoDiffProcessor +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.ExpressionAlgebra +import kscience.kmath.operations.ExtendedField +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.indices + +public object Fit { + + /** + * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation + */ + public fun chiSquared( + autoDiff: AutoDiffProcessor, + x: Buffer, + y: Buffer, + yErr: Buffer, + model: A.(I) -> I, + ): DifferentiableExpression where A : ExtendedField, A : ExpressionAlgebra { + require(x.size == y.size) { "X and y buffers should be of the same size" } + require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } + return autoDiff.process { + var sum = zero + x.indices.forEach { + val xValue = const(x[it]) + val yValue = const(y[it]) + val yErrValue = const(yErr[it]) + val modelValue = model(xValue) + sum += ((yValue - modelValue) / yErrValue).pow(2) + } + sum + } + } +} \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/OptimizationProblem.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/OptimizationProblem.kt new file mode 100644 index 000000000..c5fb3fa54 --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/OptimizationProblem.kt @@ -0,0 +1,91 @@ +package kscience.kmath.commons.optimization + +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.Symbol + +public interface OptimizationFeature + +public class OptimizationResult( + public val point: Map, + public val value: T, + public val features: Set = emptySet(), +) { + override fun toString(): String { + return "OptimizationResult(point=$point, value=$value)" + } +} + +public operator fun OptimizationResult.plus( + feature: OptimizationFeature, +): OptimizationResult = OptimizationResult(point, value, features + feature) + +/** + * A configuration builder for optimization problem + */ +public interface OptimizationProblem { + /** + * Define the initial guess for the optimization problem + */ + public fun initialGuess(map: Map): Unit + + /** + * Set an objective function expression + */ + public fun expression(expression: Expression): Unit + + /** + * Set a differentiable expression as objective function as function and gradient provider + */ + public fun diffExpression(expression: DifferentiableExpression): Unit + + /** + * Update the problem from previous optimization run + */ + public fun update(result: OptimizationResult) + + /** + * Make an optimization run + */ + public fun optimize(): OptimizationResult +} + +public interface OptimizationProblemFactory> { + public fun build(symbols: List): P + +} + +public operator fun > OptimizationProblemFactory.invoke( + symbols: List, + block: P.() -> Unit, +): P = build(symbols).apply(block) + + +/** + * Optimize expression without derivatives using specific [OptimizationProblemFactory] + */ +public fun > Expression.optimizeWith( + factory: OptimizationProblemFactory, + vararg symbols: Symbol, + configuration: F.() -> Unit, +): OptimizationResult { + require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = factory(symbols.toList(),configuration) + problem.expression(this) + return problem.optimize() +} + +/** + * Optimize differentiable expression using specific [OptimizationProblemFactory] + */ +public fun > DifferentiableExpression.optimizeWith( + factory: OptimizationProblemFactory, + vararg symbols: Symbol, + configuration: F.() -> Unit, +): OptimizationResult { + require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = factory(symbols.toList(), configuration) + problem.diffExpression(this) + return problem.optimize() +} + From f8c3d1793c80f7f6fcdbd53b39e9c004cd535a3a Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 28 Oct 2020 09:08:37 +0300 Subject: [PATCH 52/69] Fitting refactor --- .../kmath/commons/optimization/CMFit.kt | 95 ------------------- .../kmath/commons/optimization/cmFit.kt | 68 +++++++++++++ .../commons/optimization/OptimizeTest.kt | 20 ++-- .../kmath/prob/{Fit.kt => Fitting.kt} | 31 +++++- 4 files changed, 105 insertions(+), 109 deletions(-) delete mode 100644 kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt create mode 100644 kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt rename kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/{Fit.kt => Fitting.kt} (52%) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt deleted file mode 100644 index 3143dcca5..000000000 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMFit.kt +++ /dev/null @@ -1,95 +0,0 @@ -package kscience.kmath.commons.optimization - -import kscience.kmath.commons.expressions.DerivativeStructureExpression -import kscience.kmath.commons.expressions.DerivativeStructureField -import kscience.kmath.expressions.DifferentiableExpression -import kscience.kmath.expressions.Expression -import kscience.kmath.expressions.StringSymbol -import kscience.kmath.expressions.Symbol -import kscience.kmath.structures.Buffer -import kscience.kmath.structures.indices -import org.apache.commons.math3.analysis.differentiation.DerivativeStructure -import org.apache.commons.math3.optim.nonlinear.scalar.GoalType -import kotlin.math.pow - - -public object CMFit { - - /** - * Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives - * TODO move to prob/stat - */ - public fun chiSquared( - x: Buffer, - y: Buffer, - yErr: Buffer, - model: Expression, - xSymbol: Symbol = StringSymbol("x"), - ): Expression { - require(x.size == y.size) { "X and y buffers should be of the same size" } - require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } - return Expression { arguments -> - x.indices.sumByDouble { - val xValue = x[it] - val yValue = y[it] - val yErrValue = yErr[it] - val modifiedArgs = arguments + (xSymbol to xValue) - val modelValue = model(modifiedArgs) - ((yValue - modelValue) / yErrValue).pow(2) - } - } - } - - /** - * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation - */ - public fun chiSquared( - x: Buffer, - y: Buffer, - yErr: Buffer, - model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, - ): DerivativeStructureExpression { - require(x.size == y.size) { "X and y buffers should be of the same size" } - require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } - return DerivativeStructureExpression { - var sum = zero - x.indices.forEach { - val xValue = x[it] - val yValue = y[it] - val yErrValue = yErr[it] - val modelValue = model(const(xValue)) - sum += ((yValue - modelValue) / yErrValue).pow(2) - } - sum - } - } -} - -/** - * Optimize expression without derivatives - */ -public fun Expression.optimize( - vararg symbols: Symbol, - configuration: CMOptimizationProblem.() -> Unit, -): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) - - -/** - * Optimize differentiable expression - */ -public fun DifferentiableExpression.optimize( - vararg symbols: Symbol, - configuration: CMOptimizationProblem.() -> Unit, -): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) - -public fun DifferentiableExpression.minimize( - vararg startPoint: Pair, - configuration: CMOptimizationProblem.() -> Unit = {}, -): OptimizationResult { - require(startPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(startPoint.map { it.first }).apply(configuration) - problem.diffExpression(this) - problem.initialGuess(startPoint.toMap()) - problem.goal(GoalType.MINIMIZE) - return problem.optimize() -} \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt new file mode 100644 index 000000000..24df3177d --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt @@ -0,0 +1,68 @@ +package kscience.kmath.commons.optimization + +import kscience.kmath.commons.expressions.DerivativeStructureField +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.Symbol +import kscience.kmath.prob.Fitting +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.asBuffer +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType + + +/** + * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation + */ +public fun Fitting.chiSquared( + x: Buffer, + y: Buffer, + yErr: Buffer, + model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, +): DifferentiableExpression = chiSquared(DerivativeStructureField, x, y, yErr, model) + +/** + * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation + */ +public fun Fitting.chiSquared( + x: Iterable, + y: Iterable, + yErr: Iterable, + model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, +): DifferentiableExpression = chiSquared( + DerivativeStructureField, + x.toList().asBuffer(), + y.toList().asBuffer(), + yErr.toList().asBuffer(), + model +) + + +/** + * Optimize expression without derivatives + */ +public fun Expression.optimize( + vararg symbols: Symbol, + configuration: CMOptimizationProblem.() -> Unit, +): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) + + +/** + * Optimize differentiable expression + */ +public fun DifferentiableExpression.optimize( + vararg symbols: Symbol, + configuration: CMOptimizationProblem.() -> Unit, +): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) + +public fun DifferentiableExpression.minimize( + vararg startPoint: Pair, + configuration: CMOptimizationProblem.() -> Unit = {}, +): OptimizationResult { + require(startPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = CMOptimizationProblem(startPoint.map { it.first }).apply(configuration) + problem.diffExpression(this) + problem.initialGuess(startPoint.toMap()) + problem.goal(GoalType.MINIMIZE) + return problem.optimize() +} \ No newline at end of file diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index d9fc5ebef..502ed40f8 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -3,6 +3,7 @@ package kscience.kmath.commons.optimization import kscience.kmath.commons.expressions.DerivativeStructureExpression import kscience.kmath.expressions.symbol import kscience.kmath.prob.Distribution +import kscience.kmath.prob.Fitting import kscience.kmath.prob.RandomGenerator import kscience.kmath.prob.normal import kscience.kmath.structures.asBuffer @@ -39,7 +40,7 @@ internal class OptimizeTest { } @Test - fun testFit() { + fun testCmFit() { val a by symbol val b by symbol val c by symbol @@ -52,15 +53,14 @@ internal class OptimizeTest { it.pow(2) + it + 1 + chain.nextDouble() } val yErr = x.map { sigma } - with(CMFit) { - val chi2 = chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x -> - val cWithDefault = bindOrNull(c)?: one - bind(a) * x.pow(2) + bind(b) * x + cWithDefault - } - - val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) - println(result) - println("Chi2/dof = ${result.value / (x.size - 3)}") + val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x -> + val cWithDefault = bindOrNull(c) ?: one + bind(a) * x.pow(2) + bind(b) * x + cWithDefault } + + val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) + println(result) + println("Chi2/dof = ${result.value / (x.size - 3)}") } + } \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fit.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fitting.kt similarity index 52% rename from kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fit.kt rename to kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fitting.kt index efe582212..97548d676 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fit.kt +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fitting.kt @@ -1,13 +1,12 @@ package kscience.kmath.prob -import kscience.kmath.expressions.AutoDiffProcessor -import kscience.kmath.expressions.DifferentiableExpression -import kscience.kmath.expressions.ExpressionAlgebra +import kscience.kmath.expressions.* import kscience.kmath.operations.ExtendedField import kscience.kmath.structures.Buffer import kscience.kmath.structures.indices +import kotlin.math.pow -public object Fit { +public object Fitting { /** * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation @@ -33,4 +32,28 @@ public object Fit { sum } } + + /** + * Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives + */ + public fun chiSquared( + x: Buffer, + y: Buffer, + yErr: Buffer, + model: Expression, + xSymbol: Symbol = StringSymbol("x"), + ): Expression { + require(x.size == y.size) { "X and y buffers should be of the same size" } + require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } + return Expression { arguments -> + x.indices.sumByDouble { + val xValue = x[it] + val yValue = y[it] + val yErrValue = yErr[it] + val modifiedArgs = arguments + (xSymbol to xValue) + val modelValue = model(modifiedArgs) + ((yValue - modelValue) / yErrValue).pow(2) + } + } + } } \ No newline at end of file From dfa1bcaf01a2dbbf3eed62706be17a09d92a8618 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 28 Oct 2020 09:16:21 +0300 Subject: [PATCH 53/69] prob renamed to stat --- CHANGELOG.md | 2 ++ README.md | 6 +++--- build.gradle.kts | 2 +- examples/build.gradle.kts | 2 +- .../kscience/kmath/commons/prob/DistributionBenchmark.kt | 2 +- .../kscience/kmath/commons/prob/DistributionDemo.kt | 6 +++--- kmath-commons/build.gradle.kts | 2 +- .../kmath/commons/optimization/CMOptimizationProblem.kt | 4 ++++ .../kotlin/kscience/kmath/commons/optimization/cmFit.kt | 4 +++- .../kmath/commons/random/CMRandomGeneratorWrapper.kt | 2 +- .../kscience/kmath/commons/optimization/OptimizeTest.kt | 8 ++++---- kmath-core/README.md | 2 +- kmath-core/build.gradle.kts | 2 +- {kmath-prob => kmath-stat}/build.gradle.kts | 0 .../kotlin/kscience/kmath/stat}/Distribution.kt | 2 +- .../kotlin/kscience/kmath/stat}/FactorizedDistribution.kt | 2 +- .../src/commonMain/kotlin/kscience/kmath/stat}/Fitting.kt | 2 +- .../kotlin/kscience/kmath/stat}/OptimizationProblem.kt | 2 +- .../commonMain/kotlin/kscience/kmath/stat}/RandomChain.kt | 2 +- .../kotlin/kscience/kmath/stat}/RandomGenerator.kt | 2 +- .../kotlin/kscience/kmath/stat}/SamplerAlgebra.kt | 2 +- .../commonMain/kotlin/kscience/kmath/stat}/Statistic.kt | 2 +- .../kotlin/kscience/kmath/stat}/UniformDistribution.kt | 2 +- .../kotlin/kscience/kmath/stat}/RandomSourceGenerator.kt | 2 +- .../jvmMain/kotlin/kscience/kmath/stat}/distributions.kt | 2 +- .../kscience/kmath/stat}/CommonsDistributionsTest.kt | 2 +- .../jvmTest/kotlin/kscience/kmath/stat}/SamplerTest.kt | 2 +- .../jvmTest/kotlin/kscience/kmath/stat}/StatisticTest.kt | 2 +- settings.gradle.kts | 2 +- 29 files changed, 41 insertions(+), 33 deletions(-) rename {kmath-prob => kmath-stat}/build.gradle.kts (100%) rename {kmath-prob/src/commonMain/kotlin/kscience/kmath/prob => kmath-stat/src/commonMain/kotlin/kscience/kmath/stat}/Distribution.kt (98%) rename {kmath-prob/src/commonMain/kotlin/kscience/kmath/prob => kmath-stat/src/commonMain/kotlin/kscience/kmath/stat}/FactorizedDistribution.kt (98%) rename {kmath-prob/src/commonMain/kotlin/kscience/kmath/prob => kmath-stat/src/commonMain/kotlin/kscience/kmath/stat}/Fitting.kt (98%) rename {kmath-prob/src/commonMain/kotlin/kscience/kmath/prob => kmath-stat/src/commonMain/kotlin/kscience/kmath/stat}/OptimizationProblem.kt (98%) rename {kmath-prob/src/commonMain/kotlin/kscience/kmath/prob => kmath-stat/src/commonMain/kotlin/kscience/kmath/stat}/RandomChain.kt (94%) rename {kmath-prob/src/commonMain/kotlin/kscience/kmath/prob => kmath-stat/src/commonMain/kotlin/kscience/kmath/stat}/RandomGenerator.kt (99%) rename {kmath-prob/src/commonMain/kotlin/kscience/kmath/prob => kmath-stat/src/commonMain/kotlin/kscience/kmath/stat}/SamplerAlgebra.kt (97%) rename {kmath-prob/src/commonMain/kotlin/kscience/kmath/prob => kmath-stat/src/commonMain/kotlin/kscience/kmath/stat}/Statistic.kt (99%) rename {kmath-prob/src/commonMain/kotlin/kscience/kmath/prob => kmath-stat/src/commonMain/kotlin/kscience/kmath/stat}/UniformDistribution.kt (96%) rename {kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob => kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat}/RandomSourceGenerator.kt (98%) rename {kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob => kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat}/distributions.kt (99%) rename {kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob => kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat}/CommonsDistributionsTest.kt (96%) rename {kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob => kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat}/SamplerTest.kt (92%) rename {kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob => kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat}/StatisticTest.kt (96%) diff --git a/CHANGELOG.md b/CHANGELOG.md index f28041adf..2f802d85d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ - A `Symbol` indexing scope. - Basic optimization API for Commons-math. - Chi squared optimization for array-like data in CM +- `Fitting` utility object in prob/stat ### Changed - Package changed from `scientifik` to `kscience.kmath`. @@ -21,6 +22,7 @@ - Kotlin version: 1.3.72 -> 1.4.20-M1 - `kmath-ast` doesn't depend on heavy `kotlin-reflect` library. - Full autodiff refactoring based on `Symbol` +- `kmath-prob` renamed to `kmath-stat` ### Deprecated diff --git a/README.md b/README.md index cbdf98afb..afab32dcf 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ can be used for a wide variety of purposes from high performance calculations to > - [buffers](kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure > - [expressions](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions > - [domains](kmath-core/src/commonMain/kotlin/kscience/kmath/domains) : Domains -> - [autodif](kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt) : Automatic differentiation +> - [autodif](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
@@ -151,7 +151,7 @@ can be used for a wide variety of purposes from high performance calculations to > **Maturity**: EXPERIMENTAL
-* ### [kmath-prob](kmath-prob) +* ### [kmath-stat](kmath-stat) > > > **Maturity**: EXPERIMENTAL @@ -201,4 +201,4 @@ with the same artifact names. ## Contributing -The project requires a lot of additional work. Please feel free to contribute in any way and propose new features. +The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero). \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index 74b76d731..acb9f3b68 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -2,7 +2,7 @@ plugins { id("ru.mipt.npm.project") } -val kmathVersion: String by extra("0.2.0-dev-2") +val kmathVersion: String by extra("0.2.0-dev-3") val bintrayRepo: String by extra("kscience") val githubProject: String by extra("kmath") diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 900da966b..9ba1ec5be 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -23,7 +23,7 @@ dependencies { implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) - implementation(project(":kmath-prob")) + implementation(project(":kmath-stat")) implementation(project(":kmath-viktor")) implementation(project(":kmath-dimensions")) implementation(project(":kmath-ejml")) diff --git a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt index 9c0a01961..ef554aeff 100644 --- a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt @@ -4,7 +4,7 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.runBlocking import kscience.kmath.chains.BlockingRealChain -import kscience.kmath.prob.* +import kscience.kmath.stat.* import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler import org.apache.commons.rng.simple.RandomSource import java.time.Duration diff --git a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt index 7d53e5178..6146e17af 100644 --- a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt +++ b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt @@ -3,9 +3,9 @@ package kscience.kmath.commons.prob import kotlinx.coroutines.runBlocking import kscience.kmath.chains.Chain import kscience.kmath.chains.collectWithState -import kscience.kmath.prob.Distribution -import kscience.kmath.prob.RandomGenerator -import kscience.kmath.prob.normal +import kscience.kmath.stat.Distribution +import kscience.kmath.stat.RandomGenerator +import kscience.kmath.stat.normal private data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) diff --git a/kmath-commons/build.gradle.kts b/kmath-commons/build.gradle.kts index f0b20e82f..6a44c92f2 100644 --- a/kmath-commons/build.gradle.kts +++ b/kmath-commons/build.gradle.kts @@ -6,7 +6,7 @@ description = "Commons math binding for kmath" dependencies { api(project(":kmath-core")) api(project(":kmath-coroutines")) - api(project(":kmath-prob")) + api(project(":kmath-stat")) api(project(":kmath-functions")) api("org.apache.commons:commons-math3:3.6.1") } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt index 0d96faaa3..13f9af7bb 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt @@ -1,6 +1,10 @@ package kscience.kmath.commons.optimization import kscience.kmath.expressions.* +import kscience.kmath.stat.OptimizationFeature +import kscience.kmath.stat.OptimizationProblem +import kscience.kmath.stat.OptimizationProblemFactory +import kscience.kmath.stat.OptimizationResult import org.apache.commons.math3.optim.* import org.apache.commons.math3.optim.nonlinear.scalar.GoalType import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt index 24df3177d..42475db6c 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt @@ -4,7 +4,9 @@ import kscience.kmath.commons.expressions.DerivativeStructureField import kscience.kmath.expressions.DifferentiableExpression import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Symbol -import kscience.kmath.prob.Fitting +import kscience.kmath.stat.Fitting +import kscience.kmath.stat.OptimizationResult +import kscience.kmath.stat.optimizeWith import kscience.kmath.structures.Buffer import kscience.kmath.structures.asBuffer import org.apache.commons.math3.analysis.differentiation.DerivativeStructure diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt index 9600f6901..1eab5f2bd 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -1,6 +1,6 @@ package kscience.kmath.commons.random -import kscience.kmath.prob.RandomGenerator +import kscience.kmath.stat.RandomGenerator public class CMRandomGeneratorWrapper( public val factory: (IntArray) -> RandomGenerator, diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index 502ed40f8..4384a5124 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -2,10 +2,10 @@ package kscience.kmath.commons.optimization import kscience.kmath.commons.expressions.DerivativeStructureExpression import kscience.kmath.expressions.symbol -import kscience.kmath.prob.Distribution -import kscience.kmath.prob.Fitting -import kscience.kmath.prob.RandomGenerator -import kscience.kmath.prob.normal +import kscience.kmath.stat.Distribution +import kscience.kmath.stat.Fitting +import kscience.kmath.stat.RandomGenerator +import kscience.kmath.stat.normal import kscience.kmath.structures.asBuffer import org.junit.jupiter.api.Test import kotlin.math.pow diff --git a/kmath-core/README.md b/kmath-core/README.md index 6935c0d3c..5501b1d7a 100644 --- a/kmath-core/README.md +++ b/kmath-core/README.md @@ -7,7 +7,7 @@ The core features of KMath: - [buffers](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure - [expressions](src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions - [domains](src/commonMain/kotlin/kscience/kmath/domains) : Domains - - [autodif](src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt) : Automatic differentiation + - [autodif](src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation > #### Artifact: diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index bd254c39d..b0849eca5 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -41,6 +41,6 @@ readme { feature( id = "autodif", description = "Automatic differentiation", - ref = "src/commonMain/kotlin/kscience/kmath/misc/SimpleAutoDiff.kt" + ref = "src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt" ) } \ No newline at end of file diff --git a/kmath-prob/build.gradle.kts b/kmath-stat/build.gradle.kts similarity index 100% rename from kmath-prob/build.gradle.kts rename to kmath-stat/build.gradle.kts diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Distribution.kt similarity index 98% rename from kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Distribution.kt index 72660e20d..c4ceb29eb 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Distribution.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kscience.kmath.chains.Chain import kscience.kmath.chains.collect diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/FactorizedDistribution.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/FactorizedDistribution.kt similarity index 98% rename from kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/FactorizedDistribution.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/FactorizedDistribution.kt index 4d713fc4e..1ed9deba9 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/FactorizedDistribution.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/FactorizedDistribution.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kscience.kmath.chains.Chain import kscience.kmath.chains.SimpleChain diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fitting.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Fitting.kt similarity index 98% rename from kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fitting.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Fitting.kt index 97548d676..01fdf4c5e 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fitting.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Fitting.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kscience.kmath.expressions.* import kscience.kmath.operations.ExtendedField diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/OptimizationProblem.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/OptimizationProblem.kt similarity index 98% rename from kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/OptimizationProblem.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/OptimizationProblem.kt index c5fb3fa54..ea522bff9 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/OptimizationProblem.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/OptimizationProblem.kt @@ -1,4 +1,4 @@ -package kscience.kmath.commons.optimization +package kscience.kmath.stat import kscience.kmath.expressions.DifferentiableExpression import kscience.kmath.expressions.Expression diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/RandomChain.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomChain.kt similarity index 94% rename from kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/RandomChain.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomChain.kt index b4a80f6c5..0f10851b9 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/RandomChain.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomChain.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kscience.kmath.chains.Chain diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/RandomGenerator.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomGenerator.kt similarity index 99% rename from kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/RandomGenerator.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomGenerator.kt index 2dd4ce51e..4486ae016 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/RandomGenerator.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomGenerator.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kotlin.random.Random diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/SamplerAlgebra.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/SamplerAlgebra.kt similarity index 97% rename from kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/SamplerAlgebra.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/SamplerAlgebra.kt index e363ba30b..f416028a5 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/SamplerAlgebra.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/SamplerAlgebra.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kscience.kmath.chains.Chain import kscience.kmath.chains.ConstantChain diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Statistic.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Statistic.kt similarity index 99% rename from kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Statistic.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Statistic.kt index 6720a3d7f..a4624fc21 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Statistic.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Statistic.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/UniformDistribution.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/UniformDistribution.kt similarity index 96% rename from kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/UniformDistribution.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/UniformDistribution.kt index 8df2c01e1..1ba5c96f1 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/UniformDistribution.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/UniformDistribution.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kscience.kmath.chains.Chain import kscience.kmath.chains.SimpleChain diff --git a/kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob/RandomSourceGenerator.kt b/kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/RandomSourceGenerator.kt similarity index 98% rename from kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob/RandomSourceGenerator.kt rename to kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/RandomSourceGenerator.kt index 18be6f019..5cba28a95 100644 --- a/kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob/RandomSourceGenerator.kt +++ b/kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/RandomSourceGenerator.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import org.apache.commons.rng.UniformRandomProvider import org.apache.commons.rng.simple.RandomSource diff --git a/kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob/distributions.kt b/kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/distributions.kt similarity index 99% rename from kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob/distributions.kt rename to kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/distributions.kt index ff20572cc..9a77b0bd2 100644 --- a/kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob/distributions.kt +++ b/kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/distributions.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kscience.kmath.chains.BlockingIntChain import kscience.kmath.chains.BlockingRealChain diff --git a/kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/CommonsDistributionsTest.kt b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/CommonsDistributionsTest.kt similarity index 96% rename from kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/CommonsDistributionsTest.kt rename to kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/CommonsDistributionsTest.kt index 12a00684b..fe58fac08 100644 --- a/kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/CommonsDistributionsTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/CommonsDistributionsTest.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kotlinx.coroutines.flow.take import kotlinx.coroutines.flow.toList diff --git a/kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/SamplerTest.kt b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/SamplerTest.kt similarity index 92% rename from kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/SamplerTest.kt rename to kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/SamplerTest.kt index 75db5c402..afed4c5d0 100644 --- a/kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/SamplerTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/SamplerTest.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kotlinx.coroutines.runBlocking import kotlin.test.Test diff --git a/kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/StatisticTest.kt b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/StatisticTest.kt similarity index 96% rename from kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/StatisticTest.kt rename to kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/StatisticTest.kt index 22ca472a8..5cee4d172 100644 --- a/kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/StatisticTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/StatisticTest.kt @@ -1,4 +1,4 @@ -package kscience.kmath.prob +package kscience.kmath.stat import kotlinx.coroutines.flow.drop import kotlinx.coroutines.flow.first diff --git a/settings.gradle.kts b/settings.gradle.kts index 0f549f9ab..fa9edcf22 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -34,7 +34,7 @@ include( ":kmath-histograms", ":kmath-commons", ":kmath-viktor", - ":kmath-prob", + ":kmath-stat", ":kmath-dimensions", ":kmath-for-real", ":kmath-geometry", From 5fa4d40f415e95303f17b7af34850eeeb0508602 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 28 Oct 2020 09:25:37 +0300 Subject: [PATCH 54/69] Remove Differentiable --- .../kmath/expressions/DifferentiableExpression.kt | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt index 705839b57..4fe73f283 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -1,20 +1,15 @@ package kscience.kmath.expressions /** - * And object that could be differentiated + * An expression that provides derivatives */ -public interface Differentiable { - public fun derivativeOrNull(orders: Map): T? +public interface DifferentiableExpression : Expression{ + public fun derivativeOrNull(orders: Map): Expression? } -public fun Differentiable.derivative(orders: Map): T = +public fun DifferentiableExpression.derivative(orders: Map): Expression = derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided") -/** - * An expression that provid - */ -public interface DifferentiableExpression : Differentiable>, Expression - public fun DifferentiableExpression.derivative(vararg orders: Pair): Expression = derivative(mapOf(*orders)) From 73b4294122a966095eddf26942cd6bff5a673405 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 28 Oct 2020 09:56:33 +0300 Subject: [PATCH 55/69] Try to fix Native compilation bug --- .../kotlin/kscience/kmath/expressions/Expression.kt | 5 +++-- .../kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index 7e1eb0cd7..9743363c6 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -3,6 +3,7 @@ package kscience.kmath.expressions import kscience.kmath.operations.Algebra import kotlin.jvm.JvmName import kotlin.properties.ReadOnlyProperty +import kotlin.reflect.KProperty /** * A marker interface for a symbol. A symbol mus have an identity @@ -84,8 +85,8 @@ public interface ExpressionAlgebra : Algebra { public fun ExpressionAlgebra.bind(symbol: Symbol): E = bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this") -public val symbol: ReadOnlyProperty = ReadOnlyProperty { _, property -> - StringSymbol(property.name) +public val symbol: ReadOnlyProperty = object : ReadOnlyProperty { + override fun getValue(thisRef: Any?, property: KProperty<*>): StringSymbol = StringSymbol(property.name) } public fun ExpressionAlgebra.binding(): ReadOnlyProperty = diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index e66832fdb..5a9642690 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -244,7 +244,6 @@ public fun > simpleAutoDiff(field: F): AutoDiffProcessor Date: Wed, 28 Oct 2020 10:07:50 +0300 Subject: [PATCH 56/69] Fix did not work, rolled back. --- .../kscience/kmath/expressions/Expression.kt | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index 9743363c6..ab9ff0e72 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -3,7 +3,6 @@ package kscience.kmath.expressions import kscience.kmath.operations.Algebra import kotlin.jvm.JvmName import kotlin.properties.ReadOnlyProperty -import kotlin.reflect.KProperty /** * A marker interface for a symbol. A symbol mus have an identity @@ -85,11 +84,16 @@ public interface ExpressionAlgebra : Algebra { public fun ExpressionAlgebra.bind(symbol: Symbol): E = bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this") -public val symbol: ReadOnlyProperty = object : ReadOnlyProperty { - override fun getValue(thisRef: Any?, property: KProperty<*>): StringSymbol = StringSymbol(property.name) +/** + * A delegate to create a symbol with a string identity in this scope + */ +public val symbol: ReadOnlyProperty = ReadOnlyProperty { thisRef, property -> + StringSymbol(property.name) } -public fun ExpressionAlgebra.binding(): ReadOnlyProperty = - ReadOnlyProperty { _, property -> - bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist") - } \ No newline at end of file +/** + * Bind a symbol by name inside the [ExpressionAlgebra] + */ +public fun ExpressionAlgebra.binding(): ReadOnlyProperty = ReadOnlyProperty { _, property -> + bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist") +} \ No newline at end of file From 7f8abbdd206f41da85c834508efa53815e234183 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 29 Oct 2020 02:22:34 +0700 Subject: [PATCH 57/69] Fix typo, introduce KG protocol delegating to algebra --- .../kscience/kmath/ast/KotlingradSupport.kt | 8 +-- .../kmath/kotlingrad/ScalarsAdapters.kt | 57 ++++++++++++++----- .../kmath/kotlingrad/AdaptingTests.kt | 12 ++-- 3 files changed, 53 insertions(+), 24 deletions(-) diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt index 366a2b4fd..c8478a631 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -2,10 +2,10 @@ package kscience.kmath.ast import edu.umontreal.kotlingrad.experimental.DoublePrecision import kscience.kmath.asm.compile -import kscience.kmath.kotlingrad.toMst -import kscience.kmath.kotlingrad.tSFun -import kscience.kmath.kotlingrad.toSVar import kscience.kmath.expressions.invoke +import kscience.kmath.kotlingrad.toMst +import kscience.kmath.kotlingrad.toSFun +import kscience.kmath.kotlingrad.toSVar import kscience.kmath.operations.RealField /** @@ -15,7 +15,7 @@ import kscience.kmath.operations.RealField fun main() { val proto = DoublePrecision.prototype val x by MstAlgebra.symbol("x").toSVar(proto) - val quadratic = "x^2-4*x-44".parseMath().tSFun(proto) + val quadratic = "x^2-4*x-44".parseMath().toSFun(proto) val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0)) diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt index 99ab5e635..a9a8a14b2 100644 --- a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt @@ -2,8 +2,13 @@ package kscience.kmath.kotlingrad import edu.umontreal.kotlingrad.experimental.* import kscience.kmath.ast.MST +import kscience.kmath.ast.MstAlgebra +import kscience.kmath.ast.MstExpression import kscience.kmath.ast.MstExtendedField import kscience.kmath.ast.MstExtendedField.unaryMinus +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.Symbol import kscience.kmath.operations.* /** @@ -80,28 +85,52 @@ public fun > MST.Symbolic.toSVar(proto: X): SVar = SVar(proto, va * @param proto the prototype instance. * @return a scalar function. */ -public fun > MST.tSFun(proto: X): SFun = when (this) { +public fun > MST.toSFun(proto: X): SFun = when (this) { is MST.Numeric -> toSConst() is MST.Symbolic -> toSVar(proto) is MST.Unary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> value.tSFun(proto) - SpaceOperations.MINUS_OPERATION -> -value.tSFun(proto) - TrigonometricOperations.SIN_OPERATION -> sin(value.tSFun(proto)) - TrigonometricOperations.COS_OPERATION -> cos(value.tSFun(proto)) - TrigonometricOperations.TAN_OPERATION -> tan(value.tSFun(proto)) - PowerOperations.SQRT_OPERATION -> value.tSFun(proto).sqrt() - ExponentialOperations.EXP_OPERATION -> E() pow value.tSFun(proto) - ExponentialOperations.LN_OPERATION -> value.tSFun(proto).ln() + SpaceOperations.PLUS_OPERATION -> value.toSFun(proto) + SpaceOperations.MINUS_OPERATION -> -value.toSFun(proto) + TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun(proto)) + TrigonometricOperations.COS_OPERATION -> cos(value.toSFun(proto)) + TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun(proto)) + PowerOperations.SQRT_OPERATION -> value.toSFun(proto).sqrt() + ExponentialOperations.EXP_OPERATION -> E() pow value.toSFun(proto) + ExponentialOperations.LN_OPERATION -> value.toSFun(proto).ln() else -> error("Unary operation $operation not defined in $this") } is MST.Binary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> left.tSFun(proto) + right.tSFun(proto) - SpaceOperations.MINUS_OPERATION -> left.tSFun(proto) - right.tSFun(proto) - RingOperations.TIMES_OPERATION -> left.tSFun(proto) * right.tSFun(proto) - FieldOperations.DIV_OPERATION -> left.tSFun(proto) / right.tSFun(proto) - PowerOperations.POW_OPERATION -> left.tSFun(proto) pow (right as MST.Numeric).toSConst() + SpaceOperations.PLUS_OPERATION -> left.toSFun(proto) + right.toSFun(proto) + SpaceOperations.MINUS_OPERATION -> left.toSFun(proto) - right.toSFun(proto) + RingOperations.TIMES_OPERATION -> left.toSFun(proto) * right.toSFun(proto) + FieldOperations.DIV_OPERATION -> left.toSFun(proto) / right.toSFun(proto) + PowerOperations.POW_OPERATION -> left.toSFun(proto) pow (right as MST.Numeric).toSConst() else -> error("Binary operation $operation not defined in $this") } } + +public class KMathNumber(public val algebra: A, value: T) : + RealNumber, T>(value) where T : Number, A : NumericAlgebra { + public override fun wrap(number: Number): SConst> = SConst(algebra.number(number)) + override val proto: KMathNumber by lazy { KMathNumber(algebra, algebra.number(Double.NaN)) } +} + +public class KMathProtocol(algebra: A) : + Protocol>(KMathNumber(algebra, algebra.number(Double.NaN))) + where T : Number, A : NumericAlgebra + +public class DifferentiableMstExpression(public val algebra: A, public val mst: MST) : + DifferentiableExpression where A : NumericAlgebra, T : Number { + public val proto by lazy { KMathProtocol(algebra).prototype } + public val expr by lazy { MstExpression(algebra, mst) } + + public override fun invoke(arguments: Map): T = expr(arguments) + + public override fun derivativeOrNull(orders: Map): Expression { + val sfun = mst.toSFun(proto) + val orders2 = orders.mapKeys { (k, _) -> MstAlgebra.symbol(k.identity).toSVar(proto) } + TODO() + } +} diff --git a/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt b/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt index 25bdbf4be..682b0cf2e 100644 --- a/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt +++ b/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt @@ -19,7 +19,7 @@ internal class AdaptingTests { fun symbol() { val c1 = MstAlgebra.symbol("x") assertTrue(c1.toSVar(proto).name == "x") - val c2 = "kitten".parseMath().tSFun(proto) + val c2 = "kitten".parseMath().toSFun(proto) if (c2 is SVar) assertTrue(c2.name == "kitten") else fail() } @@ -27,15 +27,15 @@ internal class AdaptingTests { fun number() { val c1 = MstAlgebra.number(12354324) assertTrue(c1.toSConst().doubleValue == 12354324.0) - val c2 = "0.234".parseMath().tSFun(proto) + val c2 = "0.234".parseMath().toSFun(proto) if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail() - val c3 = "1e-3".parseMath().tSFun(proto) + val c3 = "1e-3".parseMath().toSFun(proto) if (c3 is SConst) assertEquals(0.001, c3.value) else fail() } @Test fun simpleFunctionShape() { - val linear = "2*x+16".parseMath().tSFun(proto) + val linear = "2*x+16".parseMath().toSFun(proto) if (linear !is Sum) fail() if (linear.left !is Prod) fail() if (linear.right !is SConst) fail() @@ -44,7 +44,7 @@ internal class AdaptingTests { @Test fun simpleFunctionDerivative() { val x = MstAlgebra.symbol("x").toSVar(proto) - val quadratic = "x^2-4*x-44".parseMath().tSFun(proto) + val quadratic = "x^2-4*x-44".parseMath().toSFun(proto) val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0)) @@ -53,7 +53,7 @@ internal class AdaptingTests { @Test fun moreComplexDerivative() { val x = MstAlgebra.symbol("x").toSVar(proto) - val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().tSFun(proto) + val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun(proto) val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile() val expectedDerivative = MstExpression( From 6f0f6577de5c555993d19ae3c2b19f5ab27f1e5b Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 29 Oct 2020 13:34:12 +0700 Subject: [PATCH 58/69] Refactor toSFun, update KG, delete KMath algebra protocol, update DifferentiableMstExpr. --- kmath-kotlingrad/build.gradle.kts | 2 +- .../kmath/kotlingrad/ScalarsAdapters.kt | 50 +++++++++---------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/kmath-kotlingrad/build.gradle.kts b/kmath-kotlingrad/build.gradle.kts index 0fe6e6b93..f2245c3d5 100644 --- a/kmath-kotlingrad/build.gradle.kts +++ b/kmath-kotlingrad/build.gradle.kts @@ -3,6 +3,6 @@ plugins { } dependencies { - api("com.github.breandan:kotlingrad:0.3.2") + api("com.github.breandan:kotlingrad:0.3.7") api(project(":kmath-ast")) } diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt index a9a8a14b2..24e8377bd 100644 --- a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt @@ -69,7 +69,7 @@ public fun > MST.Numeric.toSConst(): SConst = SConst(value) * @param proto the prototype instance. * @return a new variable. */ -public fun > MST.Symbolic.toSVar(proto: X): SVar = SVar(proto, value) +public fun > MST.Symbolic.toSVar(): SVar = SVar(value) /** * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. @@ -85,28 +85,28 @@ public fun > MST.Symbolic.toSVar(proto: X): SVar = SVar(proto, va * @param proto the prototype instance. * @return a scalar function. */ -public fun > MST.toSFun(proto: X): SFun = when (this) { +public fun > MST.toSFun(): SFun = when (this) { is MST.Numeric -> toSConst() - is MST.Symbolic -> toSVar(proto) + is MST.Symbolic -> toSVar() is MST.Unary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> value.toSFun(proto) - SpaceOperations.MINUS_OPERATION -> -value.toSFun(proto) - TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun(proto)) - TrigonometricOperations.COS_OPERATION -> cos(value.toSFun(proto)) - TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun(proto)) - PowerOperations.SQRT_OPERATION -> value.toSFun(proto).sqrt() - ExponentialOperations.EXP_OPERATION -> E() pow value.toSFun(proto) - ExponentialOperations.LN_OPERATION -> value.toSFun(proto).ln() + SpaceOperations.PLUS_OPERATION -> value.toSFun() + SpaceOperations.MINUS_OPERATION -> (-value).toSFun() + TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun()) + TrigonometricOperations.COS_OPERATION -> cos(value.toSFun()) + TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun()) + PowerOperations.SQRT_OPERATION -> value.toSFun().sqrt() + ExponentialOperations.EXP_OPERATION -> exp(value.toSFun()) + ExponentialOperations.LN_OPERATION -> value.toSFun().ln() else -> error("Unary operation $operation not defined in $this") } is MST.Binary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> left.toSFun(proto) + right.toSFun(proto) - SpaceOperations.MINUS_OPERATION -> left.toSFun(proto) - right.toSFun(proto) - RingOperations.TIMES_OPERATION -> left.toSFun(proto) * right.toSFun(proto) - FieldOperations.DIV_OPERATION -> left.toSFun(proto) / right.toSFun(proto) - PowerOperations.POW_OPERATION -> left.toSFun(proto) pow (right as MST.Numeric).toSConst() + SpaceOperations.PLUS_OPERATION -> left.toSFun() + right.toSFun() + SpaceOperations.MINUS_OPERATION -> left.toSFun() - right.toSFun() + RingOperations.TIMES_OPERATION -> left.toSFun() * right.toSFun() + FieldOperations.DIV_OPERATION -> left.toSFun() / right.toSFun() + PowerOperations.POW_OPERATION -> left.toSFun() pow (right as MST.Numeric).toSConst() else -> error("Binary operation $operation not defined in $this") } } @@ -114,23 +114,23 @@ public fun > MST.toSFun(proto: X): SFun = when (this) { public class KMathNumber(public val algebra: A, value: T) : RealNumber, T>(value) where T : Number, A : NumericAlgebra { public override fun wrap(number: Number): SConst> = SConst(algebra.number(number)) - override val proto: KMathNumber by lazy { KMathNumber(algebra, algebra.number(Double.NaN)) } } -public class KMathProtocol(algebra: A) : - Protocol>(KMathNumber(algebra, algebra.number(Double.NaN))) - where T : Number, A : NumericAlgebra - public class DifferentiableMstExpression(public val algebra: A, public val mst: MST) : DifferentiableExpression where A : NumericAlgebra, T : Number { - public val proto by lazy { KMathProtocol(algebra).prototype } public val expr by lazy { MstExpression(algebra, mst) } - public override fun invoke(arguments: Map): T = expr(arguments) + public override fun invoke(arguments: Map): T = expr(arguments) public override fun derivativeOrNull(orders: Map): Expression { - val sfun = mst.toSFun(proto) - val orders2 = orders.mapKeys { (k, _) -> MstAlgebra.symbol(k.identity).toSVar(proto) } + TODO() + } + + public fun derivativeOrNull(orders: List): Expression { + orders.map { MstAlgebra.symbol(it.identity).toSVar>() } + .fold>, SFun>>(mst.toSFun()) { result, sVar -> result.d(sVar) } + .toMst() + TODO() } } From 57910f617ab71493a4299a62b4134505a07b4aab Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 29 Oct 2020 15:39:53 +0700 Subject: [PATCH 59/69] Rename API classes, update readme files --- README.md | 92 ++++-- build.gradle.kts | 11 +- docs/templates/README-TEMPLATE.md | 78 +++-- kmath-core/README.md | 6 +- kmath-core/build.gradle.kts | 12 +- .../kscience/kmath/expressions/Expression.kt | 20 +- .../kscience/kmath/structures/NDAlgebra.kt | 2 +- kmath-nd4j/build.gradle.kts | 26 ++ kmath-nd4j/docs/README-TEMPLATE.md | 43 +++ .../kscience.kmath.nd4j/INDArrayAlgebra.kt | 284 ----------------- .../kscience.kmath.nd4j/Nd4jArrayAlgebra.kt | 288 ++++++++++++++++++ ...ArrayIterators.kt => Nd4jArrayIterator.kt} | 22 +- ...rayStructures.kt => Nd4jArrayStructure.kt} | 28 +- ...AlgebraTest.kt => Nd4jArrayAlgebraTest.kt} | 8 +- ...ctureTest.kt => Nd4jArrayStructureTest.kt} | 2 +- 15 files changed, 534 insertions(+), 388 deletions(-) create mode 100644 kmath-nd4j/docs/README-TEMPLATE.md delete mode 100644 kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayAlgebra.kt create mode 100644 kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt rename kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/{INDArrayIterators.kt => Nd4jArrayIterator.kt} (63%) rename kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/{INDArrayStructures.kt => Nd4jArrayStructure.kt} (63%) rename kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/{INDArrayAlgebraTest.kt => Nd4jArrayAlgebraTest.kt} (79%) rename kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/{INDArrayStructureTest.kt => Nd4jArrayStructureTest.kt} (98%) diff --git a/README.md b/README.md index afab32dcf..2df9d3246 100644 --- a/README.md +++ b/README.md @@ -8,41 +8,50 @@ Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion) # KMath -Could be pronounced as `key-math`. -The Kotlin MATHematics library was initially intended as a Kotlin-based analog to Python's `numpy` library. Later we found that kotlin is much more flexible language and allows superior architecture designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could be achieved with [kmath-for-real](/kmath-for-real) extension module. + +Could be pronounced as `key-math`. The Kotlin MATHematics library was initially intended as a Kotlin-based analog to +Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior architecture +designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could +be achieved with [kmath-for-real](/kmath-for-real) extension module. ## Publications and talks + * [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2) * [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814) * [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103) # Goal -* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM and JS for now and Native in future). + +* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native). * Provide basic multiplatform implementations for those abstractions (without significant performance optimization). * Provide bindings and wrappers with those abstractions for popular optimized platform libraries. ## Non-goals -* Be like Numpy. It was the idea at the beginning, but we decided that we can do better in terms of API. -* Provide best performance out of the box. We have specialized libraries for that. Need only API wrappers for them. + +* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in terms of API. +* Provide the best performance out of the box. We have specialized libraries for that. Need only API wrappers for them. * Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually. -* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better experience for those, who want to work with specific types. +* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like +for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better +experience for those, who want to work with specific types. ## Features -Actual feature list is [here](/docs/features.md) +Current feature list is [here](/docs/features.md) * **Algebra** - * Algebraic structures like rings, spaces and field (**TODO** add example to wiki) + * Algebraic structures like rings, spaces and fields (**TODO** add example to wiki) * Basic linear algebra operations (sums, products, etc.), backed by the `Space` API. - * Complex numbers backed by the `Field` API (meaning that they will be usable in any structure like vectors and N-dimensional arrays). + * Complex numbers backed by the `Field` API (meaning they will be usable in any structure like vectors and + N-dimensional arrays). * Advanced linear algebra operations like matrix inversion and LU decomposition. * **Array-like structures** Full support of many-dimensional array-like structures including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking). -* **Expressions** By writing a single mathematical expression -once, users will be able to apply different types of objects to the expression by providing a context. Expressions -can be used for a wide variety of purposes from high performance calculations to code generation. +* **Expressions** By writing a single mathematical expression once, users will be able to apply different types of +objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high +performance calculations to code generation. * **Histograms** Fast multi-dimensional histograms. @@ -50,9 +59,10 @@ can be used for a wide variety of purposes from high performance calculations to * **Type-safe dimensions** Type-safe dimensions for matrix operations. -* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/) - library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free - to submit a feature request if you want something to be done first. +* **Commons-math wrapper** It is planned to gradually wrap most parts of +[Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some +parts to better suit the Kotlin programming paradigm, however there is no established roadmap for that. Feel free to +submit a feature request if you want something to be implemented first. ## Planned features @@ -151,6 +161,18 @@ can be used for a wide variety of purposes from high performance calculations to > **Maturity**: EXPERIMENTAL
+* ### [kmath-nd4j](kmath-nd4j) +> ND4J NDStructure implementation and according NDAlgebra classes +> +> **Maturity**: EXPERIMENTAL +> +> **Features:** +> - [nd4jarraystrucure](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt) : NDStructure wrapper for INDArray +> - [nd4jarrayrings](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt) : Rings over Nd4jArrayStructure of Int and Long +> - [nd4jarrayfields](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : Fields over Nd4jArrayStructure of Float and Double + +
+ * ### [kmath-stat](kmath-stat) > > @@ -166,39 +188,53 @@ can be used for a wide variety of purposes from high performance calculations to ## Multi-platform support -KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the [common module](/kmath-core/src/commonMain). Implementation is also done in the common module wherever possible. In some cases, features are delegated to platform-specific implementations even if they could be done in the common module for performance reasons. Currently, the JVM is the main focus of development, however Kotlin/Native and Kotlin/JS contributions are also welcome. +KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the +[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features +are delegated to platform-specific implementations even if they could be provided in the common module for performance +reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and +feedback are also welcome. ## Performance -Calculation performance is one of major goals of KMath in the future, but in some cases it is not possible to achieve both performance and flexibility. We expect to focus on creating convenient universal API first and then work on increasing performance for specific cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be better than SciPy. +Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve +both performance and flexibility. -### Dependency +We expect to focus on creating convenient universal API first and then work on increasing performance for specific +cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized +native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be +better than SciPy. -Release artifacts are accessible from bintray with following configuration (see documentation for [kotlin-multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) form more details): +### Repositories + +Release artifacts are accessible from bintray with following configuration (see documentation of +[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details): ```kotlin -repositories{ +repositories { maven("https://dl.bintray.com/mipt-npm/kscience") } -dependencies{ - api("kscience.kmath:kmath-core:0.2.0-dev-2") - //api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version +dependencies { + api("kscience.kmath:kmath-core:0.2.0-dev-3") + // api("kscience.kmath:kmath-core-jvm:0.2.0-dev-3") for jvm-specific version } ``` Gradle `6.0+` is required for multiplatform artifacts. -### Development +#### Development + +Development builds are uploaded to the separate repository: -Development builds are accessible from the reposirtory ```kotlin -repositories{ +repositories { maven("https://dl.bintray.com/mipt-npm/dev") } ``` -with the same artifact names. ## Contributing -The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero). \ No newline at end of file +The project requires a lot of additional work. The most important thing we need is a feedback about what features are +required the most. Feel free to create feature requests. We are also welcome to code contributions, +especially in issues marked with +[waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero) label. diff --git a/build.gradle.kts b/build.gradle.kts index b03c03ab8..de0714543 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -2,16 +2,15 @@ plugins { id("ru.mipt.npm.project") } -val kmathVersion: String by extra("0.2.0-dev-3") -val bintrayRepo: String by extra("kscience") -val githubProject: String by extra("kmath") +internal val kmathVersion: String by extra("0.2.0-dev-3") +internal val bintrayRepo: String by extra("kscience") +internal val githubProject: String by extra("kmath") allprojects { repositories { jcenter() maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlinx") - mavenCentral() maven("https://dl.bintray.com/hotkeytlt/maven") } @@ -27,6 +26,6 @@ readme { readmeTemplate = file("docs/templates/README-TEMPLATE.md") } -apiValidation{ +apiValidation { validationDisabled = true -} \ No newline at end of file +} diff --git a/docs/templates/README-TEMPLATE.md b/docs/templates/README-TEMPLATE.md index 5117e0694..ee1df818c 100644 --- a/docs/templates/README-TEMPLATE.md +++ b/docs/templates/README-TEMPLATE.md @@ -8,41 +8,50 @@ Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion) # KMath -Could be pronounced as `key-math`. -The Kotlin MATHematics library was initially intended as a Kotlin-based analog to Python's `numpy` library. Later we found that kotlin is much more flexible language and allows superior architecture designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could be achieved with [kmath-for-real](/kmath-for-real) extension module. + +Could be pronounced as `key-math`. The Kotlin MATHematics library was initially intended as a Kotlin-based analog to +Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior architecture +designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could +be achieved with [kmath-for-real](/kmath-for-real) extension module. ## Publications and talks + * [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2) * [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814) * [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103) # Goal -* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM and JS for now and Native in future). + +* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native). * Provide basic multiplatform implementations for those abstractions (without significant performance optimization). * Provide bindings and wrappers with those abstractions for popular optimized platform libraries. ## Non-goals -* Be like Numpy. It was the idea at the beginning, but we decided that we can do better in terms of API. -* Provide best performance out of the box. We have specialized libraries for that. Need only API wrappers for them. + +* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in terms of API. +* Provide the best performance out of the box. We have specialized libraries for that. Need only API wrappers for them. * Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually. -* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better experience for those, who want to work with specific types. +* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like +for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better +experience for those, who want to work with specific types. ## Features -Actual feature list is [here](/docs/features.md) +Current feature list is [here](/docs/features.md) * **Algebra** - * Algebraic structures like rings, spaces and field (**TODO** add example to wiki) + * Algebraic structures like rings, spaces and fields (**TODO** add example to wiki) * Basic linear algebra operations (sums, products, etc.), backed by the `Space` API. - * Complex numbers backed by the `Field` API (meaning that they will be usable in any structure like vectors and N-dimensional arrays). + * Complex numbers backed by the `Field` API (meaning they will be usable in any structure like vectors and + N-dimensional arrays). * Advanced linear algebra operations like matrix inversion and LU decomposition. * **Array-like structures** Full support of many-dimensional array-like structures including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking). -* **Expressions** By writing a single mathematical expression -once, users will be able to apply different types of objects to the expression by providing a context. Expressions -can be used for a wide variety of purposes from high performance calculations to code generation. +* **Expressions** By writing a single mathematical expression once, users will be able to apply different types of +objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high +performance calculations to code generation. * **Histograms** Fast multi-dimensional histograms. @@ -50,9 +59,10 @@ can be used for a wide variety of purposes from high performance calculations to * **Type-safe dimensions** Type-safe dimensions for matrix operations. -* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/) - library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free - to submit a feature request if you want something to be done first. +* **Commons-math wrapper** It is planned to gradually wrap most parts of +[Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some +parts to better suit the Kotlin programming paradigm, however there is no established roadmap for that. Feel free to +submit a feature request if you want something to be implemented first. ## Planned features @@ -72,39 +82,53 @@ $modules ## Multi-platform support -KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the [common module](/kmath-core/src/commonMain). Implementation is also done in the common module wherever possible. In some cases, features are delegated to platform-specific implementations even if they could be done in the common module for performance reasons. Currently, the JVM is the main focus of development, however Kotlin/Native and Kotlin/JS contributions are also welcome. +KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the +[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features +are delegated to platform-specific implementations even if they could be provided in the common module for performance +reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and +feedback are also welcome. ## Performance -Calculation performance is one of major goals of KMath in the future, but in some cases it is not possible to achieve both performance and flexibility. We expect to focus on creating convenient universal API first and then work on increasing performance for specific cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be better than SciPy. +Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve +both performance and flexibility. -### Dependency +We expect to focus on creating convenient universal API first and then work on increasing performance for specific +cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized +native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be +better than SciPy. -Release artifacts are accessible from bintray with following configuration (see documentation for [kotlin-multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) form more details): +### Repositories + +Release artifacts are accessible from bintray with following configuration (see documentation of +[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details): ```kotlin -repositories{ +repositories { maven("https://dl.bintray.com/mipt-npm/kscience") } -dependencies{ +dependencies { api("kscience.kmath:kmath-core:$version") - //api("kscience.kmath:kmath-core-jvm:$version") for jvm-specific version + // api("kscience.kmath:kmath-core-jvm:$version") for jvm-specific version } ``` Gradle `6.0+` is required for multiplatform artifacts. -### Development +#### Development + +Development builds are uploaded to the separate repository: -Development builds are accessible from the reposirtory ```kotlin -repositories{ +repositories { maven("https://dl.bintray.com/mipt-npm/dev") } ``` -with the same artifact names. ## Contributing -The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero). \ No newline at end of file +The project requires a lot of additional work. The most important thing we need is a feedback about what features are +required the most. Feel free to create feature requests. We are also welcome to code contributions, +especially in issues marked with +[waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero) label. diff --git a/kmath-core/README.md b/kmath-core/README.md index 5501b1d7a..42a513a10 100644 --- a/kmath-core/README.md +++ b/kmath-core/README.md @@ -12,7 +12,7 @@ The core features of KMath: > #### Artifact: > -> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`. +> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-3`. > > Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion) > @@ -30,7 +30,7 @@ The core features of KMath: > } > > dependencies { -> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2' +> implementation 'kscience.kmath:kmath-core:0.2.0-dev-3' > } > ``` > **Gradle Kotlin DSL:** @@ -44,6 +44,6 @@ The core features of KMath: > } > > dependencies { -> implementation("kscience.kmath:kmath-core:0.2.0-dev-2") +> implementation("kscience.kmath:kmath-core:0.2.0-dev-3") > } > ``` diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index b0849eca5..7f889d9b4 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -1,3 +1,5 @@ +import ru.mipt.npm.gradle.Maturity + plugins { id("ru.mipt.npm.mpp") id("ru.mipt.npm.native") @@ -11,36 +13,42 @@ kotlin.sourceSets.commonMain { readme { description = "Core classes, algebra definitions, basic linear algebra" - maturity = ru.mipt.npm.gradle.Maturity.DEVELOPMENT + maturity = Maturity.DEVELOPMENT propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) + feature( id = "algebras", description = "Algebraic structures: contexts and elements", ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt" ) + feature( id = "nd", description = "Many-dimensional structures", ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt" ) + feature( id = "buffers", description = "One-dimensional structure", ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt" ) + feature( id = "expressions", description = "Functional Expressions", ref = "src/commonMain/kotlin/kscience/kmath/expressions" ) + feature( id = "domains", description = "Domains", ref = "src/commonMain/kotlin/kscience/kmath/domains" ) + feature( id = "autodif", description = "Automatic differentiation", ref = "src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt" ) -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index ab9ff0e72..568de255e 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -35,20 +35,27 @@ public fun interface Expression { } /** - * Invoke an expression without parameters + * Calls this expression without providing any arguments. + * + * @return a value. */ public operator fun Expression.invoke(): T = invoke(emptyMap()) -//This method exists to avoid resolution ambiguity of vararg methods /** * Calls this expression from arguments. * - * @param pairs the pair of arguments' names to values. - * @return the value. + * @param pairs the pairs of arguments to values. + * @return a value. */ @JvmName("callBySymbol") public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) +/** + * Calls this expression from arguments. + * + * @param pairs the pairs of arguments' names to values. + * @return a value. + */ @JvmName("callByString") public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) }) @@ -61,7 +68,6 @@ public operator fun Expression.invoke(vararg pairs: Pair): T = * @param E type of the actual expression state */ public interface ExpressionAlgebra : Algebra { - /** * Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context. */ @@ -87,7 +93,7 @@ public fun ExpressionAlgebra.bind(symbol: Symbol): E = /** * A delegate to create a symbol with a string identity in this scope */ -public val symbol: ReadOnlyProperty = ReadOnlyProperty { thisRef, property -> +public val symbol: ReadOnlyProperty = ReadOnlyProperty { _, property -> StringSymbol(property.name) } @@ -96,4 +102,4 @@ public val symbol: ReadOnlyProperty = ReadOnlyProperty { */ public fun ExpressionAlgebra.binding(): ReadOnlyProperty = ReadOnlyProperty { _, property -> bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist") -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt index c1cfcbe49..d7b019c65 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt @@ -73,7 +73,7 @@ public interface NDAlgebra> { public fun check(vararg elements: N): Array = elements .map(NDStructure::shape) .singleOrNull { !shape.contentEquals(it) } - ?.let { throw ShapeMismatchException(shape, it) } + ?.let> { throw ShapeMismatchException(shape, it) } ?: elements /** diff --git a/kmath-nd4j/build.gradle.kts b/kmath-nd4j/build.gradle.kts index 67569b870..953530b01 100644 --- a/kmath-nd4j/build.gradle.kts +++ b/kmath-nd4j/build.gradle.kts @@ -1,3 +1,5 @@ +import ru.mipt.npm.gradle.Maturity + plugins { id("ru.mipt.npm.jvm") } @@ -9,3 +11,27 @@ dependencies { testImplementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") testImplementation("org.slf4j:slf4j-simple:1.7.30") } + +readme { + description = "ND4J NDStructure implementation and according NDAlgebra classes" + maturity = Maturity.EXPERIMENTAL + propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) + + feature( + id = "nd4jarraystrucure", + description = "NDStructure wrapper for INDArray", + ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt" + ) + + feature( + id = "nd4jarrayrings", + description = "Rings over Nd4jArrayStructure of Int and Long", + ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt" + ) + + feature( + id = "nd4jarrayfields", + description = "Fields over Nd4jArrayStructure of Float and Double", + ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt" + ) +} diff --git a/kmath-nd4j/docs/README-TEMPLATE.md b/kmath-nd4j/docs/README-TEMPLATE.md new file mode 100644 index 000000000..76ce8c9a7 --- /dev/null +++ b/kmath-nd4j/docs/README-TEMPLATE.md @@ -0,0 +1,43 @@ +# ND4J NDStructure implementation (`kmath-nd4j`) + +This subproject implements the following features: + +${features} + +${artifact} + +## Examples + +NDStructure wrapper for INDArray: + +```kotlin +import org.nd4j.linalg.factory.* +import scientifik.kmath.nd4j.* +import scientifik.kmath.structures.* + +val array = Nd4j.ones(2, 2).asRealStructure() +println(array[0, 0]) // 1.0 +array[intArrayOf(0, 0)] = 24.0 +println(array[0, 0]) // 24.0 +``` + +Fast element-wise and in-place arithmetics for INDArray: + +```kotlin +import org.nd4j.linalg.factory.* +import scientifik.kmath.nd4j.* +import scientifik.kmath.operations.* + +val field = RealNd4jArrayField(intArrayOf(2, 2)) +val array = Nd4j.rand(2, 2).asRealStructure() + +val res = field { + (25.0 / array + 20) * 4 +} + +println(res.ndArray) +// [[ 250.6449, 428.5840], +// [ 269.7913, 202.2077]] +``` + +Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayAlgebra.kt deleted file mode 100644 index 728ce3773..000000000 --- a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayAlgebra.kt +++ /dev/null @@ -1,284 +0,0 @@ -package kscience.kmath.nd4j - -import org.nd4j.linalg.api.ndarray.INDArray -import org.nd4j.linalg.factory.Nd4j -import kscience.kmath.operations.* -import kscience.kmath.structures.* - -/** - * Represents [NDAlgebra] over [INDArrayAlgebra]. - * - * @param T the type of ND-structure element. - * @param C the type of the element context. - */ -public interface INDArrayAlgebra : NDAlgebra> { - /** - * Wraps [INDArray] to [N]. - */ - public fun INDArray.wrap(): INDArrayStructure - - public override fun produce(initializer: C.(IntArray) -> T): INDArrayStructure { - val struct = Nd4j.create(*shape)!!.wrap() - struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) } - return struct - } - - public override fun map(arg: INDArrayStructure, transform: C.(T) -> T): INDArrayStructure { - check(arg) - val newStruct = arg.ndArray.dup().wrap() - newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } - return newStruct - } - - public override fun mapIndexed( - arg: INDArrayStructure, - transform: C.(index: IntArray, T) -> T - ): INDArrayStructure { - check(arg) - val new = Nd4j.create(*shape).wrap() - new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) } - return new - } - - public override fun combine( - a: INDArrayStructure, - b: INDArrayStructure, - transform: C.(T, T) -> T - ): INDArrayStructure { - check(a, b) - val new = Nd4j.create(*shape).wrap() - new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) } - return new - } -} - -/** - * Represents [NDSpace] over [INDArrayStructure]. - * - * @param T the type of the element contained in ND structure. - * @param S the type of space of structure elements. - */ -public interface INDArraySpace : NDSpace>, INDArrayAlgebra where S : Space { - public override val zero: INDArrayStructure - get() = Nd4j.zeros(*shape).wrap() - - public override fun add(a: INDArrayStructure, b: INDArrayStructure): INDArrayStructure { - check(a, b) - return a.ndArray.add(b.ndArray).wrap() - } - - public override operator fun INDArrayStructure.minus(b: INDArrayStructure): INDArrayStructure { - check(this, b) - return ndArray.sub(b.ndArray).wrap() - } - - public override operator fun INDArrayStructure.unaryMinus(): INDArrayStructure { - check(this) - return ndArray.neg().wrap() - } - - public override fun multiply(a: INDArrayStructure, k: Number): INDArrayStructure { - check(a) - return a.ndArray.mul(k).wrap() - } - - public override operator fun INDArrayStructure.div(k: Number): INDArrayStructure { - check(this) - return ndArray.div(k).wrap() - } - - public override operator fun INDArrayStructure.times(k: Number): INDArrayStructure { - check(this) - return ndArray.mul(k).wrap() - } -} - -/** - * Represents [NDRing] over [INDArrayStructure]. - * - * @param T the type of the element contained in ND structure. - * @param R the type of ring of structure elements. - */ -public interface INDArrayRing : NDRing>, INDArraySpace where R : Ring { - public override val one: INDArrayStructure - get() = Nd4j.ones(*shape).wrap() - - public override fun multiply(a: INDArrayStructure, b: INDArrayStructure): INDArrayStructure { - check(a, b) - return a.ndArray.mul(b.ndArray).wrap() - } - - public override operator fun INDArrayStructure.minus(b: Number): INDArrayStructure { - check(this) - return ndArray.sub(b).wrap() - } - - public override operator fun INDArrayStructure.plus(b: Number): INDArrayStructure { - check(this) - return ndArray.add(b).wrap() - } - - public override operator fun Number.minus(b: INDArrayStructure): INDArrayStructure { - check(b) - return b.ndArray.rsub(this).wrap() - } -} - -/** - * Represents [NDField] over [INDArrayStructure]. - * - * @param T the type of the element contained in ND structure. - * @param N the type of ND structure. - * @param F the type field of structure elements. - */ -public interface INDArrayField : NDField>, INDArrayRing where F : Field { - public override fun divide(a: INDArrayStructure, b: INDArrayStructure): INDArrayStructure { - check(a, b) - return a.ndArray.div(b.ndArray).wrap() - } - - public override operator fun Number.div(b: INDArrayStructure): INDArrayStructure { - check(b) - return b.ndArray.rdiv(this).wrap() - } -} - -/** - * Represents [NDField] over [INDArrayRealStructure]. - */ -public class RealINDArrayField(public override val shape: IntArray) : INDArrayField { - public override val elementContext: RealField - get() = RealField - - public override fun INDArray.wrap(): INDArrayStructure = check(asRealStructure()) - - public override operator fun INDArrayStructure.div(arg: Double): INDArrayStructure { - check(this) - return ndArray.div(arg).wrap() - } - - public override operator fun INDArrayStructure.plus(arg: Double): INDArrayStructure { - check(this) - return ndArray.add(arg).wrap() - } - - public override operator fun INDArrayStructure.minus(arg: Double): INDArrayStructure { - check(this) - return ndArray.sub(arg).wrap() - } - - public override operator fun INDArrayStructure.times(arg: Double): INDArrayStructure { - check(this) - return ndArray.mul(arg).wrap() - } - - public override operator fun Double.div(arg: INDArrayStructure): INDArrayStructure { - check(arg) - return arg.ndArray.rdiv(this).wrap() - } - - public override operator fun Double.minus(arg: INDArrayStructure): INDArrayStructure { - check(arg) - return arg.ndArray.rsub(this).wrap() - } -} - -/** - * Represents [NDField] over [INDArrayStructure] of [Float]. - */ -public class FloatINDArrayField(public override val shape: IntArray) : INDArrayField { - public override val elementContext: FloatField - get() = FloatField - - public override fun INDArray.wrap(): INDArrayStructure = check(asFloatStructure()) - - public override operator fun INDArrayStructure.div(arg: Float): INDArrayStructure { - check(this) - return ndArray.div(arg).wrap() - } - - public override operator fun INDArrayStructure.plus(arg: Float): INDArrayStructure { - check(this) - return ndArray.add(arg).wrap() - } - - public override operator fun INDArrayStructure.minus(arg: Float): INDArrayStructure { - check(this) - return ndArray.sub(arg).wrap() - } - - public override operator fun INDArrayStructure.times(arg: Float): INDArrayStructure { - check(this) - return ndArray.mul(arg).wrap() - } - - public override operator fun Float.div(arg: INDArrayStructure): INDArrayStructure { - check(arg) - return arg.ndArray.rdiv(this).wrap() - } - - public override operator fun Float.minus(arg: INDArrayStructure): INDArrayStructure { - check(arg) - return arg.ndArray.rsub(this).wrap() - } -} - -/** - * Represents [NDRing] over [INDArrayIntStructure]. - */ -public class IntINDArrayRing(public override val shape: IntArray) : INDArrayRing { - public override val elementContext: IntRing - get() = IntRing - - public override fun INDArray.wrap(): INDArrayStructure = check(asIntStructure()) - - public override operator fun INDArrayStructure.plus(arg: Int): INDArrayStructure { - check(this) - return ndArray.add(arg).wrap() - } - - public override operator fun INDArrayStructure.minus(arg: Int): INDArrayStructure { - check(this) - return ndArray.sub(arg).wrap() - } - - public override operator fun INDArrayStructure.times(arg: Int): INDArrayStructure { - check(this) - return ndArray.mul(arg).wrap() - } - - public override operator fun Int.minus(arg: INDArrayStructure): INDArrayStructure { - check(arg) - return arg.ndArray.rsub(this).wrap() - } -} - -/** - * Represents [NDRing] over [INDArrayStructure] of [Long]. - */ -public class LongINDArrayRing(public override val shape: IntArray) : INDArrayRing { - public override val elementContext: LongRing - get() = LongRing - - public override fun INDArray.wrap(): INDArrayStructure = check(asLongStructure()) - - public override operator fun INDArrayStructure.plus(arg: Long): INDArrayStructure { - check(this) - return ndArray.add(arg).wrap() - } - - public override operator fun INDArrayStructure.minus(arg: Long): INDArrayStructure { - check(this) - return ndArray.sub(arg).wrap() - } - - public override operator fun INDArrayStructure.times(arg: Long): INDArrayStructure { - check(this) - return ndArray.mul(arg).wrap() - } - - public override operator fun Long.minus(arg: INDArrayStructure): INDArrayStructure { - check(arg) - return arg.ndArray.rsub(this).wrap() - } -} diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt new file mode 100644 index 000000000..2093a3cb3 --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt @@ -0,0 +1,288 @@ +package kscience.kmath.nd4j + +import kscience.kmath.operations.* +import kscience.kmath.structures.NDAlgebra +import kscience.kmath.structures.NDField +import kscience.kmath.structures.NDRing +import kscience.kmath.structures.NDSpace +import org.nd4j.linalg.api.ndarray.INDArray +import org.nd4j.linalg.factory.Nd4j + +/** + * Represents [NDAlgebra] over [Nd4jArrayAlgebra]. + * + * @param T the type of ND-structure element. + * @param C the type of the element context. + */ +public interface Nd4jArrayAlgebra : NDAlgebra> { + /** + * Wraps [INDArray] to [N]. + */ + public fun INDArray.wrap(): Nd4jArrayStructure + + public override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure { + val struct = Nd4j.create(*shape)!!.wrap() + struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) } + return struct + } + + public override fun map(arg: Nd4jArrayStructure, transform: C.(T) -> T): Nd4jArrayStructure { + check(arg) + val newStruct = arg.ndArray.dup().wrap() + newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } + return newStruct + } + + public override fun mapIndexed( + arg: Nd4jArrayStructure, + transform: C.(index: IntArray, T) -> T + ): Nd4jArrayStructure { + check(arg) + val new = Nd4j.create(*shape).wrap() + new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) } + return new + } + + public override fun combine( + a: Nd4jArrayStructure, + b: Nd4jArrayStructure, + transform: C.(T, T) -> T + ): Nd4jArrayStructure { + check(a, b) + val new = Nd4j.create(*shape).wrap() + new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) } + return new + } +} + +/** + * Represents [NDSpace] over [Nd4jArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param S the type of space of structure elements. + */ +public interface Nd4jArraySpace : NDSpace>, + Nd4jArrayAlgebra where S : Space { + public override val zero: Nd4jArrayStructure + get() = Nd4j.zeros(*shape).wrap() + + public override fun add(a: Nd4jArrayStructure, b: Nd4jArrayStructure): Nd4jArrayStructure { + check(a, b) + return a.ndArray.add(b.ndArray).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(b: Nd4jArrayStructure): Nd4jArrayStructure { + check(this, b) + return ndArray.sub(b.ndArray).wrap() + } + + public override operator fun Nd4jArrayStructure.unaryMinus(): Nd4jArrayStructure { + check(this) + return ndArray.neg().wrap() + } + + public override fun multiply(a: Nd4jArrayStructure, k: Number): Nd4jArrayStructure { + check(a) + return a.ndArray.mul(k).wrap() + } + + public override operator fun Nd4jArrayStructure.div(k: Number): Nd4jArrayStructure { + check(this) + return ndArray.div(k).wrap() + } + + public override operator fun Nd4jArrayStructure.times(k: Number): Nd4jArrayStructure { + check(this) + return ndArray.mul(k).wrap() + } +} + +/** + * Represents [NDRing] over [Nd4jArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param R the type of ring of structure elements. + */ +public interface Nd4jArrayRing : NDRing>, Nd4jArraySpace where R : Ring { + public override val one: Nd4jArrayStructure + get() = Nd4j.ones(*shape).wrap() + + public override fun multiply(a: Nd4jArrayStructure, b: Nd4jArrayStructure): Nd4jArrayStructure { + check(a, b) + return a.ndArray.mul(b.ndArray).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(b: Number): Nd4jArrayStructure { + check(this) + return ndArray.sub(b).wrap() + } + + public override operator fun Nd4jArrayStructure.plus(b: Number): Nd4jArrayStructure { + check(this) + return ndArray.add(b).wrap() + } + + public override operator fun Number.minus(b: Nd4jArrayStructure): Nd4jArrayStructure { + check(b) + return b.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDField] over [Nd4jArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param F the type field of structure elements. + */ +public interface Nd4jArrayField : NDField>, Nd4jArrayRing where F : Field { + public override fun divide(a: Nd4jArrayStructure, b: Nd4jArrayStructure): Nd4jArrayStructure { + check(a, b) + return a.ndArray.div(b.ndArray).wrap() + } + + public override operator fun Number.div(b: Nd4jArrayStructure): Nd4jArrayStructure { + check(b) + return b.ndArray.rdiv(this).wrap() + } +} + +/** + * Represents [NDField] over [Nd4jArrayRealStructure]. + */ +public class RealNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField { + public override val elementContext: RealField + get() = RealField + + public override fun INDArray.wrap(): Nd4jArrayStructure = check(asRealStructure()) + + public override operator fun Nd4jArrayStructure.div(arg: Double): Nd4jArrayStructure { + check(this) + return ndArray.div(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.plus(arg: Double): Nd4jArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(arg: Double): Nd4jArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.times(arg: Double): Nd4jArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Double.div(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rdiv(this).wrap() + } + + public override operator fun Double.minus(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDField] over [Nd4jArrayStructure] of [Float]. + */ +public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField { + public override val elementContext: FloatField + get() = FloatField + + public override fun INDArray.wrap(): Nd4jArrayStructure = check(asFloatStructure()) + + public override operator fun Nd4jArrayStructure.div(arg: Float): Nd4jArrayStructure { + check(this) + return ndArray.div(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.plus(arg: Float): Nd4jArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(arg: Float): Nd4jArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.times(arg: Float): Nd4jArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Float.div(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rdiv(this).wrap() + } + + public override operator fun Float.minus(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDRing] over [Nd4jArrayIntStructure]. + */ +public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing { + public override val elementContext: IntRing + get() = IntRing + + public override fun INDArray.wrap(): Nd4jArrayStructure = check(asIntStructure()) + + public override operator fun Nd4jArrayStructure.plus(arg: Int): Nd4jArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(arg: Int): Nd4jArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.times(arg: Int): Nd4jArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Int.minus(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDRing] over [Nd4jArrayStructure] of [Long]. + */ +public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing { + public override val elementContext: LongRing + get() = LongRing + + public override fun INDArray.wrap(): Nd4jArrayStructure = check(asLongStructure()) + + public override operator fun Nd4jArrayStructure.plus(arg: Long): Nd4jArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(arg: Long): Nd4jArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.times(arg: Long): Nd4jArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Long.minus(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayIterators.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayIterator.kt similarity index 63% rename from kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayIterators.kt rename to kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayIterator.kt index 9e7ef9e16..1463a92fe 100644 --- a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayIterators.kt +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayIterator.kt @@ -3,7 +3,7 @@ package kscience.kmath.nd4j import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.shape.Shape -private class INDArrayIndicesIterator(private val iterateOver: INDArray) : Iterator { +private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iterator { private var i: Int = 0 override fun hasNext(): Boolean = i < iterateOver.length() @@ -18,9 +18,9 @@ private class INDArrayIndicesIterator(private val iterateOver: INDArray) : Itera } } -internal fun INDArray.indicesIterator(): Iterator = INDArrayIndicesIterator(this) +internal fun INDArray.indicesIterator(): Iterator = Nd4jArrayIndicesIterator(this) -private sealed class INDArrayIteratorBase(protected val iterateOver: INDArray) : Iterator> { +private sealed class Nd4jArrayIteratorBase(protected val iterateOver: INDArray) : Iterator> { private var i: Int = 0 final override fun hasNext(): Boolean = i < iterateOver.length() @@ -37,26 +37,26 @@ private sealed class INDArrayIteratorBase(protected val iterateOver: INDArray } } -private class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { +private class Nd4jArrayRealIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices) } -internal fun INDArray.realIterator(): Iterator> = INDArrayRealIterator(this) +internal fun INDArray.realIterator(): Iterator> = Nd4jArrayRealIterator(this) -private class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { +private class Nd4jArrayLongIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices) } -internal fun INDArray.longIterator(): Iterator> = INDArrayLongIterator(this) +internal fun INDArray.longIterator(): Iterator> = Nd4jArrayLongIterator(this) -private class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { +private class Nd4jArrayIntIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray()) } -internal fun INDArray.intIterator(): Iterator> = INDArrayIntIterator(this) +internal fun INDArray.intIterator(): Iterator> = Nd4jArrayIntIterator(this) -private class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { +private class Nd4jArrayFloatIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices) } -internal fun INDArray.floatIterator(): Iterator> = INDArrayFloatIterator(this) +internal fun INDArray.floatIterator(): Iterator> = Nd4jArrayFloatIterator(this) diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayStructure.kt similarity index 63% rename from kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayStructures.kt rename to kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayStructure.kt index 5d4e1a979..d47a293c3 100644 --- a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/INDArrayStructures.kt +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayStructure.kt @@ -1,15 +1,15 @@ package kscience.kmath.nd4j -import org.nd4j.linalg.api.ndarray.INDArray import kscience.kmath.structures.MutableNDStructure import kscience.kmath.structures.NDStructure +import org.nd4j.linalg.api.ndarray.INDArray /** * Represents a [NDStructure] wrapping an [INDArray] object. * * @param T the type of items. */ -public sealed class INDArrayStructure : MutableNDStructure { +public sealed class Nd4jArrayStructure : MutableNDStructure { /** * The wrapped [INDArray]. */ @@ -23,46 +23,46 @@ public sealed class INDArrayStructure : MutableNDStructure { public override fun elements(): Sequence> = Sequence(::elementsIterator) } -private data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure() { +private data class Nd4jArrayIntStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { override fun elementsIterator(): Iterator> = ndArray.intIterator() override fun get(index: IntArray): Int = ndArray.getInt(*index) override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } } /** - * Wraps this [INDArray] to [INDArrayStructure]. + * Wraps this [INDArray] to [Nd4jArrayStructure]. */ -public fun INDArray.asIntStructure(): INDArrayStructure = INDArrayIntStructure(this) +public fun INDArray.asIntStructure(): Nd4jArrayStructure = Nd4jArrayIntStructure(this) -private data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure() { +private data class Nd4jArrayLongStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { override fun elementsIterator(): Iterator> = ndArray.longIterator() override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray()) override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) } } /** - * Wraps this [INDArray] to [INDArrayStructure]. + * Wraps this [INDArray] to [Nd4jArrayStructure]. */ -public fun INDArray.asLongStructure(): INDArrayStructure = INDArrayLongStructure(this) +public fun INDArray.asLongStructure(): Nd4jArrayStructure = Nd4jArrayLongStructure(this) -private data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure() { +private data class Nd4jArrayRealStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { override fun elementsIterator(): Iterator> = ndArray.realIterator() override fun get(index: IntArray): Double = ndArray.getDouble(*index) override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } } /** - * Wraps this [INDArray] to [INDArrayStructure]. + * Wraps this [INDArray] to [Nd4jArrayStructure]. */ -public fun INDArray.asRealStructure(): INDArrayStructure = INDArrayRealStructure(this) +public fun INDArray.asRealStructure(): Nd4jArrayStructure = Nd4jArrayRealStructure(this) -private data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure() { +private data class Nd4jArrayFloatStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { override fun elementsIterator(): Iterator> = ndArray.floatIterator() override fun get(index: IntArray): Float = ndArray.getFloat(*index) override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } } /** - * Wraps this [INDArray] to [INDArrayStructure]. + * Wraps this [INDArray] to [Nd4jArrayStructure]. */ -public fun INDArray.asFloatStructure(): INDArrayStructure = INDArrayFloatStructure(this) +public fun INDArray.asFloatStructure(): Nd4jArrayStructure = Nd4jArrayFloatStructure(this) diff --git a/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayAlgebraTest.kt b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt similarity index 79% rename from kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayAlgebraTest.kt rename to kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt index 1a4f4c9f3..650d5670c 100644 --- a/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayAlgebraTest.kt +++ b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt @@ -6,10 +6,10 @@ import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.fail -internal class INDArrayAlgebraTest { +internal class Nd4jArrayAlgebraTest { @Test fun testProduce() { - val res = (RealINDArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } + val res = (RealNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } val expected = (Nd4j.create(2, 2) ?: fail()).asRealStructure() expected[intArrayOf(0, 0)] = 0.0 expected[intArrayOf(0, 1)] = 1.0 @@ -20,7 +20,7 @@ internal class INDArrayAlgebraTest { @Test fun testMap() { - val res = (IntINDArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } } + val res = (IntNd4jArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } } val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() expected[intArrayOf(0, 0)] = 3 expected[intArrayOf(0, 1)] = 3 @@ -31,7 +31,7 @@ internal class INDArrayAlgebraTest { @Test fun testAdd() { - val res = (IntINDArrayRing(intArrayOf(2, 2))) { one + 25 } + val res = (IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 } val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() expected[intArrayOf(0, 0)] = 26 expected[intArrayOf(0, 1)] = 26 diff --git a/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt similarity index 98% rename from kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayStructureTest.kt rename to kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt index 63426d7f9..7e46211c1 100644 --- a/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/INDArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt @@ -7,7 +7,7 @@ import kotlin.test.assertEquals import kotlin.test.assertNotEquals import kotlin.test.fail -internal class INDArrayStructureTest { +internal class Nd4jArrayStructureTest { @Test fun testElements() { val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! From 022b8f0fa347aff5bcba3aae6eabae52c60aa390 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 29 Oct 2020 15:44:30 +0700 Subject: [PATCH 60/69] Regenerate readme --- kmath-nd4j/README.md | 40 +++++++++++++++++++------------------ kmath-nd4j/build.gradle.kts | 2 +- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/kmath-nd4j/README.md b/kmath-nd4j/README.md index fac24504a..b071df499 100644 --- a/kmath-nd4j/README.md +++ b/kmath-nd4j/README.md @@ -2,45 +2,48 @@ This subproject implements the following features: -- NDStructure wrapper for INDArray. -- Optimized NDRing implementations for INDArray storing Ints and Longs. -- Optimized NDField implementations for INDArray storing Floats and Doubles. + - [nd4jarraystrucure](src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt) : NDStructure wrapper for INDArray + - [nd4jarrayrings](src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt) : Rings over Nd4jArrayStructure of Int and Long + - [nd4jarrayfields](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : Fields over Nd4jArrayStructure of Float and Double + > #### Artifact: -> This module is distributed in the artifact `scientifik:kmath-nd4j:0.1.4-dev-8`. -> +> +> This module artifact: `kscience.kmath:kmath-nd4j:0.2.0-dev-3`. +> +> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-nd4j/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-nd4j/_latestVersion) +> +> Bintray development version: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-nd4j/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-nd4j/_latestVersion) +> > **Gradle:** > > ```gradle > repositories { -> mavenCentral() -> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' } +> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } +> maven { url 'https://dl.bintray.com/mipt-npm/kscience' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' } +> maven { url 'https://dl.bintray.com/hotkeytlt/maven' } + > } > > dependencies { -> implementation 'scientifik:kmath-nd4j:0.1.4-dev-8' -> implementation 'org.nd4j:nd4j-native-platform:1.0.0-beta7' +> implementation 'kscience.kmath:kmath-nd4j:0.2.0-dev-3' > } > ``` > **Gradle Kotlin DSL:** > > ```kotlin > repositories { -> mavenCentral() -> maven("https://dl.bintray.com/mipt-npm/scientifik") +> maven("https://dl.bintray.com/kotlin/kotlin-eap") +> maven("https://dl.bintray.com/mipt-npm/kscience") > maven("https://dl.bintray.com/mipt-npm/dev") +> maven("https://dl.bintray.com/hotkeytlt/maven") > } > > dependencies { -> implementation("scientifik:kmath-nd4j:0.1.4-dev-8") -> implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") +> implementation("kscience.kmath:kmath-nd4j:0.2.0-dev-3") > } > ``` -> -> This distribution also needs an implementation of ND4J API. The ND4J Native Platform is usually the fastest one, so -> it is included to the snippet. -> ## Examples @@ -64,7 +67,7 @@ import org.nd4j.linalg.factory.* import scientifik.kmath.nd4j.* import scientifik.kmath.operations.* -val field = RealINDArrayField(intArrayOf(2, 2)) +val field = RealNd4jArrayField(intArrayOf(2, 2)) val array = Nd4j.rand(2, 2).asRealStructure() val res = field { @@ -76,5 +79,4 @@ println(res.ndArray) // [ 269.7913, 202.2077]] ``` - Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). diff --git a/kmath-nd4j/build.gradle.kts b/kmath-nd4j/build.gradle.kts index 953530b01..391727c45 100644 --- a/kmath-nd4j/build.gradle.kts +++ b/kmath-nd4j/build.gradle.kts @@ -18,7 +18,7 @@ readme { propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) feature( - id = "nd4jarraystrucure", + id = "nd4jarraystructure", description = "NDStructure wrapper for INDArray", ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt" ) From fbe1ab94a4127263b09b5182797deaeab5012e2c Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 29 Oct 2020 19:35:08 +0300 Subject: [PATCH 61/69] Change DifferentiableExpression API to use ordered symbol list instead of orders map. --- .../DerivativeStructureExpression.kt | 42 +++++++++++-------- .../DerivativeStructureExpressionTest.kt | 22 ++++++---- .../expressions/DifferentiableExpression.kt | 20 ++++----- .../kscience/kmath/expressions/Expression.kt | 2 +- 4 files changed, 48 insertions(+), 38 deletions(-) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index c593f5103..e4311a56b 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -12,46 +12,51 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure */ public class DerivativeStructureField( public val order: Int, - private val bindings: Map + bindings: Map, ) : ExtendedField, ExpressionAlgebra { - public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) } - public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) } + public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0) } + public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0, 1.0) } /** * A class that implements both [DerivativeStructure] and a [Symbol] */ - public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) : - DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol { + public inner class DerivativeStructureSymbol( + size: Int, + index: Int, + symbol: Symbol, + value: Double, + ) : DerivativeStructure(size, order, index, value), Symbol { override val identity: String = symbol.identity override fun toString(): String = identity override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity override fun hashCode(): Int = identity.hashCode() } + public val numberOfVariables: Int = bindings.size + /** * Identity-based symbol bindings map */ - private val variables: Map = bindings.entries.associate { (key, value) -> - key.identity to DerivativeStructureSymbol(key, value) - } + private val variables: Map = bindings.entries.mapIndexed { index, (key, value) -> + key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value) + }.toMap() - override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value) + override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, 0, value) public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity) - //public fun Number.const(): DerivativeStructure = const(toDouble()) + override fun symbol(value: String): DerivativeStructureSymbol = bind(StringSymbol(value)) - public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double { - return derivative(mapOf(parameter to order)) + public fun DerivativeStructure.derivative(symbols: List): Double { + require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" } + val ordersCount = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size } + return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray()) } - public fun DerivativeStructure.derivative(orders: Map): Double { - return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray()) - } + public fun DerivativeStructure.derivative(vararg symbols: Symbol): Double = derivative(symbols.toList()) - public fun DerivativeStructure.derivative(vararg orders: Pair): Double = derivative(mapOf(*orders)) public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) { @@ -97,6 +102,7 @@ public class DerivativeStructureField( } } + /** * A constructs that creates a derivative structure with required order on-demand */ @@ -109,7 +115,7 @@ public class DerivativeStructureExpression( /** * Get the derivative expression with given orders */ - public override fun derivativeOrNull(orders: Map): Expression = Expression { arguments -> - with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) } + public override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + with(DerivativeStructureField(symbols.size, arguments)) { function().derivative(symbols) } } } diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt index 8886e123f..7511a38ed 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt @@ -5,14 +5,15 @@ import kotlin.contracts.InvocationKind import kotlin.contracts.contract import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFails -internal inline fun diff( +internal inline fun diff( order: Int, vararg parameters: Pair, - block: DerivativeStructureField.() -> R, -): R { + block: DerivativeStructureField.() -> Unit, +): Unit { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return DerivativeStructureField(order, mapOf(*parameters)).run(block) + DerivativeStructureField(order, mapOf(*parameters)).run(block) } internal class AutoDiffTest { @@ -21,13 +22,16 @@ internal class AutoDiffTest { @Test fun derivativeStructureFieldTest() { - val res: Double = diff(3, x to 1.0, y to 1.0) { + diff(2, x to 1.0, y to 1.0) { val x = bind(x)//by binding() val y = symbol("y") - val z = x * (-sin(x * y) + y) - z.derivative(x) + val z = x * (-sin(x * y) + y) + 2.0 + println(z.derivative(x)) + println(z.derivative(y,x)) + assertEquals(z.derivative(x, y), z.derivative(y, x)) + //check that improper order cause failure + assertFails { z.derivative(x,x,y) } } - println(res) } @Test @@ -40,5 +44,7 @@ internal class AutoDiffTest { assertEquals(10.0, f(x to 1.0, y to 2.0)) assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0)) + assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0)) + assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0)) } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt index 4fe73f283..ac1f4bc20 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -3,20 +3,18 @@ package kscience.kmath.expressions /** * An expression that provides derivatives */ -public interface DifferentiableExpression : Expression{ - public fun derivativeOrNull(orders: Map): Expression? +public interface DifferentiableExpression : Expression { + public fun derivativeOrNull(symbols: List): Expression? } -public fun DifferentiableExpression.derivative(orders: Map): Expression = - derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided") +public fun DifferentiableExpression.derivative(symbols: List): Expression = + derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided") -public fun DifferentiableExpression.derivative(vararg orders: Pair): Expression = - derivative(mapOf(*orders)) - -public fun DifferentiableExpression.derivative(symbol: Symbol): Expression = derivative(symbol to 1) +public fun DifferentiableExpression.derivative(vararg symbols: Symbol): Expression = + derivative(symbols.toList()) public fun DifferentiableExpression.derivative(name: String): Expression = - derivative(StringSymbol(name) to 1) + derivative(StringSymbol(name)) /** * A [DifferentiableExpression] that defines only first derivatives @@ -25,8 +23,8 @@ public abstract class FirstDerivativeExpression : DifferentiableExpression public abstract fun derivativeOrNull(symbol: Symbol): Expression? - public override fun derivativeOrNull(orders: Map): Expression? { - val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null + public override fun derivativeOrNull(symbols: List): Expression? { + val dSymbol = symbols.firstOrNull() ?: return null return derivativeOrNull(dSymbol) } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index ab9ff0e72..8f408e09e 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -87,7 +87,7 @@ public fun ExpressionAlgebra.bind(symbol: Symbol): E = /** * A delegate to create a symbol with a string identity in this scope */ -public val symbol: ReadOnlyProperty = ReadOnlyProperty { thisRef, property -> +public val symbol: ReadOnlyProperty = ReadOnlyProperty { _, property -> StringSymbol(property.name) } From 6f31ddba301b2cfb2a8ebbc900afb566b510f612 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 29 Oct 2020 19:50:45 +0300 Subject: [PATCH 62/69] Fix CM DerivativeStructureField constants --- .../expressions/DerivativeStructureExpression.kt | 10 +++++----- .../kmath/commons/optimization/OptimizeTest.kt | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index e4311a56b..244dc1314 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -14,8 +14,10 @@ public class DerivativeStructureField( public val order: Int, bindings: Map, ) : ExtendedField, ExpressionAlgebra { - public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0) } - public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0, 1.0) } + public val numberOfVariables: Int = bindings.size + + public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } + public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) } /** * A class that implements both [DerivativeStructure] and a [Symbol] @@ -32,8 +34,6 @@ public class DerivativeStructureField( override fun hashCode(): Int = identity.hashCode() } - public val numberOfVariables: Int = bindings.size - /** * Identity-based symbol bindings map */ @@ -41,7 +41,7 @@ public class DerivativeStructureField( key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value) }.toMap() - override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, 0, value) + override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value) public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index 4384a5124..fa1978f95 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -6,7 +6,6 @@ import kscience.kmath.stat.Distribution import kscience.kmath.stat.Fitting import kscience.kmath.stat.RandomGenerator import kscience.kmath.stat.normal -import kscience.kmath.structures.asBuffer import org.junit.jupiter.api.Test import kotlin.math.pow @@ -53,7 +52,7 @@ internal class OptimizeTest { it.pow(2) + it + 1 + chain.nextDouble() } val yErr = x.map { sigma } - val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x -> + val chi2 = Fitting.chiSquared(x, y, yErr) { x -> val cWithDefault = bindOrNull(c) ?: one bind(a) * x.pow(2) + bind(b) * x + cWithDefault } From 095b165fa47effa6271274b24b953f47761300dc Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 29 Oct 2020 23:59:36 +0700 Subject: [PATCH 63/69] Uncomment expressions benchmark, and add factory methods for Nd4jRing and Nd4jField --- examples/build.gradle.kts | 20 ++- .../ast/ExpressionsInterpretersBenchmark.kt | 138 +++++++++--------- .../kscience/kmath/structures/NDField.kt | 13 ++ .../kscience.kmath.nd4j/Nd4jArrayAlgebra.kt | 61 ++++++++ 4 files changed, 161 insertions(+), 71 deletions(-) diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 9ba1ec5be..f35031140 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -19,7 +19,7 @@ repositories { sourceSets.register("benchmarks") dependencies { -// implementation(project(":kmath-ast")) + implementation(project(":kmath-ast")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) @@ -27,6 +27,20 @@ dependencies { implementation(project(":kmath-viktor")) implementation(project(":kmath-dimensions")) implementation(project(":kmath-ejml")) + implementation(project(":kmath-nd4j")) + implementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7") + implementation("org.nd4j:nd4j-native:1.0.0-beta7") + +// uncomment if your system supports AVX2 +// val os = System.getProperty("os.name") +// +// if (System.getProperty("os.arch") in arrayOf("x86_64", "amd64")) when { +// os.startsWith("Windows") -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:windows-x86_64-avx2") +// os == "Linux" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:linux-x86_64-avx2") +// os == "Mac OS X" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:macosx-x86_64-avx2") +// } else + implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") + implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11") implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20") implementation("org.slf4j:slf4j-simple:1.7.30") @@ -55,4 +69,6 @@ kotlin.sourceSets.all { } } -tasks.withType { kotlinOptions.jvmTarget = "11" } +tasks.withType { + kotlinOptions.jvmTarget = "11" +} diff --git a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt index f0a32e5bd..35875747c 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -1,70 +1,70 @@ package kscience.kmath.ast -// -//import kscience.kmath.asm.compile -//import kscience.kmath.expressions.Expression -//import kscience.kmath.expressions.expressionInField -//import kscience.kmath.expressions.invoke -//import kscience.kmath.operations.Field -//import kscience.kmath.operations.RealField -//import kotlin.random.Random -//import kotlin.system.measureTimeMillis -// -//class ExpressionsInterpretersBenchmark { -// private val algebra: Field = RealField -// fun functionalExpression() { -// val expr = algebra.expressionInField { -// variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) -// } -// -// invokeAndSum(expr) -// } -// -// fun mstExpression() { -// val expr = algebra.mstInField { -// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) -// } -// -// invokeAndSum(expr) -// } -// -// fun asmExpression() { -// val expr = algebra.mstInField { -// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) -// }.compile() -// -// invokeAndSum(expr) -// } -// -// private fun invokeAndSum(expr: Expression) { -// val random = Random(0) -// var sum = 0.0 -// -// repeat(1000000) { -// sum += expr("x" to random.nextDouble()) -// } -// -// println(sum) -// } -//} -// -//fun main() { -// val benchmark = ExpressionsInterpretersBenchmark() -// -// val fe = measureTimeMillis { -// benchmark.functionalExpression() -// } -// -// println("fe=$fe") -// -// val mst = measureTimeMillis { -// benchmark.mstExpression() -// } -// -// println("mst=$mst") -// -// val asm = measureTimeMillis { -// benchmark.asmExpression() -// } -// -// println("asm=$asm") -//} + +import kscience.kmath.asm.compile +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.expressionInField +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.Field +import kscience.kmath.operations.RealField +import kotlin.random.Random +import kotlin.system.measureTimeMillis + +class ExpressionsInterpretersBenchmark { + private val algebra: Field = RealField + fun functionalExpression() { + val expr = algebra.expressionInField { + symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0) + } + + invokeAndSum(expr) + } + + fun mstExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + } + + invokeAndSum(expr) + } + + fun asmExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + }.compile() + + invokeAndSum(expr) + } + + private fun invokeAndSum(expr: Expression) { + val random = Random(0) + var sum = 0.0 + + repeat(1000000) { + sum += expr("x" to random.nextDouble()) + } + + println(sum) + } +} + +fun main() { + val benchmark = ExpressionsInterpretersBenchmark() + + val fe = measureTimeMillis { + benchmark.functionalExpression() + } + + println("fe=$fe") + + val mst = measureTimeMillis { + benchmark.mstExpression() + } + + println("mst=$mst") + + val asm = measureTimeMillis { + benchmark.asmExpression() + } + + println("asm=$asm") +} diff --git a/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt index 28bfab779..e53af0dee 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt @@ -1,8 +1,10 @@ package kscience.kmath.structures import kotlinx.coroutines.GlobalScope +import kscience.kmath.nd4j.Nd4jArrayField import kscience.kmath.operations.RealField import kscience.kmath.operations.invoke +import org.nd4j.linalg.factory.Nd4j import kotlin.contracts.InvocationKind import kotlin.contracts.contract import kotlin.system.measureTimeMillis @@ -14,6 +16,8 @@ internal inline fun measureAndPrint(title: String, block: () -> Unit) { } fun main() { + // initializing Nd4j + Nd4j.zeros(0) val dim = 1000 val n = 1000 @@ -23,6 +27,8 @@ fun main() { val specializedField = NDField.real(dim, dim) //A generic boxing field. It should be used for objects, not primitives. val genericField = NDField.boxing(RealField, dim, dim) + // Nd4j specialized field. + val nd4jField = Nd4jArrayField.real(dim, dim) measureAndPrint("Automatic field addition") { autoField { @@ -43,6 +49,13 @@ fun main() { } } + measureAndPrint("Nd4j specialized addition") { + nd4jField { + var res = one + repeat(n) { res += 1.0 as Number } + } + } + measureAndPrint("Lazy addition") { val res = specializedField.one.mapAsync(GlobalScope) { var c = 0.0 diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt index 2093a3cb3..a8c874fc3 100644 --- a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt @@ -126,6 +126,36 @@ public interface Nd4jArrayRing : NDRing>, Nd4j check(b) return b.ndArray.rsub(this).wrap() } + + public companion object { + private val intNd4jArrayRingCache: ThreadLocal> = + ThreadLocal.withInitial { hashMapOf() } + + private val longNd4jArrayRingCache: ThreadLocal> = + ThreadLocal.withInitial { hashMapOf() } + + /** + * Creates an [NDRing] for [Int] values or pull it from cache if it was created previously. + */ + public fun int(vararg shape: Int): Nd4jArrayRing = + intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) } + + /** + * Creates an [NDRing] for [Long] values or pull it from cache if it was created previously. + */ + public fun long(vararg shape: Int): Nd4jArrayRing = + longNd4jArrayRingCache.get().getOrPut(shape) { LongNd4jArrayRing(shape) } + + /** + * Creates a most suitable implementation of [NDRing] using reified class. + */ + @Suppress("UNCHECKED_CAST") + public inline fun auto(vararg shape: Int): Nd4jArrayRing> = when { + T::class == Int::class -> int(*shape) as Nd4jArrayRing> + T::class == Long::class -> long(*shape) as Nd4jArrayRing> + else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.") + } + } } /** @@ -145,6 +175,37 @@ public interface Nd4jArrayField : NDField>, Nd check(b) return b.ndArray.rdiv(this).wrap() } + + + public companion object { + private val floatNd4jArrayFieldCache: ThreadLocal> = + ThreadLocal.withInitial { hashMapOf() } + + private val realNd4jArrayFieldCache: ThreadLocal> = + ThreadLocal.withInitial { hashMapOf() } + + /** + * Creates an [NDField] for [Float] values or pull it from cache if it was created previously. + */ + public fun float(vararg shape: Int): Nd4jArrayRing = + floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) } + + /** + * Creates an [NDField] for [Double] values or pull it from cache if it was created previously. + */ + public fun real(vararg shape: Int): Nd4jArrayRing = + realNd4jArrayFieldCache.get().getOrPut(shape) { RealNd4jArrayField(shape) } + + /** + * Creates a most suitable implementation of [NDRing] using reified class. + */ + @Suppress("UNCHECKED_CAST") + public inline fun auto(vararg shape: Int): Nd4jArrayField> = when { + T::class == Float::class -> float(*shape) as Nd4jArrayField> + T::class == Double::class -> real(*shape) as Nd4jArrayField> + else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.") + } + } } /** From 29a670483ba315a3adde3f6fabc2f69c3d1c09e8 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Fri, 30 Oct 2020 01:09:11 +0700 Subject: [PATCH 64/69] Update KG and Maven repos, delete symbol delegate provider, implement working differentiable mst expression based on SFun shape to MST conversion --- build.gradle.kts | 7 ++- examples/build.gradle.kts | 8 --- .../ast/ExpressionsInterpretersBenchmark.kt | 2 +- .../kscience/kmath/ast/KotlingradSupport.kt | 16 +++--- .../kscience/kmath/ast/MstExpression.kt | 25 +++++---- .../kotlin/kscience/kmath/ast/extensions.kt | 12 ---- .../jvmMain/kotlin/kscience/kmath/asm/asm.kt | 3 +- kmath-kotlingrad/build.gradle.kts | 3 +- .../kotlingrad/DifferentiableMstExpression.kt | 53 ++++++++++++++++++ .../kscience/kmath/kotlingrad/KMathNumber.kt | 18 ++++++ .../kmath/kotlingrad/ScalarsAdapters.kt | 56 ++++++++----------- 11 files changed, 125 insertions(+), 78 deletions(-) delete mode 100644 kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt create mode 100644 kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt create mode 100644 kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/KMathNumber.kt diff --git a/build.gradle.kts b/build.gradle.kts index 51ce48d84..095697bc4 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -9,10 +9,15 @@ internal val githubProject: String by extra("kmath") allprojects { repositories { jcenter() + maven("https://clojars.org/repo") + maven("https://dl.bintray.com/egor-bogomolov/astminer/") + maven("https://dl.bintray.com/hotkeytlt/maven") maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlinx") - maven("https://dl.bintray.com/hotkeytlt/maven") + maven("https://dl.bintray.com/mipt-npm/dev") + maven("https://dl.bintray.com/mipt-npm/kscience") maven("https://jitpack.io") + mavenCentral() } group = "kscience.kmath" diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 3ca9bbb47..33018976d 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -8,14 +8,6 @@ plugins { } allOpen.annotation("org.openjdk.jmh.annotations.State") - -repositories { - maven("https://dl.bintray.com/mipt-npm/kscience") - maven("https://dl.bintray.com/mipt-npm/dev") - maven("https://dl.bintray.com/kotlin/kotlin-dev/") - mavenCentral() -} - sourceSets.register("benchmarks") dependencies { diff --git a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt index b25a61e96..a4806ed68 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -13,7 +13,7 @@ internal class ExpressionsInterpretersBenchmark { private val algebra: Field = RealField fun functionalExpression() { val expr = algebra.expressionInField { - variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) + symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0) } invokeAndSum(expr) diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt index c8478a631..7b5e1565d 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -1,11 +1,9 @@ package kscience.kmath.ast -import edu.umontreal.kotlingrad.experimental.DoublePrecision import kscience.kmath.asm.compile import kscience.kmath.expressions.invoke -import kscience.kmath.kotlingrad.toMst -import kscience.kmath.kotlingrad.toSFun -import kscience.kmath.kotlingrad.toSVar +import kscience.kmath.expressions.symbol +import kscience.kmath.kotlingrad.DifferentiableMstExpression import kscience.kmath.operations.RealField /** @@ -13,10 +11,12 @@ import kscience.kmath.operations.RealField * valid derivative. */ fun main() { - val proto = DoublePrecision.prototype - val x by MstAlgebra.symbol("x").toSVar(proto) - val quadratic = "x^2-4*x-44".parseMath().toSFun(proto) - val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile() + val x by symbol + + val actualDerivative = DifferentiableMstExpression(RealField, "x^2-4*x-44".parseMath()) + .derivativeOrNull(listOf(x)) + .compile() + val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0)) } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt index 5ca75e993..f68e3f5f8 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt @@ -13,7 +13,7 @@ import kotlin.contracts.contract * @property mst the [MST] node. * @author Alexander Nozik */ -public class MstExpression(public val algebra: Algebra, public val mst: MST) : Expression { +public class MstExpression>(public val algebra: A, public val mst: MST) : Expression { private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) @@ -21,8 +21,9 @@ public class MstExpression(public val algebra: Algebra, public val mst: MS override fun binaryOperation(operation: String, left: T, right: T): T = algebra.binaryOperation(operation, left, right) - override fun number(value: Number): T = if (algebra is NumericAlgebra) - algebra.number(value) + @Suppress("UNCHECKED_CAST") + override fun number(value: Number): T = if (algebra is NumericAlgebra<*>) + (algebra as NumericAlgebra).number(value) else error("Numeric nodes are not supported by $this") } @@ -38,14 +39,14 @@ public class MstExpression(public val algebra: Algebra, public val mst: MS public inline fun , E : Algebra> A.mst( mstAlgebra: E, block: E.() -> MST, -): MstExpression = MstExpression(this, mstAlgebra.block()) +): MstExpression = MstExpression(this, mstAlgebra.block()) /** * Builds [MstExpression] over [Space]. * * @author Alexander Nozik */ -public inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression { +public inline fun > A.mstInSpace(block: MstSpace.() -> MST): MstExpression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return MstExpression(this, MstSpace.block()) } @@ -55,7 +56,7 @@ public inline fun Space.mstInSpace(block: MstSpace.() -> MS * * @author Alexander Nozik */ -public inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression { +public inline fun > A.mstInRing(block: MstRing.() -> MST): MstExpression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return MstExpression(this, MstRing.block()) } @@ -65,7 +66,7 @@ public inline fun Ring.mstInRing(block: MstRing.() -> MST): * * @author Alexander Nozik */ -public inline fun Field.mstInField(block: MstField.() -> MST): MstExpression { +public inline fun > A.mstInField(block: MstField.() -> MST): MstExpression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return MstExpression(this, MstField.block()) } @@ -75,7 +76,7 @@ public inline fun Field.mstInField(block: MstField.() -> MS * * @author Iaroslav Postovalov */ -public inline fun Field.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression { +public inline fun > A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return MstExpression(this, MstExtendedField.block()) } @@ -85,7 +86,7 @@ public inline fun Field.mstInExtendedField(block: MstExtend * * @author Alexander Nozik */ -public inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression { +public inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return algebra.mstInSpace(block) } @@ -95,7 +96,7 @@ public inline fun > FunctionalExpressionSpace> FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression { +public inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return algebra.mstInRing(block) } @@ -105,7 +106,7 @@ public inline fun > FunctionalExpressionRing. * * @author Alexander Nozik */ -public inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression { +public inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return algebra.mstInField(block) } @@ -117,7 +118,7 @@ public inline fun > FunctionalExpressionField> FunctionalExpressionExtendedField.mstInExtendedField( block: MstExtendedField.() -> MST, -): MstExpression { +): MstExpression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return algebra.mstInExtendedField(block) } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt deleted file mode 100644 index b790a3a88..000000000 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt +++ /dev/null @@ -1,12 +0,0 @@ -package kscience.kmath.ast - -import kscience.kmath.operations.Algebra -import kotlin.properties.PropertyDelegateProvider -import kotlin.properties.ReadOnlyProperty - -/** - * Returns [PropertyDelegateProvider] providing [ReadOnlyProperty] of [MST.Symbolic] with its value equal to the name - * of the property. - */ -public val Algebra.symbol: PropertyDelegateProvider, ReadOnlyProperty, MST.Symbolic>> - get() = PropertyDelegateProvider { _, _ -> ReadOnlyProperty { _, p -> MST.Symbolic(p.name) } } diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt index 2b6fa6247..9ccfa464c 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt @@ -69,4 +69,5 @@ public inline fun Algebra.expression(mst: MST): Expression< * * @author Alexander Nozik. */ -public inline fun MstExpression.compile(): Expression = mst.compileWith(T::class.java, algebra) +public inline fun MstExpression>.compile(): Expression = + mst.compileWith(T::class.java, algebra) diff --git a/kmath-kotlingrad/build.gradle.kts b/kmath-kotlingrad/build.gradle.kts index f2245c3d5..027a03bc9 100644 --- a/kmath-kotlingrad/build.gradle.kts +++ b/kmath-kotlingrad/build.gradle.kts @@ -3,6 +3,7 @@ plugins { } dependencies { - api("com.github.breandan:kotlingrad:0.3.7") + implementation("com.github.breandan:kaliningraph:0.1.2") + implementation("com.github.breandan:kotlingrad:0.3.7") api(project(":kmath-ast")) } diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt new file mode 100644 index 000000000..88cc20639 --- /dev/null +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt @@ -0,0 +1,53 @@ +package kscience.kmath.kotlingrad + +import edu.umontreal.kotlingrad.experimental.SFun +import kscience.kmath.ast.MST +import kscience.kmath.ast.MstAlgebra +import kscience.kmath.ast.MstExpression +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Symbol +import kscience.kmath.operations.NumericAlgebra + +/** + * Represents wrapper of [MstExpression] implementing [DifferentiableExpression]. + * + * The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin∇, then converting + * [SFun] back to [MST]. + * + * @param T the type of number. + * @param A the [NumericAlgebra] of [T]. + * @property expr the underlying [MstExpression]. + */ +public inline class DifferentiableMstExpression(public val expr: MstExpression) : + DifferentiableExpression where A : NumericAlgebra, T : Number { + public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst)) + + /** + * The [MstExpression.algebra] of [expr]. + */ + public val algebra: A + get() = expr.algebra + + /** + * The [MstExpression.mst] of [expr]. + */ + public val mst: MST + get() = expr.mst + + public override fun invoke(arguments: Map): T = expr(arguments) + + public override fun derivativeOrNull(symbols: List): MstExpression = MstExpression( + algebra, + symbols.map(Symbol::identity) + .map(MstAlgebra::symbol) + .map { it.toSVar>() } + .fold(mst.toSFun(), SFun>::d) + .toMst(), + ) +} + +/** + * Wraps this [MstExpression] into [DifferentiableMstExpression]. + */ +public fun > MstExpression.differentiable(): DifferentiableMstExpression = + DifferentiableMstExpression(this) diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/KMathNumber.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/KMathNumber.kt new file mode 100644 index 000000000..ce5658137 --- /dev/null +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/KMathNumber.kt @@ -0,0 +1,18 @@ +package kscience.kmath.kotlingrad + +import edu.umontreal.kotlingrad.experimental.RealNumber +import edu.umontreal.kotlingrad.experimental.SConst +import kscience.kmath.operations.NumericAlgebra + +/** + * Implements [RealNumber] by delegating its functionality to [NumericAlgebra]. + * + * @param T the type of number. + * @param A the [NumericAlgebra] of [T]. + * @property algebra the algebra. + * @param value the value of this number. + */ +public class KMathNumber(public val algebra: A, value: T) : + RealNumber, T>(value) where T : Number, A : NumericAlgebra { + public override fun wrap(number: Number): SConst> = SConst(algebra.number(number)) +} diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt index 24e8377bd..b6effab4b 100644 --- a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt @@ -3,14 +3,26 @@ package kscience.kmath.kotlingrad import edu.umontreal.kotlingrad.experimental.* import kscience.kmath.ast.MST import kscience.kmath.ast.MstAlgebra -import kscience.kmath.ast.MstExpression import kscience.kmath.ast.MstExtendedField import kscience.kmath.ast.MstExtendedField.unaryMinus -import kscience.kmath.expressions.DifferentiableExpression -import kscience.kmath.expressions.Expression -import kscience.kmath.expressions.Symbol import kscience.kmath.operations.* +/** + * Maps [SVar] to [MST.Symbolic] directly. + * + * @receiver the variable. + * @return a node. + */ +public fun > SVar.toMst(): MST.Symbolic = MstAlgebra.symbol(name) + +/** + * Maps [SVar] to [MST.Numeric] directly. + * + * @receiver the constant. + * @return a node. + */ +public fun > SConst.toMst(): MST.Numeric = MstAlgebra.number(doubleValue) + /** * Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then. * [Power] operation is limited to constant right-hand side arguments. @@ -37,8 +49,8 @@ import kscience.kmath.operations.* */ public fun > SFun.toMst(): MST = MstExtendedField { when (this@toMst) { - is SVar -> symbol(name) - is SConst -> number(doubleValue) + is SVar -> toMst() + is SConst -> toMst() is Sum -> left.toMst() + right.toMst() is Prod -> left.toMst() * right.toMst() is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue) @@ -69,7 +81,7 @@ public fun > MST.Numeric.toSConst(): SConst = SConst(value) * @param proto the prototype instance. * @return a new variable. */ -public fun > MST.Symbolic.toSVar(): SVar = SVar(value) +internal fun > MST.Symbolic.toSVar(): SVar = SVar(value) /** * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. @@ -90,12 +102,12 @@ public fun > MST.toSFun(): SFun = when (this) { is MST.Symbolic -> toSVar() is MST.Unary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> value.toSFun() - SpaceOperations.MINUS_OPERATION -> (-value).toSFun() + SpaceOperations.PLUS_OPERATION -> +value.toSFun() + SpaceOperations.MINUS_OPERATION -> -value.toSFun() TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun()) TrigonometricOperations.COS_OPERATION -> cos(value.toSFun()) TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun()) - PowerOperations.SQRT_OPERATION -> value.toSFun().sqrt() + PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun()) ExponentialOperations.EXP_OPERATION -> exp(value.toSFun()) ExponentialOperations.LN_OPERATION -> value.toSFun().ln() else -> error("Unary operation $operation not defined in $this") @@ -110,27 +122,3 @@ public fun > MST.toSFun(): SFun = when (this) { else -> error("Binary operation $operation not defined in $this") } } - -public class KMathNumber(public val algebra: A, value: T) : - RealNumber, T>(value) where T : Number, A : NumericAlgebra { - public override fun wrap(number: Number): SConst> = SConst(algebra.number(number)) -} - -public class DifferentiableMstExpression(public val algebra: A, public val mst: MST) : - DifferentiableExpression where A : NumericAlgebra, T : Number { - public val expr by lazy { MstExpression(algebra, mst) } - - public override fun invoke(arguments: Map): T = expr(arguments) - - public override fun derivativeOrNull(orders: Map): Expression { - TODO() - } - - public fun derivativeOrNull(orders: List): Expression { - orders.map { MstAlgebra.symbol(it.identity).toSVar>() } - .fold>, SFun>>(mst.toSFun()) { result, sVar -> result.d(sVar) } - .toMst() - - TODO() - } -} From bc4eb95ae7c05ff957092586d8823d6064dfc21f Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Fri, 30 Oct 2020 16:40:43 +0700 Subject: [PATCH 65/69] Add extension functions for DifferentiableMstExpression --- .../kotlin/kscience/kmath/ast/KotlingradSupport.kt | 3 ++- .../kmath/kotlingrad/DifferentiableMstExpression.kt | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt index 7b5e1565d..9b34426f7 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -4,6 +4,7 @@ import kscience.kmath.asm.compile import kscience.kmath.expressions.invoke import kscience.kmath.expressions.symbol import kscience.kmath.kotlingrad.DifferentiableMstExpression +import kscience.kmath.kotlingrad.derivative import kscience.kmath.operations.RealField /** @@ -14,7 +15,7 @@ fun main() { val x by symbol val actualDerivative = DifferentiableMstExpression(RealField, "x^2-4*x-44".parseMath()) - .derivativeOrNull(listOf(x)) + .derivative(x) .compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt index 88cc20639..ddfc7cccb 100644 --- a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt @@ -5,6 +5,7 @@ import kscience.kmath.ast.MST import kscience.kmath.ast.MstAlgebra import kscience.kmath.ast.MstExpression import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.StringSymbol import kscience.kmath.expressions.Symbol import kscience.kmath.operations.NumericAlgebra @@ -46,6 +47,15 @@ public inline class DifferentiableMstExpression(public val expr: MstExpres ) } +public fun > DifferentiableMstExpression.derivative(symbols: List): MstExpression = + derivativeOrNull(symbols) + +public fun > DifferentiableMstExpression.derivative(vararg symbols: Symbol): MstExpression = + derivative(symbols.toList()) + +public fun > DifferentiableMstExpression.derivative(name: String): MstExpression = + derivative(StringSymbol(name)) + /** * Wraps this [MstExpression] into [DifferentiableMstExpression]. */ From ef7066b8c94d0c9450d2b1a0f1eeb79ffb27fdc9 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Fri, 30 Oct 2020 16:40:58 +0700 Subject: [PATCH 66/69] Update example --- .../src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt index 9b34426f7..5fbb5b86a 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -5,6 +5,7 @@ import kscience.kmath.expressions.invoke import kscience.kmath.expressions.symbol import kscience.kmath.kotlingrad.DifferentiableMstExpression import kscience.kmath.kotlingrad.derivative +import kscience.kmath.kotlingrad.differentiable import kscience.kmath.operations.RealField /** @@ -14,7 +15,8 @@ import kscience.kmath.operations.RealField fun main() { val x by symbol - val actualDerivative = DifferentiableMstExpression(RealField, "x^2-4*x-44".parseMath()) + val actualDerivative = MstExpression(RealField, "x^2-4*x-44".parseMath()) + .differentiable() .derivative(x) .compile() From d14e4376595670f7ff9148f01d34fccd277f9e2b Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Fri, 30 Oct 2020 16:57:19 +0700 Subject: [PATCH 67/69] Update DifferentiableExpression by providing second type argument representing the result of differentiation --- .../kscience/kmath/ast/KotlingradSupport.kt | 1 - .../DerivativeStructureExpression.kt | 10 ++-- .../expressions/DifferentiableExpression.kt | 34 ++++++----- .../kscience/kmath/expressions/Expression.kt | 4 +- .../kmath/expressions/SimpleAutoDiff.kt | 56 +++++++++---------- .../kotlingrad/DifferentiableMstExpression.kt | 12 +--- 6 files changed, 57 insertions(+), 60 deletions(-) diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt index 5fbb5b86a..5acd97e3d 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -3,7 +3,6 @@ package kscience.kmath.ast import kscience.kmath.asm.compile import kscience.kmath.expressions.invoke import kscience.kmath.expressions.symbol -import kscience.kmath.kotlingrad.DifferentiableMstExpression import kscience.kmath.kotlingrad.derivative import kscience.kmath.kotlingrad.differentiable import kscience.kmath.operations.RealField diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 244dc1314..345babe8b 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -95,10 +95,10 @@ public class DerivativeStructureField( public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this - public companion object : AutoDiffProcessor { - override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression { - return DerivativeStructureExpression(function) - } + public companion object : + AutoDiffProcessor> { + public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression> = + DerivativeStructureExpression(function) } } @@ -108,7 +108,7 @@ public class DerivativeStructureField( */ public class DerivativeStructureExpression( public val function: DerivativeStructureField.() -> DerivativeStructure, -) : DifferentiableExpression { +) : DifferentiableExpression> { public override operator fun invoke(arguments: Map): Double = DerivativeStructureField(0, arguments).function().value diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt index ac1f4bc20..890ad5f71 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -1,29 +1,37 @@ package kscience.kmath.expressions /** - * An expression that provides derivatives + * Represents expression which structure can be differentiated. + * + * @param T the type this expression takes as argument and returns. + * @param R the type of expression this expression can be differentiated to. */ -public interface DifferentiableExpression : Expression { - public fun derivativeOrNull(symbols: List): Expression? +public interface DifferentiableExpression> : Expression { + /** + * Differentiates this expression by ordered collection of [symbols]. + */ + public fun derivativeOrNull(symbols: List): R? } -public fun DifferentiableExpression.derivative(symbols: List): Expression = +public fun > DifferentiableExpression.derivative(symbols: List): R = derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided") -public fun DifferentiableExpression.derivative(vararg symbols: Symbol): Expression = +public fun > DifferentiableExpression.derivative(vararg symbols: Symbol): R = derivative(symbols.toList()) -public fun DifferentiableExpression.derivative(name: String): Expression = +public fun > DifferentiableExpression.derivative(name: String): R = derivative(StringSymbol(name)) /** * A [DifferentiableExpression] that defines only first derivatives */ -public abstract class FirstDerivativeExpression : DifferentiableExpression { +public abstract class FirstDerivativeExpression> : DifferentiableExpression { + /** + * Returns first derivative of this expression by given [symbol]. + */ + public abstract fun derivativeOrNull(symbol: Symbol): R? - public abstract fun derivativeOrNull(symbol: Symbol): Expression? - - public override fun derivativeOrNull(symbols: List): Expression? { + public final override fun derivativeOrNull(symbols: List): R? { val dSymbol = symbols.firstOrNull() ?: return null return derivativeOrNull(dSymbol) } @@ -32,6 +40,6 @@ public abstract class FirstDerivativeExpression : DifferentiableExpression /** * A factory that converts an expression in autodiff variables to a [DifferentiableExpression] */ -public interface AutoDiffProcessor> { - public fun process(function: A.() -> I): DifferentiableExpression -} \ No newline at end of file +public fun interface AutoDiffProcessor, R : Expression> { + public fun process(function: A.() -> I): DifferentiableExpression +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index 568de255e..98940e767 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -22,7 +22,9 @@ public inline class StringSymbol(override val identity: String) : Symbol { } /** - * An elementary function that could be invoked on a map of arguments + * An elementary function that could be invoked on a map of arguments. + * + * @param T the type this expression takes as argument and returns. */ public fun interface Expression { /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index 5a9642690..e8a894d23 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -68,7 +68,7 @@ public fun > F.simpleAutoDiff( ): DerivationResult { contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } - return SimpleAutoDiffField(this, bindings).derivate(body) + return SimpleAutoDiffField(this, bindings).differentiate(body) } public fun > F.simpleAutoDiff( @@ -83,12 +83,21 @@ public open class SimpleAutoDiffField>( public val context: F, bindings: Map, ) : Field>, ExpressionAlgebra> { + public override val zero: AutoDiffValue + get() = const(context.zero) + + public override val one: AutoDiffValue + get() = const(context.one) // this stack contains pairs of blocks and values to apply them to private var stack: Array = arrayOfNulls(8) private var sp: Int = 0 private val derivatives: MutableMap, T> = hashMapOf() + private val bindings: Map> = bindings.entries.associate { + it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero) + } + /** * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result * with respect to this variable. @@ -106,11 +115,7 @@ public open class SimpleAutoDiffField>( override fun hashCode(): Int = identity.hashCode() } - private val bindings: Map> = bindings.entries.associate { - it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero) - } - - override fun bindOrNull(symbol: Symbol): AutoDiffValue? = bindings[symbol.identity] + public override fun bindOrNull(symbol: Symbol): AutoDiffValue? = bindings[symbol.identity] private fun getDerivative(variable: AutoDiffValue): T = (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero @@ -119,7 +124,6 @@ public open class SimpleAutoDiffField>( if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value } - @Suppress("UNCHECKED_CAST") private fun runBackwardPass() { while (sp > 0) { @@ -129,9 +133,6 @@ public open class SimpleAutoDiffField>( } } - override val zero: AutoDiffValue get() = const(context.zero) - override val one: AutoDiffValue get() = const(context.one) - override fun const(value: T): AutoDiffValue = AutoDiffValue(value) /** @@ -165,7 +166,7 @@ public open class SimpleAutoDiffField>( } - internal fun derivate(function: SimpleAutoDiffField.() -> AutoDiffValue): DerivationResult { + internal fun differentiate(function: SimpleAutoDiffField.() -> AutoDiffValue): DerivationResult { val result = function() result.d = context.one // computing derivative w.r.t result runBackwardPass() @@ -174,41 +175,41 @@ public open class SimpleAutoDiffField>( // Overloads for Double constants - override operator fun Number.plus(b: AutoDiffValue): AutoDiffValue = + public override operator fun Number.plus(b: AutoDiffValue): AutoDiffValue = derive(const { this@plus.toDouble() * one + b.value }) { z -> b.d += z.d } - override operator fun AutoDiffValue.plus(b: Number): AutoDiffValue = b.plus(this) + public override operator fun AutoDiffValue.plus(b: Number): AutoDiffValue = b.plus(this) - override operator fun Number.minus(b: AutoDiffValue): AutoDiffValue = + public override operator fun Number.minus(b: AutoDiffValue): AutoDiffValue = derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } - override operator fun AutoDiffValue.minus(b: Number): AutoDiffValue = + public override operator fun AutoDiffValue.minus(b: Number): AutoDiffValue = derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } // Basic math (+, -, *, /) - override fun add(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = + public override fun add(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = derive(const { a.value + b.value }) { z -> a.d += z.d b.d += z.d } - override fun multiply(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = + public override fun multiply(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = derive(const { a.value * b.value }) { z -> a.d += z.d * b.value b.d += z.d * a.value } - override fun divide(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = + public override fun divide(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = derive(const { a.value / b.value }) { z -> a.d += z.d / b.value b.d -= z.d * a.value / (b.value * b.value) } - override fun multiply(a: AutoDiffValue, k: Number): AutoDiffValue = + public override fun multiply(a: AutoDiffValue, k: Number): AutoDiffValue = derive(const { k.toDouble() * a.value }) { z -> a.d += z.d * k.toDouble() } @@ -220,15 +221,15 @@ public open class SimpleAutoDiffField>( public class SimpleAutoDiffExpression>( public val field: F, public val function: SimpleAutoDiffField.() -> AutoDiffValue, -) : FirstDerivativeExpression() { +) : FirstDerivativeExpression>() { public override operator fun invoke(arguments: Map): T { //val bindings = arguments.entries.map { it.key.bind(it.value) } return SimpleAutoDiffField(field, arguments).function().value } - override fun derivativeOrNull(symbol: Symbol): Expression = Expression { arguments -> + public override fun derivativeOrNull(symbol: Symbol): Expression = Expression { arguments -> //val bindings = arguments.entries.map { it.key.bind(it.value) } - val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function) + val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function) derivationResult.derivative(symbol) } } @@ -236,13 +237,10 @@ public class SimpleAutoDiffExpression>( /** * Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression] */ -public fun > simpleAutoDiff(field: F): AutoDiffProcessor, SimpleAutoDiffField> { - return object : AutoDiffProcessor, SimpleAutoDiffField> { - override fun process(function: SimpleAutoDiffField.() -> AutoDiffValue): DifferentiableExpression { - return SimpleAutoDiffExpression(field, function) - } +public fun > simpleAutoDiff(field: F): AutoDiffProcessor, SimpleAutoDiffField, Expression> = + AutoDiffProcessor { function -> + SimpleAutoDiffExpression(field, function) } -} // Extensions for differentiation of various basic mathematical functions @@ -392,4 +390,4 @@ public class SimpleAutoDiffExtendedField>( public override fun atanh(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).atanh(arg) -} \ No newline at end of file +} diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt index ddfc7cccb..dd5e46f90 100644 --- a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt @@ -5,7 +5,6 @@ import kscience.kmath.ast.MST import kscience.kmath.ast.MstAlgebra import kscience.kmath.ast.MstExpression import kscience.kmath.expressions.DifferentiableExpression -import kscience.kmath.expressions.StringSymbol import kscience.kmath.expressions.Symbol import kscience.kmath.operations.NumericAlgebra @@ -20,7 +19,7 @@ import kscience.kmath.operations.NumericAlgebra * @property expr the underlying [MstExpression]. */ public inline class DifferentiableMstExpression(public val expr: MstExpression) : - DifferentiableExpression where A : NumericAlgebra, T : Number { + DifferentiableExpression> where A : NumericAlgebra, T : Number { public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst)) /** @@ -47,15 +46,6 @@ public inline class DifferentiableMstExpression(public val expr: MstExpres ) } -public fun > DifferentiableMstExpression.derivative(symbols: List): MstExpression = - derivativeOrNull(symbols) - -public fun > DifferentiableMstExpression.derivative(vararg symbols: Symbol): MstExpression = - derivative(symbols.toList()) - -public fun > DifferentiableMstExpression.derivative(name: String): MstExpression = - derivative(StringSymbol(name)) - /** * Wraps this [MstExpression] into [DifferentiableMstExpression]. */ From 658a1703ed77f82ce19d2f7e7df89782e5307890 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Sat, 31 Oct 2020 21:44:52 +0700 Subject: [PATCH 68/69] Add KDoc comment --- .../kscience/kmath/expressions/DifferentiableExpression.kt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt index 890ad5f71..a15df1ac8 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -9,6 +9,9 @@ package kscience.kmath.expressions public interface DifferentiableExpression> : Expression { /** * Differentiates this expression by ordered collection of [symbols]. + * + * @param symbols the symbols. + * @return the derivative or `null`. */ public fun derivativeOrNull(symbols: List): R? } From 33d23c8d289784a00707c553f9c33fd7c64186b3 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 2 Nov 2020 01:08:55 +0700 Subject: [PATCH 69/69] Duplicate repositories declared in main build script, fix errors --- README.md | 16 +++++++++++++++ build.gradle.kts | 5 ++++- examples/build.gradle.kts | 14 +++++++++++++ .../kscience/kmath/ast/KotlingradSupport.kt | 2 +- .../optimization/CMOptimizationProblem.kt | 7 +++---- .../kmath/commons/optimization/cmFit.kt | 11 ++++------ .../commons/optimization/OptimizeTest.kt | 13 +++++++----- .../expressions/DifferentiableExpression.kt | 4 ++-- .../kmath/kotlingrad/AdaptingTests.kt | 20 +++++++++---------- .../kotlin/kscience/kmath/stat/Fitting.kt | 10 +++++++--- .../kmath/stat/OptimizationProblem.kt | 13 +++++------- settings.gradle.kts | 3 +-- 12 files changed, 74 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 2df9d3246..c4e3e5374 100644 --- a/README.md +++ b/README.md @@ -211,7 +211,15 @@ Release artifacts are accessible from bintray with following configuration (see ```kotlin repositories { + jcenter() + maven("https://clojars.org/repo") + maven("https://dl.bintray.com/egor-bogomolov/astminer/") + maven("https://dl.bintray.com/hotkeytlt/maven") + maven("https://dl.bintray.com/kotlin/kotlin-eap") + maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/mipt-npm/kscience") + maven("https://jitpack.io") + mavenCentral() } dependencies { @@ -228,7 +236,15 @@ Development builds are uploaded to the separate repository: ```kotlin repositories { + jcenter() + maven("https://clojars.org/repo") + maven("https://dl.bintray.com/egor-bogomolov/astminer/") + maven("https://dl.bintray.com/hotkeytlt/maven") + maven("https://dl.bintray.com/kotlin/kotlin-eap") + maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/mipt-npm/dev") + maven("https://jitpack.io") + mavenCentral() } ``` diff --git a/build.gradle.kts b/build.gradle.kts index 095697bc4..3514c91e6 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,3 +1,5 @@ +import ru.mipt.npm.gradle.KSciencePublishPlugin + plugins { id("ru.mipt.npm.project") } @@ -17,6 +19,7 @@ allprojects { maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/kscience") maven("https://jitpack.io") + maven("http://logicrunch.research.it.uu.se/maven/") mavenCentral() } @@ -25,7 +28,7 @@ allprojects { } subprojects { - if (name.startsWith("kmath")) apply() + if (name.startsWith("kmath")) apply() } readme { diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 99828c621..d42627ff0 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -10,6 +10,20 @@ plugins { allOpen.annotation("org.openjdk.jmh.annotations.State") sourceSets.register("benchmarks") +repositories { + jcenter() + maven("https://clojars.org/repo") + maven("https://dl.bintray.com/egor-bogomolov/astminer/") + maven("https://dl.bintray.com/hotkeytlt/maven") + maven("https://dl.bintray.com/kotlin/kotlin-eap") + maven("https://dl.bintray.com/kotlin/kotlinx") + maven("https://dl.bintray.com/mipt-npm/dev") + maven("https://dl.bintray.com/mipt-npm/kscience") + maven("https://jitpack.io") + maven("http://logicrunch.research.it.uu.se/maven/") + mavenCentral() +} + dependencies { implementation(project(":kmath-ast")) implementation(project(":kmath-kotlingrad")) diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt index 5acd97e3d..b3c827503 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -1,9 +1,9 @@ package kscience.kmath.ast import kscience.kmath.asm.compile +import kscience.kmath.expressions.derivative import kscience.kmath.expressions.invoke import kscience.kmath.expressions.symbol -import kscience.kmath.kotlingrad.derivative import kscience.kmath.kotlingrad.differentiable import kscience.kmath.operations.RealField diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt index 13f9af7bb..d6f79529a 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt @@ -19,9 +19,8 @@ import kotlin.reflect.KClass public operator fun PointValuePair.component1(): DoubleArray = point public operator fun PointValuePair.component2(): Double = value -public class CMOptimizationProblem( - override val symbols: List, -) : OptimizationProblem, SymbolIndexer, OptimizationFeature { +public class CMOptimizationProblem(override val symbols: List, ) : + OptimizationProblem, SymbolIndexer, OptimizationFeature { private val optimizationData: HashMap, OptimizationData> = HashMap() private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null public var convergenceChecker: ConvergenceChecker = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE, @@ -49,7 +48,7 @@ public class CMOptimizationProblem( addOptimizationData(objectiveFunction) } - public override fun diffExpression(expression: DifferentiableExpression): Unit { + public override fun diffExpression(expression: DifferentiableExpression>) { expression(expression) val gradientFunction = ObjectiveFunctionGradient { val args = it.toMap() diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt index 42475db6c..b8e8bfd4b 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt @@ -12,7 +12,6 @@ import kscience.kmath.structures.asBuffer import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import org.apache.commons.math3.optim.nonlinear.scalar.GoalType - /** * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation */ @@ -21,7 +20,7 @@ public fun Fitting.chiSquared( y: Buffer, yErr: Buffer, model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, -): DifferentiableExpression = chiSquared(DerivativeStructureField, x, y, yErr, model) +): DifferentiableExpression> = chiSquared(DerivativeStructureField, x, y, yErr, model) /** * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation @@ -31,7 +30,7 @@ public fun Fitting.chiSquared( y: Iterable, yErr: Iterable, model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, -): DifferentiableExpression = chiSquared( +): DifferentiableExpression> = chiSquared( DerivativeStructureField, x.toList().asBuffer(), y.toList().asBuffer(), @@ -39,7 +38,6 @@ public fun Fitting.chiSquared( model ) - /** * Optimize expression without derivatives */ @@ -48,16 +46,15 @@ public fun Expression.optimize( configuration: CMOptimizationProblem.() -> Unit, ): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) - /** * Optimize differentiable expression */ -public fun DifferentiableExpression.optimize( +public fun DifferentiableExpression>.optimize( vararg symbols: Symbol, configuration: CMOptimizationProblem.() -> Unit, ): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) -public fun DifferentiableExpression.minimize( +public fun DifferentiableExpression>.minimize( vararg startPoint: Pair, configuration: CMOptimizationProblem.() -> Unit = {}, ): OptimizationResult { diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index fa1978f95..3290c8f32 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -47,14 +47,17 @@ internal class OptimizeTest { val sigma = 1.0 val generator = Distribution.normal(0.0, sigma) val chain = generator.sample(RandomGenerator.default(112667)) - val x = (1..100).map { it.toDouble() } - val y = x.map { it -> + val x = (1..100).map(Int::toDouble) + + val y = x.map { it.pow(2) + it + 1 + chain.nextDouble() } - val yErr = x.map { sigma } - val chi2 = Fitting.chiSquared(x, y, yErr) { x -> + + val yErr = List(x.size) { sigma } + + val chi2 = Fitting.chiSquared(x, y, yErr) { x1 -> val cWithDefault = bindOrNull(c) ?: one - bind(a) * x.pow(2) + bind(b) * x + cWithDefault + bind(a) * x1.pow(2) + bind(b) * x1 + cWithDefault } val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt index a15df1ac8..abce9c4ec 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -6,7 +6,7 @@ package kscience.kmath.expressions * @param T the type this expression takes as argument and returns. * @param R the type of expression this expression can be differentiated to. */ -public interface DifferentiableExpression> : Expression { +public interface DifferentiableExpression> : Expression { /** * Differentiates this expression by ordered collection of [symbols]. * @@ -43,6 +43,6 @@ public abstract class FirstDerivativeExpression> : Differen /** * A factory that converts an expression in autodiff variables to a [DifferentiableExpression] */ -public fun interface AutoDiffProcessor, R : Expression> { +public fun interface AutoDiffProcessor, out R : Expression> { public fun process(function: A.() -> I): DifferentiableExpression } diff --git a/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt b/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt index 682b0cf2e..77902211b 100644 --- a/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt +++ b/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt @@ -13,13 +13,11 @@ import kotlin.test.assertTrue import kotlin.test.fail internal class AdaptingTests { - private val proto: DReal = DoublePrecision.prototype - @Test fun symbol() { val c1 = MstAlgebra.symbol("x") - assertTrue(c1.toSVar(proto).name == "x") - val c2 = "kitten".parseMath().toSFun(proto) + assertTrue(c1.toSVar>().name == "x") + val c2 = "kitten".parseMath().toSFun>() if (c2 is SVar) assertTrue(c2.name == "kitten") else fail() } @@ -27,15 +25,15 @@ internal class AdaptingTests { fun number() { val c1 = MstAlgebra.number(12354324) assertTrue(c1.toSConst().doubleValue == 12354324.0) - val c2 = "0.234".parseMath().toSFun(proto) + val c2 = "0.234".parseMath().toSFun>() if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail() - val c3 = "1e-3".parseMath().toSFun(proto) + val c3 = "1e-3".parseMath().toSFun>() if (c3 is SConst) assertEquals(0.001, c3.value) else fail() } @Test fun simpleFunctionShape() { - val linear = "2*x+16".parseMath().toSFun(proto) + val linear = "2*x+16".parseMath().toSFun>() if (linear !is Sum) fail() if (linear.left !is Prod) fail() if (linear.right !is SConst) fail() @@ -43,8 +41,8 @@ internal class AdaptingTests { @Test fun simpleFunctionDerivative() { - val x = MstAlgebra.symbol("x").toSVar(proto) - val quadratic = "x^2-4*x-44".parseMath().toSFun(proto) + val x = MstAlgebra.symbol("x").toSVar>() + val quadratic = "x^2-4*x-44".parseMath().toSFun>() val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0)) @@ -52,8 +50,8 @@ internal class AdaptingTests { @Test fun moreComplexDerivative() { - val x = MstAlgebra.symbol("x").toSVar(proto) - val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun(proto) + val x = MstAlgebra.symbol("x").toSVar>() + val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun>() val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile() val expectedDerivative = MstExpression( diff --git a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Fitting.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Fitting.kt index 01fdf4c5e..9d4655df2 100644 --- a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Fitting.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Fitting.kt @@ -12,16 +12,18 @@ public object Fitting { * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation */ public fun chiSquared( - autoDiff: AutoDiffProcessor, + autoDiff: AutoDiffProcessor>, x: Buffer, y: Buffer, yErr: Buffer, model: A.(I) -> I, - ): DifferentiableExpression where A : ExtendedField, A : ExpressionAlgebra { + ): DifferentiableExpression> where A : ExtendedField, A : ExpressionAlgebra { require(x.size == y.size) { "X and y buffers should be of the same size" } require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } + return autoDiff.process { var sum = zero + x.indices.forEach { val xValue = const(x[it]) val yValue = const(y[it]) @@ -29,6 +31,7 @@ public object Fitting { val modelValue = model(xValue) sum += ((yValue - modelValue) / yErrValue).pow(2) } + sum } } @@ -45,6 +48,7 @@ public object Fitting { ): Expression { require(x.size == y.size) { "X and y buffers should be of the same size" } require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } + return Expression { arguments -> x.indices.sumByDouble { val xValue = x[it] @@ -56,4 +60,4 @@ public object Fitting { } } } -} \ No newline at end of file +} diff --git a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/OptimizationProblem.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/OptimizationProblem.kt index ea522bff9..0f3cd9dd9 100644 --- a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/OptimizationProblem.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/OptimizationProblem.kt @@ -27,17 +27,17 @@ public interface OptimizationProblem { /** * Define the initial guess for the optimization problem */ - public fun initialGuess(map: Map): Unit + public fun initialGuess(map: Map) /** * Set an objective function expression */ - public fun expression(expression: Expression): Unit + public fun expression(expression: Expression) /** * Set a differentiable expression as objective function as function and gradient provider */ - public fun diffExpression(expression: DifferentiableExpression): Unit + public fun diffExpression(expression: DifferentiableExpression>) /** * Update the problem from previous optimization run @@ -50,9 +50,8 @@ public interface OptimizationProblem { public fun optimize(): OptimizationResult } -public interface OptimizationProblemFactory> { +public fun interface OptimizationProblemFactory> { public fun build(symbols: List): P - } public operator fun > OptimizationProblemFactory.invoke( @@ -60,7 +59,6 @@ public operator fun > OptimizationProblemFac block: P.() -> Unit, ): P = build(symbols).apply(block) - /** * Optimize expression without derivatives using specific [OptimizationProblemFactory] */ @@ -78,7 +76,7 @@ public fun > Expression.optimizeWith( /** * Optimize differentiable expression using specific [OptimizationProblemFactory] */ -public fun > DifferentiableExpression.optimizeWith( +public fun > DifferentiableExpression>.optimizeWith( factory: OptimizationProblemFactory, vararg symbols: Symbol, configuration: F.() -> Unit, @@ -88,4 +86,3 @@ public fun > DifferentiableExpression.op problem.diffExpression(this) return problem.optimize() } - diff --git a/settings.gradle.kts b/settings.gradle.kts index e825ddbdf..97dfe1b96 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,8 +1,7 @@ pluginManagement { repositories { - mavenLocal() - jcenter() gradlePluginPortal() + jcenter() maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/mipt-npm/kscience") maven("https://dl.bintray.com/mipt-npm/dev")