diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..2c99c7bc1 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,32 @@ +# KMath + +## [Unreleased] + +### Added +- Functional Expressions API +- Mathematical Syntax Tree, its interpreter and API +- String to MST parser (https://github.com/mipt-npm/kmath/pull/120) +- MST to JVM bytecode translator (https://github.com/mipt-npm/kmath/pull/94) +- FloatBuffer (specialized MutableBuffer over FloatArray) +- FlaggedBuffer to associate primitive numbers buffer with flags (to mark values infinite or missing, etc.) +- Specialized builder functions for all primitive buffers like `IntBuffer(25) { it + 1 }` (https://github.com/mipt-npm/kmath/pull/125) +- Interface `NumericAlgebra` where `number` operation is available to convert numbers to algebraic elements +- Inverse trigonometric functions support in ExtendedField (`asin`, `acos`, `atan`) (https://github.com/mipt-npm/kmath/pull/114) +- New space extensions: `average` and `averageWith` +- Local coding conventions +- Geometric Domains API in `kmath-core` +- Blocking chains in `kmath-coroutines` + +### Changed +- BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations +- `power(T, Int)` extension function has preconditions and supports `Field` +- Memory objects have more preconditions (overflow checking) +- `tg` function is renamed to `tan` (https://github.com/mipt-npm/kmath/pull/114) +- Gradle version: 6.3 -> 6.5.1 +- Moved probability distributions to commons-rng and to `kmath-prob`. + +### Fixed +- Missing copy method in Memory implementation on JS (https://github.com/mipt-npm/kmath/pull/106) +- D3.dim value in `kmath-dimensions` +- Multiplication in integer rings in `kmath-core` (https://github.com/mipt-npm/kmath/pull/101) +- Commons RNG compatibility (https://github.com/mipt-npm/kmath/issues/93) diff --git a/build.gradle.kts b/build.gradle.kts index 6ab33d31c..8a2ba3617 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,8 +1,8 @@ plugins { - id("scientifik.publish") version "0.4.2" apply false + id("scientifik.publish") apply false } -val kmathVersion by extra("0.1.4-dev-4") +val kmathVersion by extra("0.1.4-dev-8") val bintrayRepo by extra("scientifik") val githubProject by extra("kmath") @@ -11,6 +11,7 @@ allprojects { repositories { jcenter() maven("https://dl.bintray.com/kotlin/kotlinx") + maven("https://dl.bintray.com/hotkeytlt/maven") } group = "scientifik" diff --git a/doc/algebra.md b/doc/algebra.md index 015f4fc82..b1b77a31f 100644 --- a/doc/algebra.md +++ b/doc/algebra.md @@ -1,110 +1,124 @@ -# Algebra and algebra elements +# Algebraic Structures and Algebraic Elements -The mathematical operations in `kmath` are generally separated from mathematical objects. -This means that in order to perform an operation, say `+`, one needs two objects of a type `T` and -and algebra context which defines appropriate operation, say `Space`. Next one needs to run actual operation -in the context: +The mathematical operations in KMath are generally separated from mathematical objects. This means that to perform an +operation, say `+`, one needs two objects of a type `T` and an algebra context, which draws appropriate operation up, +say `Space`. Next one needs to run the actual operation in the context: ```kotlin -val a: T -val b: T -val space: Space +import scientifik.kmath.operations.* -val c = space.run{a + b} +val a: T = ... +val b: T = ... +val space: Space = ... + +val c = space { a + b } ``` -From the first glance, this distinction seems to be a needless complication, but in fact one needs -to remember that in mathematics, one could define different operations on the same objects. For example, -one could use different types of geometry for vectors. +At first glance, this distinction seems to be a needless complication, but in fact one needs to remember that in +mathematics, one could draw up different operations on same objects. For example, one could use different types of +geometry for vectors. -## Algebra hierarchy +## Algebraic Structures Mathematical contexts have the following hierarchy: -**Space** <- **Ring** <- **Field** +**Algebra** ← **Space** ← **Ring** ← **Field** -All classes follow abstract mathematical constructs. -[Space](http://mathworld.wolfram.com/Space.html) defines `zero` element, addition operation and multiplication by constant, -[Ring](http://mathworld.wolfram.com/Ring.html) adds multiplication and unit `one` element, -[Field](http://mathworld.wolfram.com/Field.html) adds division operation. +These interfaces follow real algebraic structures: -Typical case of `Field` is the `RealField` which works on doubles. And typical case of `Space` is a `VectorSpace`. +- [Space](https://mathworld.wolfram.com/VectorSpace.html) defines addition, its neutral element (i.e. 0) and scalar +multiplication; +- [Ring](http://mathworld.wolfram.com/Ring.html) adds multiplication and its neutral element (i.e. 1); +- [Field](http://mathworld.wolfram.com/Field.html) adds division operation. -In some cases algebra context could hold additional operation like `exp` or `sin`, in this case it inherits appropriate -interface. Also a context could have an operation which produces an element outside of its context. For example -`Matrix` `dot` operation produces a matrix with new dimensions which can be incompatible with initial matrix in -terms of linear operations. +A typical implementation of `Field` is the `RealField` which works on doubles, and `VectorSpace` for `Space`. -## Algebra element +In some cases algebra context can hold additional operations like `exp` or `sin`, and then it inherits appropriate +interface. Also, contexts may have operations, which produce elements outside of the context. For example, `Matrix.dot` +operation produces a matrix with new dimensions, which can be incompatible with initial matrix in terms of linear +operations. -In order to achieve more familiar behavior (where you apply operations directly to mathematical objects), without involving contexts -`kmath` introduces special type objects called `MathElement`. A `MathElement` is basically some object coupled to +## Algebraic Element + +To achieve more familiar behavior (where you apply operations directly to mathematical objects), without involving +contexts KMath submits special type objects called `MathElement`. A `MathElement` is basically some object coupled to a mathematical context. For example `Complex` is the pair of real numbers representing real and imaginary parts, -but it also holds reference to the `ComplexField` singleton which allows to perform direct operations on `Complex` +but it also holds reference to the `ComplexField` singleton, which allows performing direct operations on `Complex` numbers without explicit involving the context like: ```kotlin - val c1 = Complex(1.0, 1.0) - val c2 = Complex(1.0, -1.0) - val c3 = c1 + c2 + 3.0.toComplex() - //or with field notation: - val c4 = ComplexField.run{c1 + i - 2.0} +import scientifik.kmath.operations.* + +// Using elements +val c1 = Complex(1.0, 1.0) +val c2 = Complex(1.0, -1.0) +val c3 = c1 + c2 + 3.0.toComplex() + +// Using context +val c4 = ComplexField { c1 + i - 2.0 } ``` Both notations have their pros and cons. -The hierarchy for algebra elements follows the hierarchy for the corresponding algebra. +The hierarchy for algebraic elements follows the hierarchy for the corresponding algebraic structures. -**MathElement** <- **SpaceElement** <- **RingElement** <- **FieldElement** +**MathElement** ← **SpaceElement** ← **RingElement** ← **FieldElement** -**MathElement** is the generic common ancestor of the class with context. +`MathElement` is the generic common ancestor of the class with context. -One important distinction between algebra elements and algebra contexts is that algebra element has three type parameters: +One major distinction between algebraic elements and algebraic contexts is that elements have three type +parameters: -1. The type of elements, field operates on. -2. The self-type of the element returned from operation (must be algebra element). +1. The type of elements, the field operates on. +2. The self-type of the element returned from operation (which has to be an algebraic element). 3. The type of the algebra over first type-parameter. -The middle type is needed in case algebra members do not store context. For example, it is not possible to add -a context to regular `Double`. The element performs automatic conversions from context types and back. -One should used context operations in all important places. The performance of element operations is not guaranteed. +The middle type is needed for of algebra members do not store context. For example, it is impossible to add a context +to regular `Double`. The element performs automatic conversions from context types and back. One should use context +operations in all performance-critical places. The performance of element operations is not guaranteed. -## Spaces and fields +## Spaces and Fields -An obvious first choice of mathematical objects to implement in a context-oriented style are algebraic elements like spaces, -rings and fields. Those are located in the `scientifik.kmath.operations.Algebra.kt` file. Alongside common contexts, the file includes definitions for algebra elements like `FieldElement`. A `FieldElement` object -stores a reference to the `Field` which contains additive and multiplicative operations, meaning -it has one fixed context attached and does not require explicit external context. So those `MathElements` can be operated without context: +KMath submits both contexts and elements for builtin algebraic structures: ```kotlin +import scientifik.kmath.operations.* + val c1 = Complex(1.0, 2.0) val c2 = ComplexField.i + val c3 = c1 + c2 +// or +val c3 = ComplexField { c1 + c2 } ``` -`ComplexField` also features special operations to mix complex and real numbers, for example: +Also, `ComplexField` features special operations to mix complex and real numbers, for example: ```kotlin +import scientifik.kmath.operations.* + val c1 = Complex(1.0, 2.0) -val c2 = ComplexField.run{ c1 - 1.0} // Returns: [re:0.0, im: 2.0] -val c3 = ComplexField.run{ c1 - i*2.0} +val c2 = ComplexField { c1 - 1.0 } // Returns: Complex(re=0.0, im=2.0) +val c3 = ComplexField { c1 - i * 2.0 } ``` -**Note**: In theory it is possible to add behaviors directly to the context, but currently kotlin syntax does not support -that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and [KEEP-176](https://github.com/Kotlin/KEEP/pull/176) for updates. +**Note**: In theory it is possible to add behaviors directly to the context, but as for now Kotlin does not support +that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and +[KEEP-176](https://github.com/Kotlin/KEEP/pull/176) for updates. ## Nested fields -Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex elements like so: +Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex +elements like so: ```kotlin -val element = NDElement.complex(shape = intArrayOf(2,2)){ index: IntArray -> +val element = NDElement.complex(shape = intArrayOf(2, 2)) { index: IntArray -> Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) } ``` -The `element` in this example is a member of the `Field` of 2-d structures, each element of which is a member of its own -`ComplexField`. The important thing is one does not need to create a special n-d class to hold complex +The `element` in this example is a member of the `Field` of 2D structures, each element of which is a member of its own +`ComplexField`. It is important one does not need to create a special n-d class to hold complex numbers and implement operations on it, one just needs to provide a field for its elements. **Note**: Fields themselves do not solve the problem of JVM boxing, but it is possible to solve with special contexts like diff --git a/doc/buffers.md b/doc/buffers.md index b0b7489b3..679bd4e78 100644 --- a/doc/buffers.md +++ b/doc/buffers.md @@ -1,8 +1,9 @@ # Buffers + Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`). There are different types of buffers: -* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays. +* Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays. * Boxing `ListBuffer` wrapping a list * Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value * `MemoryBuffer` allows direct allocation of objects in continuous memory block. @@ -12,4 +13,5 @@ Some kmath features require a `BufferFactory` class to operate properly. A gener buffer for given reified type (for types with custom memory buffer it still better to use their own `MemoryBuffer.create()` factory). ## Buffer performance -One should avoid using default boxing buffer wherever it is possible. Try to use primitive buffers or memory buffers instead \ No newline at end of file + +One should avoid using default boxing buffer wherever it is possible. Try to use primitive buffers or memory buffers instead diff --git a/doc/codestyle.md b/doc/codestyle.md new file mode 100644 index 000000000..541dc4973 --- /dev/null +++ b/doc/codestyle.md @@ -0,0 +1,34 @@ +# Coding Conventions + +KMath code follows general [Kotlin conventions](https://kotlinlang.org/docs/reference/coding-conventions.html), but +with a number of small changes and clarifications. + +## Utility Class Naming + +Filename should coincide with a name of one of the classes contained in the file or start with small letter and +describe its contents. + +The code convention [here](https://kotlinlang.org/docs/reference/coding-conventions.html#source-file-names) says that +file names should start with a capital letter even if file does not contain classes. Yet starting utility classes and +aggregators with a small letter seems to be a good way to visually separate those files. + +This convention could be changed in future in a non-breaking way. + +## Private Variable Naming + +Private variables' names may start with underscore `_` for of the private mutable variable is shadowed by the public +read-only value with the same meaning. + +This rule does not permit underscores in names, but it is sometimes useful to "underscore" the fact that public and +private versions draw up the same entity. It is allowed only for private variables. + +This convention could be changed in future in a non-breaking way. + +## Functions and Properties One-liners + +Use one-liners when they occupy single code window line both for functions and properties with getters like +`val b: String get() = "fff"`. The same should be performed with multiline expressions when they could be +cleanly separated. + +There is no universal consensus whenever use `fun a() = ...` or `fun a() { return ... }`. Yet from reader outlook +one-lines seem to better show that the property or function is easily calculated. diff --git a/doc/linear.md b/doc/linear.md index bbcc435ba..883df275e 100644 --- a/doc/linear.md +++ b/doc/linear.md @@ -1,6 +1,6 @@ ## Basic linear algebra layout -Kmath support for linear algebra organized in a context-oriented way. Meaning that operations are in most cases declared +KMath support for linear algebra organized in a context-oriented way. Meaning that operations are in most cases declared in context classes, and are not the members of classes that store data. This allows more flexible approach to maintain multiple back-ends. The new operations added as extensions to contexts instead of being member functions of data structures. diff --git a/doc/nd-structure.md b/doc/nd-structure.md index cf13c6a29..835304b9f 100644 --- a/doc/nd-structure.md +++ b/doc/nd-structure.md @@ -1,4 +1,4 @@ -# Nd-structure generation and operations +# ND-structure generation and operations **TODO** diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 8853b78a5..73def3572 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { java kotlin("jvm") - kotlin("plugin.allopen") version "1.3.71" - id("kotlinx.benchmark") version "0.2.0-dev-7" + kotlin("plugin.allopen") version "1.3.72" + id("kotlinx.benchmark") version "0.2.0-dev-8" } configure { @@ -24,16 +24,18 @@ sourceSets { } dependencies { + implementation(project(":kmath-ast")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) + implementation(project(":kmath-prob")) implementation(project(":kmath-koma")) implementation(project(":kmath-viktor")) implementation(project(":kmath-dimensions")) implementation("com.kyonifer:koma-core-ejml:0.12") implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") - implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7") - "benchmarksCompile"(sourceSets.main.get().compileClasspath) + implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8") + "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath } // Configure benchmark @@ -57,6 +59,6 @@ benchmark { tasks.withType { kotlinOptions { - jvmTarget = Scientifik.JVM_VERSION + jvmTarget = Scientifik.JVM_TARGET.toString() } } \ No newline at end of file diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt index 9676b5e4a..e40b0c4b7 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt @@ -10,8 +10,8 @@ import scientifik.kmath.operations.complex class BufferBenchmark { @Benchmark - fun genericDoubleBufferReadWrite() { - val buffer = DoubleBuffer(size){it.toDouble()} + fun genericRealBufferReadWrite() { + val buffer = RealBuffer(size){it.toDouble()} (0 until size).forEach { buffer[it] diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt index be4115d81..f7b9661ef 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt @@ -20,48 +20,39 @@ class ViktorBenchmark { final val viktorField = ViktorNDField(intArrayOf(dim, dim)) @Benchmark - fun `Automatic field addition`() { + fun automaticFieldAddition() { autoField.run { var res = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += one } } } @Benchmark - fun `Viktor field addition`() { + fun viktorFieldAddition() { viktorField.run { var res = one - repeat(n) { - res += one - } + repeat(n) { res += one } } } @Benchmark - fun `Raw Viktor`() { + fun rawViktor() { val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) var res = one - repeat(n) { - res = res + one - } + repeat(n) { res = res + one } } @Benchmark - fun `Real field log`() { + fun realdFieldLog() { realField.run { val fortyTwo = produce { 42.0 } var res = one - - repeat(n) { - res = ln(fortyTwo) - } + repeat(n) { res = ln(fortyTwo) } } } @Benchmark - fun `Raw Viktor log`() { + fun rawViktorLog() { val fortyTwo = F64Array.full(dim, dim, init = 42.0) var res: F64Array repeat(n) { diff --git a/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt new file mode 100644 index 000000000..17a70a4aa --- /dev/null +++ b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -0,0 +1,70 @@ +package scientifik.kmath.ast + +import scientifik.kmath.asm.compile +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.expressionInField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.RealField +import kotlin.random.Random +import kotlin.system.measureTimeMillis + +class ExpressionsInterpretersBenchmark { + private val algebra: Field = RealField + fun functionalExpression() { + val expr = algebra.expressionInField { + variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) + } + + invokeAndSum(expr) + } + + fun mstExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + } + + invokeAndSum(expr) + } + + fun asmExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + }.compile() + + invokeAndSum(expr) + } + + private fun invokeAndSum(expr: Expression) { + val random = Random(0) + var sum = 0.0 + + repeat(1000000) { + sum += expr("x" to random.nextDouble()) + } + + println(sum) + } +} + +fun main() { + val benchmark = ExpressionsInterpretersBenchmark() + + val fe = measureTimeMillis { + benchmark.functionalExpression() + } + + println("fe=$fe") + + val mst = measureTimeMillis { + benchmark.mstExpression() + } + + println("mst=$mst") + + val asm = measureTimeMillis { + benchmark.asmExpression() + } + + println("asm=$asm") +} diff --git a/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionBenchmark.kt new file mode 100644 index 000000000..b060cddb6 --- /dev/null +++ b/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionBenchmark.kt @@ -0,0 +1,71 @@ +package scientifik.kmath.commons.prob + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.runBlocking +import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler +import org.apache.commons.rng.simple.RandomSource +import scientifik.kmath.chains.BlockingRealChain +import scientifik.kmath.prob.* +import java.time.Duration +import java.time.Instant + + +private suspend fun runChain(): Duration { + val generator = RandomGenerator.fromSource(RandomSource.MT, 123L) + + val normal = Distribution.normal(NormalSamplerMethod.Ziggurat) + val chain = normal.sample(generator) as BlockingRealChain + + val startTime = Instant.now() + var sum = 0.0 + repeat(10000001) { counter -> + + sum += chain.nextDouble() + + if (counter % 100000 == 0) { + val duration = Duration.between(startTime, Instant.now()) + val meanValue = sum / counter + println("Chain sampler completed $counter elements in $duration: $meanValue") + } + } + return Duration.between(startTime, Instant.now()) +} + +private fun runDirect(): Duration { + val provider = RandomSource.create(RandomSource.MT, 123L) + val sampler = ZigguratNormalizedGaussianSampler(provider) + val startTime = Instant.now() + + var sum = 0.0 + repeat(10000001) { counter -> + + sum += sampler.sample() + + if (counter % 100000 == 0) { + val duration = Duration.between(startTime, Instant.now()) + val meanValue = sum / counter + println("Direct sampler completed $counter elements in $duration: $meanValue") + } + } + return Duration.between(startTime, Instant.now()) +} + +/** + * Comparing chain sampling performance with direct sampling performance + */ +fun main() { + runBlocking(Dispatchers.Default) { + val chainJob = async { + runChain() + } + + val directJob = async { + runDirect() + } + + println("Chain: ${chainJob.await()}") + println("Direct: ${directJob.await()}") + } + +} \ No newline at end of file diff --git a/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt b/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt index 3c5f53e13..e059415dc 100644 --- a/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt +++ b/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt @@ -5,10 +5,11 @@ import scientifik.kmath.chains.Chain import scientifik.kmath.chains.collectWithState import scientifik.kmath.prob.Distribution import scientifik.kmath.prob.RandomGenerator +import scientifik.kmath.prob.normal data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) -fun Chain.mean(): Chain = collectWithState(AveragingChainState(),{it.copy()}){ chain-> +fun Chain.mean(): Chain = collectWithState(AveragingChainState(), { it.copy() }) { chain -> val next = chain.next() num++ value += next diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt index cc8b68d85..991cd34a1 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt @@ -27,7 +27,7 @@ fun main() { val complexTime = measureTimeMillis { complexField.run { - var res = one + var res: NDBuffer = one repeat(n) { res += 1.0 } diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt index cfd1206ff..2aafb504d 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt @@ -23,14 +23,14 @@ fun main() { measureAndPrint("Automatic field addition") { autoField.run { - var res = one + var res: NDBuffer = one repeat(n) { - res += 1.0 + res += number(1.0) } } } - measureAndPrint("Element addition"){ + measureAndPrint("Element addition") { var res = genericField.one repeat(n) { res += 1.0 @@ -63,7 +63,7 @@ fun main() { genericField.run { var res: NDBuffer = one repeat(n) { - res += 1.0 + res += one // con't avoid using `one` due to resolution ambiguity } } } diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt index ecfb4ab20..a33fdb2c4 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt @@ -6,7 +6,7 @@ fun main(args: Array) { val n = 6000 val array = DoubleArray(n * n) { 1.0 } - val buffer = DoubleBuffer(array) + val buffer = RealBuffer(array) val strides = DefaultStrides(intArrayOf(n, n)) val structure = BufferNDStructure(strides, buffer) diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt index 2d16cc8f4..0241f12ad 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt @@ -26,10 +26,10 @@ fun main(args: Array) { } println("Array mapping finished in $time2 millis") - val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 }) + val buffer = RealBuffer(DoubleArray(n * n) { 1.0 }) val time3 = measureTimeMillis { - val target = DoubleBuffer(DoubleArray(n * n)) + val target = RealBuffer(DoubleArray(n * n)) val res = array.forEachIndexed { index, value -> target[index] = value + 1 } diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 490fda857..62d4c0535 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index a4b442974..bb8b2fc26 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.3-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.5.1-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 2fe81a7d9..fbd7c5158 100755 --- a/gradlew +++ b/gradlew @@ -82,6 +82,7 @@ esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + # Determine the Java command to use to start the JVM. if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then @@ -129,6 +130,7 @@ fi if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then APP_HOME=`cygpath --path --mixed "$APP_HOME"` CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` # We build the pattern for arguments to be converted via cygpath diff --git a/gradlew.bat b/gradlew.bat index 62bd9b9cc..5093609d5 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -84,6 +84,7 @@ set CMD_LINE_ARGS=%* set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + @rem Execute Gradle "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% diff --git a/kmath-ast/README.md b/kmath-ast/README.md new file mode 100644 index 000000000..2339d0426 --- /dev/null +++ b/kmath-ast/README.md @@ -0,0 +1,91 @@ +# Abstract Syntax Tree Expression Representation and Operations (`kmath-ast`) + +This subproject implements the following features: + +- Expression Language and its parser. +- MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation. +- Type-safe builder for MST. +- Evaluating expressions by traversing MST. + +> #### Artifact: +> This module is distributed in the artifact `scientifik:kmath-ast:0.1.4-dev-8`. +> +> **Gradle:** +> +> ```gradle +> repositories { +> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' } +> maven { url 'https://dl.bintray.com/mipt-npm/dev' } +> maven { url https://dl.bintray.com/hotkeytlt/maven' } +> } +> +> dependencies { +> implementation 'scientifik:kmath-ast:0.1.4-dev-8' +> } +> ``` +> **Gradle Kotlin DSL:** +> +> ```kotlin +> repositories { +> maven("https://dl.bintray.com/mipt-npm/scientifik") +> maven("https://dl.bintray.com/mipt-npm/dev") +> maven("https://dl.bintray.com/hotkeytlt/maven") +> } +> +> dependencies { +> implementation("scientifik:kmath-ast:0.1.4-dev-8") +> } +> ``` +> + +## Dynamic Expression Code Generation with ObjectWeb ASM + +`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds +a special implementation of `Expression` with implemented `invoke` function. + +For example, the following builder: + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +``` + +… leads to generation of bytecode, which can be decompiled to the following Java class: + +```java +package scientifik.kmath.asm.generated; + +import java.util.Map; +import scientifik.kmath.asm.internal.MapIntrinsics; +import scientifik.kmath.expressions.Expression; +import scientifik.kmath.operations.RealField; + +public final class AsmCompiledExpression_1073786867_0 implements Expression { + private final RealField algebra; + + public final Double invoke(Map arguments) { + return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(), 2.0D); + } + + public AsmCompiledExpression_1073786867_0(RealField algebra) { + this.algebra = algebra; + } +} + +``` + +### Example Usage + +This API extends MST and MstExpression, so you may optimize as both of them: + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +RealField.expression("x+2".parseMath()) +``` + +### Known issues + +- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid +class loading overhead. +- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. + +Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts new file mode 100644 index 000000000..d13a7712d --- /dev/null +++ b/kmath-ast/build.gradle.kts @@ -0,0 +1,23 @@ +plugins { id("scientifik.mpp") } + +kotlin.sourceSets { +// all { +// languageSettings.apply{ +// enableLanguageFeature("NewInference") +// } +// } + commonMain { + dependencies { + api(project(":kmath-core")) + implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0") + } + } + + jvmMain { + dependencies { + implementation("org.ow2.asm:asm:8.0.1") + implementation("org.ow2.asm:asm-commons:8.0.1") + implementation(kotlin("reflect")) + } + } +} \ No newline at end of file diff --git a/kmath-ast/reference/ArithmeticsEvaluator.g4 b/kmath-ast/reference/ArithmeticsEvaluator.g4 new file mode 100644 index 000000000..dc47b23fb --- /dev/null +++ b/kmath-ast/reference/ArithmeticsEvaluator.g4 @@ -0,0 +1,59 @@ +grammar ArithmeticsEvaluator; + +fragment DIGIT: '0'..'9'; +fragment LETTER: 'a'..'z'; +fragment CAPITAL_LETTER: 'A'..'Z'; +fragment UNDERSCORE: '_'; + +ID: (LETTER | UNDERSCORE | CAPITAL_LETTER) (LETTER | UNDERSCORE | DIGIT | CAPITAL_LETTER)*; +NUM: (DIGIT | '.')+ ([eE] [-+]? DIGIT+)?; +MUL: '*'; +DIV: '/'; +PLUS: '+'; +MINUS: '-'; +POW: '^'; +COMMA: ','; +LPAR: '('; +RPAR: ')'; +WS: [ \n\t\r]+ -> skip; + +num + : NUM + ; + +singular + : ID + ; + +unaryFunction + : ID LPAR subSumChain RPAR + ; + +binaryFunction + : ID LPAR subSumChain COMMA subSumChain RPAR + ; + +term + : num + | singular + | unaryFunction + | binaryFunction + | MINUS term + | LPAR subSumChain RPAR + ; + +powChain + : term (POW term)* + ; + +divMulChain + : powChain ((DIV | MUL) powChain)* + ; + +subSumChain + : divMulChain ((PLUS | MINUS) divMulChain)* + ; + +rootParser + : subSumChain EOF + ; diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt new file mode 100644 index 000000000..0e8151c04 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt @@ -0,0 +1,87 @@ +package scientifik.kmath.ast + +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.NumericAlgebra +import scientifik.kmath.operations.RealField + +/** + * A Mathematical Syntax Tree node for mathematical expressions. + */ +sealed class MST { + /** + * A node containing raw string. + * + * @property value the value of this node. + */ + data class Symbolic(val value: String) : MST() + + /** + * A node containing a numeric value or scalar. + * + * @property value the value of this number. + */ + data class Numeric(val value: Number) : MST() + + /** + * A node containing an unary operation. + * + * @property operation the identifier of operation. + * @property value the argument of this operation. + */ + data class Unary(val operation: String, val value: MST) : MST() { + companion object + } + + /** + * A node containing binary operation. + * + * @property operation the identifier operation. + * @property left the left operand. + * @property right the right operand. + */ + data class Binary(val operation: String, val left: MST, val right: MST) : MST() { + companion object + } +} + +// TODO add a function with named arguments + +/** + * Interprets the [MST] node with this [Algebra]. + * + * @receiver the algebra that provides operations. + * @param node the node to evaluate. + * @return the value of expression. + */ +fun Algebra.evaluate(node: MST): T = when (node) { + is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) + ?: error("Numeric nodes are not supported by $this") + is MST.Symbolic -> symbol(node.value) + is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) + is MST.Binary -> when { + this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + + node.left is MST.Numeric && node.right is MST.Numeric -> { + val number = RealField.binaryOperation( + node.operation, + node.left.value.toDouble(), + node.right.value.toDouble() + ) + + number(number) + } + + node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) + node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) + else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + } +} + +/** + * Interprets the [MST] node with this [Algebra]. + * + * @receiver the node to evaluate. + * @param algebra the algebra that provides operations. + * @return the value of expression. + */ +fun MST.interpret(algebra: Algebra): T = algebra.evaluate(this) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt new file mode 100644 index 000000000..b47c7cae8 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt @@ -0,0 +1,102 @@ +package scientifik.kmath.ast + +import scientifik.kmath.operations.* + +/** + * [Algebra] over [MST] nodes. + */ +object MstAlgebra : NumericAlgebra { + override fun number(value: Number): MST = MST.Numeric(value) + + override fun symbol(value: String): MST = MST.Symbolic(value) + + override fun unaryOperation(operation: String, arg: MST): MST = + MST.Unary(operation, arg) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MST.Binary(operation, left, right) +} + +/** + * [Space] over [MST] nodes. + */ +object MstSpace : Space, NumericAlgebra { + override val zero: MST = number(0.0) + + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MstAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) +} + +/** + * [Ring] over [MST] nodes. + */ +object MstRing : Ring, NumericAlgebra { + override val zero: MST = number(0.0) + override val one: MST = number(1.0) + + override fun number(value: Number): MST = MstSpace.number(value) + override fun symbol(value: String): MST = MstSpace.symbol(value) + override fun add(a: MST, b: MST): MST = MstSpace.add(a, b) + + override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k) + + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MstSpace.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) +} + +/** + * [Field] over [MST] nodes. + */ +object MstField : Field { + override val zero: MST = number(0.0) + override val one: MST = number(1.0) + + override fun symbol(value: String): MST = MstRing.symbol(value) + override fun number(value: Number): MST = MstRing.number(value) + override fun add(a: MST, b: MST): MST = MstRing.add(a, b) + override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) + override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) + override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MstRing.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg) +} + +/** + * [ExtendedField] over [MST] nodes. + */ +object MstExtendedField : ExtendedField { + override val zero: MST = number(0.0) + override val one: MST = number(1.0) + + override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) + override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) + override fun asin(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg) + override fun acos(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg) + override fun atan(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg) + override fun add(a: MST, b: MST): MST = MstField.add(a, b) + override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) + override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) + override fun divide(a: MST, b: MST): MST = MstField.divide(a, b) + override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) + override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) + override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MstField.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg) +} diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt new file mode 100644 index 000000000..59f3f15d8 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt @@ -0,0 +1,88 @@ +package scientifik.kmath.ast + +import scientifik.kmath.expressions.* +import scientifik.kmath.operations.* + +/** + * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than + * ASM-generated expressions. + * + * @property algebra the algebra that provides operations. + * @property mst the [MST] node. + */ +class MstExpression(val algebra: Algebra, val mst: MST) : Expression { + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { + override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) + override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: T, right: T): T = + algebra.binaryOperation(operation, left, right) + + override fun number(value: Number): T = if (algebra is NumericAlgebra) + algebra.number(value) + else + error("Numeric nodes are not supported by $this") + } + + override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) +} + +/** + * Builds [MstExpression] over [Algebra]. + */ +inline fun , E : Algebra> A.mst( + mstAlgebra: E, + block: E.() -> MST +): MstExpression = MstExpression(this, mstAlgebra.block()) + +/** + * Builds [MstExpression] over [Space]. + */ +inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression = + MstExpression(this, MstSpace.block()) + +/** + * Builds [MstExpression] over [Ring]. + */ +inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression = + MstExpression(this, MstRing.block()) + +/** + * Builds [MstExpression] over [Field]. + */ +inline fun Field.mstInField(block: MstField.() -> MST): MstExpression = + MstExpression(this, MstField.block()) + +/** + * Builds [MstExpression] over [ExtendedField]. + */ +inline fun Field.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression = + MstExpression(this, MstExtendedField.block()) + +/** + * Builds [MstExpression] over [FunctionalExpressionSpace]. + */ +inline fun > FunctionalExpressionSpace.mstInSpace( + block: MstSpace.() -> MST +): MstExpression = algebra.mstInSpace(block) + +/** + * Builds [MstExpression] over [FunctionalExpressionRing]. + */ +inline fun > FunctionalExpressionRing.mstInRing( + block: MstRing.() -> MST +): MstExpression = algebra.mstInRing(block) + +/** + * Builds [MstExpression] over [FunctionalExpressionField]. + */ +inline fun > FunctionalExpressionField.mstInField( + block: MstField.() -> MST +): MstExpression = algebra.mstInField(block) + +/** + * Builds [MstExpression] over [FunctionalExpressionExtendedField]. + */ +inline fun > FunctionalExpressionExtendedField.mstInExtendedField( + block: MstExtendedField.() -> MST +): MstExpression = algebra.mstInExtendedField(block) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt new file mode 100644 index 000000000..cba335a8d --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt @@ -0,0 +1,97 @@ +package scientifik.kmath.ast + +import com.github.h0tk3y.betterParse.combinators.* +import com.github.h0tk3y.betterParse.grammar.Grammar +import com.github.h0tk3y.betterParse.grammar.parseToEnd +import com.github.h0tk3y.betterParse.grammar.parser +import com.github.h0tk3y.betterParse.grammar.tryParseToEnd +import com.github.h0tk3y.betterParse.lexer.Token +import com.github.h0tk3y.betterParse.lexer.TokenMatch +import com.github.h0tk3y.betterParse.lexer.regexToken +import com.github.h0tk3y.betterParse.parser.ParseResult +import com.github.h0tk3y.betterParse.parser.Parser +import scientifik.kmath.operations.FieldOperations +import scientifik.kmath.operations.PowerOperations +import scientifik.kmath.operations.RingOperations +import scientifik.kmath.operations.SpaceOperations + +/** + * TODO move to core + */ +object ArithmeticsEvaluator : Grammar() { + // TODO replace with "...".toRegex() when better-parse 0.4.1 is released + private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?") + private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*") + private val lpar: Token by regexToken("\\(") + private val rpar: Token by regexToken("\\)") + private val comma: Token by regexToken(",") + private val mul: Token by regexToken("\\*") + private val pow: Token by regexToken("\\^") + private val div: Token by regexToken("/") + private val minus: Token by regexToken("-") + private val plus: Token by regexToken("\\+") + private val ws: Token by regexToken("\\s+", ignore = true) + + private val number: Parser by num use { MST.Numeric(text.toDouble()) } + private val singular: Parser by id use { MST.Symbolic(text) } + + private val unaryFunction: Parser by (id and skip(lpar) and parser(::subSumChain) and skip(rpar)) + .map { (id, term) -> MST.Unary(id.text, term) } + + private val binaryFunction: Parser by id + .and(skip(lpar)) + .and(parser(::subSumChain)) + .and(skip(comma)) + .and(parser(::subSumChain)) + .and(skip(rpar)) + .map { (id, left, right) -> MST.Binary(id.text, left, right) } + + private val term: Parser by number + .or(binaryFunction) + .or(unaryFunction) + .or(singular) + .or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) + .or(skip(lpar) and parser(::subSumChain) and skip(rpar)) + + private val powChain: Parser by leftAssociative(term = term, operator = pow) { a, _, b -> + MST.Binary(PowerOperations.POW_OPERATION, a, b) + } + + private val divMulChain: Parser by leftAssociative( + term = powChain, + operator = div or mul use TokenMatch::type + ) { a, op, b -> + if (op == div) + MST.Binary(FieldOperations.DIV_OPERATION, a, b) + else + MST.Binary(RingOperations.TIMES_OPERATION, a, b) + } + + private val subSumChain: Parser by leftAssociative( + term = divMulChain, + operator = plus or minus use TokenMatch::type + ) { a, op, b -> + if (op == plus) + MST.Binary(SpaceOperations.PLUS_OPERATION, a, b) + else + MST.Binary(SpaceOperations.MINUS_OPERATION, a, b) + } + + override val rootParser: Parser by subSumChain +} + +/** + * Tries to parse the string into [MST]. + * + * @receiver the string to parse. + * @return the [MST] node. + */ +fun String.tryParseMath(): ParseResult = ArithmeticsEvaluator.tryParseToEnd(this) + +/** + * Parses the string into [MST]. + * + * @receiver the string to parse. + * @return the [MST] node. + */ +fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt new file mode 100644 index 000000000..ee0ea15ff --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -0,0 +1,64 @@ +package scientifik.kmath.asm + +import scientifik.kmath.asm.internal.AsmBuilder +import scientifik.kmath.asm.internal.MstType +import scientifik.kmath.asm.internal.buildAlgebraOperationCall +import scientifik.kmath.asm.internal.buildName +import scientifik.kmath.ast.MST +import scientifik.kmath.ast.MstExpression +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra +import kotlin.reflect.KClass + +/** + * Compile given MST to an Expression using AST compiler + */ +fun MST.compileWith(type: KClass, algebra: Algebra): Expression { + fun AsmBuilder.visit(node: MST) { + when (node) { + is MST.Symbolic -> { + val symbol = try { + algebra.symbol(node.value) + } catch (ignored: Throwable) { + null + } + + if (symbol != null) + loadTConstant(symbol) + else + loadVariable(node.value) + } + + is MST.Numeric -> loadNumeric(node.value) + + is MST.Unary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "unaryOperation", + parameterTypes = arrayOf(MstType.fromMst(node.value)) + ) { visit(node.value) } + + is MST.Binary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "binaryOperation", + parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right)) + ) { + visit(node.left) + visit(node.right) + } + } + } + + return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() +} + +/** + * Compile an [MST] to ASM using given algebra + */ +inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class, this) + +/** + * Optimize performance of an [MstExpression] using ASM codegen + */ +inline fun MstExpression.compile(): Expression = mst.compileWith(T::class, algebra) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt new file mode 100644 index 000000000..f8c159baf --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -0,0 +1,568 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.* +import org.objectweb.asm.Opcodes.* +import org.objectweb.asm.commons.InstructionAdapter +import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader +import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.NumericAlgebra +import java.util.* +import java.util.stream.Collectors +import kotlin.reflect.KClass + +/** + * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. + * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. + * + * @property T the type of AsmExpression to unwrap. + * @property algebra the algebra the applied AsmExpressions use. + * @property className the unique class name of new loaded class. + * @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. + */ +internal class AsmBuilder internal constructor( + private val classOfT: KClass<*>, + private val algebra: Algebra, + private val className: String, + private val invokeLabel0Visitor: AsmBuilder.() -> Unit +) { + /** + * Internal classloader of [AsmBuilder] with alias to define class from byte array. + */ + private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) + } + + /** + * The instance of [ClassLoader] used by this builder. + */ + private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + + /** + * ASM Type for [algebra]. + */ + private val tAlgebraType: Type = algebra::class.asm + + /** + * ASM type for [T]. + */ + internal val tType: Type = classOfT.asm + + /** + * ASM type for new class. + */ + private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! + + /** + * Index of `this` variable in invoke method of the built subclass. + */ + private val invokeThisVar: Int = 0 + + /** + * Index of `arguments` variable in invoke method of the built subclass. + */ + private val invokeArgumentsVar: Int = 1 + + /** + * List of constants to provide to the subclass. + */ + private val constants: MutableList = mutableListOf() + + /** + * Method visitor of `invoke` method of the subclass. + */ + private lateinit var invokeMethodVisitor: InstructionAdapter + + /** + * State if this [AsmBuilder] needs to generate constants field. + */ + private var hasConstants: Boolean = true + + /** + * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. + */ + internal var primitiveMode: Boolean = false + + /** + * Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + */ + internal var primitiveMask: Type = OBJECT_TYPE + + /** + * Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + */ + internal var primitiveMaskBoxed: Type = OBJECT_TYPE + + /** + * Stack of useful objects types on stack to verify types. + */ + private val typeStack: ArrayDeque = ArrayDeque() + + /** + * Stack of useful objects types on stack expected by algebra calls. + */ + internal val expectationStack: ArrayDeque = ArrayDeque(listOf(tType)) + + /** + * The cache for instance built by this builder. + */ + private var generatedInstance: Expression? = null + + /** + * Subclasses, loads and instantiates [Expression] for given parameters. + * + * The built instance is cached. + */ + @Suppress("UNCHECKED_CAST") + internal fun getInstance(): Expression { + generatedInstance?.let { return it } + + if (SIGNATURE_LETTERS.containsKey(classOfT)) { + primitiveMode = true + primitiveMask = SIGNATURE_LETTERS.getValue(classOfT) + primitiveMaskBoxed = tType + } + + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { + visit( + V1_8, + ACC_PUBLIC or ACC_FINAL or ACC_SUPER, + classType.internalName, + "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", + OBJECT_TYPE.internalName, + arrayOf(EXPRESSION_TYPE.internalName) + ) + + visitMethod( + ACC_PUBLIC or ACC_FINAL, + "invoke", + Type.getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", + null + ).instructionAdapter { + invokeMethodVisitor = this + visitCode() + val l0 = label() + invokeLabel0Visitor() + areturn(tType) + val l1 = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + l0, + l1, + invokeThisVar + ) + + visitLocalVariable( + "arguments", + MAP_TYPE.descriptor, + "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", + l0, + l1, + invokeArgumentsVar + ) + + visitMaxs(0, 2) + visitEnd() + } + + visitMethod( + ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, + "invoke", + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + null, + null + ).instructionAdapter { + val thisVar = 0 + val argumentsVar = 1 + visitCode() + val l0 = label() + load(thisVar, OBJECT_TYPE) + load(argumentsVar, MAP_TYPE) + invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false) + areturn(tType) + val l1 = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + l0, + l1, + thisVar + ) + + visitMaxs(0, 2) + visitEnd() + } + + hasConstants = constants.isNotEmpty() + + visitField( + access = ACC_PRIVATE or ACC_FINAL, + name = "algebra", + descriptor = tAlgebraType.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd + ) + + if (hasConstants) + visitField( + access = ACC_PRIVATE or ACC_FINAL, + name = "constants", + descriptor = OBJECT_ARRAY_TYPE.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd + ) + + visitMethod( + ACC_PUBLIC, + "", + + Type.getMethodDescriptor( + Type.VOID_TYPE, + tAlgebraType, + *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), + + null, + null + ).instructionAdapter { + val thisVar = 0 + val algebraVar = 1 + val constantsVar = 2 + val l0 = label() + load(thisVar, classType) + invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) + label() + load(thisVar, classType) + load(algebraVar, tAlgebraType) + putfield(classType.internalName, "algebra", tAlgebraType.descriptor) + + if (hasConstants) { + label() + load(thisVar, classType) + load(constantsVar, OBJECT_ARRAY_TYPE) + putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + } + + label() + visitInsn(RETURN) + val l4 = label() + visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) + + visitLocalVariable( + "algebra", + tAlgebraType.descriptor, + null, + l0, + l4, + algebraVar + ) + + if (hasConstants) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) + + visitMaxs(0, 3) + visitEnd() + } + + visitEnd() + } + + val new = classLoader + .defineClass(className, classWriter.toByteArray()) + .constructors + .first() + .newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression + + generatedInstance = new + return new + } + + /** + * Loads a [T] constant from [constants]. + */ + internal fun loadTConstant(value: T) { + if (classOfT in INLINABLE_NUMBERS) { + val expectedType = expectationStack.pop() + val mustBeBoxed = expectedType.sort == Type.OBJECT + loadNumberConstant(value as Number, mustBeBoxed) + + if (mustBeBoxed) + invokeMethodVisitor.checkcast(tType) + + if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) + return + } + + loadObjectConstant(value as Any, tType) + } + + /** + * Boxes the current value and pushes it. + */ + private fun box(primitive: Type) { + val r = PRIMITIVES_TO_BOXED.getValue(primitive) + + invokeMethodVisitor.invokestatic( + r.internalName, + "valueOf", + Type.getMethodDescriptor(r, primitive), + false + ) + } + + /** + * Unboxes the current boxed value and pushes it. + */ + private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual( + NUMBER_TYPE.internalName, + NUMBER_CONVERTER_METHODS.getValue(primitive), + Type.getMethodDescriptor(primitive), + false + ) + + /** + * Loads [java.lang.Object] constant from constants. + */ + private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { + val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex + loadThis() + getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + iconst(idx) + visitInsn(AALOAD) + checkcast(type) + } + + internal fun loadNumeric(value: Number) { + if (expectationStack.peek() == NUMBER_TYPE) { + loadNumberConstant(value, true) + expectationStack.pop() + typeStack.push(NUMBER_TYPE) + } else (algebra as? NumericAlgebra)?.number(value)?.let { loadTConstant(it) } + ?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.") + } + + /** + * Loads this variable. + */ + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) + + /** + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive + * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded + * from it). + */ + private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) { + val boxed = value::class.asm + val primitive = BOXED_TO_PRIMITIVES[boxed] + + if (primitive != null) { + when (primitive) { + Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + } + + if (mustBeBoxed) + box(primitive) + + return + } + + loadObjectConstant(value, boxed) + + if (!mustBeBoxed) + unboxTo(primitiveMask) + } + + /** + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be + * provided. + */ + internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { + load(invokeArgumentsVar, MAP_TYPE) + aconst(name) + + if (defaultValue != null) + loadTConstant(defaultValue) + + invokestatic( + MAP_INTRINSICS_TYPE.internalName, + "getOrFail", + + Type.getMethodDescriptor( + OBJECT_TYPE, + MAP_TYPE, + OBJECT_TYPE, + *OBJECT_TYPE.wrapToArrayIf { defaultValue != null }), + false + ) + + checkcast(tType) + val expectedType = expectationStack.pop() + + if (expectedType.sort == Type.OBJECT) + typeStack.push(tType) + else { + unboxTo(primitiveMask) + typeStack.push(primitiveMask) + } + } + + /** + * Loads algebra from according field of the class and casts it to class of [algebra] provided. + */ + internal fun loadAlgebra() { + loadThis() + invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor) + } + + /** + * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is + * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be + * called before the arguments and this operation. + * + * The result is casted to [T] automatically. + */ + internal fun invokeAlgebraOperation( + owner: String, + method: String, + descriptor: String, + expectedArity: Int, + opcode: Int = INVOKEINTERFACE + ) { + run loop@{ + repeat(expectedArity) { + if (typeStack.isEmpty()) return@loop + typeStack.pop() + } + } + + invokeMethodVisitor.visitMethodInsn( + opcode, + owner, + method, + descriptor, + opcode == INVOKEINTERFACE + ) + + invokeMethodVisitor.checkcast(tType) + val isLastExpr = expectationStack.size == 1 + val expectedType = expectationStack.pop() + + if (expectedType.sort == Type.OBJECT || isLastExpr) + typeStack.push(tType) + else { + unboxTo(primitiveMask) + typeStack.push(primitiveMask) + } + } + + /** + * Writes a LDC Instruction with string constant provided. + */ + internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string) + + internal companion object { + /** + * Maps JVM primitive numbers boxed types to their primitive ASM types. + */ + private val SIGNATURE_LETTERS: Map, Type> by lazy { + hashMapOf( + java.lang.Byte::class to Type.BYTE_TYPE, + java.lang.Short::class to Type.SHORT_TYPE, + java.lang.Integer::class to Type.INT_TYPE, + java.lang.Long::class to Type.LONG_TYPE, + java.lang.Float::class to Type.FLOAT_TYPE, + java.lang.Double::class to Type.DOUBLE_TYPE + ) + } + + /** + * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. + */ + private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } + + /** + * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. + */ + private val PRIMITIVES_TO_BOXED: Map by lazy { + BOXED_TO_PRIMITIVES.entries.stream().collect( + Collectors.toMap( + Map.Entry::value, + Map.Entry::key + ) + ) + } + + /** + * Maps primitive ASM types to [Number] functions unboxing them. + */ + private val NUMBER_CONVERTER_METHODS: Map by lazy { + hashMapOf( + Type.BYTE_TYPE to "byteValue", + Type.SHORT_TYPE to "shortValue", + Type.INT_TYPE to "intValue", + Type.LONG_TYPE to "longValue", + Type.FLOAT_TYPE to "floatValue", + Type.DOUBLE_TYPE to "doubleValue" + ) + } + + /** + * Provides boxed number types values of which can be stored in JVM bytecode constant pool. + */ + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + + /** + * ASM type for [Expression]. + */ + internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } + + /** + * ASM type for [java.lang.Number]. + */ + internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } + + /** + * ASM type for [java.util.Map]. + */ + internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } + + /** + * ASM type for [java.lang.Object]. + */ + internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } + + /** + * ASM type for array of [java.lang.Object]. + */ + @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") + internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } + + /** + * ASM type for [Algebra]. + */ + internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } + + /** + * ASM type for [java.lang.String]. + */ + internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } + + /** + * ASM type for MapIntrinsics. + */ + internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") } + } +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MstType.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MstType.kt new file mode 100644 index 000000000..bf73d304b --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MstType.kt @@ -0,0 +1,17 @@ +package scientifik.kmath.asm.internal + +import scientifik.kmath.ast.MST + +internal enum class MstType { + GENERAL, + NUMBER; + + companion object { + fun fromMst(mst: MST): MstType { + if (mst is MST.Numeric) + return NUMBER + + return GENERAL + } + } +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt new file mode 100644 index 000000000..a637289b8 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt @@ -0,0 +1,178 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.* +import org.objectweb.asm.Opcodes.INVOKEVIRTUAL +import org.objectweb.asm.commons.InstructionAdapter +import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra +import java.lang.reflect.Method +import kotlin.reflect.KClass + +private val methodNameAdapters: Map, String> by lazy { + hashMapOf( + "+" to 2 to "add", + "*" to 2 to "multiply", + "/" to 2 to "divide", + "+" to 1 to "unaryPlus", + "-" to 1 to "unaryMinus", + "-" to 2 to "minus" + ) +} + +internal val KClass<*>.asm: Type + get() = Type.getType(java) + +/** + * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. + */ +internal inline fun T.wrapToArrayIf(predicate: (T) -> Boolean): Array = + if (predicate(this)) arrayOf(this) else emptyArray() + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor]. + */ +private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. + */ +internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = + instructionAdapter().apply(block) + +/** + * Constructs a [Label], then applies it to this visitor. + */ +internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) } + +/** + * Creates a class name for [Expression] subclassed to implement [mst] provided. + * + * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there + * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. + */ +internal tailrec fun buildName(mst: MST, collision: Int = 0): String { + val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(mst, collision + 1) +} + +@Suppress("FunctionName") +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = + ClassWriter(flags).apply(block) + +internal inline fun ClassWriter.visitField( + access: Int, + name: String, + descriptor: String, + signature: String?, + value: Any?, + block: FieldVisitor.() -> Unit +): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) + +private fun AsmBuilder.findSpecific(context: Algebra, name: String, parameterTypes: Array): Method? = + context.javaClass.methods.find { method -> + val nameValid = method.name == name + val arityValid = method.parameters.size == parameterTypes.size + val notBridgeInPrimitive = !(primitiveMode && method.isBridge) + + val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) -> + !(mstType != MstType.NUMBER && type == java.lang.Number::class.java) + } + + nameValid && arityValid && notBridgeInPrimitive && paramsValid + } + +/** + * Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds + * type expectation stack for needed arity. + * + * @return `true` if contains, else `false`. + */ +private fun AsmBuilder.buildExpectationStack( + context: Algebra, + name: String, + parameterTypes: Array +): Boolean { + val arity = parameterTypes.size + val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes) + + if (specific != null) + mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) } + else + repeat(arity) { expectationStack.push(tType) } + + return specific != null +} + +private fun AsmBuilder.mapTypes(method: Method, parameterTypes: Array): List = method + .parameterTypes + .zip(parameterTypes) + .map { (type, mstType) -> + when { + type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE + else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed + } + } + +/** + * Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts + * [AsmBuilder.invokeAlgebraOperation] of this method. + * + * @return `true` if contains, else `false`. + */ +private fun AsmBuilder.tryInvokeSpecific( + context: Algebra, + name: String, + parameterTypes: Array +): Boolean { + val arity = parameterTypes.size + val theName = methodNameAdapters[name to arity] ?: name + val spec = findSpecific(context, theName, parameterTypes) ?: return false + val owner = context::class.asm + + invokeAlgebraOperation( + owner = owner.internalName, + method = theName, + descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()), + expectedArity = arity, + opcode = INVOKEVIRTUAL + ) + + return true +} + +/** + * Builds specialized algebra call with option to fallback to generic algebra operation accepting String. + */ +internal inline fun AsmBuilder.buildAlgebraOperationCall( + context: Algebra, + name: String, + fallbackMethodName: String, + parameterTypes: Array, + parameters: AsmBuilder.() -> Unit +) { + val arity = parameterTypes.size + loadAlgebra() + if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name) + parameters() + + if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = fallbackMethodName, + + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + *Array(arity) { AsmBuilder.OBJECT_TYPE } + ), + + expectedArity = arity + ) +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt new file mode 100644 index 000000000..80e83c1bf --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt @@ -0,0 +1,7 @@ +@file:JvmName("MapIntrinsics") + +package scientifik.kmath.asm.internal + +@JvmOverloads +internal fun Map.getOrFail(key: K, default: V? = null): V = + this[key] ?: default ?: error("Parameter not found: $key") diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt new file mode 100644 index 000000000..3acc6eb28 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -0,0 +1,110 @@ +package scietifik.kmath.asm + +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.mstInRing +import scientifik.kmath.ast.mstInSpace +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestAsmAlgebras { + @Test + fun space() { + val res1 = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }("x" to 2.toByte()) + + val res2 = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }.compile()("x" to 2.toByte()) + + assertEquals(res1, res2) + } + + @Test + fun ring() { + val res1 = ByteRing.mstInRing { + binaryOperation( + "+", + + unaryOperation( + "+", + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * number(2) + }("x" to 3.toByte()) + + val res2 = ByteRing.mstInRing { + binaryOperation( + "+", + + unaryOperation( + "+", + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * number(2) + }.compile()("x" to 3.toByte()) + + assertEquals(res1, res2) + } + + @Test + fun field() { + val res1 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( + "+", + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }("x" to 2.0) + + val res2 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( + "+", + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }.compile()("x" to 2.0) + + assertEquals(res1, res2) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt new file mode 100644 index 000000000..36c254c38 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -0,0 +1,31 @@ +package scietifik.kmath.asm + +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.mstInSpace +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestAsmExpressions { + @Test + fun testUnaryOperationInvocation() { + val expression = RealField.mstInSpace { -symbol("x") }.compile() + val res = expression("x" to 2.0) + assertEquals(-2.0, res) + } + + @Test + fun testBinaryOperationInvocation() { + val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile() + val res = expression("x" to 2.0) + assertEquals(-1.0, res) + } + + @Test + fun testConstProductInvocation() { + val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) + assertEquals(4.0, res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt new file mode 100644 index 000000000..a88431e9d --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt @@ -0,0 +1,55 @@ +package scietifik.kmath.asm + +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestAsmSpecialization { + @Test + fun testUnaryPlus() { + val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() + assertEquals(2.0, expr("x" to 2.0)) + } + + @Test + fun testUnaryMinus() { + val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() + assertEquals(-2.0, expr("x" to 2.0)) + } + + @Test + fun testAdd() { + val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() + assertEquals(4.0, expr("x" to 2.0)) + } + + @Test + fun testSine() { + val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 0.0)) + } + + @Test + fun testMinus() { + val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 2.0)) + } + + @Test + fun testDivide() { + val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() + assertEquals(1.0, expr("x" to 2.0)) + } + + @Test + fun testPower() { + val expr = RealField + .mstInField { binaryOperation("power", symbol("x"), number(2)) } + .compile() + + assertEquals(4.0, expr("x" to 2.0)) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt new file mode 100644 index 000000000..aafc75448 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt @@ -0,0 +1,22 @@ +package scietifik.kmath.asm + +import scientifik.kmath.ast.mstInRing +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +internal class TestAsmVariables { + @Test + fun testVariableWithoutDefault() { + val expr = ByteRing.mstInRing { symbol("x") } + assertEquals(1.toByte(), expr("x" to 1.toByte())) + } + + @Test + fun testVariableWithoutDefaultFails() { + val expr = ByteRing.mstInRing { symbol("x") } + assertFailsWith { expr() } + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt new file mode 100644 index 000000000..75659cc35 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -0,0 +1,25 @@ +package scietifik.kmath.ast + +import scientifik.kmath.asm.compile +import scientifik.kmath.asm.expression +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.parseMath +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.Complex +import scientifik.kmath.operations.ComplexField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class AsmTest { + @Test + fun `compile MST`() { + val res = ComplexField.expression("2+2*(2+2)".parseMath())() + assertEquals(Complex(10.0, 0.0), res) + } + + @Test + fun `compile MSTExpression`() { + val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()() + assertEquals(Complex(10.0, 0.0), res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserPrecedenceTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserPrecedenceTest.kt new file mode 100644 index 000000000..9bdbb12c9 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserPrecedenceTest.kt @@ -0,0 +1,36 @@ +package scietifik.kmath.ast + +import scientifik.kmath.ast.evaluate +import scientifik.kmath.ast.parseMath +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class ParserPrecedenceTest { + private val f: Field = RealField + + @Test + fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath())) + + @Test + fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath())) + + @Test + fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath())) + + @Test + fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath())) + + @Test + fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath())) + + @Test + fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath())) + + @Test + fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath())) + + @Test + fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath())) +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt new file mode 100644 index 000000000..9179c3428 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt @@ -0,0 +1,60 @@ +package scietifik.kmath.ast + +import scientifik.kmath.ast.evaluate +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.parseMath +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.Complex +import scientifik.kmath.operations.ComplexField +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class ParserTest { + @Test + fun `evaluate MST`() { + val mst = "2+2*(2+2)".parseMath() + val res = ComplexField.evaluate(mst) + assertEquals(Complex(10.0, 0.0), res) + } + + @Test + fun `evaluate MSTExpression`() { + val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() + assertEquals(Complex(10.0, 0.0), res) + } + + @Test + fun `evaluate MST with singular`() { + val mst = "i".parseMath() + val res = ComplexField.evaluate(mst) + assertEquals(ComplexField.i, res) + } + + + @Test + fun `evaluate MST with unary function`() { + val mst = "sin(0)".parseMath() + val res = RealField.evaluate(mst) + assertEquals(0.0, res) + } + + @Test + fun `evaluate MST with binary function`() { + val magicalAlgebra = object : Algebra { + override fun symbol(value: String): String = value + + override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError() + + override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) { + "magic" -> "$left ★ $right" + else -> throw NotImplementedError() + } + } + + val mst = "magic(a, b)".parseMath() + val res = magicalAlgebra.evaluate(mst) + assertEquals("a ★ b", res) + } +} diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index d5c038dc4..54c404f57 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -2,7 +2,7 @@ package scientifik.kmath.commons.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.Field import kotlin.properties.ReadOnlyProperty @@ -59,8 +59,10 @@ class DerivativeStructureField( override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() - override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() + override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() + override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() + override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { is Double -> arg.pow(pow) @@ -74,10 +76,10 @@ class DerivativeStructureField( override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() - operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble()) - operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble()) - operator fun Number.plus(s: DerivativeStructure) = s + this - operator fun Number.minus(s: DerivativeStructure) = s - this + override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) + override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) + override operator fun Number.plus(b: DerivativeStructure) = b + this + override operator fun Number.minus(b: DerivativeStructure) = b - this } /** @@ -113,7 +115,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1) /** * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) */ -object DiffExpressionContext : ExpressionContext, Field { +object DiffExpressionAlgebra : ExpressionAlgebra, Field { override fun variable(name: String, default: Double?) = DiffExpression { variable(name, default?.const()) } @@ -136,6 +138,3 @@ object DiffExpressionContext : ExpressionContext, Field override fun divide(a: DiffExpression, b: DiffExpression) = DiffExpression { a.function(this) / b.function(this) } } - - - diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt index 72e5fb95a..a17effccc 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt @@ -5,6 +5,7 @@ import org.apache.commons.math3.linear.RealMatrix import org.apache.commons.math3.linear.RealVector import scientifik.kmath.linear.* import scientifik.kmath.structures.Matrix +import scientifik.kmath.structures.NDStructure class CMMatrix(val origin: RealMatrix, features: Set? = null) : FeaturedMatrix { @@ -19,6 +20,16 @@ class CMMatrix(val origin: RealMatrix, features: Set? = null) : CMMatrix(origin, this.features + features) override fun get(i: Int, j: Int): Double = origin.getEntry(i, j) + + override fun equals(other: Any?): Boolean { + return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + } + + override fun hashCode(): Int { + var result = origin.hashCode() + result = 31 * result + features.hashCode() + return result + } } fun Matrix.toCM(): CMMatrix = if (this is CMMatrix) { diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/prob/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/prob/CMRandomGeneratorWrapper.kt deleted file mode 100644 index 74e035ecb..000000000 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/prob/CMRandomGeneratorWrapper.kt +++ /dev/null @@ -1,32 +0,0 @@ -package scientifik.kmath.commons.prob - -import org.apache.commons.math3.random.JDKRandomGenerator -import scientifik.kmath.prob.RandomGenerator -import org.apache.commons.math3.random.RandomGenerator as CMRandom - -inline class CMRandomGeneratorWrapper(val generator: CMRandom) : RandomGenerator { - override fun nextDouble(): Double = generator.nextDouble() - - override fun nextInt(): Int = generator.nextInt() - - override fun nextLong(): Long = generator.nextLong() - - override fun nextBlock(size: Int): ByteArray = ByteArray(size).apply { generator.nextBytes(this) } - - override fun fork(): RandomGenerator { - TODO("not implemented") //To change body of created functions use File | Settings | File Templates. - } -} - -fun CMRandom.asKmathGenerator(): RandomGenerator = CMRandomGeneratorWrapper(this) - -fun RandomGenerator.asCMGenerator(): CMRandom = - (this as? CMRandomGeneratorWrapper)?.generator ?: TODO("Implement reverse CM wrapper") - -val RandomGenerator.Companion.default: RandomGenerator by lazy { JDKRandomGenerator().asKmathGenerator() } - -fun RandomGenerator.Companion.jdk(seed: Int? = null): RandomGenerator = if (seed == null) { - JDKRandomGenerator() -} else { - JDKRandomGenerator(seed) -}.asKmathGenerator() \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/prob/CommonsDistribution.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/prob/CommonsDistribution.kt deleted file mode 100644 index 94f8560a4..000000000 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/prob/CommonsDistribution.kt +++ /dev/null @@ -1,82 +0,0 @@ -package scientifik.kmath.commons.prob - -import org.apache.commons.math3.distribution.* -import scientifik.kmath.prob.Distribution -import scientifik.kmath.prob.RandomChain -import scientifik.kmath.prob.RandomGenerator -import scientifik.kmath.prob.UnivariateDistribution -import org.apache.commons.math3.random.RandomGenerator as CMRandom - -class CMRealDistributionWrapper(val builder: (CMRandom?) -> RealDistribution) : UnivariateDistribution { - - private val defaultDistribution by lazy { builder(null) } - - override fun probability(arg: Double): Double = defaultDistribution.probability(arg) - - override fun cumulative(arg: Double): Double = defaultDistribution.cumulativeProbability(arg) - - override fun sample(generator: RandomGenerator): RandomChain { - val distribution = builder(generator.asCMGenerator()) - return RandomChain(generator) { distribution.sample() } - } -} - -class CMIntDistributionWrapper(val builder: (CMRandom?) -> IntegerDistribution) : UnivariateDistribution { - - private val defaultDistribution by lazy { builder(null) } - - override fun probability(arg: Int): Double = defaultDistribution.probability(arg) - - override fun cumulative(arg: Int): Double = defaultDistribution.cumulativeProbability(arg) - - override fun sample(generator: RandomGenerator): RandomChain { - val distribution = builder(generator.asCMGenerator()) - return RandomChain(generator) { distribution.sample() } - } -} - - -fun Distribution.Companion.normal(mean: Double = 0.0, sigma: Double = 1.0): UnivariateDistribution = - CMRealDistributionWrapper { generator -> NormalDistribution(generator, mean, sigma) } - -fun Distribution.Companion.poisson(mean: Double): UnivariateDistribution = CMIntDistributionWrapper { generator -> - PoissonDistribution( - generator, - mean, - PoissonDistribution.DEFAULT_EPSILON, - PoissonDistribution.DEFAULT_MAX_ITERATIONS - ) -} - -fun Distribution.Companion.binomial(trials: Int, p: Double): UnivariateDistribution = - CMIntDistributionWrapper { generator -> - BinomialDistribution(generator, trials, p) - } - -fun Distribution.Companion.student(degreesOfFreedom: Double): UnivariateDistribution = - CMRealDistributionWrapper { generator -> - TDistribution(generator, degreesOfFreedom, TDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY) - } - -fun Distribution.Companion.chi2(degreesOfFreedom: Double): UnivariateDistribution = - CMRealDistributionWrapper { generator -> - ChiSquaredDistribution(generator, degreesOfFreedom) - } - -fun Distribution.Companion.fisher( - numeratorDegreesOfFreedom: Double, - denominatorDegreesOfFreedom: Double -): UnivariateDistribution = - CMRealDistributionWrapper { generator -> - FDistribution(generator, numeratorDegreesOfFreedom, denominatorDegreesOfFreedom) - } - -fun Distribution.Companion.exponential(mean: Double): UnivariateDistribution = - CMRealDistributionWrapper { generator -> - ExponentialDistribution(generator, mean) - } - -fun Distribution.Companion.uniform(a: Double, b: Double): UnivariateDistribution = - CMRealDistributionWrapper { generator -> - UniformRealDistribution(generator, a, b) - } \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt new file mode 100644 index 000000000..13e79d60e --- /dev/null +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -0,0 +1,38 @@ +package scientifik.kmath.commons.random + +import scientifik.kmath.prob.RandomGenerator + +class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : + org.apache.commons.math3.random.RandomGenerator { + private var generator = factory(intArrayOf()) + + override fun nextBoolean(): Boolean = generator.nextBoolean() + + override fun nextFloat(): Float = generator.nextDouble().toFloat() + + override fun setSeed(seed: Int) { + generator = factory(intArrayOf(seed)) + } + + override fun setSeed(seed: IntArray) { + generator = factory(seed) + } + + override fun setSeed(seed: Long) { + setSeed(seed.toInt()) + } + + override fun nextBytes(bytes: ByteArray) { + generator.fillBytes(bytes) + } + + override fun nextInt(): Int = generator.nextInt() + + override fun nextInt(n: Int): Int = generator.nextInt(n) + + override fun nextGaussian(): Double = TODO() + + override fun nextDouble(): Double = generator.nextDouble() + + override fun nextLong(): Long = generator.nextLong() +} \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt index bcb3ea87b..eb1b5b69a 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt @@ -18,7 +18,7 @@ object Transformations { private fun Buffer.toArray(): Array = Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) } - private fun Buffer.asArray() = if (this is DoubleBuffer) { + private fun Buffer.asArray() = if (this is RealBuffer) { array } else { DoubleArray(size) { i -> get(i) } diff --git a/kmath-core/README.md b/kmath-core/README.md new file mode 100644 index 000000000..24e2c57d3 --- /dev/null +++ b/kmath-core/README.md @@ -0,0 +1,40 @@ +# The Core Module (`kmath-ast`) + +The core features of KMath: + +- Algebraic structures: contexts and elements. +- ND structures. +- Buffers. +- Functional Expressions. +- Domains. +- Automatic differentiation. + +> #### Artifact: +> This module is distributed in the artifact `scientifik:kmath-core:0.1.4-dev-8`. +> +> **Gradle:** +> +> ```gradle +> repositories { +> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' } +> maven { url 'https://dl.bintray.com/mipt-npm/dev' } +> maven { url https://dl.bintray.com/hotkeytlt/maven' } +> } +> +> dependencies { +> implementation 'scientifik:kmath-core:0.1.4-dev-8' +> } +> ``` +> **Gradle Kotlin DSL:** +> +> ```kotlin +> repositories { +> maven("https://dl.bintray.com/mipt-npm/scientifik") +> maven("https://dl.bintray.com/mipt-npm/dev") +> maven("https://dl.bintray.com/hotkeytlt/maven") +> } +> +> dependencies {`` +> implementation("scientifik:kmath-core:0.1.4-dev-8") +> } +> ``` diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index 092f3deb7..bea0fbf42 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -1,11 +1,7 @@ -plugins { - id("scientifik.mpp") -} +plugins { id("scientifik.mpp") } kotlin.sourceSets { commonMain { - dependencies { - api(project(":kmath-memory")) - } + dependencies { api(project(":kmath-memory")) } } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt new file mode 100644 index 000000000..341383bfb --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +/** + * A simple geometric domain. + * + * @param T the type of element of this domain. + */ +interface Domain { + /** + * Checks if the specified point is contained in this domain. + */ + operator fun contains(point: Point): Boolean + + /** + * Number of hyperspace dimensions. + */ + val dimension: Int +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt new file mode 100644 index 000000000..66798c42f --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt @@ -0,0 +1,68 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.RealBuffer +import scientifik.kmath.structures.indices + +/** + * + * HyperSquareDomain class. + * + * @author Alexander Nozik + */ +class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain { + + override operator fun contains(point: Point): Boolean = point.indices.all { i -> + point[i] in lower[i]..upper[i] + } + + override val dimension: Int get() = lower.size + + override fun getLowerBound(num: Int, point: Point): Double? = lower[num] + + override fun getLowerBound(num: Int): Double? = lower[num] + + override fun getUpperBound(num: Int, point: Point): Double? = upper[num] + + override fun getUpperBound(num: Int): Double? = upper[num] + + override fun nearestInDomain(point: Point): Point { + val res = DoubleArray(point.size) { i -> + when { + point[i] < lower[i] -> lower[i] + point[i] > upper[i] -> upper[i] + else -> point[i] + } + } + + return RealBuffer(*res) + } + + override fun volume(): Double { + var res = 1.0 + for (i in 0 until dimension) { + if (lower[i].isInfinite() || upper[i].isInfinite()) { + return Double.POSITIVE_INFINITY + } + if (upper[i] > lower[i]) { + res *= upper[i] - lower[i] + } + } + return res + } +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt new file mode 100644 index 000000000..7507ccd59 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt @@ -0,0 +1,63 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +/** + * n-dimensional volume + * + * @author Alexander Nozik + */ +interface RealDomain : Domain { + fun nearestInDomain(point: Point): Point + + /** + * The lower edge for the domain going down from point + * @param num + * @param point + * @return + */ + fun getLowerBound(num: Int, point: Point): Double? + + /** + * The upper edge of the domain going up from point + * @param num + * @param point + * @return + */ + fun getUpperBound(num: Int, point: Point): Double? + + /** + * Global lower edge + * @param num + * @return + */ + fun getLowerBound(num: Int): Double? + + /** + * Global upper edge + * @param num + * @return + */ + fun getUpperBound(num: Int): Double? + + /** + * Hyper volume + * @return + */ + fun volume(): Double +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt new file mode 100644 index 000000000..595a3dbe7 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +class UnconstrainedDomain(override val dimension: Int) : RealDomain { + override operator fun contains(point: Point): Boolean = true + + override fun getLowerBound(num: Int, point: Point): Double? = Double.NEGATIVE_INFINITY + + override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY + + override fun getUpperBound(num: Int, point: Point): Double? = Double.POSITIVE_INFINITY + + override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY + + override fun nearestInDomain(point: Point): Point = point + + override fun volume(): Double = Double.POSITIVE_INFINITY +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt new file mode 100644 index 000000000..280dc7d66 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt @@ -0,0 +1,47 @@ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.asBuffer + +inline class UnivariateDomain(val range: ClosedFloatingPointRange) : RealDomain { + operator fun contains(d: Double): Boolean = range.contains(d) + + override operator fun contains(point: Point): Boolean { + require(point.size == 0) + return contains(point[0]) + } + + override fun nearestInDomain(point: Point): Point { + require(point.size == 1) + val value = point[0] + return when { + value in range -> point + value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer() + else -> doubleArrayOf(range.start).asBuffer() + } + } + + override fun getLowerBound(num: Int, point: Point): Double? { + require(num == 0) + return range.start + } + + override fun getUpperBound(num: Int, point: Point): Double? { + require(num == 0) + return range.endInclusive + } + + override fun getLowerBound(num: Int): Double? { + require(num == 0) + return range.start + } + + override fun getUpperBound(num: Int): Double? { + require(num == 0) + return range.endInclusive + } + + override fun volume(): Double = range.endInclusive - range.start + + override val dimension: Int get() = 1 +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt new file mode 100644 index 000000000..8cd6e28f8 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt @@ -0,0 +1,31 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.ExtendedField +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.Space + +/** + * Creates a functional expression with this [Space]. + */ +fun Space.spaceExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression = + FunctionalExpressionSpace(this).run(block) + +/** + * Creates a functional expression with this [Ring]. + */ +fun Ring.ringExpression(block: FunctionalExpressionRing>.() -> Expression): Expression = + FunctionalExpressionRing(this).run(block) + +/** + * Creates a functional expression with this [Field]. + */ +fun Field.fieldExpression(block: FunctionalExpressionField>.() -> Expression): Expression = + FunctionalExpressionField(this).run(block) + +/** + * Creates a functional expression with this [ExtendedField]. + */ +fun ExtendedField.fieldExpression( + block: FunctionalExpressionExtendedField>.() -> Expression +): Expression = FunctionalExpressionExtendedField(this).run(block) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index aa7407c0a..380822f78 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -1,92 +1,49 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space +import scientifik.kmath.operations.Algebra /** * An elementary function that could be invoked on a map of arguments */ interface Expression { + /** + * Calls this expression from arguments. + * + * @param arguments the map of arguments. + * @return the value. + */ operator fun invoke(arguments: Map): T + + companion object } +/** + * Create simple lazily evaluated expression inside given algebra + */ +fun Algebra.expression(block: Algebra.(arguments: Map) -> T): Expression = + object : Expression { + override fun invoke(arguments: Map): T = block(arguments) + } + +/** + * Calls this expression from arguments. + * + * @param pairs the pair of arguments' names to values. + * @return the value. + */ operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) /** * A context for expression construction */ -interface ExpressionContext { +interface ExpressionAlgebra : Algebra { /** * Introduce a variable into expression context */ - fun variable(name: String, default: T? = null): Expression + fun variable(name: String, default: T? = null): E /** * A constant expression which does not depend on arguments */ - fun const(value: T): Expression + fun const(value: T): E } - -internal class VariableExpression(val name: String, val default: T? = null) : Expression { - override fun invoke(arguments: Map): T = - arguments[name] ?: default ?: error("Parameter not found: $name") -} - -internal class ConstantExpression(val value: T) : Expression { - override fun invoke(arguments: Map): T = value -} - -internal class SumExpression(val context: Space, val first: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = context.add(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ProductExpression(val context: Ring, val first: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = - context.multiply(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : - Expression { - override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) -} - -internal class DivExpession(val context: Field, val expr: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) -} - -open class ExpressionSpace(val space: Space) : Space>, ExpressionContext { - override val zero: Expression = ConstantExpression(space.zero) - - override fun const(value: T): Expression = ConstantExpression(value) - - override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) - - override fun add(a: Expression, b: Expression): Expression = SumExpression(space, a, b) - - override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(space, a, k) - - - operator fun Expression.plus(arg: T) = this + const(arg) - operator fun Expression.minus(arg: T) = this - const(arg) - - operator fun T.plus(arg: Expression) = arg + this - operator fun T.minus(arg: Expression) = arg - this -} - - -class ExpressionField(val field: Field) : Field>, ExpressionSpace(field) { - override val one: Expression = ConstantExpression(field.one) - override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) - - override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) - - operator fun Expression.times(arg: T) = this * const(arg) - operator fun Expression.div(arg: T) = this / const(arg) - - operator fun T.times(arg: Expression) = arg * this - operator fun T.div(arg: Expression) = arg / this -} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt new file mode 100644 index 000000000..dd5fb572a --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -0,0 +1,175 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.* + +internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) : + Expression { + override fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) +} + +internal class FunctionalBinaryOperation( + val context: Algebra, + val name: String, + val first: Expression, + val second: Expression +) : Expression { + override fun invoke(arguments: Map): T = + context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) +} + +internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression { + override fun invoke(arguments: Map): T = + arguments[name] ?: default ?: error("Parameter not found: $name") +} + +internal class FunctionalConstantExpression(val value: T) : Expression { + override fun invoke(arguments: Map): T = value +} + +internal class FunctionalConstProductExpression( + val context: Space, + private val expr: Expression, + val const: Number +) : Expression { + override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) +} + +/** + * A context class for [Expression] construction. + * + * @param algebra The algebra to provide for Expressions built. + */ +abstract class FunctionalExpressionAlgebra>(val algebra: A) : ExpressionAlgebra> { + /** + * Builds an Expression of constant expression which does not depend on arguments. + */ + override fun const(value: T): Expression = FunctionalConstantExpression(value) + + /** + * Builds an Expression to access a variable. + */ + override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default) + + /** + * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. + */ + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) + + /** + * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. + */ + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) +} + +/** + * A context class for [Expression] construction for [Space] algebras. + */ +open class FunctionalExpressionSpace>(algebra: A) : + FunctionalExpressionAlgebra(algebra), Space> { + override val zero: Expression get() = const(algebra.zero) + + /** + * Builds an Expression of addition of two another expressions. + */ + override fun add(a: Expression, b: Expression): Expression = + binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + /** + * Builds an Expression of multiplication of expression by number. + */ + override fun multiply(a: Expression, k: Number): Expression = + FunctionalConstProductExpression(algebra, a, k) + + operator fun Expression.plus(arg: T): Expression = this + const(arg) + operator fun Expression.minus(arg: T): Expression = this - const(arg) + operator fun T.plus(arg: Expression): Expression = arg + this + operator fun T.minus(arg: Expression): Expression = arg - this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) +} + +open class FunctionalExpressionRing(algebra: A) : FunctionalExpressionSpace(algebra), + Ring> where A : Ring, A : NumericAlgebra { + override val one: Expression + get() = const(algebra.one) + + /** + * Builds an Expression of multiplication of two expressions. + */ + override fun multiply(a: Expression, b: Expression): Expression = + binaryOperation(RingOperations.TIMES_OPERATION, a, b) + + operator fun Expression.times(arg: T): Expression = this * const(arg) + operator fun T.times(arg: Expression): Expression = arg * this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) +} + +open class FunctionalExpressionField(algebra: A) : + FunctionalExpressionRing(algebra), + Field> where A : Field, A : NumericAlgebra { + /** + * Builds an Expression of division an expression by another one. + */ + override fun divide(a: Expression, b: Expression): Expression = + binaryOperation(FieldOperations.DIV_OPERATION, a, b) + + operator fun Expression.div(arg: T): Expression = this / const(arg) + operator fun T.div(arg: Expression): Expression = arg / this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) +} + +open class FunctionalExpressionExtendedField(algebra: A) : + FunctionalExpressionField(algebra), + ExtendedField> where A : ExtendedField, A : NumericAlgebra { + override fun sin(arg: Expression): Expression = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) + override fun cos(arg: Expression): Expression = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) + + override fun asin(arg: Expression): Expression = + unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg) + + override fun acos(arg: Expression): Expression = + unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg) + + override fun atan(arg: Expression): Expression = + unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg) + + override fun power(arg: Expression, pow: Number): Expression = + binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) + + override fun exp(arg: Expression): Expression = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) + override fun ln(arg: Expression): Expression = unaryOperation(ExponentialOperations.LN_OPERATION, arg) + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) +} + +inline fun > A.expressionInSpace(block: FunctionalExpressionSpace.() -> Expression): Expression = + FunctionalExpressionSpace(this).block() + +inline fun > A.expressionInRing(block: FunctionalExpressionRing.() -> Expression): Expression = + FunctionalExpressionRing(this).block() + +inline fun > A.expressionInField(block: FunctionalExpressionField.() -> Expression): Expression = + FunctionalExpressionField(this).block() + +inline fun > A.expressionInExtendedField(block: FunctionalExpressionExtendedField.() -> Expression): Expression = + FunctionalExpressionExtendedField(this).block() diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt index 624cd44d4..2e1f32501 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt @@ -19,22 +19,20 @@ class BufferMatrixContext>( override fun point(size: Int, initializer: (Int) -> T): Point = bufferFactory(size, initializer) - companion object { - - } + companion object } @Suppress("OVERRIDE_BY_INLINE") object RealMatrixContext : GenericMatrixContext { - override val elementContext = RealField + override val elementContext: RealField get() = RealField override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix { - val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } + val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } return BufferMatrix(rows, columns, buffer) } - override inline fun point(size: Int, initializer: (Int) -> Double): Point = DoubleBuffer(size,initializer) + override inline fun point(size: Int, initializer: (Int) -> Double): Point = RealBuffer(size, initializer) } class BufferMatrix( @@ -52,7 +50,7 @@ class BufferMatrix( override val shape: IntArray get() = intArrayOf(rowNum, colNum) - override fun suggestFeature(vararg features: MatrixFeature) = + override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix = BufferMatrix(rowNum, colNum, buffer, this.features + features) override fun get(index: IntArray): T = get(index[0], index[1]) @@ -84,8 +82,8 @@ class BufferMatrix( override fun toString(): String { return if (rowNum <= 5 && colNum <= 5) { "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" + - rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { - it.asSequence().joinToString(separator = "\t") { it.toString() } + rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer -> + buffer.asSequence().joinToString(separator = "\t") { it.toString() } } } else { "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)" @@ -101,8 +99,15 @@ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix.unsafeArray(): DoubleArray = if (this is RealBuffer) { + array + } else { + DoubleArray(size) { get(it) } + } + + val a = this.buffer.unsafeArray() + val b = other.buffer.unsafeArray() for (i in (0 until rowNum)) { for (j in (0 until other.colNum)) { @@ -112,6 +117,6 @@ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix : Matrix { */ fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix - companion object { - - } + companion object } -fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = +fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix = MatrixContext.real.produce(rows, columns, initializer) /** @@ -41,7 +39,7 @@ fun Structure2D.Companion.square(vararg elements: T): FeaturedMatrix.features get() = (this as? FeaturedMatrix)?.features?: emptySet() +val Matrix<*>.features: Set get() = (this as? FeaturedMatrix)?.features ?: emptySet() /** * Check if matrix has the given feature class @@ -68,7 +66,7 @@ fun > GenericMatrixContext.one(rows: Int, columns: In * A virtual matrix of zeroes */ fun > GenericMatrixContext.zero(rows: Int, columns: Int): FeaturedMatrix = - VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } + VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } class TransposedFeature(val original: Matrix) : MatrixFeature @@ -83,4 +81,4 @@ fun Matrix.transpose(): Matrix { ) { i, j -> get(j, i) } } -infix fun Matrix.dot(other: Matrix): Matrix = with(MatrixContext.real) { dot(other) } \ No newline at end of file +infix fun Matrix.dot(other: Matrix): Matrix = with(MatrixContext.real) { dot(other) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt index 87e0ef027..d04a99fbb 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt @@ -18,7 +18,7 @@ class LUPDecomposition( private val even: Boolean ) : LUPDecompositionFeature, DeterminantFeature { - val elementContext get() = context.elementContext + val elementContext: Field get() = context.elementContext /** * Returns the matrix L of the decomposition. @@ -67,7 +67,7 @@ class LUPDecomposition( } -fun , F : Field> GenericMatrixContext.abs(value: T) = +fun , F : Field> GenericMatrixContext.abs(value: T): T = if (value > elementContext.zero) value else with(elementContext) { -value } @@ -128,14 +128,14 @@ fun , F : Field> GenericMatrixContext.lup( luRow[col] = sum // maintain best permutation choice - if (abs(sum) > largest) { - largest = abs(sum) + if (this@lup.abs(sum) > largest) { + largest = this@lup.abs(sum) max = row } } // Singularity check - if (checkSingular(abs(lu[max, col]))) { + if (checkSingular(this@lup.abs(lu[max, col]))) { error("The matrix is singular") } @@ -169,9 +169,10 @@ fun , F : Field> GenericMatrixContext.lup( inline fun , F : Field> GenericMatrixContext.lup( matrix: Matrix, noinline checkSingular: (T) -> Boolean -) = lup(T::class, matrix, checkSingular) +): LUPDecomposition = lup(T::class, matrix, checkSingular) -fun GenericMatrixContext.lup(matrix: Matrix) = lup(Double::class, matrix) { it < 1e-11 } +fun GenericMatrixContext.lup(matrix: Matrix): LUPDecomposition = + lup(Double::class, matrix) { it < 1e-11 } fun LUPDecomposition.solve(type: KClass, matrix: Matrix): Matrix { @@ -185,7 +186,7 @@ fun LUPDecomposition.solve(type: KClass, matrix: Matrix): Mat // Apply permutations to b val bp = create { _, _ -> zero } - for (row in 0 until pivot.size) { + for (row in pivot.indices) { val bpRow = bp.row(row) val pRow = pivot[row] for (col in 0 until matrix.colNum) { @@ -194,7 +195,7 @@ fun LUPDecomposition.solve(type: KClass, matrix: Matrix): Mat } // Solve LY = b - for (col in 0 until pivot.size) { + for (col in pivot.indices) { val bpCol = bp.row(col) for (i in col + 1 until pivot.size) { val bpI = bp.row(i) @@ -225,7 +226,7 @@ fun LUPDecomposition.solve(type: KClass, matrix: Matrix): Mat } } -inline fun LUPDecomposition.solve(matrix: Matrix) = solve(T::class, matrix) +inline fun LUPDecomposition.solve(matrix: Matrix): Matrix = solve(T::class, matrix) /** * Solve a linear equation **a*x = b** @@ -240,13 +241,12 @@ inline fun , F : Field> GenericMatrixContext. return decomposition.solve(T::class, b) } -fun RealMatrixContext.solve(a: Matrix, b: Matrix) = - solve(a, b) { it < 1e-11 } +fun RealMatrixContext.solve(a: Matrix, b: Matrix): Matrix = solve(a, b) { it < 1e-11 } inline fun , F : Field> GenericMatrixContext.inverse( matrix: Matrix, noinline checkSingular: (T) -> Boolean -) = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular) +): Matrix = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular) -fun RealMatrixContext.inverse(matrix: Matrix) = - solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 } \ No newline at end of file +fun RealMatrixContext.inverse(matrix: Matrix): Matrix = + solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgebra.kt index 0456ffebb..fb49d18ed 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgebra.kt @@ -1,12 +1,8 @@ package scientifik.kmath.linear -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Norm -import scientifik.kmath.operations.RealField import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.VirtualBuffer -import scientifik.kmath.structures.asSequence typealias Point = Buffer @@ -19,8 +15,6 @@ interface LinearSolver { fun inverse(a: Matrix): Matrix } -typealias RealMatrix = Matrix - /** * Convert matrix to vector if it is possible */ @@ -31,4 +25,4 @@ fun Matrix.asPoint(): Point = error("Can't convert matrix with more than one column to vector") } -fun Point.asMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) } \ No newline at end of file +fun Point.asMatrix(): VirtualMatrix = VirtualMatrix(size, 1) { i, _ -> get(i) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt index 466dbea6e..516f65bb8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt @@ -1,14 +1,46 @@ package scientifik.kmath.linear +import scientifik.kmath.structures.Buffer +import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.Structure2D import scientifik.kmath.structures.asBuffer -class MatrixBuilder(val rows: Int, val columns: Int) { - operator fun invoke(vararg elements: T): FeaturedMatrix { +class MatrixBuilder(val rows: Int, val columns: Int) { + operator fun invoke(vararg elements: T): FeaturedMatrix { if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns") val buffer = elements.asBuffer() return BufferMatrix(rows, columns, buffer) } + + //TODO add specific matrix builder functions like diagonal, etc } -fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns) \ No newline at end of file +fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns) + +fun Structure2D.Companion.row(vararg values: T): FeaturedMatrix { + val buffer = values.asBuffer() + return BufferMatrix(1, values.size, buffer) +} + +inline fun Structure2D.Companion.row( + size: Int, + factory: BufferFactory = Buffer.Companion::auto, + noinline builder: (Int) -> T +): FeaturedMatrix { + val buffer = factory(size, builder) + return BufferMatrix(1, size, buffer) +} + +fun Structure2D.Companion.column(vararg values: T): FeaturedMatrix { + val buffer = values.asBuffer() + return BufferMatrix(values.size, 1, buffer) +} + +inline fun Structure2D.Companion.column( + size: Int, + factory: BufferFactory = Buffer.Companion::auto, + noinline builder: (Int) -> T +): FeaturedMatrix { + val buffer = factory(size, builder) + return BufferMatrix(size, 1, buffer) +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt index 7797fdadf..5dc86a7dd 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt @@ -29,7 +29,7 @@ interface MatrixContext : SpaceOperations> { /** * Non-boxing double matrix */ - val real = RealMatrixContext + val real: RealMatrixContext = RealMatrixContext /** * A structured matrix with custom buffer @@ -82,12 +82,12 @@ interface GenericMatrixContext> : MatrixContext { } } - override operator fun Matrix.unaryMinus() = + override operator fun Matrix.unaryMinus(): Matrix = produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } } override fun add(a: Matrix, b: Matrix): Matrix { if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]") - return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) + b[i, j] } } + return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] + b[i, j] } } } override operator fun Matrix.minus(b: Matrix): Matrix { @@ -96,7 +96,7 @@ interface GenericMatrixContext> : MatrixContext { } override fun multiply(a: Matrix, k: Number): Matrix = - produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) * k } } + produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } } operator fun Number.times(matrix: FeaturedMatrix): Matrix = matrix * this diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt index de315071f..87cfe21b0 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt @@ -1,7 +1,7 @@ package scientifik.kmath.linear /** - * A marker interface representing some matrix feature like diagonal, sparce, zero, etc. Features used to optimize matrix + * A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix * operations performance in some cases. */ interface MatrixFeature @@ -36,19 +36,19 @@ interface DeterminantFeature : MatrixFeature { } @Suppress("FunctionName") -fun DeterminantFeature(determinant: T) = object: DeterminantFeature{ +fun DeterminantFeature(determinant: T): DeterminantFeature = object : DeterminantFeature { override val determinant: T = determinant } /** * Lower triangular matrix */ -object LFeature: MatrixFeature +object LFeature : MatrixFeature /** * Upper triangular feature */ -object UFeature: MatrixFeature +object UFeature : MatrixFeature /** * TODO add documentation @@ -59,4 +59,4 @@ interface LUPDecompositionFeature : MatrixFeature { val p: FeaturedMatrix } -//TODO add sparse matrix feature \ No newline at end of file +//TODO add sparse matrix feature diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt index 8e14e2882..691b464fc 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt @@ -54,7 +54,7 @@ interface VectorSpace> : Space> { size: Int, space: S, bufferFactory: BufferFactory = Buffer.Companion::boxing - ) = BufferVectorSpace(size, space, bufferFactory) + ): BufferVectorSpace = BufferVectorSpace(size, space, bufferFactory) /** * Automatic buffered vector, unboxed if it is possible @@ -70,6 +70,6 @@ class BufferVectorSpace>( override val space: S, val bufferFactory: BufferFactory ) : VectorSpace { - override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer) + override fun produce(initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) //override fun produceElement(initializer: (Int) -> T): Vector = BufferVector(this, produce(initializer)) -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt index 0806cabea..207151d57 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt @@ -20,7 +20,7 @@ class VirtualMatrix( override fun get(i: Int, j: Int): T = generator(i, j) - override fun suggestFeature(vararg features: MatrixFeature) = + override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix = VirtualMatrix(rowNum, colNum, this.features + features, generator) override fun equals(other: Any?): Boolean { @@ -56,4 +56,4 @@ class VirtualMatrix( } } } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt index ed77054cf..db8863ae8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt @@ -22,12 +22,12 @@ class DerivationResult( val deriv: Map, T>, val context: Field ) : Variable(value) { - fun deriv(variable: Variable) = deriv[variable] ?: context.zero + fun deriv(variable: Variable): T = deriv[variable] ?: context.zero /** * compute divergence */ - fun div() = context.run { sum(deriv.values) } + fun div(): T = context.run { sum(deriv.values) } /** * Compute a gradient for variables in given order @@ -53,7 +53,7 @@ class DerivationResult( * ``` */ fun > F.deriv(body: AutoDiffField.() -> Variable): DerivationResult = - AutoDiffContext(this).run { + AutoDiffContext(this).run { val result = body() result.d = context.one// computing derivative w.r.t result runBackwardPass() @@ -86,24 +86,24 @@ abstract class AutoDiffField> : Field> { abstract fun variable(value: T): Variable - inline fun variable(block: F.() -> T) = variable(context.block()) + inline fun variable(block: F.() -> T): Variable = variable(context.block()) // Overloads for Double constants - operator fun Number.plus(that: Variable): Variable = - derive(variable { this@plus.toDouble() * one + that.value }) { z -> - that.d += z.d + override operator fun Number.plus(b: Variable): Variable = + derive(variable { this@plus.toDouble() * one + b.value }) { z -> + b.d += z.d } - operator fun Variable.plus(b: Number): Variable = b.plus(this) + override operator fun Variable.plus(b: Number): Variable = b.plus(this) - operator fun Number.minus(that: Variable): Variable = - derive(variable { this@minus.toDouble() * one - that.value }) { z -> - that.d -= z.d + override operator fun Number.minus(b: Variable): Variable = + derive(variable { this@minus.toDouble() * one - b.value }) { z -> + b.d -= z.d } - operator fun Variable.minus(that: Number): Variable = - derive(variable { this@minus.value - one * that.toDouble() }) { z -> + override operator fun Variable.minus(b: Number): Variable = + derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } } @@ -236,4 +236,4 @@ fun > AutoDiffField.sin(x: Variable): Var fun > AutoDiffField.cos(x: Variable): Variable = derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) - } \ No newline at end of file + } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt index 90ce5da68..d3bf0891f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt @@ -1,5 +1,7 @@ package scientifik.kmath.misc +import kotlin.math.abs + /** * Convert double range to sequence. * @@ -8,29 +10,37 @@ package scientifik.kmath.misc * * If step is negative, the same goes from upper boundary downwards */ -fun ClosedFloatingPointRange.toSequence(step: Double): Sequence = - when { - step == 0.0 -> error("Zero step in double progression") - step > 0 -> sequence { - var current = start - while (current <= endInclusive) { - yield(current) - current += step - } - } - else -> sequence { - var current = endInclusive - while (current >= start) { - yield(current) - current += step - } - } +fun ClosedFloatingPointRange.toSequenceWithStep(step: Double): Sequence = when { + step == 0.0 -> error("Zero step in double progression") + step > 0 -> sequence { + var current = start + while (current <= endInclusive) { + yield(current) + current += step } + } + else -> sequence { + var current = endInclusive + while (current >= start) { + yield(current) + current += step + } + } +} + +/** + * Convert double range to sequence with the fixed number of points + */ +fun ClosedFloatingPointRange.toSequenceWithPoints(numPoints: Int): Sequence { + require(numPoints > 1) { "The number of points should be more than 2" } + return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1)) +} /** * Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints] */ +@Deprecated("Replace by 'toSequenceWithPoints'") fun ClosedFloatingPointRange.toGrid(numPoints: Int): DoubleArray { if (numPoints < 2) error("Can't create generic grid with less than two points") return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt index c3cfc448a..a0f4525cc 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt @@ -1,14 +1,15 @@ package scientifik.kmath.misc import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke import kotlin.jvm.JvmName - /** - * Generic cumulative operation on iterator - * @param T type of initial iterable - * @param R type of resulting iterable - * @param initial lazy evaluated + * Generic cumulative operation on iterator. + * + * @param T the type of initial iterable. + * @param R the type of resulting iterable. + * @param initial lazy evaluated. */ fun Iterator.cumulative(initial: R, operation: (R, T) -> R): Iterator = object : Iterator { var state: R = initial @@ -36,41 +37,41 @@ fun List.cumulative(initial: R, operation: (R, T) -> R): List = /** * Cumulative sum with custom space */ -fun Iterable.cumulativeSum(space: Space) = with(space) { +fun Iterable.cumulativeSum(space: Space): Iterable = space { cumulative(zero) { element: T, sum: T -> sum + element } } @JvmName("cumulativeSumOfDouble") -fun Iterable.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } +fun Iterable.cumulativeSum(): Iterable = this.cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun Iterable.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } +fun Iterable.cumulativeSum(): Iterable = this.cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun Iterable.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } +fun Iterable.cumulativeSum(): Iterable = this.cumulative(0L) { element, sum -> sum + element } -fun Sequence.cumulativeSum(space: Space) = with(space) { +fun Sequence.cumulativeSum(space: Space): Sequence = with(space) { cumulative(zero) { element: T, sum: T -> sum + element } } @JvmName("cumulativeSumOfDouble") -fun Sequence.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } +fun Sequence.cumulativeSum(): Sequence = this.cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun Sequence.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } +fun Sequence.cumulativeSum(): Sequence = this.cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun Sequence.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } +fun Sequence.cumulativeSum(): Sequence = this.cumulative(0L) { element, sum -> sum + element } -fun List.cumulativeSum(space: Space) = with(space) { +fun List.cumulativeSum(space: Space): List = with(space) { cumulative(zero) { element: T, sum: T -> sum + element } } @JvmName("cumulativeSumOfDouble") -fun List.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } +fun List.cumulativeSum(): List = this.cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun List.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } +fun List.cumulativeSum(): List = this.cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun List.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } \ No newline at end of file +fun List.cumulativeSum(): List = this.cumulative(0L) { element, sum -> sum + element } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index 485185526..f18bde597 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -1,95 +1,340 @@ package scientifik.kmath.operations +/** + * Stub for DSL the [Algebra] is. + */ @DslMarker annotation class KMathContext /** - * Marker interface for any algebra + * Represents an algebraic structure. + * + * @param T the type of element of this structure. */ -interface Algebra +interface Algebra { + /** + * Wrap raw string or variable + */ + fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") -inline operator fun , R> T.invoke(block: T.() -> R): R = run(block) + /** + * Dynamic call of unary operation with name [operation] on [arg] + */ + fun unaryOperation(operation: String, arg: T): T + + /** + * Dynamic call of binary operation [operation] on [left] and [right] + */ + fun binaryOperation(operation: String, left: T, right: T): T +} /** - * Space-like operations without neutral element + * An algebraic structure where elements can have numeric representation. + * + * @param T the type of element of this structure. + */ +interface NumericAlgebra : Algebra { + /** + * Wraps a number. + */ + fun number(value: Number): T + + /** + * Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number]. + */ + fun leftSideNumberOperation(operation: String, left: Number, right: T): T = + binaryOperation(operation, number(left), right) + + /** + * Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number]. + */ + fun rightSideNumberOperation(operation: String, left: T, right: Number): T = + leftSideNumberOperation(operation, right, left) +} + +/** + * Call a block with an [Algebra] as receiver. + */ +inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) + +/** + * Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as + * multiplication by scalars. + * + * @param T the type of element of this semispace. */ interface SpaceOperations : Algebra { /** - * Addition operation for two context elements + * Addition of two elements. + * + * @param a the addend. + * @param b the augend. + * @return the sum. */ fun add(a: T, b: T): T /** - * Multiplication operation for context element and real number + * Multiplication of element by scalar. + * + * @param a the multiplier. + * @param k the multiplicand. + * @return the produce. */ fun multiply(a: T, k: Number): T - //Operation to be performed in this context + // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176 + + /** + * The negation of this element. + * + * @receiver this value. + * @return the additive inverse of this value. + */ operator fun T.unaryMinus(): T = multiply(this, -1.0) + /** + * Returns this value. + * + * @receiver this value. + * @return this value. + */ + operator fun T.unaryPlus(): T = this + + /** + * Addition of two elements. + * + * @receiver the addend. + * @param b the augend. + * @return the sum. + */ operator fun T.plus(b: T): T = add(this, b) + + /** + * Subtraction of two elements. + * + * @receiver the minuend. + * @param b the subtrahend. + * @return the difference. + */ operator fun T.minus(b: T): T = add(this, -b) - operator fun T.times(k: Number) = multiply(this, k.toDouble()) - operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble()) - operator fun Number.times(b: T) = b * this + + /** + * Multiplication of this element by a scalar. + * + * @receiver the multiplier. + * @param k the multiplicand. + * @return the product. + */ + operator fun T.times(k: Number): T = multiply(this, k.toDouble()) + + /** + * Division of this element by scalar. + * + * @receiver the dividend. + * @param k the divisor. + * @return the quotient. + */ + operator fun T.div(k: Number): T = multiply(this, 1.0 / k.toDouble()) + + /** + * Multiplication of this number by element. + * + * @receiver the multiplier. + * @param b the multiplicand. + * @return the product. + */ + operator fun Number.times(b: T): T = b * this + + override fun unaryOperation(operation: String, arg: T): T = when (operation) { + PLUS_OPERATION -> arg + MINUS_OPERATION -> -arg + else -> error("Unary operation $operation not defined in $this") + } + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + PLUS_OPERATION -> add(left, right) + MINUS_OPERATION -> left - right + else -> error("Binary operation $operation not defined in $this") + } + + companion object { + /** + * The identifier of addition. + */ + const val PLUS_OPERATION: String = "+" + + /** + * The identifier of subtraction (and negation). + */ + const val MINUS_OPERATION: String = "-" + + const val NOT_OPERATION: String = "!" + } } - /** - * A general interface representing linear context of some kind. - * The context defines sum operation for its elements and multiplication by real value. - * One must note that in some cases context is a singleton class, but in some cases it - * works as a context for operations inside it. + * Represents linear space, i.e. algebraic structure with associative binary operation called "addition" and its neutral + * element as well as multiplication by scalars. * - * TODO do we need non-commutative context? + * @param T the type of element of this group. */ interface Space : SpaceOperations { /** - * Neutral element for sum operation + * The neutral element of addition. */ val zero: T } /** - * Operations on ring without multiplication neutral element + * Represents semiring, i.e. algebraic structure with two associative binary operations called "addition" and + * "multiplication". + * + * @param T the type of element of this semiring. */ interface RingOperations : SpaceOperations { /** - * Multiplication for two field elements + * Multiplies two elements. + * + * @param a the multiplier. + * @param b the multiplicand. */ fun multiply(a: T, b: T): T + /** + * Multiplies this element by scalar. + * + * @receiver the multiplier. + * @param b the multiplicand. + */ operator fun T.times(b: T): T = multiply(this, b) + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + TIMES_OPERATION -> multiply(left, right) + else -> super.binaryOperation(operation, left, right) + } + + companion object { + /** + * The identifier of multiplication. + */ + const val TIMES_OPERATION: String = "*" + } } /** - * The same as {@link Space} but with additional multiplication operation + * Represents ring, i.e. algebraic structure with two associative binary operations called "addition" and + * "multiplication" and their neutral elements. + * + * @param T the type of element of this ring. */ -interface Ring : Space, RingOperations { +interface Ring : Space, RingOperations, NumericAlgebra { /** * neutral operation for multiplication */ val one: T -// operator fun T.plus(b: Number) = this.plus(b * one) -// operator fun Number.plus(b: T) = b + this -// -// operator fun T.minus(b: Number) = this.minus(b * one) -// operator fun Number.minus(b: T) = -b + this + override fun number(value: Number): T = one * value.toDouble() + + override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) { + SpaceOperations.PLUS_OPERATION -> left + right + SpaceOperations.MINUS_OPERATION -> left - right + RingOperations.TIMES_OPERATION -> left * right + else -> super.leftSideNumberOperation(operation, left, right) + } + + override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { + SpaceOperations.PLUS_OPERATION -> left + right + SpaceOperations.MINUS_OPERATION -> left - right + RingOperations.TIMES_OPERATION -> left * right + else -> super.rightSideNumberOperation(operation, left, right) + } + + /** + * Addition of element and scalar. + * + * @receiver the addend. + * @param b the augend. + */ + operator fun T.plus(b: Number): T = this + number(b) + + /** + * Addition of scalar and element. + * + * @receiver the addend. + * @param b the augend. + */ + operator fun Number.plus(b: T): T = b + this + + /** + * Subtraction of element from number. + * + * @receiver the minuend. + * @param b the subtrahend. + * @receiver the difference. + */ + operator fun T.minus(b: Number): T = this - number(b) + + /** + * Subtraction of number from element. + * + * @receiver the minuend. + * @param b the subtrahend. + * @receiver the difference. + */ + operator fun Number.minus(b: T): T = -b + this } /** - * All ring operations but without neutral elements + * Represents semifield, i.e. algebraic structure with three operations: associative "addition" and "multiplication", + * and "division". + * + * @param T the type of element of this semifield. */ interface FieldOperations : RingOperations { + /** + * Division of two elements. + * + * @param a the dividend. + * @param b the divisor. + * @return the quotient. + */ fun divide(a: T, b: T): T + /** + * Division of two elements. + * + * @receiver the dividend. + * @param b the divisor. + * @return the quotient. + */ operator fun T.div(b: T): T = divide(this, b) + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + DIV_OPERATION -> divide(left, right) + else -> super.binaryOperation(operation, left, right) + } + + companion object { + /** + * The identifier of division. + */ + const val DIV_OPERATION: String = "/" + } } /** - * Four operations algebra + * Represents field, i.e. algebraic structure with three operations: associative "addition" and "multiplication", + * and "division" and their neutral elements. + * + * @param T the type of element of this semifield. */ interface Field : Ring, FieldOperations { - operator fun Number.div(b: T) = this * divide(one, b) + /** + * Division of element by scalar. + * + * @receiver the dividend. + * @param b the divisor. + * @return the quotient. + */ + operator fun Number.div(b: T): T = this * divide(one, b) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt index 093021ae3..197897c14 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt @@ -2,47 +2,107 @@ package scientifik.kmath.operations /** * The generic mathematics elements which is able to store its context - * @param T the type of space operation results - * @param I self type of the element. Needed for static type checking - * @param C the type of mathematical context for this element + * + * @param C the type of mathematical context for this element. */ interface MathElement { /** - * The context this element belongs to + * The context this element belongs to. */ val context: C } +/** + * Represents element that can be wrapped to its "primitive" value. + * + * @param T the type wrapped by this wrapper. + * @param I the type of this wrapper. + */ interface MathWrapper { + /** + * Unwraps [I] to [T]. + */ fun unwrap(): T + + /** + * Wraps [T] to [I]. + */ fun T.wrap(): I } /** - * The element of linear context - * @param T the type of space operation results - * @param I self type of the element. Needed for static type checking - * @param S the type of space + * The element of [Space]. + * + * @param T the type of space operation results. + * @param I self type of the element. Needed for static type checking. + * @param S the type of space. */ interface SpaceElement, S : Space> : MathElement, MathWrapper { + /** + * Adds element to this one. + * + * @param b the augend. + * @return the sum. + */ + operator fun plus(b: T): I = context.add(unwrap(), b).wrap() - operator fun plus(b: T) = context.add(unwrap(), b).wrap() - operator fun minus(b: T) = context.add(unwrap(), context.multiply(b, -1.0)).wrap() - operator fun times(k: Number) = context.multiply(unwrap(), k.toDouble()).wrap() - operator fun div(k: Number) = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap() + /** + * Subtracts element from this one. + * + * @param b the subtrahend. + * @return the difference. + */ + operator fun minus(b: T): I = context.add(unwrap(), context.multiply(b, -1.0)).wrap() + + /** + * Multiplies this element by number. + * + * @param k the multiplicand. + * @return the product. + */ + operator fun times(k: Number): I = context.multiply(unwrap(), k.toDouble()).wrap() + + /** + * Divides this element by number. + * + * @param k the divisor. + * @return the quotient. + */ + operator fun div(k: Number): I = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap() } /** - * Ring element + * The element of [Ring]. + * + * @param T the type of space operation results. + * @param I self type of the element. Needed for static type checking. + * @param R the type of space. */ interface RingElement, R : Ring> : SpaceElement { - operator fun times(b: T) = context.multiply(unwrap(), b).wrap() + /** + * Multiplies this element by another one. + * + * @param b the multiplicand. + * @return the product. + */ + operator fun times(b: T): I = context.multiply(unwrap(), b).wrap() } /** - * Field element + * The element of [Field]. + * + * @param T the type of space operation results. + * @param I self type of the element. Needed for static type checking. + * @param F the type of field. */ interface FieldElement, F : Field> : RingElement { override val context: F - operator fun div(b: T) = context.divide(unwrap(), b).wrap() -} \ No newline at end of file + + /** + * Divides this element by another one. + * + * @param b the divisor. + * @return the quotient. + */ + operator fun div(b: T): I = context.divide(unwrap(), b).wrap() +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraExtensions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraExtensions.kt index bfb4199a3..00b16dc98 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraExtensions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraExtensions.kt @@ -1,15 +1,107 @@ package scientifik.kmath.operations +/** + * Returns the sum of all elements in the iterable in this [Space]. + * + * @receiver the algebra that provides addition. + * @param data the iterable to sum up. + * @return the sum. + */ fun Space.sum(data: Iterable): T = data.fold(zero) { left, right -> add(left, right) } + +/** + * Returns the sum of all elements in the sequence in this [Space]. + * + * @receiver the algebra that provides addition. + * @param data the sequence to sum up. + * @return the sum. + */ fun Space.sum(data: Sequence): T = data.fold(zero) { left, right -> add(left, right) } -fun > Iterable.sumWith(space: S): T = space.sum(this) +/** + * Returns an average value of elements in the iterable in this [Space]. + * + * @receiver the algebra that provides addition and division. + * @param data the iterable to find average. + * @return the average value. + */ +fun Space.average(data: Iterable): T = sum(data) / data.count() + +/** + * Returns an average value of elements in the sequence in this [Space]. + * + * @receiver the algebra that provides addition and division. + * @param data the sequence to find average. + * @return the average value. + */ +fun Space.average(data: Sequence): T = sum(data) / data.count() + +/** + * Returns the sum of all elements in the iterable in provided space. + * + * @receiver the collection to sum up. + * @param space the algebra that provides addition. + * @return the sum. + */ +fun Iterable.sumWith(space: Space): T = space.sum(this) + +/** + * Returns the sum of all elements in the sequence in provided space. + * + * @receiver the collection to sum up. + * @param space the algebra that provides addition. + * @return the sum. + */ +fun Sequence.sumWith(space: Space): T = space.sum(this) + +/** + * Returns an average value of elements in the iterable in this [Space]. + * + * @receiver the iterable to find average. + * @param space the algebra that provides addition and division. + * @return the average value. + */ +fun Iterable.averageWith(space: Space): T = space.average(this) + +/** + * Returns an average value of elements in the sequence in this [Space]. + * + * @receiver the sequence to find average. + * @param space the algebra that provides addition and division. + * @return the average value. + */ +fun Sequence.averageWith(space: Space): T = space.average(this) //TODO optimized power operation -fun RingOperations.power(arg: T, power: Int): T { + +/** + * Raises [arg] to the natural power [power]. + * + * @receiver the algebra to provide multiplication. + * @param arg the base. + * @param power the exponent. + * @return the base raised to the power. + */ +fun Ring.power(arg: T, power: Int): T { + require(power >= 0) { "The power can't be negative." } + require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." } + if (power == 0) return one var res = arg - repeat(power - 1) { - res *= arg - } + repeat(power - 1) { res *= arg } return res -} \ No newline at end of file +} + +/** + * Raises [arg] to the integer power [power]. + * + * @receiver the algebra to provide multiplication and division. + * @param arg the base. + * @param power the exponent. + * @return the base raised to the power. + */ +fun Field.power(arg: T, power: Int): T { + require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." } + if (power == 0) return one + if (power < 0) return one / (this as Ring).power(arg, -power) + return (this as Ring).power(arg, power) +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt index 1661170d3..fd7719157 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt @@ -2,6 +2,7 @@ package scientifik.kmath.operations import scientifik.kmath.operations.BigInt.Companion.BASE import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE +import scientifik.kmath.structures.* import kotlin.math.log2 import kotlin.math.max import kotlin.math.min @@ -193,8 +194,8 @@ class BigInt internal constructor( } infix fun or(other: BigInt): BigInt { - if (this == ZERO) return other; - if (other == ZERO) return this; + if (this == ZERO) return other + if (other == ZERO) return this val resSize = max(this.magnitude.size, other.magnitude.size) val newMagnitude: Magnitude = Magnitude(resSize) for (i in 0 until resSize) { @@ -209,7 +210,7 @@ class BigInt internal constructor( } infix fun and(other: BigInt): BigInt { - if ((this == ZERO) or (other == ZERO)) return ZERO; + if ((this == ZERO) or (other == ZERO)) return ZERO val resSize = min(this.magnitude.size, other.magnitude.size) val newMagnitude: Magnitude = Magnitude(resSize) for (i in 0 until resSize) { @@ -259,7 +260,7 @@ class BigInt internal constructor( } companion object { - const val BASE = 0xffffffffUL + const val BASE: ULong = 0xffffffffUL const val BASE_SIZE: Int = 32 val ZERO: BigInt = BigInt(0, uintArrayOf()) val ONE: BigInt = BigInt(1, uintArrayOf(1u)) @@ -393,12 +394,12 @@ fun abs(x: BigInt): BigInt = x.abs() /** * Convert this [Int] to [BigInt] */ -fun Int.toBigInt() = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt())) +fun Int.toBigInt(): BigInt = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt())) /** * Convert this [Long] to [BigInt] */ -fun Long.toBigInt() = BigInt( +fun Long.toBigInt(): BigInt = BigInt( sign.toByte(), stripLeadingZeros( uintArrayOf( (kotlin.math.abs(this).toULong() and BASE).toUInt(), @@ -410,17 +411,17 @@ fun Long.toBigInt() = BigInt( /** * Convert UInt to [BigInt] */ -fun UInt.toBigInt() = BigInt(1, uintArrayOf(this)) +fun UInt.toBigInt(): BigInt = BigInt(1, uintArrayOf(this)) /** * Convert ULong to [BigInt] */ -fun ULong.toBigInt() = BigInt( +fun ULong.toBigInt(): BigInt = BigInt( 1, stripLeadingZeros( uintArrayOf( - (this and BigInt.BASE).toUInt(), - ((this shr BigInt.BASE_SIZE) and BigInt.BASE).toUInt() + (this and BASE).toUInt(), + ((this shr BASE_SIZE) and BASE).toUInt() ) ) ) @@ -433,7 +434,7 @@ fun UIntArray.toBigInt(sign: Byte): BigInt { return BigInt(sign, this.copyOf()) } -val hexChToInt = hashMapOf( +val hexChToInt: MutableMap = hashMapOf( '0' to 0, '1' to 1, '2' to 2, '3' to 3, '4' to 4, '5' to 5, '6' to 6, '7' to 7, '8' to 8, '9' to 9, 'A' to 10, 'B' to 11, @@ -482,3 +483,18 @@ fun String.parseBigInteger(): BigInt? { } return res * sign } + +inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer = + boxing(size, initializer) + +inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer = + boxing(size, initializer) + +fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing = + BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt) + +fun NDElement.Companion.bigInt( + vararg shape: Int, + initializer: BigIntField.(IntArray) -> BigInt +): BufferedNDRingElement = + NDAlgebra.bigInt(*shape).produce(initializer) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 6c529f55e..0ce144a33 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -8,15 +8,20 @@ import scientifik.memory.MemorySpec import scientifik.memory.MemoryWriter import kotlin.math.* +private val PI_DIV_2 = Complex(PI / 2, 0) + /** - * A field for complex numbers + * A field of [Complex]. */ -object ComplexField : ExtendedFieldOperations, Field { +object ComplexField : ExtendedField { override val zero: Complex = Complex(0.0, 0.0) override val one: Complex = Complex(1.0, 0.0) - val i = Complex(0.0, 1.0) + /** + * The imaginary unit. + */ + val i: Complex = Complex(0.0, 1.0) override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) @@ -30,9 +35,11 @@ object ComplexField : ExtendedFieldOperations, Field { return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm) } - override fun sin(arg: Complex): Complex = i / 2 * (exp(-i * arg) - exp(i * arg)) - + override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 + override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg) + override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg) + override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2 override fun power(arg: Complex, pow: Number): Complex = arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta)) @@ -41,19 +48,59 @@ object ComplexField : ExtendedFieldOperations, Field { override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re) - operator fun Double.plus(c: Complex) = add(this.toComplex(), c) + /** + * Adds complex number to real one. + * + * @receiver the addend. + * @param c the augend. + * @return the sum. + */ + operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c) - operator fun Double.minus(c: Complex) = add(this.toComplex(), -c) + /** + * Subtracts complex number from real one. + * + * @receiver the minuend. + * @param c the subtrahend. + * @return the difference. + */ + operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c) - operator fun Complex.plus(d: Double) = d + this + /** + * Adds real number to complex one. + * + * @receiver the addend. + * @param d the augend. + * @return the sum. + */ + operator fun Complex.plus(d: Double): Complex = d + this - operator fun Complex.minus(d: Double) = add(this, -d.toComplex()) + /** + * Subtracts real number from complex one. + * + * @receiver the minuend. + * @param d the subtrahend. + * @return the difference. + */ + operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex()) - operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this) + /** + * Multiplies real number by complex one. + * + * @receiver the multiplier. + * @param c the multiplicand. + * @receiver the product. + */ + operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) + + override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) } /** - * Complex number class + * Represents complex number. + * + * @property re The real part. + * @property im The imaginary part. */ data class Complex(val re: Double, val im: Double) : FieldElement, Comparable { constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) @@ -94,7 +141,13 @@ val Complex.r: Double get() = sqrt(re * re + im * im) */ val Complex.theta: Double get() = atan(im / re) -fun Double.toComplex() = Complex(this, 0.0) +/** + * Creates a complex number with real part equal to this real. + * + * @receiver the real part. + * @return the new complex number. + */ +fun Double.toComplex(): Complex = Complex(this, 0.0) inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer { return MemoryBuffer.create(Complex, size, init) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 9639e4c28..b113e07a1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -4,19 +4,45 @@ import kotlin.math.abs import kotlin.math.pow as kpow /** - * Advanced Number-like field that implements basic operations + * Advanced Number-like semifield that implements basic operations. */ interface ExtendedFieldOperations : - FieldOperations, - TrigonometricOperations, + InverseTrigonometricOperations, PowerOperations, - ExponentialOperations + ExponentialOperations { -interface ExtendedField : ExtendedFieldOperations, Field + override fun tan(arg: T): T = sin(arg) / cos(arg) + + override fun unaryOperation(operation: String, arg: T): T = when (operation) { + TrigonometricOperations.COS_OPERATION -> cos(arg) + TrigonometricOperations.SIN_OPERATION -> sin(arg) + TrigonometricOperations.TAN_OPERATION -> tan(arg) + InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg) + InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg) + InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg) + PowerOperations.SQRT_OPERATION -> sqrt(arg) + ExponentialOperations.EXP_OPERATION -> exp(arg) + ExponentialOperations.LN_OPERATION -> ln(arg) + else -> super.unaryOperation(operation, arg) + } +} + + +/** + * Advanced Number-like field that implements basic operations. + */ +interface ExtendedField : ExtendedFieldOperations, Field { + override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { + PowerOperations.POW_OPERATION -> power(left, right) + else -> super.rightSideNumberOperation(operation, left, right) + } +} /** * Real field element wrapping double. * + * @property value the [Double] value wrapped by this [Real]. + * * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 */ inline class Real(val value: Double) : FieldElement { @@ -24,74 +50,90 @@ inline class Real(val value: Double) : FieldElement { override fun Double.wrap(): Real = Real(value) - override val context get() = RealField + override val context: RealField get() = RealField companion object } /** - * A field for double without boxing. Does not produce appropriate field element + * A field for [Double] without boxing. Does not produce appropriate field element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object RealField : ExtendedField, Norm { override val zero: Double = 0.0 - override inline fun add(a: Double, b: Double) = a + b - override inline fun multiply(a: Double, b: Double) = a * b - override inline fun multiply(a: Double, k: Number) = a * k.toDouble() + override inline fun add(a: Double, b: Double): Double = a + b + override inline fun multiply(a: Double, b: Double): Double = a * b + override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() override val one: Double = 1.0 - override inline fun divide(a: Double, b: Double) = a / b + override inline fun divide(a: Double, b: Double): Double = a / b - override inline fun sin(arg: Double) = kotlin.math.sin(arg) - override inline fun cos(arg: Double) = kotlin.math.cos(arg) + override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) + override inline fun cos(arg: Double): Double = kotlin.math.cos(arg) + override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) + override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) + override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) + override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) - override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble()) + override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) - override inline fun exp(arg: Double) = kotlin.math.exp(arg) - override inline fun ln(arg: Double) = kotlin.math.ln(arg) + override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) + override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) - override inline fun norm(arg: Double) = abs(arg) + override inline fun norm(arg: Double): Double = abs(arg) - override inline fun Double.unaryMinus() = -this + override inline fun Double.unaryMinus(): Double = -this - override inline fun Double.plus(b: Double) = this + b + override inline fun Double.plus(b: Double): Double = this + b - override inline fun Double.minus(b: Double) = this - b + override inline fun Double.minus(b: Double): Double = this - b - override inline fun Double.times(b: Double) = this * b + override inline fun Double.times(b: Double): Double = this * b - override inline fun Double.div(b: Double) = this / b + override inline fun Double.div(b: Double): Double = this / b + + override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { + PowerOperations.POW_OPERATION -> left pow right + else -> super.binaryOperation(operation, left, right) + } } +/** + * A field for [Float] without boxing. Does not produce appropriate field element. + */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object FloatField : ExtendedField, Norm { override val zero: Float = 0f - override inline fun add(a: Float, b: Float) = a + b - override inline fun multiply(a: Float, b: Float) = a * b - override inline fun multiply(a: Float, k: Number) = a * k.toFloat() + override inline fun add(a: Float, b: Float): Float = a + b + override inline fun multiply(a: Float, b: Float): Float = a * b + override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat() override val one: Float = 1f - override inline fun divide(a: Float, b: Float) = a / b + override inline fun divide(a: Float, b: Float): Float = a / b - override inline fun sin(arg: Float) = kotlin.math.sin(arg) - override inline fun cos(arg: Float) = kotlin.math.cos(arg) + override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) + override inline fun cos(arg: Float): Float = kotlin.math.cos(arg) + override inline fun tan(arg: Float): Float = kotlin.math.tan(arg) + override inline fun acos(arg: Float): Float = kotlin.math.acos(arg) + override inline fun asin(arg: Float): Float = kotlin.math.asin(arg) + override inline fun atan(arg: Float): Float = kotlin.math.atan(arg) - override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat()) + override inline fun power(arg: Float, pow: Number): Float = arg.pow(pow.toFloat()) - override inline fun exp(arg: Float) = kotlin.math.exp(arg) - override inline fun ln(arg: Float) = kotlin.math.ln(arg) + override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) + override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) - override inline fun norm(arg: Float) = abs(arg) + override inline fun norm(arg: Float): Float = abs(arg) - override inline fun Float.unaryMinus() = -this + override inline fun Float.unaryMinus(): Float = -this - override inline fun Float.plus(b: Float) = this + b + override inline fun Float.plus(b: Float): Float = this + b - override inline fun Float.minus(b: Float) = this - b + override inline fun Float.minus(b: Float): Float = this - b - override inline fun Float.times(b: Float) = this * b + override inline fun Float.times(b: Float): Float = this * b - override inline fun Float.div(b: Float) = this / b + override inline fun Float.div(b: Float): Float = this / b } /** @@ -100,14 +142,14 @@ object FloatField : ExtendedField, Norm { @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object IntRing : Ring, Norm { override val zero: Int = 0 - override inline fun add(a: Int, b: Int) = a + b - override inline fun multiply(a: Int, b: Int) = a * b - override inline fun multiply(a: Int, k: Number) = (k * a) + override inline fun add(a: Int, b: Int): Int = a + b + override inline fun multiply(a: Int, b: Int): Int = a * b + override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a override val one: Int = 1 - override inline fun norm(arg: Int) = abs(arg) + override inline fun norm(arg: Int): Int = abs(arg) - override inline fun Int.unaryMinus() = -this + override inline fun Int.unaryMinus(): Int = -this override inline fun Int.plus(b: Int): Int = this + b @@ -122,20 +164,20 @@ object IntRing : Ring, Norm { @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object ShortRing : Ring, Norm { override val zero: Short = 0 - override inline fun add(a: Short, b: Short) = (a + b).toShort() - override inline fun multiply(a: Short, b: Short) = (a * b).toShort() - override inline fun multiply(a: Short, k: Number) = (a * k) + override inline fun add(a: Short, b: Short): Short = (a + b).toShort() + override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() + override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() override val one: Short = 1 override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() - override inline fun Short.unaryMinus() = (-this).toShort() + override inline fun Short.unaryMinus(): Short = (-this).toShort() - override inline fun Short.plus(b: Short) = (this + b).toShort() + override inline fun Short.plus(b: Short): Short = (this + b).toShort() - override inline fun Short.minus(b: Short) = (this - b).toShort() + override inline fun Short.minus(b: Short): Short = (this - b).toShort() - override inline fun Short.times(b: Short) = (this * b).toShort() + override inline fun Short.times(b: Short): Short = (this * b).toShort() } /** @@ -144,20 +186,20 @@ object ShortRing : Ring, Norm { @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object ByteRing : Ring, Norm { override val zero: Byte = 0 - override inline fun add(a: Byte, b: Byte) = (a + b).toByte() - override inline fun multiply(a: Byte, b: Byte) = (a * b).toByte() - override inline fun multiply(a: Byte, k: Number) = (a * k) + override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() + override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() + override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() override val one: Byte = 1 override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() - override inline fun Byte.unaryMinus() = (-this).toByte() + override inline fun Byte.unaryMinus(): Byte = (-this).toByte() - override inline fun Byte.plus(b: Byte) = (this + b).toByte() + override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() - override inline fun Byte.minus(b: Byte) = (this - b).toByte() + override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() - override inline fun Byte.times(b: Byte) = (this * b).toByte() + override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() } /** @@ -166,18 +208,18 @@ object ByteRing : Ring, Norm { @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object LongRing : Ring, Norm { override val zero: Long = 0 - override inline fun add(a: Long, b: Long) = (a + b) - override inline fun multiply(a: Long, b: Long) = (a * b) - override inline fun multiply(a: Long, k: Number) = (a * k) + override inline fun add(a: Long, b: Long): Long = (a + b) + override inline fun multiply(a: Long, b: Long): Long = (a * b) + override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() override val one: Long = 1 override fun norm(arg: Long): Long = abs(arg) - override inline fun Long.unaryMinus() = (-this) + override inline fun Long.unaryMinus(): Long = (-this) - override inline fun Long.plus(b: Long) = (this + b) + override inline fun Long.plus(b: Long): Long = (this + b) - override inline fun Long.minus(b: Long) = (this - b) + override inline fun Long.minus(b: Long): Long = (this - b) - override inline fun Long.times(b: Long) = (this * b) -} \ No newline at end of file + override inline fun Long.times(b: Long): Long = (this * b) +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index bd83932e7..dea45a145 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -1,57 +1,214 @@ package scientifik.kmath.operations - -/* Trigonometric operations */ - /** - * A container for trigonometric operations for specific type. Trigonometric operations are limited to fields. + * A container for trigonometric operations for specific type. They are limited to semifields. * * The operations are not exposed to class directly to avoid method bloat but instead are declared in the field. - * It also allows to override behavior for optional operations - * + * It also allows to override behavior for optional operations. */ interface TrigonometricOperations : FieldOperations { + /** + * Computes the sine of [arg]. + */ fun sin(arg: T): T + + /** + * Computes the cosine of [arg]. + */ fun cos(arg: T): T - fun tg(arg: T): T = sin(arg) / cos(arg) + /** + * Computes the tangent of [arg]. + */ + fun tan(arg: T): T - fun ctg(arg: T): T = cos(arg) / sin(arg) + companion object { + /** + * The identifier of sine. + */ + const val SIN_OPERATION: String = "sin" + + /** + * The identifier of cosine. + */ + const val COS_OPERATION: String = "cos" + + /** + * The identifier of tangent. + */ + const val TAN_OPERATION: String = "tan" + } } -fun >> sin(arg: T): T = arg.context.sin(arg) -fun >> cos(arg: T): T = arg.context.cos(arg) -fun >> tg(arg: T): T = arg.context.tg(arg) -fun >> ctg(arg: T): T = arg.context.ctg(arg) - -/* Power and roots */ - /** - * A context extension to include power operations like square roots, etc + * A container for inverse trigonometric operations for specific type. They are limited to semifields. + * + * The operations are not exposed to class directly to avoid method bloat but instead are declared in the field. + * It also allows to override behavior for optional operations. + */ +interface InverseTrigonometricOperations : TrigonometricOperations { + /** + * Computes the inverse sine of [arg]. + */ + fun asin(arg: T): T + + /** + * Computes the inverse cosine of [arg]. + */ + fun acos(arg: T): T + + /** + * Computes the inverse tangent of [arg]. + */ + fun atan(arg: T): T + + companion object { + /** + * The identifier of inverse sine. + */ + const val ASIN_OPERATION: String = "asin" + + /** + * The identifier of inverse cosine. + */ + const val ACOS_OPERATION: String = "acos" + + /** + * The identifier of inverse tangent. + */ + const val ATAN_OPERATION: String = "atan" + } +} + +/** + * Computes the sine of [arg]. + */ +fun >> sin(arg: T): T = arg.context.sin(arg) + +/** + * Computes the cosine of [arg]. + */ +fun >> cos(arg: T): T = arg.context.cos(arg) + +/** + * Computes the tangent of [arg]. + */ +fun >> tan(arg: T): T = arg.context.tan(arg) + +/** + * Computes the inverse sine of [arg]. + */ +fun >> asin(arg: T): T = arg.context.asin(arg) + +/** + * Computes the inverse cosine of [arg]. + */ +fun >> acos(arg: T): T = arg.context.acos(arg) + +/** + * Computes the inverse tangent of [arg]. + */ +fun >> atan(arg: T): T = arg.context.atan(arg) + +/** + * A context extension to include power operations based on exponentiation. */ interface PowerOperations : Algebra { + /** + * Raises [arg] to the power [pow]. + */ fun power(arg: T, pow: Number): T - fun sqrt(arg: T) = power(arg, 0.5) - infix fun T.pow(pow: Number) = power(this, pow) + /** + * Computes the square root of the value [arg]. + */ + fun sqrt(arg: T): T = power(arg, 0.5) + + /** + * Raises this value to the power [pow]. + */ + infix fun T.pow(pow: Number): T = power(this, pow) + + companion object { + /** + * The identifier of exponentiation. + */ + const val POW_OPERATION: String = "pow" + + /** + * The identifier of square root. + */ + const val SQRT_OPERATION: String = "sqrt" + } } +/** + * Raises this element to the power [pow]. + * + * @receiver the base. + * @param power the exponent. + * @return the base raised to the power. + */ infix fun >> T.pow(power: Double): T = context.power(this, power) + +/** + * Computes the square root of the value [arg]. + */ fun >> sqrt(arg: T): T = arg pow 0.5 + +/** + * Computes the square of the value [arg]. + */ fun >> sqr(arg: T): T = arg pow 2.0 -/* Exponential */ - -interface ExponentialOperations: Algebra { +/** + * A container for operations related to `exp` and `ln` functions. + */ +interface ExponentialOperations : Algebra { + /** + * Computes Euler's number `e` raised to the power of the value [arg]. + */ fun exp(arg: T): T + + /** + * Computes the natural logarithm (base `e`) of the value [arg]. + */ fun ln(arg: T): T + + companion object { + /** + * The identifier of exponential function. + */ + const val EXP_OPERATION: String = "exp" + + /** + * The identifier of natural logarithm. + */ + const val LN_OPERATION: String = "ln" + } } +/** + * The identifier of exponential function. + */ fun >> exp(arg: T): T = arg.context.exp(arg) + +/** + * The identifier of natural logarithm. + */ fun >> ln(arg: T): T = arg.context.ln(arg) +/** + * A container for norm functional on element. + */ interface Norm { + /** + * Computes the norm of [arg] (i.e. absolute value or vector length). + */ fun norm(arg: T): R } -fun >, R> norm(arg: T): R = arg.context.norm(arg) \ No newline at end of file +/** + * Computes the norm of [arg] (i.e. absolute value or vector length). + */ +fun >, R> norm(arg: T): R = arg.context.norm(arg) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt index e6d4b226d..4cbb565c1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt @@ -3,7 +3,6 @@ package scientifik.kmath.structures import scientifik.kmath.operations.Field import scientifik.kmath.operations.FieldElement - class BoxingNDField>( override val shape: IntArray, override val elementContext: F, @@ -19,10 +18,10 @@ class BoxingNDField>( if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") } - override val zero by lazy { produce { zero } } - override val one by lazy { produce { one } } + override val zero: BufferedNDFieldElement by lazy { produce { zero } } + override val one: BufferedNDFieldElement by lazy { produce { one } } - override fun produce(initializer: F.(IntArray) -> T) = + override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement = BufferedNDFieldElement( this, buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }) @@ -79,4 +78,4 @@ inline fun , R> F.nd( ): R { val ndfield: BoxingNDField = NDField.boxing(this, *shape, bufferFactory = bufferFactory) return ndfield.action() -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt index 39fc555e8..f7be95736 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt @@ -18,10 +18,10 @@ class BoxingNDRing>( if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") } - override val zero by lazy { produce { zero } } - override val one by lazy { produce { one } } + override val zero: BufferedNDRingElement by lazy { produce { zero } } + override val one: BufferedNDRingElement by lazy { produce { one } } - override fun produce(initializer: R.(IntArray) -> T) = + override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement = BufferedNDRingElement( this, buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }) @@ -69,4 +69,4 @@ class BoxingNDRing>( override fun NDBuffer.toElement(): RingElement, *, out BufferedNDRing> = BufferedNDRingElement(this@BoxingNDRing, buffer) -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt index b14da5d99..00832b69c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt @@ -7,16 +7,16 @@ import kotlin.reflect.KClass */ class BufferAccessor2D(val type: KClass, val rowNum: Int, val colNum: Int) { - operator fun Buffer.get(i: Int, j: Int) = get(i + colNum * j) + operator fun Buffer.get(i: Int, j: Int): T = get(i + colNum * j) operator fun MutableBuffer.set(i: Int, j: Int, value: T) { set(i + colNum * j, value) } - inline fun create(init: (i: Int, j: Int) -> T) = + inline fun create(init: (i: Int, j: Int) -> T): MutableBuffer = MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) } - fun create(mat: Structure2D) = create { i, j -> mat[i, j] } + fun create(mat: Structure2D): MutableBuffer = create { i, j -> mat[i, j] } //TODO optimize wrapper fun MutableBuffer.collect(): Structure2D = @@ -41,5 +41,5 @@ class BufferAccessor2D(val type: KClass, val rowNum: Int, val colNum /** * Get row */ - fun MutableBuffer.row(i: Int) = Row(this, i) + fun MutableBuffer.row(i: Int): Row = Row(this, i) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt index 9742f3662..06922c56f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt @@ -2,7 +2,7 @@ package scientifik.kmath.structures import scientifik.kmath.operations.* -interface BufferedNDAlgebra: NDAlgebra>{ +interface BufferedNDAlgebra : NDAlgebra> { val strides: Strides override fun check(vararg elements: NDBuffer) { @@ -11,7 +11,8 @@ interface BufferedNDAlgebra: NDAlgebra>{ /** * Convert any [NDStructure] to buffered structure using strides from this context. - * If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over indexes + * If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over + * indices. * * If the argument is [NDBuffer] with different strides structure, the new element will be produced. */ @@ -30,7 +31,7 @@ interface BufferedNDAlgebra: NDAlgebra>{ } -interface BufferedNDSpace> : NDSpace>, BufferedNDAlgebra { +interface BufferedNDSpace> : NDSpace>, BufferedNDAlgebra { override fun NDBuffer.toElement(): SpaceElement, *, out BufferedNDSpace> } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt index 04049368a..d1d622b23 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt @@ -3,12 +3,12 @@ package scientifik.kmath.structures import scientifik.kmath.operations.* /** - * Base interface for an element with context, containing strides + * Base class for an element with context, containing strides */ -interface BufferedNDElement : NDBuffer, NDElement> { - override val context: BufferedNDAlgebra +abstract class BufferedNDElement : NDBuffer(), NDElement> { + abstract override val context: BufferedNDAlgebra - override val strides get() = context.strides + override val strides: Strides get() = context.strides override val shape: IntArray get() = context.shape } @@ -16,7 +16,7 @@ interface BufferedNDElement : NDBuffer, NDElement> { class BufferedNDSpaceElement>( override val context: BufferedNDSpace, override val buffer: Buffer -) : BufferedNDElement, SpaceElement, BufferedNDSpaceElement, BufferedNDSpace> { +) : BufferedNDElement(), SpaceElement, BufferedNDSpaceElement, BufferedNDSpace> { override fun unwrap(): NDBuffer = this @@ -29,7 +29,7 @@ class BufferedNDSpaceElement>( class BufferedNDRingElement>( override val context: BufferedNDRing, override val buffer: Buffer -) : BufferedNDElement, RingElement, BufferedNDRingElement, BufferedNDRing> { +) : BufferedNDElement(), RingElement, BufferedNDRingElement, BufferedNDRing> { override fun unwrap(): NDBuffer = this @@ -42,7 +42,7 @@ class BufferedNDRingElement>( class BufferedNDFieldElement>( override val context: BufferedNDField, override val buffer: Buffer -) : BufferedNDElement, FieldElement, BufferedNDFieldElement, BufferedNDField> { +) : BufferedNDElement(), FieldElement, BufferedNDFieldElement, BufferedNDField> { override fun unwrap(): NDBuffer = this @@ -54,9 +54,9 @@ class BufferedNDFieldElement>( /** - * Element by element application of any operation on elements to the whole array. Just like in numpy + * Element by element application of any operation on elements to the whole array. Just like in numpy. */ -operator fun > Function1.invoke(ndElement: BufferedNDElement) = +operator fun > Function1.invoke(ndElement: BufferedNDElement): MathElement> = ndElement.context.run { map(ndElement) { invoke(it) }.toElement() } /* plus and minus */ @@ -64,13 +64,13 @@ operator fun > Function1.invoke(ndElement: BufferedN /** * Summation operation for [BufferedNDElement] and single element */ -operator fun > BufferedNDElement.plus(arg: T) = +operator fun > BufferedNDElement.plus(arg: T): NDElement> = context.map(this) { it + arg }.wrap() /** * Subtraction operation between [BufferedNDElement] and single element */ -operator fun > BufferedNDElement.minus(arg: T) = +operator fun > BufferedNDElement.minus(arg: T): NDElement> = context.map(this) { it - arg }.wrap() /* prod and div */ @@ -78,11 +78,11 @@ operator fun > BufferedNDElement.minus(arg: T) = /** * Product operation for [BufferedNDElement] and single element */ -operator fun > BufferedNDElement.times(arg: T) = +operator fun > BufferedNDElement.times(arg: T): NDElement> = context.map(this) { it * arg }.wrap() /** * Division operation between [BufferedNDElement] and single element */ -operator fun > BufferedNDElement.div(arg: T) = - context.map(this) { it / arg }.wrap() \ No newline at end of file +operator fun > BufferedNDElement.div(arg: T): NDElement> = + context.map(this) { it / arg }.wrap() diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt index f02fd8dd0..5fdf79e88 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt @@ -4,42 +4,51 @@ import scientifik.kmath.operations.Complex import scientifik.kmath.operations.complex import kotlin.reflect.KClass - +/** + * Function that produces [Buffer] from its size and function that supplies values. + * + * @param T the type of buffer. + */ typealias BufferFactory = (Int, (Int) -> T) -> Buffer -typealias MutableBufferFactory = (Int, (Int) -> T) -> MutableBuffer - /** - * A generic random access structure for both primitives and objects + * Function that produces [MutableBuffer] from its size and function that supplies values. + * + * @param T the type of buffer. + */ +typealias MutableBufferFactory = (Int, (Int) -> T) -> MutableBuffer + +/** + * A generic immutable random-access structure for both primitives and objects. + * + * @param T the type of elements contained in the buffer. */ interface Buffer { - /** - * The size of the buffer + * The size of this buffer. */ val size: Int /** - * Get element at given index + * Gets element at given index. */ operator fun get(index: Int): T /** - * Iterate over all elements + * Iterates over all elements. */ operator fun iterator(): Iterator /** - * Check content eqiality with another buffer + * Checks content equality with another buffer. */ fun contentEquals(other: Buffer<*>): Boolean = asSequence().mapIndexed { index, value -> value == other[index] }.all { it } companion object { - - inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer { + inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer { val array = DoubleArray(size) { initializer(it) } - return DoubleBuffer(array) + return RealBuffer(array) } /** @@ -51,7 +60,7 @@ interface Buffer { inline fun auto(type: KClass, size: Int, crossinline initializer: (Int) -> T): Buffer { //TODO add resolution based on Annotation or companion resolution return when (type) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer + Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer @@ -69,17 +78,34 @@ interface Buffer { } } +/** + * Creates a sequence that returns all elements from this [Buffer]. + */ fun Buffer.asSequence(): Sequence = Sequence(::iterator) -fun Buffer.asIterable(): Iterable = asSequence().asIterable() +/** + * Creates an iterable that returns all elements from this [Buffer]. + */ +fun Buffer.asIterable(): Iterable = Iterable(::iterator) -val Buffer<*>.indices: IntRange get() = IntRange(0, size - 1) +/** + * Returns an [IntRange] of the valid indices for this [Buffer]. + */ +val Buffer<*>.indices: IntRange get() = 0 until size +/** + * A generic mutable random-access structure for both primitives and objects. + * + * @param T the type of elements contained in the buffer. + */ interface MutableBuffer : Buffer { + /** + * Sets the array element at the specified [index] to the specified [value]. + */ operator fun set(index: Int, value: T) /** - * A shallow copy of the buffer + * Returns a shallow copy of the buffer. */ fun copy(): MutableBuffer @@ -93,7 +119,7 @@ interface MutableBuffer : Buffer { @Suppress("UNCHECKED_CAST") inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer { return when (type) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer + Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer @@ -109,14 +135,18 @@ interface MutableBuffer : Buffer { auto(T::class, size, initializer) val real: MutableBufferFactory = { size: Int, initializer: (Int) -> Double -> - DoubleBuffer(DoubleArray(size) { initializer(it) }) + RealBuffer(DoubleArray(size) { initializer(it) }) } } } - +/** + * [Buffer] implementation over [List]. + * + * @param T the type of elements contained in the buffer. + * @property list The underlying list. + */ inline class ListBuffer(val list: List) : Buffer { - override val size: Int get() = list.size @@ -125,11 +155,26 @@ inline class ListBuffer(val list: List) : Buffer { override fun iterator(): Iterator = list.iterator() } -fun List.asBuffer() = ListBuffer(this) +/** + * Returns an [ListBuffer] that wraps the original list. + */ +fun List.asBuffer(): ListBuffer = ListBuffer(this) -@Suppress("FunctionName") -inline fun ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer() +/** + * Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified + * [init] function. + * + * The function [init] is called for each array element sequentially starting from the first one. + * It should return the value for an array element given its index. + */ +inline fun ListBuffer(size: Int, init: (Int) -> T): ListBuffer = List(size, init).asBuffer() +/** + * [MutableBuffer] implementation over [MutableList]. + * + * @param T the type of elements contained in the buffer. + * @property list The underlying list. + */ inline class MutableListBuffer(val list: MutableList) : MutableBuffer { override val size: Int @@ -145,8 +190,14 @@ inline class MutableListBuffer(val list: MutableList) : MutableBuffer { override fun copy(): MutableBuffer = MutableListBuffer(ArrayList(list)) } +/** + * [MutableBuffer] implementation over [Array]. + * + * @param T the type of elements contained in the buffer. + * @property array The underlying array. + */ class ArrayBuffer(private val array: Array) : MutableBuffer { - //Can't inline because array is invariant + // Can't inline because array is invariant override val size: Int get() = array.size @@ -161,99 +212,30 @@ class ArrayBuffer(private val array: Array) : MutableBuffer { override fun copy(): MutableBuffer = ArrayBuffer(array.copyOf()) } -fun Array.asBuffer() = ArrayBuffer(this) - -inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Double = array[index] - - override fun set(index: Int, value: Double) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = DoubleBuffer(array.copyOf()) -} - -@Suppress("FunctionName") -inline fun DoubleBuffer(size: Int, init: (Int) -> Double) = DoubleBuffer(DoubleArray(size) { init(it) }) +/** + * Returns an [ArrayBuffer] that wraps the original array. + */ +fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) /** - * Transform buffer of doubles into array for high performance operations + * Immutable wrapper for [MutableBuffer]. + * + * @param T the type of elements contained in the buffer. + * @property buffer The underlying buffer. */ -val Buffer.array: DoubleArray - get() = if (this is DoubleBuffer) { - array - } else { - DoubleArray(size) { get(it) } - } - -fun DoubleArray.asBuffer() = DoubleBuffer(this) - -inline class ShortBuffer(val array: ShortArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Short = array[index] - - override fun set(index: Int, value: Short) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = ShortBuffer(array.copyOf()) - -} - -fun ShortArray.asBuffer() = ShortBuffer(this) - -inline class IntBuffer(val array: IntArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Int = array[index] - - override fun set(index: Int, value: Int) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = IntBuffer(array.copyOf()) - -} - -fun IntArray.asBuffer() = IntBuffer(this) - -inline class LongBuffer(val array: LongArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Long = array[index] - - override fun set(index: Int, value: Long) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = LongBuffer(array.copyOf()) - -} - -fun LongArray.asBuffer() = LongBuffer(this) - inline class ReadOnlyBuffer(val buffer: MutableBuffer) : Buffer { override val size: Int get() = buffer.size - override fun get(index: Int): T = buffer.get(index) + override fun get(index: Int): T = buffer[index] - override fun iterator() = buffer.iterator() + override fun iterator(): Iterator = buffer.iterator() } /** - * A buffer with content calculated on-demand. The calculated contect is not stored, so it is recalculated on each call. + * A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call. * Useful when one needs single element from the buffer. + * + * @param T the type of elements provided by the buffer. */ class VirtualBuffer(override val size: Int, private val generator: (Int) -> T) : Buffer { override fun get(index: Int): T { @@ -273,17 +255,16 @@ class VirtualBuffer(override val size: Int, private val generator: (Int) -> T } /** - * Convert this buffer to read-only buffer + * Convert this buffer to read-only buffer. */ -fun Buffer.asReadOnly(): Buffer = if (this is MutableBuffer) { - ReadOnlyBuffer(this) -} else { - this -} +fun Buffer.asReadOnly(): Buffer = if (this is MutableBuffer) ReadOnlyBuffer(this) else this /** - * Typealias for buffer transformations + * Typealias for buffer transformations. */ typealias BufferTransform = (Buffer) -> Buffer -typealias SuspendBufferTransform = suspend (Buffer) -> Buffer \ No newline at end of file +/** + * Typealias for buffer transformations with suspend function. + */ +typealias SuspendBufferTransform = suspend (Buffer) -> Buffer diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt index a79366a99..be0b9e5c6 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt @@ -17,8 +17,8 @@ class ComplexNDField(override val shape: IntArray) : override val strides: Strides = DefaultStrides(shape) override val elementContext: ComplexField get() = ComplexField - override val zero by lazy { produce { zero } } - override val one by lazy { produce { one } } + override val zero: ComplexNDElement by lazy { produce { zero } } + override val one: ComplexNDElement by lazy { produce { one } } inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer = Buffer.complex(size) { initializer(it) } @@ -69,16 +69,23 @@ class ComplexNDField(override val shape: IntArray) : override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = BufferedNDFieldElement(this@ComplexNDField, buffer) - override fun power(arg: NDBuffer, pow: Number) = map(arg) { power(it, pow) } + override fun power(arg: NDBuffer, pow: Number): ComplexNDElement = map(arg) { power(it, pow) } - override fun exp(arg: NDBuffer) = map(arg) { exp(it) } + override fun exp(arg: NDBuffer): ComplexNDElement = map(arg) { exp(it) } - override fun ln(arg: NDBuffer) = map(arg) { ln(it) } + override fun ln(arg: NDBuffer): ComplexNDElement = map(arg) { ln(it) } - override fun sin(arg: NDBuffer) = map(arg) { sin(it) } + override fun sin(arg: NDBuffer): ComplexNDElement = map(arg) { sin(it) } - override fun cos(arg: NDBuffer) = map(arg) { cos(it) } + override fun cos(arg: NDBuffer): ComplexNDElement = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): ComplexNDElement = map(arg) { tan(it) } + + override fun asin(arg: NDBuffer): ComplexNDElement = map(arg) { asin(it) } + + override fun acos(arg: NDBuffer): ComplexNDElement = map(arg) { acos(it) } + + override fun atan(arg: NDBuffer): ComplexNDElement = map(arg) { atan(it) } } @@ -91,13 +98,13 @@ inline fun BufferedNDField.produceInline(crossinline init } /** - * Map one [ComplexNDElement] using function with indexes + * Map one [ComplexNDElement] using function with indices. */ -inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex) = +inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement = context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } /** - * Map one [ComplexNDElement] using function without indexes + * Map one [ComplexNDElement] using function without indices. */ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement { val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) } @@ -107,7 +114,7 @@ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> /** * Element by element application of any operation on elements to the whole array. Just like in numpy */ -operator fun Function1.invoke(ndElement: ComplexNDElement) = +operator fun Function1.invoke(ndElement: ComplexNDElement): ComplexNDElement = ndElement.map { this@invoke(it) } @@ -116,19 +123,18 @@ operator fun Function1.invoke(ndElement: ComplexNDElement) = /** * Summation operation for [BufferedNDElement] and single element */ -operator fun ComplexNDElement.plus(arg: Complex) = - map { it + arg } +operator fun ComplexNDElement.plus(arg: Complex): ComplexNDElement = map { it + arg } /** * Subtraction operation between [BufferedNDElement] and single element */ -operator fun ComplexNDElement.minus(arg: Complex) = +operator fun ComplexNDElement.minus(arg: Complex): ComplexNDElement = map { it - arg } -operator fun ComplexNDElement.plus(arg: Double) = +operator fun ComplexNDElement.plus(arg: Double): ComplexNDElement = map { it + arg } -operator fun ComplexNDElement.minus(arg: Double) = +operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement = map { it - arg } fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape) @@ -141,4 +147,4 @@ fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(In */ inline fun ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R { return NDField.complex(*shape).run(action) -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt index 3437644ff..24aa48c6b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt @@ -1,14 +1,15 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.* - -interface ExtendedNDField> : - NDField, - TrigonometricOperations, - PowerOperations, - ExponentialOperations - where F : ExtendedFieldOperations, F : Field +import scientifik.kmath.operations.ExtendedField +/** + * [ExtendedField] over [NDStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param F the extended field of structure elements. + */ +interface ExtendedNDField, N : NDStructure> : NDField, ExtendedField ///** // * NDField that supports [ExtendedField] operations on its elements @@ -41,5 +42,3 @@ interface ExtendedNDField> : // return produce { with(elementContext) { cos(arg[it]) } } // } //} - - diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt new file mode 100644 index 000000000..a2d0a71b3 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt @@ -0,0 +1,73 @@ +package scientifik.kmath.structures + +import kotlin.experimental.and + +/** + * Represents flags to supply additional info about values of buffer. + * + * @property mask bit mask value of this flag. + */ +enum class ValueFlag(val mask: Byte) { + /** + * Reports the value is NaN. + */ + NAN(0b0000_0001), + + /** + * Reports the value doesn't present in the buffer (when the type of value doesn't support `null`). + */ + MISSING(0b0000_0010), + + /** + * Reports the value is negative infinity. + */ + NEGATIVE_INFINITY(0b0000_0100), + + /** + * Reports the value is positive infinity + */ + POSITIVE_INFINITY(0b0000_1000) +} + +/** + * A buffer with flagged values. + */ +interface FlaggedBuffer : Buffer { + fun getFlag(index: Int): Byte +} + +/** + * The value is valid if all flags are down + */ +fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte() + +fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte() + +fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING) + +/** + * A real buffer which supports flags for each value like NaN or Missing + */ +class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer, Buffer { + init { + require(values.size == flags.size) { "Values and flags must have the same dimensions" } + } + + override fun getFlag(index: Int): Byte = flags[index] + + override val size: Int get() = values.size + + override fun get(index: Int): Double? = if (isValid(index)) values[index] else null + + override fun iterator(): Iterator = values.indices.asSequence().map { + if (isValid(it)) values[it] else null + }.iterator() +} + +inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { + for (i in indices) { + if (isValid(i)) { + block(values[i]) + } + } +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt new file mode 100644 index 000000000..e42df8c14 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt @@ -0,0 +1,49 @@ +package scientifik.kmath.structures + +/** + * Specialized [MutableBuffer] implementation over [FloatArray]. + * + * @property array the underlying array. + */ +inline class FloatBuffer(val array: FloatArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Float = array[index] + + override fun set(index: Int, value: Float) { + array[index] = value + } + + override fun iterator(): FloatIterator = array.iterator() + + override fun copy(): MutableBuffer = + FloatBuffer(array.copyOf()) +} + +/** + * Creates a new [FloatBuffer] with the specified [size], where each element is calculated by calling the specified + * [init] function. + * + * The function [init] is called for each array element sequentially starting from the first one. + * It should return the value for an buffer element given its index. + */ +inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) }) + +/** + * Returns a new [FloatBuffer] of given elements. + */ +fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats) + +/** + * Returns a [FloatArray] containing all of the elements of this [MutableBuffer]. + */ +val MutableBuffer.array: FloatArray + get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) }) + +/** + * Returns [FloatBuffer] over this array. + * + * @receiver the array. + * @return the new buffer. + */ +fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt new file mode 100644 index 000000000..a3f0f3c3e --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt @@ -0,0 +1,50 @@ +package scientifik.kmath.structures + +/** + * Specialized [MutableBuffer] implementation over [IntArray]. + * + * @property array the underlying array. + */ +inline class IntBuffer(val array: IntArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Int = array[index] + + override fun set(index: Int, value: Int) { + array[index] = value + } + + override fun iterator(): IntIterator = array.iterator() + + override fun copy(): MutableBuffer = + IntBuffer(array.copyOf()) + +} + +/** + * Creates a new [IntBuffer] with the specified [size], where each element is calculated by calling the specified + * [init] function. + * + * The function [init] is called for each array element sequentially starting from the first one. + * It should return the value for an buffer element given its index. + */ +inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) }) + +/** + * Returns a new [IntBuffer] of given elements. + */ +fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints) + +/** + * Returns a [IntArray] containing all of the elements of this [MutableBuffer]. + */ +val MutableBuffer.array: IntArray + get() = (if (this is IntBuffer) array else IntArray(size) { get(it) }) + +/** + * Returns [IntBuffer] over this array. + * + * @receiver the array. + * @return the new buffer. + */ +fun IntArray.asBuffer(): IntBuffer = IntBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt new file mode 100644 index 000000000..912656c68 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt @@ -0,0 +1,50 @@ +package scientifik.kmath.structures + +/** + * Specialized [MutableBuffer] implementation over [LongArray]. + * + * @property array the underlying array. + */ +inline class LongBuffer(val array: LongArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Long = array[index] + + override fun set(index: Int, value: Long) { + array[index] = value + } + + override fun iterator(): LongIterator = array.iterator() + + override fun copy(): MutableBuffer = + LongBuffer(array.copyOf()) + +} + +/** + * Creates a new [LongBuffer] with the specified [size], where each element is calculated by calling the specified + * [init] function. + * + * The function [init] is called for each array element sequentially starting from the first one. + * It should return the value for an buffer element given its index. + */ +inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) }) + +/** + * Returns a new [LongBuffer] of given elements. + */ +fun LongBuffer(vararg longs: Long): LongBuffer = LongBuffer(longs) + +/** + * Returns a [IntArray] containing all of the elements of this [MutableBuffer]. + */ +val MutableBuffer.array: LongArray + get() = (if (this is LongBuffer) array else LongArray(size) { get(it) }) + +/** + * Returns [LongBuffer] over this array. + * + * @receiver the array. + * @return the new buffer. + */ +fun LongArray.asBuffer(): LongBuffer = LongBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt index a09f09165..1d0c87580 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt @@ -3,13 +3,16 @@ package scientifik.kmath.structures import scientifik.memory.* /** - * A non-boxing buffer based on [ByteBuffer] storage + * A non-boxing buffer over [Memory] object. + * + * @param T the type of elements contained in the buffer. + * @property memory the underlying memory segment. + * @property spec the spec of [T] type. */ open class MemoryBuffer(protected val memory: Memory, protected val spec: MemorySpec) : Buffer { - override val size: Int get() = memory.size / spec.objectSize - private val reader = memory.reader() + private val reader: MemoryReader = memory.reader() override fun get(index: Int): T = reader.read(spec, spec.objectSize * index) @@ -17,7 +20,7 @@ open class MemoryBuffer(protected val memory: Memory, protected val spe companion object { - fun create(spec: MemorySpec, size: Int) = + fun create(spec: MemorySpec, size: Int): MemoryBuffer = MemoryBuffer(Memory.allocate(size * spec.objectSize), spec) inline fun create( @@ -33,28 +36,35 @@ open class MemoryBuffer(protected val memory: Memory, protected val spe } } +/** + * A mutable non-boxing buffer over [Memory] object. + * + * @param T the type of elements contained in the buffer. + * @property memory the underlying memory segment. + * @property spec the spec of [T] type. + */ class MutableMemoryBuffer(memory: Memory, spec: MemorySpec) : MemoryBuffer(memory, spec), MutableBuffer { - private val writer = memory.writer() + private val writer: MemoryWriter = memory.writer() - override fun set(index: Int, value: T) = writer.write(spec, spec.objectSize * index, value) + override fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) override fun copy(): MutableBuffer = MutableMemoryBuffer(memory.copy(), spec) companion object { - fun create(spec: MemorySpec, size: Int) = + fun create(spec: MemorySpec, size: Int): MutableMemoryBuffer = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec) inline fun create( spec: MemorySpec, size: Int, crossinline initializer: (Int) -> T - ) = + ): MutableMemoryBuffer = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> (0 until size).forEach { buffer[it] = initializer(it) } } } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt index c826565cf..f09db3c72 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt @@ -56,7 +56,7 @@ interface NDAlgebra> { /** * element-by-element invoke a function working on [T] on a [NDStructure] */ - operator fun Function1.invoke(structure: N) = map(structure) { value -> this@invoke(value) } + operator fun Function1.invoke(structure: N): N = map(structure) { value -> this@invoke(value) } companion object } @@ -76,12 +76,12 @@ interface NDSpace, N : NDStructure> : Space, NDAlgebra add(arg, value) } + operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) } - operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) } + operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) } - operator fun T.plus(arg: N) = map(arg) { value -> add(this@plus, value) } - operator fun T.minus(arg: N) = map(arg) { value -> add(-this@minus, value) } + operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) } + operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) } companion object } @@ -97,20 +97,19 @@ interface NDRing, N : NDStructure> : Ring, NDSpace override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } //TODO move to extensions after KEEP-176 - operator fun N.times(arg: T) = map(this) { value -> multiply(arg, value) } + operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) } - operator fun T.times(arg: N) = map(arg) { value -> multiply(this@times, value) } + operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) } companion object } /** - * Field for n-dimensional structures. - * @param shape - the list of dimensions of the array - * @param elementField - operations field defined on individual array element - * @param T - the type of the element contained in ND structure - * @param F - field of structure elements - * @param R - actual nd-element type of this field + * Field of [NDStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param F field of structure elements. */ interface NDField, N : NDStructure> : Field, NDRing { @@ -120,9 +119,9 @@ interface NDField, N : NDStructure> : Field, NDRing divide(aValue, bValue) } //TODO move to extensions after KEEP-176 - operator fun N.div(arg: T) = map(this) { value -> divide(arg, value) } + operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) } - operator fun T.div(arg: N) = map(arg) { divide(it, this@div) } + operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) } companion object { @@ -131,7 +130,7 @@ interface NDField, N : NDStructure> : Field, NDRing, N : NDStructure> : Field, NDRing = Buffer.Companion::boxing - ) = BoxingNDField(shape, field, bufferFactory) + ): BoxingNDField = BoxingNDField(shape, field, bufferFactory) /** * Create a most suitable implementation for nd-field using reified class. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt index a18a03364..9dfe2b5a8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt @@ -23,19 +23,23 @@ interface NDElement> : NDStructure { /** * Create a optimized NDArray of doubles */ - fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }) = + fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement = NDField.real(*shape).produce(initializer) - fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }) = + fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement = real(intArrayOf(dim)) { initializer(it[0]) } - fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }) = + fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } - fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }) = - real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } + fun real3D( + dim1: Int, + dim2: Int, + dim3: Int, + initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 } + ): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } /** @@ -62,16 +66,17 @@ interface NDElement> : NDStructure { } -fun > NDElement.mapIndexed(transform: C.(index: IntArray, T) -> T) = +fun > NDElement.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement = context.mapIndexed(unwrap(), transform).wrap() -fun > NDElement.map(transform: C.(T) -> T) = context.map(unwrap(), transform).wrap() +fun > NDElement.map(transform: C.(T) -> T): NDElement = + context.map(unwrap(), transform).wrap() /** * Element by element application of any operation on elements to the whole [NDElement] */ -operator fun > Function1.invoke(ndElement: NDElement) = +operator fun > Function1.invoke(ndElement: NDElement): NDElement = ndElement.map { value -> this@invoke(value) } /* plus and minus */ @@ -79,13 +84,13 @@ operator fun > Function1.invoke(ndElement: NDElem /** * Summation operation for [NDElement] and single element */ -operator fun , N : NDStructure> NDElement.plus(arg: T) = +operator fun , N : NDStructure> NDElement.plus(arg: T): NDElement = map { value -> arg + value } /** * Subtraction operation between [NDElement] and single element */ -operator fun , N : NDStructure> NDElement.minus(arg: T) = +operator fun , N : NDStructure> NDElement.minus(arg: T): NDElement = map { value -> arg - value } /* prod and div */ @@ -93,13 +98,13 @@ operator fun , N : NDStructure> NDElement.minus(arg: /** * Product operation for [NDElement] and single element */ -operator fun , N : NDStructure> NDElement.times(arg: T) = +operator fun , N : NDStructure> NDElement.times(arg: T): NDElement = map { value -> arg * value } /** * Division operation between [NDElement] and single element */ -operator fun , N : NDStructure> NDElement.div(arg: T) = +operator fun , N : NDStructure> NDElement.div(arg: T): NDElement = map { value -> arg / value } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt index 808f970c5..9d7735053 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt @@ -3,70 +3,138 @@ package scientifik.kmath.structures import kotlin.jvm.JvmName import kotlin.reflect.KClass - +/** + * Represents n-dimensional structure, i.e. multidimensional container of items of the same type and size. The number + * of dimensions and items in an array is defined by its shape, which is a sequence of non-negative integers that + * specify the sizes of each dimension. + * + * @param T the type of items. + */ interface NDStructure { - + /** + * The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of + * this structure. + */ val shape: IntArray - val dimension get() = shape.size + /** + * The count of dimensions in this structure. It should be equal to size of [shape]. + */ + val dimension: Int get() = shape.size + /** + * Returns the value at the specified indices. + * + * @param index the indices. + * @return the value. + */ operator fun get(index: IntArray): T + /** + * Returns the sequence of all the elements associated by their indices. + * + * @return the lazy sequence of pairs of indices to values. + */ fun elements(): Sequence> + override fun equals(other: Any?): Boolean + + override fun hashCode(): Int + companion object { + /** + * Indicates whether some [NDStructure] is equal to another one. + */ fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { - return when { - st1 === st2 -> true - st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals( - st2.buffer - ) - else -> st1.elements().all { (index, value) -> value == st2[index] } + if (st1 === st2) return true + + // fast comparison of buffers if possible + if ( + st1 is NDBuffer && + st2 is NDBuffer && + st1.strides == st2.strides + ) { + return st1.buffer.contentEquals(st2.buffer) } + + //element by element comparison if it could not be avoided + return st1.elements().all { (index, value) -> value == st2[index] } } /** - * Create a NDStructure with explicit buffer factory + * Creates a NDStructure with explicit buffer factory. * - * Strides should be reused if possible + * Strides should be reused if possible. */ fun build( strides: Strides, bufferFactory: BufferFactory = Buffer.Companion::boxing, initializer: (IntArray) -> T - ) = + ): BufferNDStructure = BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) /** * Inline create NDStructure with non-boxing buffer implementation if it is possible */ - inline fun auto(strides: Strides, crossinline initializer: (IntArray) -> T) = + inline fun auto( + strides: Strides, + crossinline initializer: (IntArray) -> T + ): BufferNDStructure = BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) - inline fun auto(type: KClass, strides: Strides, crossinline initializer: (IntArray) -> T) = + inline fun auto( + type: KClass, + strides: Strides, + crossinline initializer: (IntArray) -> T + ): BufferNDStructure = BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) }) fun build( shape: IntArray, bufferFactory: BufferFactory = Buffer.Companion::boxing, initializer: (IntArray) -> T - ) = build(DefaultStrides(shape), bufferFactory, initializer) + ): BufferNDStructure = build(DefaultStrides(shape), bufferFactory, initializer) - inline fun auto(shape: IntArray, crossinline initializer: (IntArray) -> T) = + inline fun auto( + shape: IntArray, + crossinline initializer: (IntArray) -> T + ): BufferNDStructure = auto(DefaultStrides(shape), initializer) @JvmName("autoVarArg") - inline fun auto(vararg shape: Int, crossinline initializer: (IntArray) -> T) = + inline fun auto( + vararg shape: Int, + crossinline initializer: (IntArray) -> T + ): BufferNDStructure = auto(DefaultStrides(shape), initializer) - inline fun auto(type: KClass, vararg shape: Int, crossinline initializer: (IntArray) -> T) = + inline fun auto( + type: KClass, + vararg shape: Int, + crossinline initializer: (IntArray) -> T + ): BufferNDStructure = auto(type, DefaultStrides(shape), initializer) } } +/** + * Returns the value at the specified indices. + * + * @param index the indices. + * @return the value. + */ operator fun NDStructure.get(vararg index: Int): T = get(index) +/** + * Represents mutable [NDStructure]. + */ interface MutableNDStructure : NDStructure { + /** + * Inserts an item at the specified indices. + * + * @param index the indices. + * @param value the value. + */ operator fun set(index: IntArray, value: T) } @@ -77,7 +145,7 @@ inline fun MutableNDStructure.mapInPlace(action: (IntArray, T) -> T) { } /** - * A way to convert ND index to linear one and back + * A way to convert ND index to linear one and back. */ interface Strides { /** @@ -114,11 +182,14 @@ interface Strides { } } +/** + * Simple implementation of [Strides]. + */ class DefaultStrides private constructor(override val shape: IntArray) : Strides { /** * Strides for memory access */ - override val strides by lazy { + override val strides: List by lazy { sequence { var current = 1 yield(1) @@ -153,19 +224,14 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides override val linearSize: Int get() = strides[shape.size] - override fun equals(other: Any?): Boolean { if (this === other) return true if (other !is DefaultStrides) return false - if (!shape.contentEquals(other.shape)) return false - return true } - override fun hashCode(): Int { - return shape.contentHashCode() - } + override fun hashCode(): Int = shape.contentHashCode() companion object { private val defaultStridesCache = HashMap() @@ -177,15 +243,37 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides } } -interface NDBuffer : NDStructure { - val buffer: Buffer - val strides: Strides +/** + * Represents [NDStructure] over [Buffer]. + * + * @param T the type of items. + */ +abstract class NDBuffer : NDStructure { + /** + * The underlying buffer. + */ + abstract val buffer: Buffer + + /** + * The strides to access elements of [Buffer] by linear indices. + */ + abstract val strides: Strides override fun get(index: IntArray): T = buffer[strides.offset(index)] override val shape: IntArray get() = strides.shape - override fun elements() = strides.indices().map { it to this[it] } + override fun elements(): Sequence> = strides.indices().map { it to this[it] } + + override fun equals(other: Any?): Boolean { + return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + } + + override fun hashCode(): Int { + var result = strides.hashCode() + result = 31 * result + buffer.hashCode() + return result + } } /** @@ -194,34 +282,12 @@ interface NDBuffer : NDStructure { class BufferNDStructure( override val strides: Strides, override val buffer: Buffer -) : NDBuffer { - +) : NDBuffer() { init { if (strides.linearSize != buffer.size) { error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}") } } - - override fun get(index: IntArray): T = buffer[strides.offset(index)] - - override val shape: IntArray get() = strides.shape - - override fun elements() = strides.indices().map { it to this[it] } - - override fun equals(other: Any?): Boolean { - return when { - this === other -> true - other is BufferNDStructure<*> && this.strides == other.strides -> this.buffer.contentEquals(other.buffer) - other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] } - else -> false - } - } - - override fun hashCode(): Int { - var result = strides.hashCode() - result = 31 * result + buffer.hashCode() - return result - } } /** @@ -240,20 +306,20 @@ inline fun NDStructure.mapToBuffer( } /** - * Mutable ND buffer based on linear [autoBuffer] + * Mutable ND buffer based on linear [MutableBuffer]. */ class MutableBufferNDStructure( override val strides: Strides, override val buffer: MutableBuffer -) : NDBuffer, MutableNDStructure { +) : NDBuffer(), MutableNDStructure { init { - if (strides.linearSize != buffer.size) { - error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}") + require(strides.linearSize == buffer.size) { + "Expected buffer side of ${strides.linearSize}, but found ${buffer.size}" } } - override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value) + override fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value) } inline fun NDStructure.combine( @@ -262,4 +328,4 @@ inline fun NDStructure.combine( ): NDStructure { if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination") return NDStructure.auto(shape) { block(this[it], struct[it]) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt new file mode 100644 index 000000000..e999e12b2 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt @@ -0,0 +1,49 @@ +package scientifik.kmath.structures + +/** + * Specialized [MutableBuffer] implementation over [DoubleArray]. + * + * @property array the underlying array. + */ +inline class RealBuffer(val array: DoubleArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Double = array[index] + + override fun set(index: Int, value: Double) { + array[index] = value + } + + override fun iterator(): DoubleIterator = array.iterator() + + override fun copy(): MutableBuffer = + RealBuffer(array.copyOf()) +} + +/** + * Creates a new [RealBuffer] with the specified [size], where each element is calculated by calling the specified + * [init] function. + * + * The function [init] is called for each array element sequentially starting from the first one. + * It should return the value for an buffer element given its index. + */ +inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) }) + +/** + * Returns a new [RealBuffer] of given elements. + */ +fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles) + +/** + * Returns a [DoubleArray] containing all of the elements of this [MutableBuffer]. + */ +val MutableBuffer.array: DoubleArray + get() = (if (this is RealBuffer) array else DoubleArray(size) { get(it) }) + +/** + * Returns [RealBuffer] over this array. + * + * @receiver the array. + * @return the new buffer. + */ +fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 88c8c29db..33198aac1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -6,148 +6,180 @@ import kotlin.math.* /** - * A simple field over linear buffers of [Double] + * [ExtendedFieldOperations] over [RealBuffer]. */ object RealBufferFieldOperations : ExtendedFieldOperations> { - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { val kValue = k.toDouble() - return if (a is DoubleBuffer) { + + return if (a is RealBuffer) { val aArray = a.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * kValue }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) + } else + RealBuffer(DoubleArray(a.size) { a[it] * kValue }) } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] * b[it] }) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) } - override fun sin(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) - } + override fun sin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) + } else { + RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) } - override fun cos(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) - } + override fun cos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + + override fun tan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { tan(arg[it]) }) + + override fun asin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { asin(array[it]) }) + } else { + RealBuffer(DoubleArray(arg.size) { asin(arg[it]) }) } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) - } - } + override fun acos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { acos(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { acos(arg[it]) }) - override fun exp(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - } - } + override fun atan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { atan(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) - override fun ln(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) - } - } + override fun power(arg: Buffer, pow: Number): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) + } else + RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + + override fun exp(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) + + override fun ln(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } +/** + * [ExtendedField] over [RealBuffer]. + * + * @property size the size of buffers to operate on. + */ class RealBufferField(val size: Int) : ExtendedField> { + override val zero: Buffer by lazy { RealBuffer(size) { 0.0 } } + override val one: Buffer by lazy { RealBuffer(size) { 1.0 } } - override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } - - override val one: Buffer by lazy { DoubleBuffer(size) { 1.0 } } - - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.add(a, b) } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, k) } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, b) } - - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.divide(a, b) } - override fun sin(arg: Buffer): DoubleBuffer { + override fun sin(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.sin(arg) } - override fun cos(arg: Buffer): DoubleBuffer { + override fun cos(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.cos(arg) } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { + override fun tan(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.tan(arg) + } + + override fun asin(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.asin(arg) + } + + override fun acos(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.acos(arg) + } + + override fun atan(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.atan(arg) + } + + override fun power(arg: Buffer, pow: Number): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) } - override fun exp(arg: Buffer): DoubleBuffer { + override fun exp(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.exp(arg) } - override fun ln(arg: Buffer): DoubleBuffer { + override fun ln(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.ln(arg) } - -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 8c1bd4239..e2a1a33df 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -12,11 +12,11 @@ class RealNDField(override val shape: IntArray) : override val strides: Strides = DefaultStrides(shape) override val elementContext: RealField get() = RealField - override val zero by lazy { produce { zero } } - override val one by lazy { produce { one } } + override val zero: RealNDElement by lazy { produce { zero } } + override val one: RealNDElement by lazy { produce { one } } inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = - DoubleBuffer(DoubleArray(size) { initializer(it) }) + RealBuffer(DoubleArray(size) { initializer(it) }) /** * Inline transform an NDStructure to @@ -64,16 +64,23 @@ class RealNDField(override val shape: IntArray) : override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = BufferedNDFieldElement(this@RealNDField, buffer) - override fun power(arg: NDBuffer, pow: Number) = map(arg) { power(it, pow) } + override fun power(arg: NDBuffer, pow: Number): RealNDElement = map(arg) { power(it, pow) } - override fun exp(arg: NDBuffer) = map(arg) { exp(it) } + override fun exp(arg: NDBuffer): RealNDElement = map(arg) { exp(it) } - override fun ln(arg: NDBuffer) = map(arg) { ln(it) } + override fun ln(arg: NDBuffer): RealNDElement = map(arg) { ln(it) } - override fun sin(arg: NDBuffer) = map(arg) { sin(it) } + override fun sin(arg: NDBuffer): RealNDElement = map(arg) { sin(it) } - override fun cos(arg: NDBuffer) = map(arg) { cos(it) } + override fun cos(arg: NDBuffer): RealNDElement = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): NDBuffer = map(arg) { tan(it) } + + override fun asin(arg: NDBuffer): NDBuffer = map(arg) { asin(it) } + + override fun acos(arg: NDBuffer): NDBuffer = map(arg) { acos(it) } + + override fun atan(arg: NDBuffer): NDBuffer = map(arg) { atan(it) } } @@ -82,27 +89,27 @@ class RealNDField(override val shape: IntArray) : */ inline fun BufferedNDField.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) } - return BufferedNDFieldElement(this, DoubleBuffer(array)) + return BufferedNDFieldElement(this, RealBuffer(array)) } /** - * Map one [RealNDElement] using function with indexes + * Map one [RealNDElement] using function with indices. */ -inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double) = +inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double): RealNDElement = context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } /** - * Map one [RealNDElement] using function without indexes + * Map one [RealNDElement] using function without indices. */ inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) } - return BufferedNDFieldElement(context, DoubleBuffer(array)) + return BufferedNDFieldElement(context, RealBuffer(array)) } /** - * Element by element application of any operation on elements to the whole array. Just like in numpy + * Element by element application of any operation on elements to the whole array. Just like in numpy. */ -operator fun Function1.invoke(ndElement: RealNDElement) = +operator fun Function1.invoke(ndElement: RealNDElement): RealNDElement = ndElement.map { this@invoke(it) } @@ -111,13 +118,13 @@ operator fun Function1.invoke(ndElement: RealNDElement) = /** * Summation operation for [BufferedNDElement] and single element */ -operator fun RealNDElement.plus(arg: Double) = +operator fun RealNDElement.plus(arg: Double): RealNDElement = map { it + arg } /** * Subtraction operation between [BufferedNDElement] and single element */ -operator fun RealNDElement.minus(arg: Double) = +operator fun RealNDElement.minus(arg: Double): RealNDElement = map { it - arg } /** @@ -125,4 +132,4 @@ operator fun RealNDElement.minus(arg: Double) = */ inline fun RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R { return NDField.real(*shape).run(action) -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt new file mode 100644 index 000000000..c6f19feaf --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt @@ -0,0 +1,50 @@ +package scientifik.kmath.structures + +/** + * Specialized [MutableBuffer] implementation over [ShortArray]. + * + * @property array the underlying array. + */ +inline class ShortBuffer(val array: ShortArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Short = array[index] + + override fun set(index: Int, value: Short) { + array[index] = value + } + + override fun iterator(): ShortIterator = array.iterator() + + override fun copy(): MutableBuffer = + ShortBuffer(array.copyOf()) + +} + +/** + * Creates a new [ShortBuffer] with the specified [size], where each element is calculated by calling the specified + * [init] function. + * + * The function [init] is called for each array element sequentially starting from the first one. + * It should return the value for an buffer element given its index. + */ +inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) }) + +/** + * Returns a new [ShortBuffer] of given elements. + */ +fun ShortBuffer(vararg shorts: Short): ShortBuffer = ShortBuffer(shorts) + +/** + * Returns a [ShortArray] containing all of the elements of this [MutableBuffer]. + */ +val MutableBuffer.array: ShortArray + get() = (if (this is ShortBuffer) array else ShortArray(size) { get(it) }) + +/** + * Returns [ShortBuffer] over this array. + * + * @receiver the array. + * @return the new buffer. + */ +fun ShortArray.asBuffer(): ShortBuffer = ShortBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt index 6b09c91de..f404a2a27 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt @@ -12,8 +12,8 @@ class ShortNDRing(override val shape: IntArray) : override val strides: Strides = DefaultStrides(shape) override val elementContext: ShortRing get() = ShortRing - override val zero by lazy { produce { ShortRing.zero } } - override val one by lazy { produce { ShortRing.one } } + override val zero: ShortNDElement by lazy { produce { zero } } + override val one: ShortNDElement by lazy { produce { one } } inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer = ShortBuffer(ShortArray(size) { initializer(it) }) @@ -40,6 +40,7 @@ class ShortNDRing(override val shape: IntArray) : transform: ShortRing.(index: IntArray, Short) -> Short ): ShortNDElement { check(arg) + return BufferedNDRingElement( this, buildBuffer(arg.strides.linearSize) { offset -> @@ -67,7 +68,7 @@ class ShortNDRing(override val shape: IntArray) : /** - * Fast element production using function inlining + * Fast element production using function inlining. */ inline fun BufferedNDRing.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement { val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) } @@ -75,22 +76,22 @@ inline fun BufferedNDRing.produceInline(crossinline initialize } /** - * Element by element application of any operation on elements to the whole array. Just like in numpy + * Element by element application of any operation on elements to the whole array. */ -operator fun Function1.invoke(ndElement: ShortNDElement) = +operator fun Function1.invoke(ndElement: ShortNDElement): ShortNDElement = ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) } /* plus and minus */ /** - * Summation operation for [StridedNDFieldElement] and single element + * Summation operation for [ShortNDElement] and single element. */ -operator fun ShortNDElement.plus(arg: Short) = +operator fun ShortNDElement.plus(arg: Short): ShortNDElement = context.produceInline { i -> (buffer[i] + arg).toShort() } /** - * Subtraction operation between [StridedNDFieldElement] and single element + * Subtraction operation between [ShortNDElement] and single element. */ -operator fun ShortNDElement.minus(arg: Short) = - context.produceInline { i -> (buffer[i] - arg).toShort() } \ No newline at end of file +operator fun ShortNDElement.minus(arg: Short): ShortNDElement = + context.produceInline { i -> (buffer[i] - arg).toShort() } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt index df56017a3..faf022367 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt @@ -39,14 +39,14 @@ private inline class Buffer1DWrapper(val buffer: Buffer) : Structure1D override fun elements(): Sequence> = asSequence().mapIndexed { index, value -> intArrayOf(index) to value } - override fun get(index: Int): T = buffer.get(index) + override fun get(index: Int): T = buffer[index] } /** * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch */ fun NDStructure.as1D(): Structure1D = if (shape.size == 1) { - if( this is NDBuffer){ + if (this is NDBuffer) { Buffer1DWrapper(this.buffer) } else { Structure1DWrapper(this) @@ -59,4 +59,4 @@ fun NDStructure.as1D(): Structure1D = if (shape.size == 1) { /** * Represent this buffer as 1D structure */ -fun Buffer.asND(): Structure1D = Buffer1DWrapper(this) \ No newline at end of file +fun Buffer.asND(): Structure1D = Buffer1DWrapper(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt index e736f84a0..30fd556d3 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt @@ -14,7 +14,6 @@ interface Structure2D : NDStructure { return get(index[0], index[1]) } - val rows: Buffer> get() = VirtualBuffer(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } @@ -33,9 +32,7 @@ interface Structure2D : NDStructure { } } - companion object { - - } + companion object } /** @@ -58,22 +55,4 @@ fun NDStructure.as2D(): Structure2D = if (shape.size == 2) { error("Can't create 2d-structure from ${shape.size}d-structure") } -/** - * Represent this 2D structure as 1D if it has exactly one column. Throw error otherwise. - */ -fun Structure2D.as1D() = if (colNum == 1) { - object : Structure1D { - override fun get(index: Int): T = get(index, 0) - - override val shape: IntArray get() = intArrayOf(rowNum) - - override fun elements(): Sequence> = elements() - - override val size: Int get() = rowNum - } -} else { - error("Can't convert matrix with more than one column to vector") -} - - -typealias Matrix = Structure2D \ No newline at end of file +typealias Matrix = Structure2D diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt index 033b2792f..22b924ef9 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt @@ -9,7 +9,7 @@ import kotlin.test.assertEquals class ExpressionFieldTest { @Test fun testExpression() { - val context = ExpressionField(RealField) + val context = FunctionalExpressionField(RealField) val expression = with(context) { val x = variable("x", 2.0) x * x + 2 * x + one @@ -20,7 +20,7 @@ class ExpressionFieldTest { @Test fun testComplex() { - val context = ExpressionField(ComplexField) + val context = FunctionalExpressionField(ComplexField) val expression = with(context) { val x = variable("x", Complex(2.0, 0.0)) x * x + 2 * x + one @@ -31,23 +31,23 @@ class ExpressionFieldTest { @Test fun separateContext() { - fun ExpressionField.expression(): Expression { + fun FunctionalExpressionField.expression(): Expression { val x = variable("x") return x * x + 2 * x + one } - val expression = ExpressionField(RealField).expression() + val expression = FunctionalExpressionField(RealField).expression() assertEquals(expression("x" to 1.0), 4.0) } @Test fun valueExpression() { - val expressionBuilder: ExpressionField.() -> Expression = { + val expressionBuilder: FunctionalExpressionField.() -> Expression = { val x = variable("x") x * x + 2 * x + one } - val expression = ExpressionField(RealField).expressionBuilder() + val expression = FunctionalExpressionField(RealField).expressionBuilder() assertEquals(expression("x" to 1.0), 4.0) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt index 7d1209963..987426250 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt @@ -17,7 +17,7 @@ class MatrixTest { @Test fun testBuilder() { - val matrix = Matrix.build(2, 3)( + val matrix = Matrix.build(2, 3)( 1.0, 0.0, 0.0, 0.0, 1.0, 2.0 ) @@ -49,17 +49,17 @@ class MatrixTest { @Test fun test2DDot() { - val firstMatrix = NDStructure.auto(2,3){ (i, j) -> (i + j).toDouble() }.as2D() - val secondMatrix = NDStructure.auto(3,2){ (i, j) -> (i + j).toDouble() }.as2D() + val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D() + val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D() MatrixContext.real.run { // val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() } // val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() } val result = firstMatrix dot secondMatrix assertEquals(2, result.rowNum) assertEquals(2, result.colNum) - assertEquals(8.0, result[0,1]) - assertEquals(8.0, result[1,0]) - assertEquals(14.0, result[1,1]) + assertEquals(8.0, result[0, 1]) + assertEquals(8.0, result[1, 0]) + assertEquals(14.0, result[1, 1]) } } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt index 56a0b7aad..34bd8a0e3 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt @@ -48,4 +48,4 @@ class RealLUSolverTest { assertEquals(expected, inverted) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt new file mode 100644 index 000000000..e69de29bb diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt index 9dc8f5ef7..c08a63ccb 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt @@ -8,10 +8,10 @@ import kotlin.test.assertEquals import kotlin.test.assertTrue class AutoDiffTest { + fun Variable(int: Int): Variable = Variable(int.toDouble()) - fun Variable(int: Int) = Variable(int.toDouble()) - - fun deriv(body: AutoDiffField.() -> Variable) = RealField.deriv(body) + fun deriv(body: AutoDiffField.() -> Variable): DerivationResult = + RealField.deriv(body) @Test fun testPlusX2() { @@ -178,5 +178,4 @@ class AutoDiffTest { private fun assertApprox(a: Double, b: Double) { if ((a - b) > 1e-10) assertEquals(a, b) } - -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/CumulativeKtTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/CumulativeKtTest.kt index e7c99e7d0..82ea5318f 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/CumulativeKtTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/CumulativeKtTest.kt @@ -10,4 +10,4 @@ class CumulativeKtTest { val cumulative = initial.cumulativeSum() assertEquals(listOf(-1.0, 1.0, 2.0, 3.0), cumulative) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt index 5ae977196..c22d2f27b 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt @@ -47,4 +47,3 @@ class BigIntAlgebraTest { } } - diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConstructorTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConstructorTest.kt index 2af1b7e50..5e3f6d1b0 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConstructorTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConstructorTest.kt @@ -19,8 +19,8 @@ class BigIntConstructorTest { @Test fun testConstructor_0xffffffffaL() { - val x = -0xffffffffaL.toBigInt() + val x = (-0xffffffffaL).toBigInt() val y = uintArrayOf(0xfffffffaU, 0xfU).toBigInt(-1) assertEquals(x, y) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConversionsTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConversionsTest.kt index 51b9509e0..41df1968d 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConversionsTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConversionsTest.kt @@ -19,7 +19,7 @@ class BigIntConversionsTest { @Test fun testToString_0x17ead2ffffd() { - val x = -0x17ead2ffffdL.toBigInt() + val x = (-0x17ead2ffffdL).toBigInt() assertEquals("-0x17ead2ffffd", x.toString()) } @@ -40,4 +40,4 @@ class BigIntConversionsTest { val x = "-7059135710711894913860".parseBigInteger() assertEquals("-0x17ead2ffffd11223344", x.toString()) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntOperationsTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntOperationsTest.kt index 72ac9f229..b7f4cf43b 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntOperationsTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntOperationsTest.kt @@ -31,7 +31,7 @@ class BigIntOperationsTest { @Test fun testUnaryMinus() { val x = 1234.toBigInt() - val y = -1234.toBigInt() + val y = (-1234).toBigInt() assertEquals(-x, y) } @@ -48,18 +48,18 @@ class BigIntOperationsTest { @Test fun testMinus__2_1() { - val x = -2.toBigInt() + val x = (-2).toBigInt() val y = 1.toBigInt() val res = x - y - val sum = -3.toBigInt() + val sum = (-3).toBigInt() assertEquals(sum, res) } @Test fun testMinus___2_1() { - val x = -2.toBigInt() + val x = (-2).toBigInt() val y = 1.toBigInt() val res = -x - y @@ -74,7 +74,7 @@ class BigIntOperationsTest { val y = 0xffffffffaL.toBigInt() val res = x - y - val sum = -0xfffffcfc1L.toBigInt() + val sum = (-0xfffffcfc1L).toBigInt() assertEquals(sum, res) } @@ -92,11 +92,11 @@ class BigIntOperationsTest { @Test fun testMultiply__2_3() { - val x = -2.toBigInt() + val x = (-2).toBigInt() val y = 3.toBigInt() val res = x * y - val prod = -6.toBigInt() + val prod = (-6).toBigInt() assertEquals(prod, res) } @@ -129,7 +129,7 @@ class BigIntOperationsTest { val y = -0xfff456 val res = x * y - val prod = -0xffe579ad5dc2L.toBigInt() + val prod = (-0xffe579ad5dc2L).toBigInt() assertEquals(prod, res) } @@ -259,7 +259,7 @@ class BigIntOperationsTest { val y = -3 val res = x / y - val div = -6.toBigInt() + val div = (-6).toBigInt() assertEquals(div, res) } @@ -267,10 +267,10 @@ class BigIntOperationsTest { @Test fun testBigDivision_20__3() { val x = 20.toBigInt() - val y = -3.toBigInt() + val y = (-3).toBigInt() val res = x / y - val div = -6.toBigInt() + val div = (-6).toBigInt() assertEquals(div, res) } @@ -378,4 +378,4 @@ class BigIntOperationsTest { return assertEquals(res, x % mod) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt index 779cfc4b8..9dfa3bdd1 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt @@ -11,4 +11,4 @@ class RealFieldTest { } assertEquals(5.0, sqrt) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/ComplexBufferSpecTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/ComplexBufferSpecTest.kt index 454683dac..cbbe6f0f4 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/ComplexBufferSpecTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/ComplexBufferSpecTest.kt @@ -11,4 +11,4 @@ class ComplexBufferSpecTest { val buffer = Buffer.complex(20) { Complex(it.toDouble(), -it.toDouble()) } assertEquals(Complex(5.0, -5.0), buffer[5]) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt index 39cce5c67..7abeefca6 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt @@ -10,4 +10,4 @@ class NDFieldTest { val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() } assertEquals(ndArray[5, 5], 10.0) } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt index 60f1f9979..d48aabfd0 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt @@ -8,8 +8,8 @@ import kotlin.test.Test import kotlin.test.assertEquals class NumberNDFieldTest { - val array1 = real2D(3, 3) { i, j -> (i + j).toDouble() } - val array2 = real2D(3, 3) { i, j -> (i - j).toDouble() } + val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() } + val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() } @Test fun testSum() { diff --git a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt index f44d15042..06f2b31ad 100644 --- a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt @@ -1,45 +1,60 @@ package scientifik.kmath.operations -import scientifik.kmath.structures.* import java.math.BigDecimal import java.math.BigInteger import java.math.MathContext -object BigIntegerRing : Ring { - override val zero: BigInteger = BigInteger.ZERO - override val one: BigInteger = BigInteger.ONE +/** + * A field over [BigInteger]. + */ +object JBigIntegerField : Field { + override val zero: BigInteger + get() = BigInteger.ZERO + override val one: BigInteger + get() = BigInteger.ONE + + override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) + override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) - + override fun BigInteger.minus(b: BigInteger): BigInteger = this.subtract(b) override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger()) - override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) + override fun BigInteger.unaryMinus(): BigInteger = negate() } -class BigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field { - override val zero: BigDecimal = BigDecimal.ZERO - override val one: BigDecimal = BigDecimal.ONE +/** + * An abstract field over [BigDecimal]. + * + * @property mathContext the [MathContext] to use. + */ +abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathContext = MathContext.DECIMAL64) : + Field, + PowerOperations { + override val zero: BigDecimal + get() = BigDecimal.ZERO + + override val one: BigDecimal + get() = BigDecimal.ONE override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) + override fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) + override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) override fun multiply(a: BigDecimal, k: Number): BigDecimal = a.multiply(k.toDouble().toBigDecimal(mathContext), mathContext) override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext) override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) + override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) + override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) + override fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) + } -inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInteger): Buffer = - boxing(size, initializer) - -inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInteger): MutableBuffer = - boxing(size, initializer) - -fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing = - BoxingNDRing(shape, BigIntegerRing, Buffer.Companion::bigInt) - -fun NDElement.Companion.bigInt( - vararg shape: Int, - initializer: BigIntegerRing.(IntArray) -> BigInteger -): BufferedNDRingElement = - NDAlgebra.bigInt(*shape).produce(initializer) \ No newline at end of file +/** + * A field over [BigDecimal]. + */ +class JBigDecimalField(mathContext: MathContext = MathContext.DECIMAL64) : JBigDecimalFieldBase(mathContext) { + companion object : JBigDecimalFieldBase() +} diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingIntChain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingIntChain.kt new file mode 100644 index 000000000..e9b499d71 --- /dev/null +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingIntChain.kt @@ -0,0 +1,12 @@ +package scientifik.kmath.chains + +/** + * Performance optimized chain for integer values + */ +abstract class BlockingIntChain : Chain { + abstract fun nextInt(): Int + + override suspend fun next(): Int = nextInt() + + fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() } +} diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingRealChain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingRealChain.kt new file mode 100644 index 000000000..ab819d327 --- /dev/null +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingRealChain.kt @@ -0,0 +1,12 @@ +package scientifik.kmath.chains + +/** + * Performance optimized chain for real values + */ +abstract class BlockingRealChain : Chain { + abstract fun nextDouble(): Double + + override suspend fun next(): Double = nextDouble() + + fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() } +} diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt index 1c2872d17..6cc9770af 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt @@ -22,12 +22,11 @@ import kotlinx.coroutines.flow.FlowCollector import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock - /** * A not-necessary-Markov chain of some type * @param R - the chain element type */ -interface Chain: Flow { +interface Chain : Flow { /** * Generate next value, changing state if needed */ @@ -38,13 +37,16 @@ interface Chain: Flow { */ fun fork(): Chain - @InternalCoroutinesApi + @OptIn(InternalCoroutinesApi::class) override suspend fun collect(collector: FlowCollector) { - kotlinx.coroutines.flow.flow { while (true) emit(next()) }.collect(collector) + kotlinx.coroutines.flow.flow { + while (true) { + emit(next()) + } + }.collect(collector) } companion object - } @@ -68,6 +70,8 @@ class MarkovChain(private val seed: suspend () -> R, private val ge private var value: R? = null + fun value(): R? = value + override suspend fun next(): R { mutex.withLock { val newValue = gen(value ?: seed()) @@ -92,11 +96,12 @@ class StatefulChain( private val forkState: ((S) -> S), private val gen: suspend S.(R) -> R ) : Chain { - - private val mutex = Mutex() + private val mutex: Mutex = Mutex() private var value: R? = null + fun value(): R? = value + override suspend fun next(): R { mutex.withLock { val newValue = state.gen(value ?: state.seed()) @@ -105,9 +110,7 @@ class StatefulChain( } } - override fun fork(): Chain { - return StatefulChain(forkState(state), seed, forkState, gen) - } + override fun fork(): Chain = StatefulChain(forkState(state), seed, forkState, gen) } /** @@ -156,7 +159,8 @@ fun Chain.collect(mapper: suspend (Chain) -> R): Chain = object fun Chain.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain) -> R): Chain = object : Chain { override suspend fun next(): R = state.mapper(this@collectWithState) - override fun fork(): Chain = this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper) + override fun fork(): Chain = + this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper) } /** @@ -166,4 +170,4 @@ fun Chain.zip(other: Chain, block: suspend (T, U) -> R): Chain = this@zip.fork().zip(other.fork(), block) -} \ No newline at end of file +} diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt index bfd16d763..e8537304c 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt @@ -24,4 +24,4 @@ fun Flow.mean(space: Space): Flow = with(space) { this.num += 1 } }.map { it.sum / it.num } -} \ No newline at end of file +} diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt index fdde62304..7e00b30a1 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt @@ -4,7 +4,8 @@ import kotlinx.coroutines.* import kotlinx.coroutines.channels.produce import kotlinx.coroutines.flow.* -val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default +val Dispatchers.Math: CoroutineDispatcher + get() = Default /** * An imitator of [Deferred] which holds a suspended function block and dispatcher @@ -42,7 +43,7 @@ fun Flow.async( } @FlowPreview -fun AsyncFlow.map(action: (T) -> R) = +fun AsyncFlow.map(action: (T) -> R): AsyncFlow = AsyncFlow(deferredFlow.map { input -> //TODO add function composition LazyDeferred(input.dispatcher) { @@ -82,9 +83,9 @@ suspend fun AsyncFlow.collect(concurrency: Int, collector: FlowCollector< @ExperimentalCoroutinesApi @FlowPreview -suspend fun AsyncFlow.collect(concurrency: Int, action: suspend (value: T) -> Unit): Unit { +suspend fun AsyncFlow.collect(concurrency: Int, action: suspend (value: T) -> Unit) { collect(concurrency, object : FlowCollector { - override suspend fun emit(value: T) = action(value) + override suspend fun emit(value: T): Unit = action(value) }) } @@ -94,9 +95,7 @@ fun Flow.mapParallel( dispatcher: CoroutineDispatcher = Dispatchers.Default, transform: suspend (T) -> R ): Flow { - return flatMapMerge{ value -> + return flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher) } - - diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt index cf4d4cc17..9b7e82da5 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt @@ -2,14 +2,16 @@ package scientifik.kmath.streaming import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.flow.* +import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.BufferFactory -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer +import scientifik.kmath.structures.asBuffer /** * Create a [Flow] from buffer */ -fun Buffer.asFlow() = iterator().asFlow() +fun Buffer.asFlow(): Flow = iterator().asFlow() /** * Flat map a [Flow] of [Buffer] into continuous [Flow] of elements @@ -43,22 +45,30 @@ fun Flow.chunked(bufferSize: Int, bufferFactory: BufferFactory): Flow< /** * Specialized flow chunker for real buffer */ -fun Flow.chunked(bufferSize: Int): Flow = flow { +fun Flow.chunked(bufferSize: Int): Flow = flow { require(bufferSize > 0) { "Resulting chunk size must be more than zero" } - val array = DoubleArray(bufferSize) - var counter = 0 - this@chunked.collect { element -> - array[counter] = element - counter++ - if (counter == bufferSize) { - val buffer = DoubleBuffer(array) - emit(buffer) - counter = 0 + if (this@chunked is BlockingRealChain) { + //performance optimization for blocking primitive chain + while (true) { + emit(nextBlock(bufferSize).asBuffer()) + } + } else { + val array = DoubleArray(bufferSize) + var counter = 0 + + this@chunked.collect { element -> + array[counter] = element + counter++ + if (counter == bufferSize) { + val buffer = RealBuffer(array) + emit(buffer) + counter = 0 + } + } + if (counter > 0) { + emit(RealBuffer(counter) { array[it] }) } - } - if (counter > 0) { - emit(DoubleBuffer(counter) { array[it] }) } } @@ -73,4 +83,4 @@ fun Flow.windowed(window: Int): Flow> = flow { ringBuffer.push(element) emit(ringBuffer.snapshot()) } -} \ No newline at end of file +} diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt index 6b99e34ff..245d003b3 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt @@ -5,19 +5,17 @@ import kotlinx.coroutines.sync.withLock import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.MutableBuffer import scientifik.kmath.structures.VirtualBuffer -import kotlin.reflect.KClass /** * Thread-safe ring buffer */ @Suppress("UNCHECKED_CAST") -internal class RingBuffer( +class RingBuffer( private val buffer: MutableBuffer, private var startIndex: Int = 0, size: Int = 0 ) : Buffer { - - private val mutex = Mutex() + private val mutex: Mutex = Mutex() override var size: Int = size private set @@ -28,7 +26,7 @@ internal class RingBuffer( return buffer[startIndex.forward(index)] as T } - fun isFull() = size == buffer.size + fun isFull(): Boolean = size == buffer.size /** * Iterator could provide wrong results if buffer is changed in initialization (iteration is safe) @@ -90,4 +88,4 @@ internal class RingBuffer( return RingBuffer(buffer) } } -} \ No newline at end of file +} diff --git a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt index 013ea2922..0a3c67e00 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt @@ -6,7 +6,7 @@ import kotlin.sequences.Sequence /** * Represent a chain as regular iterator (uses blocking calls) */ -operator fun Chain.iterator() = object : Iterator { +operator fun Chain.iterator(): Iterator = object : Iterator { override fun hasNext(): Boolean = true override fun next(): R = runBlocking { next() } diff --git a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt index 784b7cd10..8d5145976 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt @@ -8,10 +8,9 @@ class LazyNDStructure( override val shape: IntArray, val function: suspend (IntArray) -> T ) : NDStructure { + private val cache: MutableMap> = hashMapOf() - private val cache = HashMap>() - - fun deferred(index: IntArray) = cache.getOrPut(index) { + fun deferred(index: IntArray): Deferred = cache.getOrPut(index) { scope.async(context = Dispatchers.Math) { function(index) } @@ -30,19 +29,33 @@ class LazyNDStructure( } return res.asSequence() } + + override fun equals(other: Any?): Boolean { + return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + } + + override fun hashCode(): Int { + var result = scope.hashCode() + result = 31 * result + shape.contentHashCode() + result = 31 * result + function.hashCode() + result = 31 * result + cache.hashCode() + return result + } } -fun NDStructure.deferred(index: IntArray) = +fun NDStructure.deferred(index: IntArray): Deferred = if (this is LazyNDStructure) this.deferred(index) else CompletableDeferred(get(index)) -suspend fun NDStructure.await(index: IntArray) = +suspend fun NDStructure.await(index: IntArray): T = if (this is LazyNDStructure) this.await(index) else get(index) /** - * PENDING would benifit from KEEP-176 + * PENDING would benefit from KEEP-176 */ -fun NDStructure.mapAsyncIndexed(scope: CoroutineScope, function: suspend (T, index: IntArray) -> R) = - LazyNDStructure(scope, shape) { index -> function(get(index), index) } +fun NDStructure.mapAsyncIndexed( + scope: CoroutineScope, + function: suspend (T, index: IntArray) -> R +): LazyNDStructure = LazyNDStructure(scope, shape) { index -> function(get(index), index) } -fun NDStructure.mapAsync(scope: CoroutineScope, function: suspend (T) -> R) = +fun NDStructure.mapAsync(scope: CoroutineScope, function: suspend (T) -> R): LazyNDStructure = LazyNDStructure(scope, shape) { index -> function(get(index)) } \ No newline at end of file diff --git a/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt b/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt index 147f687f0..427349072 100644 --- a/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt +++ b/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt @@ -15,14 +15,13 @@ import kotlin.test.Test @InternalCoroutinesApi @FlowPreview class BufferFlowTest { - - val dispatcher = Executors.newFixedThreadPool(4).asCoroutineDispatcher() + val dispatcher: CoroutineDispatcher = Executors.newFixedThreadPool(4).asCoroutineDispatcher() @Test @Timeout(2000) fun map() { runBlocking { - (1..20).asFlow().mapParallel( dispatcher) { + (1..20).asFlow().mapParallel(dispatcher) { println("Started $it on ${Thread.currentThread().name}") @Suppress("BlockingMethodInNonBlockingContext") Thread.sleep(200) diff --git a/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/RingBufferTest.kt b/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/RingBufferTest.kt index c14d1a26c..c84ef89ef 100644 --- a/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/RingBufferTest.kt +++ b/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/RingBufferTest.kt @@ -19,17 +19,17 @@ class RingBufferTest { } @Test - fun windowed(){ - val flow = flow{ + fun windowed() { + val flow = flow { var i = 0 - while(true){ - emit(i++) - } + while (true) emit(i++) } + val windowed = flow.windowed(10) + runBlocking { val first = windowed.take(1).single() - val res = windowed.take(15).map { it -> it.asSequence().average() }.toList() + val res = windowed.take(15).map { it.asSequence().average() }.toList() assertEquals(0.0, res[0]) assertEquals(4.5, res[9]) assertEquals(9.5, res[14]) diff --git a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt index 37e89c111..f40483cfd 100644 --- a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt +++ b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt @@ -31,5 +31,5 @@ object D2 : Dimension { } object D3 : Dimension { - override val dim: UInt get() = 31U -} \ No newline at end of file + override val dim: UInt get() = 3U +} diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/linear/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/linear/RealVector.kt deleted file mode 100644 index d3cf07e79..000000000 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/linear/RealVector.kt +++ /dev/null @@ -1,48 +0,0 @@ -package scientifik.kmath.linear - -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Norm -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.SpaceElement -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.DoubleBuffer -import scientifik.kmath.structures.asBuffer -import scientifik.kmath.structures.asSequence - -fun DoubleArray.asVector() = RealVector(this.asBuffer()) -fun List.asVector() = RealVector(this.asBuffer()) - - -object VectorL2Norm : Norm, Double> { - override fun norm(arg: Point): Double = - kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() }) -} - -inline class RealVector(val point: Point) : - SpaceElement, RealVector, VectorSpace>, Point { - override val context: VectorSpace get() = space(point.size) - - override fun unwrap(): Point = point - - override fun Point.wrap(): RealVector = RealVector(this) - - override val size: Int get() = point.size - - override fun get(index: Int): Double = point[index] - - override fun iterator(): Iterator = point.iterator() - - companion object { - - private val spaceCache = HashMap>() - - inline operator fun invoke(dim:Int, initalizer: (Int)-> Double) = RealVector(DoubleBuffer(dim, initalizer)) - - operator fun invoke(vararg values: Double) = values.asVector() - - fun space(dim: Int) = - spaceCache.getOrPut(dim) { - BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) } - } - } -} \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/DoubleMatrixOperations.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/DoubleMatrixOperations.kt deleted file mode 100644 index 7eeba3031..000000000 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/DoubleMatrixOperations.kt +++ /dev/null @@ -1,146 +0,0 @@ -package scientifik.kmath.real - -import scientifik.kmath.linear.MatrixContext -import scientifik.kmath.linear.RealMatrixContext.elementContext -import scientifik.kmath.linear.VirtualMatrix -import scientifik.kmath.operations.sum -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.Matrix -import scientifik.kmath.structures.asSequence -import kotlin.math.pow - -/* - * Functions for convenient "numpy-like" operations with Double matrices. - * - * Initial implementation of these functions is taken from: - * https://github.com/thomasnield/numky/blob/master/src/main/kotlin/org/nield/numky/linear/DoubleOperators.kt - * - */ - -/* - * Functions that help create a real (Double) matrix - */ - -fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double) = - MatrixContext.real.produce(rowNum, colNum, initializer) - -fun Sequence.toMatrix() = toList().let { - MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] } -} - -fun Matrix.repeatStackVertical(n: Int) = VirtualMatrix(rowNum*n, colNum) { - row, col -> get(if (row == 0) 0 else row % rowNum, col) -} - -/* - * Operations for matrix and real number - */ - -operator fun Matrix.times(double: Double) = MatrixContext.real.produce(rowNum, colNum) { - row, col -> this[row, col] * double -} - -operator fun Matrix.plus(double: Double) = MatrixContext.real.produce(rowNum, colNum) { - row, col -> this[row, col] + double -} - -operator fun Matrix.minus(double: Double) = MatrixContext.real.produce(rowNum, colNum) { - row, col -> this[row, col] - double -} - -operator fun Matrix.div(double: Double) = MatrixContext.real.produce(rowNum, colNum) { - row, col -> this[row, col] / double -} - -operator fun Double.times(matrix: Matrix) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { - row, col -> this * matrix[row, col] -} - -operator fun Double.plus(matrix: Matrix) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { - row, col -> this + matrix[row, col] -} - -operator fun Double.minus(matrix: Matrix) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { - row, col -> this - matrix[row, col] -} - -// TODO: does this operation make sense? Should it be 'this/matrix[row, col]'? -//operator fun Double.div(matrix: Matrix) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { -// row, col -> matrix[row, col] / this -//} - -/* - * Per-element (!) square and power operations - */ - -fun Matrix.square() = MatrixContext.real.produce(rowNum, colNum) { - row, col -> this[row, col].pow(2) -} - -fun Matrix.pow(n: Int) = MatrixContext.real.produce(rowNum, colNum) { - i, j -> this[i, j].pow(n) -} - -/* - * Operations on two matrices (per-element!) - */ - -operator fun Matrix.times(other: Matrix) = MatrixContext.real.produce(rowNum, colNum) { - row, col -> this[row, col] * other[row, col] -} - -operator fun Matrix.plus(other: Matrix) = MatrixContext.real.add(this, other) - -operator fun Matrix.minus(other: Matrix) = MatrixContext.real.produce(rowNum, colNum) { - row, col -> this[row,col] - other[row,col] -} - -/* - * Operations on columns - */ - -inline fun Matrix.appendColumn(crossinline mapper: (Buffer) -> Double) = - MatrixContext.real.produce(rowNum,colNum+1) { - row, col -> - if (col < colNum) - this[row, col] - else - mapper(rows[row]) - } - -fun Matrix.extractColumns(columnRange: IntRange) = MatrixContext.real.produce(rowNum, columnRange.count()) { - row, col -> this[row, columnRange.first + col] -} - -fun Matrix.extractColumn(columnIndex: Int) = extractColumns(columnIndex..columnIndex) - -fun Matrix.sumByColumn() = MatrixContext.real.produce(1, colNum) { _, j -> - val column = columns[j] - with(elementContext) { - sum(column.asSequence()) - } -} - -fun Matrix.minByColumn() = MatrixContext.real.produce(1, colNum) { - _, j -> columns[j].asSequence().min() ?: throw Exception("Cannot produce min on empty column") -} - -fun Matrix.maxByColumn() = MatrixContext.real.produce(1, colNum) { - _, j -> columns[j].asSequence().max() ?: throw Exception("Cannot produce min on empty column") -} - -fun Matrix.averageByColumn() = MatrixContext.real.produce(1, colNum) { - _, j -> columns[j].asSequence().average() -} - -/* - * Operations processing all elements - */ - -fun Matrix.sum() = elements().map { (_, value) -> value }.sum() - -fun Matrix.min() = elements().map { (_, value) -> value }.min() - -fun Matrix.max() = elements().map { (_, value) -> value }.max() - -fun Matrix.average() = elements().map { (_, value) -> value }.average() diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt new file mode 100644 index 000000000..2b89904e3 --- /dev/null +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt @@ -0,0 +1,52 @@ +package scientifik.kmath.real + +import scientifik.kmath.linear.BufferVectorSpace +import scientifik.kmath.linear.Point +import scientifik.kmath.linear.VectorSpace +import scientifik.kmath.operations.Norm +import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.SpaceElement +import scientifik.kmath.structures.Buffer +import scientifik.kmath.structures.RealBuffer +import scientifik.kmath.structures.asBuffer +import scientifik.kmath.structures.asIterable +import kotlin.math.sqrt + +typealias RealPoint = Point + +fun DoubleArray.asVector() = RealVector(this.asBuffer()) +fun List.asVector() = RealVector(this.asBuffer()) + +object VectorL2Norm : Norm, Double> { + override fun norm(arg: Point): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) +} + +inline class RealVector(private val point: Point) : + SpaceElement>, RealPoint { + + override val context: VectorSpace get() = space(point.size) + + override fun unwrap(): RealPoint = point + + override fun RealPoint.wrap(): RealVector = RealVector(this) + + override val size: Int get() = point.size + + override fun get(index: Int): Double = point[index] + + override fun iterator(): Iterator = point.iterator() + + companion object { + + private val spaceCache = HashMap>() + + inline operator fun invoke(dim: Int, initializer: (Int) -> Double) = + RealVector(RealBuffer(dim, initializer)) + + operator fun invoke(vararg values: Double): RealVector = values.asVector() + + fun space(dim: Int): BufferVectorSpace = spaceCache.getOrPut(dim) { + BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) } + } + } +} \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt new file mode 100644 index 000000000..82c0e86b2 --- /dev/null +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt @@ -0,0 +1,8 @@ +package scientifik.kmath.real + +import scientifik.kmath.structures.RealBuffer + +/** + * Simplified [RealBuffer] to array comparison + */ +fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt new file mode 100644 index 000000000..65f86eec7 --- /dev/null +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt @@ -0,0 +1,165 @@ +package scientifik.kmath.real + +import scientifik.kmath.linear.MatrixContext +import scientifik.kmath.linear.RealMatrixContext.elementContext +import scientifik.kmath.linear.VirtualMatrix +import scientifik.kmath.operations.sum +import scientifik.kmath.structures.Buffer +import scientifik.kmath.structures.Matrix +import scientifik.kmath.structures.RealBuffer +import scientifik.kmath.structures.asIterable +import kotlin.math.pow + +/* + * Functions for convenient "numpy-like" operations with Double matrices. + * + * Initial implementation of these functions is taken from: + * https://github.com/thomasnield/numky/blob/master/src/main/kotlin/org/nield/numky/linear/DoubleOperators.kt + * + */ + +/* + * Functions that help create a real (Double) matrix + */ + +typealias RealMatrix = Matrix + +fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum, initializer) + +fun Array.toMatrix(): RealMatrix{ + return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] } +} + +fun Sequence.toMatrix(): RealMatrix = toList().let { + MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] } +} + +fun Matrix.repeatStackVertical(n: Int): RealMatrix = + VirtualMatrix(rowNum * n, colNum) { row, col -> + get(if (row == 0) 0 else row % rowNum, col) + } + +/* + * Operations for matrix and real number + */ + +operator fun Matrix.times(double: Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col] * double + } + +operator fun Matrix.plus(double: Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col] + double + } + +operator fun Matrix.minus(double: Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col] - double + } + +operator fun Matrix.div(double: Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col] / double + } + +operator fun Double.times(matrix: Matrix): RealMatrix = + MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { row, col -> + this * matrix[row, col] + } + +operator fun Double.plus(matrix: Matrix): RealMatrix = + MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { row, col -> + this + matrix[row, col] + } + +operator fun Double.minus(matrix: Matrix): RealMatrix = + MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { row, col -> + this - matrix[row, col] + } + +// TODO: does this operation make sense? Should it be 'this/matrix[row, col]'? +//operator fun Double.div(matrix: Matrix) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { +// row, col -> matrix[row, col] / this +//} + +/* + * Per-element (!) square and power operations + */ + +fun Matrix.square(): RealMatrix = MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col].pow(2) +} + +fun Matrix.pow(n: Int): RealMatrix = MatrixContext.real.produce(rowNum, colNum) { i, j -> + this[i, j].pow(n) +} + +/* + * Operations on two matrices (per-element!) + */ + +operator fun Matrix.times(other: Matrix): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col] * other[row, col] + } + +operator fun Matrix.plus(other: Matrix): RealMatrix = + MatrixContext.real.add(this, other) + +operator fun Matrix.minus(other: Matrix): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col] - other[row, col] + } + +/* + * Operations on columns + */ + +inline fun Matrix.appendColumn(crossinline mapper: (Buffer) -> Double) = + MatrixContext.real.produce(rowNum, colNum + 1) { row, col -> + if (col < colNum) + this[row, col] + else + mapper(rows[row]) + } + +fun Matrix.extractColumns(columnRange: IntRange): RealMatrix = + MatrixContext.real.produce(rowNum, columnRange.count()) { row, col -> + this[row, columnRange.first + col] + } + +fun Matrix.extractColumn(columnIndex: Int): RealMatrix = + extractColumns(columnIndex..columnIndex) + +fun Matrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> + val column = columns[j] + with(elementContext) { + sum(column.asIterable()) + } +} + +fun Matrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> + columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column") +} + +fun Matrix.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> + columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column") +} + +fun Matrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> + columns[j].asIterable().average() +} + +/* + * Operations processing all elements + */ + +fun Matrix.sum() = elements().map { (_, value) -> value }.sum() + +fun Matrix.min() = elements().map { (_, value) -> value }.min() + +fun Matrix.max() = elements().map { (_, value) -> value }.max() + +fun Matrix.average() = elements().map { (_, value) -> value }.average() diff --git a/kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt b/kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt index 31b8b5252..8918fb300 100644 --- a/kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt @@ -1,11 +1,12 @@ package scientific.kmath.real -import scientifik.kmath.real.* import scientifik.kmath.linear.VirtualMatrix import scientifik.kmath.linear.build +import scientifik.kmath.real.* import scientifik.kmath.structures.Matrix import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertTrue class RealMatrixTest { @Test @@ -19,72 +20,72 @@ class RealMatrixTest { fun testSequenceToMatrix() { val m = Sequence { listOf( - DoubleArray(10) { 10.0 }, - DoubleArray(10) { 20.0 }, - DoubleArray(10) { 30.0 }).iterator() + DoubleArray(10) { 10.0 }, + DoubleArray(10) { 20.0 }, + DoubleArray(10) { 30.0 }).iterator() }.toMatrix() assertEquals(m.sum(), 20.0 * 30) } @Test fun testRepeatStackVertical() { - val matrix1 = Matrix.build(2, 3)( - 1.0, 0.0, 0.0, - 0.0, 1.0, 2.0 + val matrix1 = Matrix.build(2, 3)( + 1.0, 0.0, 0.0, + 0.0, 1.0, 2.0 ) - val matrix2 = Matrix.build(6, 3)( - 1.0, 0.0, 0.0, - 0.0, 1.0, 2.0, - 1.0, 0.0, 0.0, - 0.0, 1.0, 2.0, - 1.0, 0.0, 0.0, - 0.0, 1.0, 2.0 + val matrix2 = Matrix.build(6, 3)( + 1.0, 0.0, 0.0, + 0.0, 1.0, 2.0, + 1.0, 0.0, 0.0, + 0.0, 1.0, 2.0, + 1.0, 0.0, 0.0, + 0.0, 1.0, 2.0 ) assertEquals(VirtualMatrix.wrap(matrix2), matrix1.repeatStackVertical(3)) } @Test fun testMatrixAndDouble() { - val matrix1 = Matrix.build(2, 3)( - 1.0, 0.0, 3.0, - 4.0, 6.0, 2.0 + val matrix1 = Matrix.build(2, 3)( + 1.0, 0.0, 3.0, + 4.0, 6.0, 2.0 ) val matrix2 = (matrix1 * 2.5 + 1.0 - 2.0) / 2.0 - val expectedResult = Matrix.build(2, 3)( - 0.75, -0.5, 3.25, - 4.5, 7.0, 2.0 + val expectedResult = Matrix.build(2, 3)( + 0.75, -0.5, 3.25, + 4.5, 7.0, 2.0 ) assertEquals(matrix2, expectedResult) } @Test fun testDoubleAndMatrix() { - val matrix1 = Matrix.build(2, 3)( - 1.0, 0.0, 3.0, - 4.0, 6.0, 2.0 + val matrix1 = Matrix.build(2, 3)( + 1.0, 0.0, 3.0, + 4.0, 6.0, 2.0 ) val matrix2 = 20.0 - (10.0 + (5.0 * matrix1)) //val matrix2 = 10.0 + (5.0 * matrix1) - val expectedResult = Matrix.build(2, 3)( - 5.0, 10.0, -5.0, - -10.0, -20.0, 0.0 + val expectedResult = Matrix.build(2, 3)( + 5.0, 10.0, -5.0, + -10.0, -20.0, 0.0 ) assertEquals(matrix2, expectedResult) } @Test fun testSquareAndPower() { - val matrix1 = Matrix.build(2, 3)( - -1.0, 0.0, 3.0, - 4.0, -6.0, -2.0 + val matrix1 = Matrix.build(2, 3)( + -1.0, 0.0, 3.0, + 4.0, -6.0, -2.0 ) - val matrix2 = Matrix.build(2, 3)( - 1.0, 0.0, 9.0, - 16.0, 36.0, 4.0 + val matrix2 = Matrix.build(2, 3)( + 1.0, 0.0, 9.0, + 16.0, 36.0, 4.0 ) - val matrix3 = Matrix.build(2, 3)( - -1.0, 0.0, 27.0, - 64.0, -216.0, -8.0 + val matrix3 = Matrix.build(2, 3)( + -1.0, 0.0, 27.0, + 64.0, -216.0, -8.0 ) assertEquals(matrix1.square(), matrix2) assertEquals(matrix1.pow(3), matrix3) @@ -92,51 +93,61 @@ class RealMatrixTest { @Test fun testTwoMatrixOperations() { - val matrix1 = Matrix.build(2, 3)( - -1.0, 0.0, 3.0, - 4.0, -6.0, 7.0 + val matrix1 = Matrix.build(2, 3)( + -1.0, 0.0, 3.0, + 4.0, -6.0, 7.0 ) - val matrix2 = Matrix.build(2, 3)( - 1.0, 0.0, 3.0, - 4.0, 6.0, -2.0 + val matrix2 = Matrix.build(2, 3)( + 1.0, 0.0, 3.0, + 4.0, 6.0, -2.0 ) val result = matrix1 * matrix2 + matrix1 - matrix2 - val expectedResult = Matrix.build(2, 3)( - -3.0, 0.0, 9.0, - 16.0, -48.0, -5.0 + val expectedResult = Matrix.build(2, 3)( + -3.0, 0.0, 9.0, + 16.0, -48.0, -5.0 ) assertEquals(result, expectedResult) } @Test fun testColumnOperations() { - val matrix1 = Matrix.build(2, 4)( - -1.0, 0.0, 3.0, 15.0, - 4.0, -6.0, 7.0, -11.0 + val matrix1 = Matrix.build(2, 4)( + -1.0, 0.0, 3.0, 15.0, + 4.0, -6.0, 7.0, -11.0 ) - val matrix2 = Matrix.build(2, 5)( - -1.0, 0.0, 3.0, 15.0, -1.0, - 4.0, -6.0, 7.0, -11.0, 4.0 + val matrix2 = Matrix.build(2, 5)( + -1.0, 0.0, 3.0, 15.0, -1.0, + 4.0, -6.0, 7.0, -11.0, 4.0 ) - val col1 = Matrix.build(2, 1)(0.0, -6.0) - val cols1to2 = Matrix.build(2, 2)( - 0.0, 3.0, - -6.0, 7.0 + val col1 = Matrix.build(2, 1)(0.0, -6.0) + val cols1to2 = Matrix.build(2, 2)( + 0.0, 3.0, + -6.0, 7.0 ) + assertEquals(matrix1.appendColumn { it[0] }, matrix2) assertEquals(matrix1.extractColumn(1), col1) assertEquals(matrix1.extractColumns(1..2), cols1to2) - assertEquals(matrix1.sumByColumn(), Matrix.build(4, 1)(3.0, -6.0, 10.0, 4.0)) - assertEquals(matrix1.minByColumn(), Matrix.build(4, 1)(-1.0, -6.0, 3.0, -11.0)) - assertEquals(matrix1.maxByColumn(), Matrix.build(4, 1)(4.0, 0.0, 7.0, 15.0)) - assertEquals(matrix1.averageByColumn(), Matrix.build(4, 1)(1.5, -3.0, 5.0, 2.0)) + //equals should never be called on buffers + assertTrue { + matrix1.sumByColumn().contentEquals(3.0, -6.0, 10.0, 4.0) + } //assertEquals(matrix1.sumByColumn(), DoubleBuffer(3.0, -6.0, 10.0, 4.0)) + assertTrue { + matrix1.minByColumn().contentEquals(-1.0, -6.0, 3.0, -11.0) + } //assertEquals(matrix1.minByColumn(), DoubleBuffer(-1.0, -6.0, 3.0, -11.0)) + assertTrue { + matrix1.maxByColumn().contentEquals(4.0, 0.0, 7.0, 15.0) + } //assertEquals(matrix1.maxByColumn(), DoubleBuffer(4.0, 0.0, 7.0, 15.0)) + assertTrue { + matrix1.averageByColumn().contentEquals(1.5, -3.0, 5.0, 2.0) + } //assertEquals(matrix1.averageByColumn(), DoubleBuffer(1.5, -3.0, 5.0, 2.0)) } @Test fun testAllElementOperations() { - val matrix1 = Matrix.build(2, 4)( - -1.0, 0.0, 3.0, 15.0, - 4.0, -6.0, 7.0, -11.0 + val matrix1 = Matrix.build(2, 4)( + -1.0, 0.0, 3.0, 15.0, + 4.0, -6.0, 7.0, -11.0 ) assertEquals(matrix1.sum(), 11.0) assertEquals(matrix1.min(), -11.0) diff --git a/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt b/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt index 0e73ee4a5..28e62b066 100644 --- a/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt @@ -1,5 +1,6 @@ package scientifik.kmath.linear +import scientifik.kmath.real.RealVector import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index 329af72a1..43d50ad20 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -1,17 +1,9 @@ package scientifik.kmath.histogram +import scientifik.kmath.domains.Domain import scientifik.kmath.linear.Point import scientifik.kmath.structures.ArrayBuffer -import scientifik.kmath.structures.DoubleBuffer - -/** - * A simple geometric domain - * TODO move to geometry module - */ -interface Domain { - operator fun contains(vector: Point): Boolean - val dimension: Int -} +import scientifik.kmath.structures.RealBuffer /** * The bin in the histogram. The histogram is by definition always done in the real space @@ -51,9 +43,9 @@ interface MutableHistogram> : Histogram { fun MutableHistogram.put(vararg point: T) = put(ArrayBuffer(point)) fun MutableHistogram.put(vararg point: Number) = - put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) + put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) -fun MutableHistogram.put(vararg point: Double) = put(DoubleBuffer(point)) +fun MutableHistogram.put(vararg point: Double) = put(RealBuffer(point)) fun MutableHistogram.fill(sequence: Iterable>) = sequence.forEach { put(it) } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index 85f078fda..628a68461 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -1,8 +1,8 @@ package scientifik.kmath.histogram import scientifik.kmath.linear.Point -import scientifik.kmath.linear.asVector import scientifik.kmath.operations.SpaceOperations +import scientifik.kmath.real.asVector import scientifik.kmath.structures.* import kotlin.math.floor @@ -21,7 +21,7 @@ data class BinDef>(val space: SpaceOperations>, val c class MultivariateBin>(val def: BinDef, override val value: Number) : Bin { - override fun contains(vector: Point): Boolean = def.contains(vector) + override fun contains(point: Point): Boolean = def.contains(point) override val dimension: Int get() = def.center.size @@ -50,7 +50,7 @@ class RealHistogram( override val dimension: Int get() = lower.size - private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } + private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } init { // argument checks diff --git a/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt b/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt index 4dc4dfc74..5edecb5a5 100644 --- a/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt +++ b/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt @@ -3,7 +3,7 @@ package scietifik.kmath.histogram import scientifik.kmath.histogram.RealHistogram import scientifik.kmath.histogram.fill import scientifik.kmath.histogram.put -import scientifik.kmath.linear.RealVector +import scientifik.kmath.real.RealVector import kotlin.random.Random import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt index 90b9aff5e..af01205bf 100644 --- a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt +++ b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt @@ -1,7 +1,7 @@ package scientifik.kmath.histogram -import scientifik.kmath.linear.RealVector -import scientifik.kmath.linear.asVector +import scientifik.kmath.real.RealVector +import scientifik.kmath.real.asVector import scientifik.kmath.structures.Buffer import java.util.* import kotlin.math.floor @@ -16,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) - override fun contains(vector: Buffer): Boolean = contains(vector[0]) + override fun contains(point: Buffer): Boolean = contains(point[0]) internal operator fun inc() = this.also { counter.increment() } diff --git a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt index 74681ac48..10deabd73 100644 --- a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt +++ b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt @@ -4,6 +4,7 @@ import koma.extensions.fill import koma.matrix.MatrixFactory import scientifik.kmath.operations.Space import scientifik.kmath.structures.Matrix +import scientifik.kmath.structures.NDStructure class KomaMatrixContext( private val factory: MatrixFactory>, @@ -85,6 +86,18 @@ class KomaMatrix(val origin: koma.matrix.Matrix, features: Set ?: return false) + } + + override fun hashCode(): Int { + var result = origin.hashCode() + result = 31 * result + features.hashCode() + return result + } + + } class KomaVector internal constructor(val origin: koma.matrix.Matrix) : Point { diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt b/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt index f9f938dcc..a749a7074 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt +++ b/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt @@ -1,71 +1,148 @@ package scientifik.memory +/** + * Represents a display of certain memory structure. + */ interface Memory { + /** + * The length of this memory in bytes. + */ val size: Int /** - * Get a projection of this memory (it reflects the changes in the parent memory block) + * Get a projection of this memory (it reflects the changes in the parent memory block). */ fun view(offset: Int, length: Int): Memory /** - * Create a copy of this memory, which does not know anything about this memory + * Creates an independent copy of this memory. */ fun copy(): Memory /** - * Create and possibly register a new reader + * Gets or creates a reader of this memory. */ fun reader(): MemoryReader + /** + * Gets or creates a writer of this memory. + */ fun writer(): MemoryWriter - companion object { - - } + companion object } +/** + * The interface to read primitive types in this memory. + */ interface MemoryReader { + /** + * The underlying memory. + */ val memory: Memory + /** + * Reads [Double] at certain [offset]. + */ fun readDouble(offset: Int): Double + + /** + * Reads [Float] at certain [offset]. + */ fun readFloat(offset: Int): Float + + /** + * Reads [Byte] at certain [offset]. + */ fun readByte(offset: Int): Byte + + /** + * Reads [Short] at certain [offset]. + */ fun readShort(offset: Int): Short + + /** + * Reads [Int] at certain [offset]. + */ fun readInt(offset: Int): Int + + /** + * Reads [Long] at certain [offset]. + */ fun readLong(offset: Int): Long + /** + * Disposes this reader if needed. + */ fun release() } /** - * Use the memory for read then release the reader + * Uses the memory for read then releases the reader. */ inline fun Memory.read(block: MemoryReader.() -> Unit) { - reader().apply(block).apply { release() } + reader().apply(block).release() } +/** + * The interface to write primitive types into this memory. + */ interface MemoryWriter { + /** + * The underlying memory. + */ val memory: Memory + /** + * Writes [Double] at certain [offset]. + */ fun writeDouble(offset: Int, value: Double) + + /** + * Writes [Float] at certain [offset]. + */ fun writeFloat(offset: Int, value: Float) + + /** + * Writes [Byte] at certain [offset]. + */ fun writeByte(offset: Int, value: Byte) + + /** + * Writes [Short] at certain [offset]. + */ fun writeShort(offset: Int, value: Short) + + /** + * Writes [Int] at certain [offset]. + */ fun writeInt(offset: Int, value: Int) + + /** + * Writes [Long] at certain [offset]. + */ fun writeLong(offset: Int, value: Long) + /** + * Disposes this writer if needed. + */ fun release() } /** - * Use the memory for write then release the writer + * Uses the memory for write then releases the writer. */ inline fun Memory.write(block: MemoryWriter.() -> Unit) { - writer().apply(block).apply { release() } + writer().apply(block).release() } /** - * Allocate the most effective platform-specific memory + * Allocates the most effective platform-specific memory. */ expect fun Memory.Companion.allocate(length: Int): Memory + +/** + * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied + * and could be mutated independently from the resulting [Memory]. + */ +expect fun Memory.Companion.wrap(array: ByteArray): Memory diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt index 0896f0dcb..59a93f290 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt +++ b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt @@ -1,34 +1,53 @@ package scientifik.memory /** - * A specification to read or write custom objects with fixed size in bytes + * A specification to read or write custom objects with fixed size in bytes. + * + * @param T the type of object this spec manages. */ interface MemorySpec { /** - * Size of [T] in bytes after serialization + * Size of [T] in bytes after serialization. */ val objectSize: Int + /** + * Reads the object starting from [offset]. + */ fun MemoryReader.read(offset: Int): T + + // TODO consider thread safety + + /** + * Writes the object [value] starting from [offset]. + */ fun MemoryWriter.write(offset: Int, value: T) } -fun MemoryReader.read(spec: MemorySpec, offset: Int): T = spec.run { read(offset) } -fun MemoryWriter.write(spec: MemorySpec, offset: Int, value: T) = spec.run { write(offset, value) } +/** + * Reads the object with [spec] starting from [offset]. + */ +fun MemoryReader.read(spec: MemorySpec, offset: Int): T = with(spec) { read(offset) } -inline fun MemoryReader.readArray(spec: MemorySpec, offset: Int, size: Int) = +/** + * Writes the object [value] with [spec] starting from [offset]. + */ +fun MemoryWriter.write(spec: MemorySpec, offset: Int, value: T): Unit = with(spec) { write(offset, value) } + +/** + * Reads array of [size] objects mapped by [spec] at certain [offset]. + */ +inline fun MemoryReader.readArray(spec: MemorySpec, offset: Int, size: Int): Array = Array(size) { i -> spec.run { read(offset + i * objectSize) } } -fun MemoryWriter.writeArray(spec: MemorySpec, offset: Int, array: Array) { - spec.run { - for (i in array.indices) { - write(offset + i * objectSize, array[i]) - } - } -} +/** + * Writes [array] of objects mapped by [spec] at certain [offset]. + */ +fun MemoryWriter.writeArray(spec: MemorySpec, offset: Int, array: Array): Unit = + with(spec) { array.indices.forEach { i -> write(offset + i * objectSize, array[i]) } } -//TODO It is possible to add elastic MemorySpec with unknown object size \ No newline at end of file +// TODO It is possible to add elastic MemorySpec with unknown object size diff --git a/kmath-memory/src/jsMain/kotlin/scientifik/memory/DataViewMemory.kt b/kmath-memory/src/jsMain/kotlin/scientifik/memory/DataViewMemory.kt index 843464ab9..974750502 100644 --- a/kmath-memory/src/jsMain/kotlin/scientifik/memory/DataViewMemory.kt +++ b/kmath-memory/src/jsMain/kotlin/scientifik/memory/DataViewMemory.kt @@ -2,34 +2,25 @@ package scientifik.memory import org.khronos.webgl.ArrayBuffer import org.khronos.webgl.DataView +import org.khronos.webgl.Int8Array -/** - * Allocate the most effective platform-specific memory - */ -actual fun Memory.Companion.allocate(length: Int): Memory { - val buffer = ArrayBuffer(length) - return DataViewMemory(DataView(buffer, 0, length)) -} - -class DataViewMemory(val view: DataView) : Memory { - +private class DataViewMemory(val view: DataView) : Memory { override val size: Int get() = view.byteLength override fun view(offset: Int, length: Int): Memory { require(offset >= 0) { "offset shouldn't be negative: $offset" } require(length >= 0) { "length shouldn't be negative: $length" } - if (offset + length > size) { + require(offset + length <= size) { "Can't view memory outside the parent region." } + + if (offset + length > size) throw IndexOutOfBoundsException("offset + length > size: $offset + $length > $size") - } + return DataViewMemory(DataView(view.buffer, view.byteOffset + offset, length)) } + override fun copy(): Memory = DataViewMemory(DataView(view.buffer.slice(0))) - override fun copy(): Memory { - TODO("not implemented") //To change body of created functions use File | Settings | File Templates. - } - - private val reader = object : MemoryReader { + private val reader: MemoryReader = object : MemoryReader { override val memory: Memory get() = this@DataViewMemory override fun readDouble(offset: Int): Double = view.getFloat64(offset, false) @@ -42,17 +33,17 @@ class DataViewMemory(val view: DataView) : Memory { override fun readInt(offset: Int): Int = view.getInt32(offset, false) - override fun readLong(offset: Int): Long = (view.getInt32(offset, false).toLong() shl 32) or - view.getInt32(offset + 4, false).toLong() + override fun readLong(offset: Int): Long = + view.getInt32(offset, false).toLong() shl 32 or view.getInt32(offset + 4, false).toLong() override fun release() { - // does nothing on JS because of GC + // does nothing on JS } } override fun reader(): MemoryReader = reader - private val writer = object : MemoryWriter { + private val writer: MemoryWriter = object : MemoryWriter { override val memory: Memory get() = this@DataViewMemory override fun writeDouble(offset: Int, value: Double) { @@ -81,11 +72,27 @@ class DataViewMemory(val view: DataView) : Memory { } override fun release() { - //does nothing on JS + // does nothing on JS } - } override fun writer(): MemoryWriter = writer -} \ No newline at end of file +} + +/** + * Allocates memory based on a [DataView]. + */ +actual fun Memory.Companion.allocate(length: Int): Memory { + val buffer = ArrayBuffer(length) + return DataViewMemory(DataView(buffer, 0, length)) +} + +/** + * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied + * and could be mutated independently from the resulting [Memory]. + */ +actual fun Memory.Companion.wrap(array: ByteArray): Memory { + @Suppress("CAST_NEVER_SUCCEEDS") val int8Array = array as Int8Array + return DataViewMemory(DataView(int8Array.buffer, int8Array.byteOffset, int8Array.length)) +} diff --git a/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt b/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt index df2e9847a..b5a0dd51b 100644 --- a/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt +++ b/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt @@ -6,27 +6,18 @@ import java.nio.file.Files import java.nio.file.Path import java.nio.file.StandardOpenOption - -/** - * Allocate the most effective platform-specific memory - */ -actual fun Memory.Companion.allocate(length: Int): Memory { - val buffer = ByteBuffer.allocate(length) - return ByteBufferMemory(buffer) -} - private class ByteBufferMemory( val buffer: ByteBuffer, val startOffset: Int = 0, override val size: Int = buffer.limit() ) : Memory { - - @Suppress("NOTHING_TO_INLINE") private inline fun position(o: Int): Int = startOffset + o override fun view(offset: Int, length: Int): Memory { - if (offset + length > size) error("Selecting a Memory view outside of memory range") + require(offset >= 0) { "offset shouldn't be negative: $offset" } + require(length >= 0) { "length shouldn't be negative: $length" } + require(offset + length <= size) { "Can't view memory outside the parent region." } return ByteBufferMemory(buffer, position(offset), length) } @@ -36,10 +27,9 @@ private class ByteBufferMemory( copy.put(buffer) copy.flip() return ByteBufferMemory(copy) - } - private val reader = object : MemoryReader { + private val reader: MemoryReader = object : MemoryReader { override val memory: Memory get() = this@ByteBufferMemory override fun readDouble(offset: Int) = buffer.getDouble(position(offset)) @@ -55,13 +45,13 @@ private class ByteBufferMemory( override fun readLong(offset: Int) = buffer.getLong(position(offset)) override fun release() { - //does nothing on JVM + // does nothing on JVM } } override fun reader(): MemoryReader = reader - private val writer = object : MemoryWriter { + private val writer: MemoryWriter = object : MemoryWriter { override val memory: Memory get() = this@ByteBufferMemory override fun writeDouble(offset: Int, value: Double) { @@ -89,7 +79,7 @@ private class ByteBufferMemory( } override fun release() { - //does nothing on JVM + // does nothing on JVM } } @@ -97,10 +87,32 @@ private class ByteBufferMemory( } /** - * Use direct memory-mapped buffer from file to read something and close it afterwards. + * Allocates memory based on a [ByteBuffer]. */ -fun Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R { - return FileChannel.open(this, StandardOpenOption.READ).use { +actual fun Memory.Companion.allocate(length: Int): Memory = + ByteBufferMemory(checkNotNull(ByteBuffer.allocate(length))) + +/** + * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied + * and could be mutated independently from the resulting [Memory]. + */ +actual fun Memory.Companion.wrap(array: ByteArray): Memory = ByteBufferMemory(checkNotNull(ByteBuffer.wrap(array))) + +/** + * Wraps this [ByteBuffer] to [Memory] object. + * + * @receiver the byte buffer. + * @param startOffset the start offset. + * @param size the size of memory to map. + * @return the [Memory] object. + */ +fun ByteBuffer.asMemory(startOffset: Int = 0, size: Int = limit()): Memory = + ByteBufferMemory(this, startOffset, size) + +/** + * Uses direct memory-mapped buffer from file to read something and close it afterwards. + */ +fun Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R = + FileChannel.open(this, StandardOpenOption.READ).use { ByteBufferMemory(it.map(FileChannel.MapMode.READ_ONLY, position, size)).block() } -} \ No newline at end of file diff --git a/kmath-prob/build.gradle.kts b/kmath-prob/build.gradle.kts index 59b25d340..a69d61b73 100644 --- a/kmath-prob/build.gradle.kts +++ b/kmath-prob/build.gradle.kts @@ -8,4 +8,10 @@ kotlin.sourceSets { api(project(":kmath-coroutines")) } } + jvmMain{ + dependencies{ + api("org.apache.commons:commons-rng-sampling:1.3") + api("org.apache.commons:commons-rng-simple:1.3") + } + } } \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt index 13983c6b2..2a225fe47 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt @@ -6,10 +6,16 @@ import kotlin.random.Random * A basic generator */ interface RandomGenerator { + fun nextBoolean(): Boolean + fun nextDouble(): Double fun nextInt(): Int + fun nextInt(until: Int): Int fun nextLong(): Long - fun nextBlock(size: Int): ByteArray + fun nextLong(until: Long): Long + + fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size) + fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) } /** * Create a new generator which is independent from current generator (operations on new generator do not affect this one @@ -21,21 +27,29 @@ interface RandomGenerator { fun fork(): RandomGenerator companion object { - val default by lazy { DefaultGenerator(Random.nextLong()) } + val default by lazy { DefaultGenerator() } + + fun default(seed: Long) = DefaultGenerator(Random(seed)) } } -class DefaultGenerator(seed: Long?) : RandomGenerator { - private val random = seed?.let { Random(it) } ?: Random +inline class DefaultGenerator(val random: Random = Random) : RandomGenerator { + override fun nextBoolean(): Boolean = random.nextBoolean() override fun nextDouble(): Double = random.nextDouble() override fun nextInt(): Int = random.nextInt() + override fun nextInt(until: Int): Int = random.nextInt(until) override fun nextLong(): Long = random.nextLong() - override fun nextBlock(size: Int): ByteArray = random.nextBytes(size) + override fun nextLong(until: Long): Long = random.nextLong(until) - override fun fork(): RandomGenerator = DefaultGenerator(nextLong()) + override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { + random.nextBytes(array, fromIndex, toIndex) + } + override fun nextBytes(size: Int): ByteArray = random.nextBytes(size) + + override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong()) } \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/UniformDistribution.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/UniformDistribution.kt index ff6e2a551..9d96bff59 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/UniformDistribution.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/UniformDistribution.kt @@ -28,4 +28,7 @@ class UniformDistribution(val range: ClosedFloatingPointRange) : Univari else -> (arg - range.start) / length } } -} \ No newline at end of file +} + +fun Distribution.Companion.uniform(range: ClosedFloatingPointRange): UniformDistribution = + UniformDistribution(range) \ No newline at end of file diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt new file mode 100644 index 000000000..f5a73a08b --- /dev/null +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt @@ -0,0 +1,67 @@ +package scientifik.kmath.prob + +import org.apache.commons.rng.UniformRandomProvider +import org.apache.commons.rng.simple.RandomSource + +class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGenerator { + internal val random: UniformRandomProvider = seed?.let { + RandomSource.create(source, seed) + } ?: RandomSource.create(source) + + override fun nextBoolean(): Boolean = random.nextBoolean() + + override fun nextDouble(): Double = random.nextDouble() + + override fun nextInt(): Int = random.nextInt() + override fun nextInt(until: Int): Int = random.nextInt(until) + + override fun nextLong(): Long = random.nextLong() + override fun nextLong(until: Long): Long = random.nextLong(until) + + override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { + require(toIndex > fromIndex) + random.nextBytes(array, fromIndex, toIndex - fromIndex) + } + + override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong()) +} + +inline class RandomGeneratorProvider(val generator: RandomGenerator) : UniformRandomProvider { + override fun nextBoolean(): Boolean = generator.nextBoolean() + + override fun nextFloat(): Float = generator.nextDouble().toFloat() + + override fun nextBytes(bytes: ByteArray) { + generator.fillBytes(bytes) + } + + override fun nextBytes(bytes: ByteArray, start: Int, len: Int) { + generator.fillBytes(bytes, start, start + len) + } + + override fun nextInt(): Int = generator.nextInt() + + override fun nextInt(n: Int): Int = generator.nextInt(n) + + override fun nextDouble(): Double = generator.nextDouble() + + override fun nextLong(): Long = generator.nextLong() + + override fun nextLong(n: Long): Long = generator.nextLong(n) +} + +/** + * Represent this [RandomGenerator] as commons-rng [UniformRandomProvider] preserving and mirroring its current state. + * Getting new value from one of those changes the state of another. + */ +fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this is RandomSourceGenerator) { + random +} else { + RandomGeneratorProvider(this) +} + +fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator = + RandomSourceGenerator(source, seed) + +fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator = + fromSource(RandomSource.MT, seed) diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/distributions.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/distributions.kt new file mode 100644 index 000000000..412454994 --- /dev/null +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/distributions.kt @@ -0,0 +1,109 @@ +package scientifik.kmath.prob + +import org.apache.commons.rng.UniformRandomProvider +import org.apache.commons.rng.sampling.distribution.* +import scientifik.kmath.chains.BlockingIntChain +import scientifik.kmath.chains.BlockingRealChain +import scientifik.kmath.chains.Chain +import java.util.* +import kotlin.math.PI +import kotlin.math.exp +import kotlin.math.pow +import kotlin.math.sqrt + +abstract class ContinuousSamplerDistribution : Distribution { + + private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() { + private val sampler = buildCMSampler(generator) + + override fun nextDouble(): Double = sampler.sample() + + override fun fork(): Chain = ContinuousSamplerChain(generator.fork()) + } + + protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler + + override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator) +} + +abstract class DiscreteSamplerDistribution : Distribution { + + private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() { + private val sampler = buildSampler(generator) + + override fun nextInt(): Int = sampler.sample() + + override fun fork(): Chain = ContinuousSamplerChain(generator.fork()) + } + + protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler + + override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator) +} + +enum class NormalSamplerMethod { + BoxMuller, + Marsaglia, + Ziggurat +} + +private fun normalSampler(method: NormalSamplerMethod, provider: UniformRandomProvider): NormalizedGaussianSampler = + when (method) { + NormalSamplerMethod.BoxMuller -> BoxMullerNormalizedGaussianSampler(provider) + NormalSamplerMethod.Marsaglia -> MarsagliaNormalizedGaussianSampler(provider) + NormalSamplerMethod.Ziggurat -> ZigguratNormalizedGaussianSampler(provider) + } + +fun Distribution.Companion.normal( + method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat +): Distribution = object : ContinuousSamplerDistribution() { + override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler { + val provider: UniformRandomProvider = generator.asUniformRandomProvider() + return normalSampler(method, provider) + } + + override fun probability(arg: Double): Double { + return exp(-arg.pow(2) / 2) / sqrt(PI * 2) + } +} + +fun Distribution.Companion.normal( + mean: Double, + sigma: Double, + method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat +): ContinuousSamplerDistribution = object : ContinuousSamplerDistribution() { + private val sigma2 = sigma.pow(2) + private val norm = sigma * sqrt(PI * 2) + + override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler { + val provider: UniformRandomProvider = generator.asUniformRandomProvider() + val normalizedSampler = normalSampler(method, provider) + return GaussianSampler(normalizedSampler, mean, sigma) + } + + override fun probability(arg: Double): Double { + return exp(-(arg - mean).pow(2) / 2 / sigma2) / norm + } +} + +fun Distribution.Companion.poisson( + lambda: Double +): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() { + + override fun buildSampler(generator: RandomGenerator): DiscreteSampler { + return PoissonSampler.of(generator.asUniformRandomProvider(), lambda) + } + + private val computedProb: HashMap = hashMapOf(0 to exp(-lambda)) + + override fun probability(arg: Int): Double { + require(arg >= 0) { "The argument must be >= 0" } + return if (arg > 40) { + exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda) + } else { + computedProb.getOrPut(arg) { + probability(arg - 1) * lambda / arg + } + } + } +} diff --git a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/CommonsDistributionsTest.kt b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/CommonsDistributionsTest.kt new file mode 100644 index 000000000..7638c695e --- /dev/null +++ b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/CommonsDistributionsTest.kt @@ -0,0 +1,28 @@ +package scientifik.kmath.prob + +import kotlinx.coroutines.flow.take +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +class CommonsDistributionsTest { + @Test + fun testNormalDistributionSuspend() { + val distribution = Distribution.normal(7.0, 2.0) + val generator = RandomGenerator.default(1) + val sample = runBlocking { + distribution.sample(generator).take(1000).toList() + } + Assertions.assertEquals(7.0, sample.average(), 0.1) + } + + @Test + fun testNormalDistributionBlocking() { + val distribution = Distribution.normal(7.0, 2.0) + val generator = RandomGenerator.default(1) + val sample = distribution.sample(generator).nextBlock(1000) + Assertions.assertEquals(7.0, sample.average(), 0.1) + } + +} \ No newline at end of file diff --git a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/StatisticTest.kt b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/StatisticTest.kt index af069810f..2613f71d5 100644 --- a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/StatisticTest.kt +++ b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/StatisticTest.kt @@ -9,7 +9,7 @@ import kotlin.test.Test class StatisticTest { //create a random number generator. - val generator = DefaultGenerator(1) + val generator = RandomGenerator.default(1) //Create a stateless chain from generator. val data = generator.chain { nextDouble() } //Convert a chaint to Flow and break it into chunks. diff --git a/settings.gradle.kts b/settings.gradle.kts index a08d5f7ee..487e1d87f 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,10 +1,14 @@ pluginManagement { + val toolsVersion = "0.5.0" + plugins { - id("scientifik.mpp") version "0.4.1" - id("scientifik.jvm") version "0.4.1" - id("scientifik.atomic") version "0.4.1" - id("scientifik.publish") version "0.4.1" + id("kotlinx.benchmark") version "0.2.0-dev-8" + id("scientifik.mpp") version toolsVersion + id("scientifik.jvm") version toolsVersion + id("scientifik.atomic") version toolsVersion + id("scientifik.publish") version toolsVersion + kotlin("plugin.allopen") version "1.3.72" } repositories { @@ -20,7 +24,7 @@ pluginManagement { resolutionStrategy { eachPlugin { when (requested.id.id) { - "scientifik.mpp", "scientifik.jvm", "scientifik.publish" -> useModule("scientifik:gradle-tools:${requested.version}") + "scientifik.mpp", "scientifik.jvm", "scientifik.publish" -> useModule("scientifik:gradle-tools:$toolsVersion") } } } @@ -42,5 +46,6 @@ include( ":kmath-dimensions", ":kmath-for-real", ":kmath-geometry", + ":kmath-ast", ":examples" )