Dev #127
32
CHANGELOG.md
Normal file
32
CHANGELOG.md
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# KMath
|
||||||
|
|
||||||
|
## [Unreleased]
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- Functional Expressions API
|
||||||
|
- Mathematical Syntax Tree, its interpreter and API
|
||||||
|
- String to MST parser (https://github.com/mipt-npm/kmath/pull/120)
|
||||||
|
- MST to JVM bytecode translator (https://github.com/mipt-npm/kmath/pull/94)
|
||||||
|
- FloatBuffer (specialized MutableBuffer over FloatArray)
|
||||||
|
- FlaggedBuffer to associate primitive numbers buffer with flags (to mark values infinite or missing, etc.)
|
||||||
|
- Specialized builder functions for all primitive buffers like `IntBuffer(25) { it + 1 }` (https://github.com/mipt-npm/kmath/pull/125)
|
||||||
|
- Interface `NumericAlgebra` where `number` operation is available to convert numbers to algebraic elements
|
||||||
|
- Inverse trigonometric functions support in ExtendedField (`asin`, `acos`, `atan`) (https://github.com/mipt-npm/kmath/pull/114)
|
||||||
|
- New space extensions: `average` and `averageWith`
|
||||||
|
- Local coding conventions
|
||||||
|
- Geometric Domains API in `kmath-core`
|
||||||
|
- Blocking chains in `kmath-coroutines`
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations
|
||||||
|
- `power(T, Int)` extension function has preconditions and supports `Field<T>`
|
||||||
|
- Memory objects have more preconditions (overflow checking)
|
||||||
|
- `tg` function is renamed to `tan` (https://github.com/mipt-npm/kmath/pull/114)
|
||||||
|
- Gradle version: 6.3 -> 6.5.1
|
||||||
|
- Moved probability distributions to commons-rng and to `kmath-prob`.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
- Missing copy method in Memory implementation on JS (https://github.com/mipt-npm/kmath/pull/106)
|
||||||
|
- D3.dim value in `kmath-dimensions`
|
||||||
|
- Multiplication in integer rings in `kmath-core` (https://github.com/mipt-npm/kmath/pull/101)
|
||||||
|
- Commons RNG compatibility (https://github.com/mipt-npm/kmath/issues/93)
|
@ -1,8 +1,8 @@
|
|||||||
plugins {
|
plugins {
|
||||||
id("scientifik.publish") version "0.4.2" apply false
|
id("scientifik.publish") apply false
|
||||||
}
|
}
|
||||||
|
|
||||||
val kmathVersion by extra("0.1.4-dev-4")
|
val kmathVersion by extra("0.1.4-dev-8")
|
||||||
|
|
||||||
val bintrayRepo by extra("scientifik")
|
val bintrayRepo by extra("scientifik")
|
||||||
val githubProject by extra("kmath")
|
val githubProject by extra("kmath")
|
||||||
@ -11,6 +11,7 @@ allprojects {
|
|||||||
repositories {
|
repositories {
|
||||||
jcenter()
|
jcenter()
|
||||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
}
|
}
|
||||||
|
|
||||||
group = "scientifik"
|
group = "scientifik"
|
||||||
|
124
doc/algebra.md
124
doc/algebra.md
@ -1,110 +1,124 @@
|
|||||||
# Algebra and algebra elements
|
# Algebraic Structures and Algebraic Elements
|
||||||
|
|
||||||
The mathematical operations in `kmath` are generally separated from mathematical objects.
|
The mathematical operations in KMath are generally separated from mathematical objects. This means that to perform an
|
||||||
This means that in order to perform an operation, say `+`, one needs two objects of a type `T` and
|
operation, say `+`, one needs two objects of a type `T` and an algebra context, which draws appropriate operation up,
|
||||||
and algebra context which defines appropriate operation, say `Space<T>`. Next one needs to run actual operation
|
say `Space<T>`. Next one needs to run the actual operation in the context:
|
||||||
in the context:
|
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
val a: T
|
import scientifik.kmath.operations.*
|
||||||
val b: T
|
|
||||||
val space: Space<T>
|
|
||||||
|
|
||||||
val c = space.run{a + b}
|
val a: T = ...
|
||||||
|
val b: T = ...
|
||||||
|
val space: Space<T> = ...
|
||||||
|
|
||||||
|
val c = space { a + b }
|
||||||
```
|
```
|
||||||
|
|
||||||
From the first glance, this distinction seems to be a needless complication, but in fact one needs
|
At first glance, this distinction seems to be a needless complication, but in fact one needs to remember that in
|
||||||
to remember that in mathematics, one could define different operations on the same objects. For example,
|
mathematics, one could draw up different operations on same objects. For example, one could use different types of
|
||||||
one could use different types of geometry for vectors.
|
geometry for vectors.
|
||||||
|
|
||||||
## Algebra hierarchy
|
## Algebraic Structures
|
||||||
|
|
||||||
Mathematical contexts have the following hierarchy:
|
Mathematical contexts have the following hierarchy:
|
||||||
|
|
||||||
**Space** <- **Ring** <- **Field**
|
**Algebra** ← **Space** ← **Ring** ← **Field**
|
||||||
|
|
||||||
All classes follow abstract mathematical constructs.
|
These interfaces follow real algebraic structures:
|
||||||
[Space](http://mathworld.wolfram.com/Space.html) defines `zero` element, addition operation and multiplication by constant,
|
|
||||||
[Ring](http://mathworld.wolfram.com/Ring.html) adds multiplication and unit `one` element,
|
|
||||||
[Field](http://mathworld.wolfram.com/Field.html) adds division operation.
|
|
||||||
|
|
||||||
Typical case of `Field` is the `RealField` which works on doubles. And typical case of `Space` is a `VectorSpace`.
|
- [Space](https://mathworld.wolfram.com/VectorSpace.html) defines addition, its neutral element (i.e. 0) and scalar
|
||||||
|
multiplication;
|
||||||
|
- [Ring](http://mathworld.wolfram.com/Ring.html) adds multiplication and its neutral element (i.e. 1);
|
||||||
|
- [Field](http://mathworld.wolfram.com/Field.html) adds division operation.
|
||||||
|
|
||||||
In some cases algebra context could hold additional operation like `exp` or `sin`, in this case it inherits appropriate
|
A typical implementation of `Field<T>` is the `RealField` which works on doubles, and `VectorSpace` for `Space<T>`.
|
||||||
interface. Also a context could have an operation which produces an element outside of its context. For example
|
|
||||||
`Matrix` `dot` operation produces a matrix with new dimensions which can be incompatible with initial matrix in
|
|
||||||
terms of linear operations.
|
|
||||||
|
|
||||||
## Algebra element
|
In some cases algebra context can hold additional operations like `exp` or `sin`, and then it inherits appropriate
|
||||||
|
interface. Also, contexts may have operations, which produce elements outside of the context. For example, `Matrix.dot`
|
||||||
|
operation produces a matrix with new dimensions, which can be incompatible with initial matrix in terms of linear
|
||||||
|
operations.
|
||||||
|
|
||||||
In order to achieve more familiar behavior (where you apply operations directly to mathematical objects), without involving contexts
|
## Algebraic Element
|
||||||
`kmath` introduces special type objects called `MathElement`. A `MathElement` is basically some object coupled to
|
|
||||||
|
To achieve more familiar behavior (where you apply operations directly to mathematical objects), without involving
|
||||||
|
contexts KMath submits special type objects called `MathElement`. A `MathElement` is basically some object coupled to
|
||||||
a mathematical context. For example `Complex` is the pair of real numbers representing real and imaginary parts,
|
a mathematical context. For example `Complex` is the pair of real numbers representing real and imaginary parts,
|
||||||
but it also holds reference to the `ComplexField` singleton which allows to perform direct operations on `Complex`
|
but it also holds reference to the `ComplexField` singleton, which allows performing direct operations on `Complex`
|
||||||
numbers without explicit involving the context like:
|
numbers without explicit involving the context like:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
val c1 = Complex(1.0, 1.0)
|
import scientifik.kmath.operations.*
|
||||||
val c2 = Complex(1.0, -1.0)
|
|
||||||
val c3 = c1 + c2 + 3.0.toComplex()
|
// Using elements
|
||||||
//or with field notation:
|
val c1 = Complex(1.0, 1.0)
|
||||||
val c4 = ComplexField.run{c1 + i - 2.0}
|
val c2 = Complex(1.0, -1.0)
|
||||||
|
val c3 = c1 + c2 + 3.0.toComplex()
|
||||||
|
|
||||||
|
// Using context
|
||||||
|
val c4 = ComplexField { c1 + i - 2.0 }
|
||||||
```
|
```
|
||||||
|
|
||||||
Both notations have their pros and cons.
|
Both notations have their pros and cons.
|
||||||
|
|
||||||
The hierarchy for algebra elements follows the hierarchy for the corresponding algebra.
|
The hierarchy for algebraic elements follows the hierarchy for the corresponding algebraic structures.
|
||||||
|
|
||||||
**MathElement** <- **SpaceElement** <- **RingElement** <- **FieldElement**
|
**MathElement** ← **SpaceElement** ← **RingElement** ← **FieldElement**
|
||||||
|
|
||||||
**MathElement** is the generic common ancestor of the class with context.
|
`MathElement<C>` is the generic common ancestor of the class with context.
|
||||||
|
|
||||||
One important distinction between algebra elements and algebra contexts is that algebra element has three type parameters:
|
One major distinction between algebraic elements and algebraic contexts is that elements have three type
|
||||||
|
parameters:
|
||||||
|
|
||||||
1. The type of elements, field operates on.
|
1. The type of elements, the field operates on.
|
||||||
2. The self-type of the element returned from operation (must be algebra element).
|
2. The self-type of the element returned from operation (which has to be an algebraic element).
|
||||||
3. The type of the algebra over first type-parameter.
|
3. The type of the algebra over first type-parameter.
|
||||||
|
|
||||||
The middle type is needed in case algebra members do not store context. For example, it is not possible to add
|
The middle type is needed for of algebra members do not store context. For example, it is impossible to add a context
|
||||||
a context to regular `Double`. The element performs automatic conversions from context types and back.
|
to regular `Double`. The element performs automatic conversions from context types and back. One should use context
|
||||||
One should used context operations in all important places. The performance of element operations is not guaranteed.
|
operations in all performance-critical places. The performance of element operations is not guaranteed.
|
||||||
|
|
||||||
## Spaces and fields
|
## Spaces and Fields
|
||||||
|
|
||||||
An obvious first choice of mathematical objects to implement in a context-oriented style are algebraic elements like spaces,
|
KMath submits both contexts and elements for builtin algebraic structures:
|
||||||
rings and fields. Those are located in the `scientifik.kmath.operations.Algebra.kt` file. Alongside common contexts, the file includes definitions for algebra elements like `FieldElement`. A `FieldElement` object
|
|
||||||
stores a reference to the `Field` which contains additive and multiplicative operations, meaning
|
|
||||||
it has one fixed context attached and does not require explicit external context. So those `MathElements` can be operated without context:
|
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
val c1 = Complex(1.0, 2.0)
|
val c1 = Complex(1.0, 2.0)
|
||||||
val c2 = ComplexField.i
|
val c2 = ComplexField.i
|
||||||
|
|
||||||
val c3 = c1 + c2
|
val c3 = c1 + c2
|
||||||
|
// or
|
||||||
|
val c3 = ComplexField { c1 + c2 }
|
||||||
```
|
```
|
||||||
|
|
||||||
`ComplexField` also features special operations to mix complex and real numbers, for example:
|
Also, `ComplexField` features special operations to mix complex and real numbers, for example:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
val c1 = Complex(1.0, 2.0)
|
val c1 = Complex(1.0, 2.0)
|
||||||
val c2 = ComplexField.run{ c1 - 1.0} // Returns: [re:0.0, im: 2.0]
|
val c2 = ComplexField { c1 - 1.0 } // Returns: Complex(re=0.0, im=2.0)
|
||||||
val c3 = ComplexField.run{ c1 - i*2.0}
|
val c3 = ComplexField { c1 - i * 2.0 }
|
||||||
```
|
```
|
||||||
|
|
||||||
**Note**: In theory it is possible to add behaviors directly to the context, but currently kotlin syntax does not support
|
**Note**: In theory it is possible to add behaviors directly to the context, but as for now Kotlin does not support
|
||||||
that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and [KEEP-176](https://github.com/Kotlin/KEEP/pull/176) for updates.
|
that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and
|
||||||
|
[KEEP-176](https://github.com/Kotlin/KEEP/pull/176) for updates.
|
||||||
|
|
||||||
## Nested fields
|
## Nested fields
|
||||||
|
|
||||||
Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex elements like so:
|
Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex
|
||||||
|
elements like so:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
val element = NDElement.complex(shape = intArrayOf(2,2)){ index: IntArray ->
|
val element = NDElement.complex(shape = intArrayOf(2, 2)) { index: IntArray ->
|
||||||
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
The `element` in this example is a member of the `Field` of 2-d structures, each element of which is a member of its own
|
The `element` in this example is a member of the `Field` of 2D structures, each element of which is a member of its own
|
||||||
`ComplexField`. The important thing is one does not need to create a special n-d class to hold complex
|
`ComplexField`. It is important one does not need to create a special n-d class to hold complex
|
||||||
numbers and implement operations on it, one just needs to provide a field for its elements.
|
numbers and implement operations on it, one just needs to provide a field for its elements.
|
||||||
|
|
||||||
**Note**: Fields themselves do not solve the problem of JVM boxing, but it is possible to solve with special contexts like
|
**Note**: Fields themselves do not solve the problem of JVM boxing, but it is possible to solve with special contexts like
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
# Buffers
|
# Buffers
|
||||||
|
|
||||||
Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`).
|
Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`).
|
||||||
There are different types of buffers:
|
There are different types of buffers:
|
||||||
|
|
||||||
* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays.
|
* Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays.
|
||||||
* Boxing `ListBuffer` wrapping a list
|
* Boxing `ListBuffer` wrapping a list
|
||||||
* Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value
|
* Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value
|
||||||
* `MemoryBuffer` allows direct allocation of objects in continuous memory block.
|
* `MemoryBuffer` allows direct allocation of objects in continuous memory block.
|
||||||
@ -12,4 +13,5 @@ Some kmath features require a `BufferFactory` class to operate properly. A gener
|
|||||||
buffer for given reified type (for types with custom memory buffer it still better to use their own `MemoryBuffer.create()` factory).
|
buffer for given reified type (for types with custom memory buffer it still better to use their own `MemoryBuffer.create()` factory).
|
||||||
|
|
||||||
## Buffer performance
|
## Buffer performance
|
||||||
|
|
||||||
One should avoid using default boxing buffer wherever it is possible. Try to use primitive buffers or memory buffers instead
|
One should avoid using default boxing buffer wherever it is possible. Try to use primitive buffers or memory buffers instead
|
34
doc/codestyle.md
Normal file
34
doc/codestyle.md
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# Coding Conventions
|
||||||
|
|
||||||
|
KMath code follows general [Kotlin conventions](https://kotlinlang.org/docs/reference/coding-conventions.html), but
|
||||||
|
with a number of small changes and clarifications.
|
||||||
|
|
||||||
|
## Utility Class Naming
|
||||||
|
|
||||||
|
Filename should coincide with a name of one of the classes contained in the file or start with small letter and
|
||||||
|
describe its contents.
|
||||||
|
|
||||||
|
The code convention [here](https://kotlinlang.org/docs/reference/coding-conventions.html#source-file-names) says that
|
||||||
|
file names should start with a capital letter even if file does not contain classes. Yet starting utility classes and
|
||||||
|
aggregators with a small letter seems to be a good way to visually separate those files.
|
||||||
|
|
||||||
|
This convention could be changed in future in a non-breaking way.
|
||||||
|
|
||||||
|
## Private Variable Naming
|
||||||
|
|
||||||
|
Private variables' names may start with underscore `_` for of the private mutable variable is shadowed by the public
|
||||||
|
read-only value with the same meaning.
|
||||||
|
|
||||||
|
This rule does not permit underscores in names, but it is sometimes useful to "underscore" the fact that public and
|
||||||
|
private versions draw up the same entity. It is allowed only for private variables.
|
||||||
|
|
||||||
|
This convention could be changed in future in a non-breaking way.
|
||||||
|
|
||||||
|
## Functions and Properties One-liners
|
||||||
|
|
||||||
|
Use one-liners when they occupy single code window line both for functions and properties with getters like
|
||||||
|
`val b: String get() = "fff"`. The same should be performed with multiline expressions when they could be
|
||||||
|
cleanly separated.
|
||||||
|
|
||||||
|
There is no universal consensus whenever use `fun a() = ...` or `fun a() { return ... }`. Yet from reader outlook
|
||||||
|
one-lines seem to better show that the property or function is easily calculated.
|
@ -1,6 +1,6 @@
|
|||||||
## Basic linear algebra layout
|
## Basic linear algebra layout
|
||||||
|
|
||||||
Kmath support for linear algebra organized in a context-oriented way. Meaning that operations are in most cases declared
|
KMath support for linear algebra organized in a context-oriented way. Meaning that operations are in most cases declared
|
||||||
in context classes, and are not the members of classes that store data. This allows more flexible approach to maintain multiple
|
in context classes, and are not the members of classes that store data. This allows more flexible approach to maintain multiple
|
||||||
back-ends. The new operations added as extensions to contexts instead of being member functions of data structures.
|
back-ends. The new operations added as extensions to contexts instead of being member functions of data structures.
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Nd-structure generation and operations
|
# ND-structure generation and operations
|
||||||
|
|
||||||
**TODO**
|
**TODO**
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
|
|||||||
plugins {
|
plugins {
|
||||||
java
|
java
|
||||||
kotlin("jvm")
|
kotlin("jvm")
|
||||||
kotlin("plugin.allopen") version "1.3.71"
|
kotlin("plugin.allopen") version "1.3.72"
|
||||||
id("kotlinx.benchmark") version "0.2.0-dev-7"
|
id("kotlinx.benchmark") version "0.2.0-dev-8"
|
||||||
}
|
}
|
||||||
|
|
||||||
configure<AllOpenExtension> {
|
configure<AllOpenExtension> {
|
||||||
@ -24,16 +24,18 @@ sourceSets {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
|
implementation(project(":kmath-ast"))
|
||||||
implementation(project(":kmath-core"))
|
implementation(project(":kmath-core"))
|
||||||
implementation(project(":kmath-coroutines"))
|
implementation(project(":kmath-coroutines"))
|
||||||
implementation(project(":kmath-commons"))
|
implementation(project(":kmath-commons"))
|
||||||
|
implementation(project(":kmath-prob"))
|
||||||
implementation(project(":kmath-koma"))
|
implementation(project(":kmath-koma"))
|
||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-dimensions"))
|
implementation(project(":kmath-dimensions"))
|
||||||
implementation("com.kyonifer:koma-core-ejml:0.12")
|
implementation("com.kyonifer:koma-core-ejml:0.12")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
|
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7")
|
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8")
|
||||||
"benchmarksCompile"(sourceSets.main.get().compileClasspath)
|
"benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure benchmark
|
// Configure benchmark
|
||||||
@ -57,6 +59,6 @@ benchmark {
|
|||||||
|
|
||||||
tasks.withType<KotlinCompile> {
|
tasks.withType<KotlinCompile> {
|
||||||
kotlinOptions {
|
kotlinOptions {
|
||||||
jvmTarget = Scientifik.JVM_VERSION
|
jvmTarget = Scientifik.JVM_TARGET.toString()
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -10,8 +10,8 @@ import scientifik.kmath.operations.complex
|
|||||||
class BufferBenchmark {
|
class BufferBenchmark {
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun genericDoubleBufferReadWrite() {
|
fun genericRealBufferReadWrite() {
|
||||||
val buffer = DoubleBuffer(size){it.toDouble()}
|
val buffer = RealBuffer(size){it.toDouble()}
|
||||||
|
|
||||||
(0 until size).forEach {
|
(0 until size).forEach {
|
||||||
buffer[it]
|
buffer[it]
|
||||||
|
@ -20,48 +20,39 @@ class ViktorBenchmark {
|
|||||||
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
|
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun `Automatic field addition`() {
|
fun automaticFieldAddition() {
|
||||||
autoField.run {
|
autoField.run {
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) {
|
repeat(n) { res += one }
|
||||||
res += 1.0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun `Viktor field addition`() {
|
fun viktorFieldAddition() {
|
||||||
viktorField.run {
|
viktorField.run {
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) {
|
repeat(n) { res += one }
|
||||||
res += one
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun `Raw Viktor`() {
|
fun rawViktor() {
|
||||||
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
|
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) {
|
repeat(n) { res = res + one }
|
||||||
res = res + one
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun `Real field log`() {
|
fun realdFieldLog() {
|
||||||
realField.run {
|
realField.run {
|
||||||
val fortyTwo = produce { 42.0 }
|
val fortyTwo = produce { 42.0 }
|
||||||
var res = one
|
var res = one
|
||||||
|
repeat(n) { res = ln(fortyTwo) }
|
||||||
repeat(n) {
|
|
||||||
res = ln(fortyTwo)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun `Raw Viktor log`() {
|
fun rawViktorLog() {
|
||||||
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
|
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
|
||||||
var res: F64Array
|
var res: F64Array
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
|
@ -0,0 +1,70 @@
|
|||||||
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.compile
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.expressions.expressionInField
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
class ExpressionsInterpretersBenchmark {
|
||||||
|
private val algebra: Field<Double> = RealField
|
||||||
|
fun functionalExpression() {
|
||||||
|
val expr = algebra.expressionInField {
|
||||||
|
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
invokeAndSum(expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun mstExpression() {
|
||||||
|
val expr = algebra.mstInField {
|
||||||
|
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
invokeAndSum(expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun asmExpression() {
|
||||||
|
val expr = algebra.mstInField {
|
||||||
|
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||||
|
}.compile()
|
||||||
|
|
||||||
|
invokeAndSum(expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun invokeAndSum(expr: Expression<Double>) {
|
||||||
|
val random = Random(0)
|
||||||
|
var sum = 0.0
|
||||||
|
|
||||||
|
repeat(1000000) {
|
||||||
|
sum += expr("x" to random.nextDouble())
|
||||||
|
}
|
||||||
|
|
||||||
|
println(sum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val benchmark = ExpressionsInterpretersBenchmark()
|
||||||
|
|
||||||
|
val fe = measureTimeMillis {
|
||||||
|
benchmark.functionalExpression()
|
||||||
|
}
|
||||||
|
|
||||||
|
println("fe=$fe")
|
||||||
|
|
||||||
|
val mst = measureTimeMillis {
|
||||||
|
benchmark.mstExpression()
|
||||||
|
}
|
||||||
|
|
||||||
|
println("mst=$mst")
|
||||||
|
|
||||||
|
val asm = measureTimeMillis {
|
||||||
|
benchmark.asmExpression()
|
||||||
|
}
|
||||||
|
|
||||||
|
println("asm=$asm")
|
||||||
|
}
|
@ -0,0 +1,71 @@
|
|||||||
|
package scientifik.kmath.commons.prob
|
||||||
|
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.async
|
||||||
|
import kotlinx.coroutines.runBlocking
|
||||||
|
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
|
||||||
|
import org.apache.commons.rng.simple.RandomSource
|
||||||
|
import scientifik.kmath.chains.BlockingRealChain
|
||||||
|
import scientifik.kmath.prob.*
|
||||||
|
import java.time.Duration
|
||||||
|
import java.time.Instant
|
||||||
|
|
||||||
|
|
||||||
|
private suspend fun runChain(): Duration {
|
||||||
|
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
|
||||||
|
|
||||||
|
val normal = Distribution.normal(NormalSamplerMethod.Ziggurat)
|
||||||
|
val chain = normal.sample(generator) as BlockingRealChain
|
||||||
|
|
||||||
|
val startTime = Instant.now()
|
||||||
|
var sum = 0.0
|
||||||
|
repeat(10000001) { counter ->
|
||||||
|
|
||||||
|
sum += chain.nextDouble()
|
||||||
|
|
||||||
|
if (counter % 100000 == 0) {
|
||||||
|
val duration = Duration.between(startTime, Instant.now())
|
||||||
|
val meanValue = sum / counter
|
||||||
|
println("Chain sampler completed $counter elements in $duration: $meanValue")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Duration.between(startTime, Instant.now())
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun runDirect(): Duration {
|
||||||
|
val provider = RandomSource.create(RandomSource.MT, 123L)
|
||||||
|
val sampler = ZigguratNormalizedGaussianSampler(provider)
|
||||||
|
val startTime = Instant.now()
|
||||||
|
|
||||||
|
var sum = 0.0
|
||||||
|
repeat(10000001) { counter ->
|
||||||
|
|
||||||
|
sum += sampler.sample()
|
||||||
|
|
||||||
|
if (counter % 100000 == 0) {
|
||||||
|
val duration = Duration.between(startTime, Instant.now())
|
||||||
|
val meanValue = sum / counter
|
||||||
|
println("Direct sampler completed $counter elements in $duration: $meanValue")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Duration.between(startTime, Instant.now())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Comparing chain sampling performance with direct sampling performance
|
||||||
|
*/
|
||||||
|
fun main() {
|
||||||
|
runBlocking(Dispatchers.Default) {
|
||||||
|
val chainJob = async {
|
||||||
|
runChain()
|
||||||
|
}
|
||||||
|
|
||||||
|
val directJob = async {
|
||||||
|
runDirect()
|
||||||
|
}
|
||||||
|
|
||||||
|
println("Chain: ${chainJob.await()}")
|
||||||
|
println("Direct: ${directJob.await()}")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -5,10 +5,11 @@ import scientifik.kmath.chains.Chain
|
|||||||
import scientifik.kmath.chains.collectWithState
|
import scientifik.kmath.chains.collectWithState
|
||||||
import scientifik.kmath.prob.Distribution
|
import scientifik.kmath.prob.Distribution
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.normal
|
||||||
|
|
||||||
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
||||||
|
|
||||||
fun Chain<Double>.mean(): Chain<Double> = collectWithState(AveragingChainState(),{it.copy()}){ chain->
|
fun Chain<Double>.mean(): Chain<Double> = collectWithState(AveragingChainState(), { it.copy() }) { chain ->
|
||||||
val next = chain.next()
|
val next = chain.next()
|
||||||
num++
|
num++
|
||||||
value += next
|
value += next
|
||||||
|
@ -27,7 +27,7 @@ fun main() {
|
|||||||
|
|
||||||
val complexTime = measureTimeMillis {
|
val complexTime = measureTimeMillis {
|
||||||
complexField.run {
|
complexField.run {
|
||||||
var res = one
|
var res: NDBuffer<Complex> = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
res += 1.0
|
res += 1.0
|
||||||
}
|
}
|
||||||
|
@ -23,14 +23,14 @@ fun main() {
|
|||||||
|
|
||||||
measureAndPrint("Automatic field addition") {
|
measureAndPrint("Automatic field addition") {
|
||||||
autoField.run {
|
autoField.run {
|
||||||
var res = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
res += 1.0
|
res += number(1.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Element addition"){
|
measureAndPrint("Element addition") {
|
||||||
var res = genericField.one
|
var res = genericField.one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
res += 1.0
|
res += 1.0
|
||||||
@ -63,7 +63,7 @@ fun main() {
|
|||||||
genericField.run {
|
genericField.run {
|
||||||
var res: NDBuffer<Double> = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
res += 1.0
|
res += one // con't avoid using `one` due to resolution ambiguity
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ fun main(args: Array<String>) {
|
|||||||
val n = 6000
|
val n = 6000
|
||||||
|
|
||||||
val array = DoubleArray(n * n) { 1.0 }
|
val array = DoubleArray(n * n) { 1.0 }
|
||||||
val buffer = DoubleBuffer(array)
|
val buffer = RealBuffer(array)
|
||||||
val strides = DefaultStrides(intArrayOf(n, n))
|
val strides = DefaultStrides(intArrayOf(n, n))
|
||||||
|
|
||||||
val structure = BufferNDStructure(strides, buffer)
|
val structure = BufferNDStructure(strides, buffer)
|
||||||
|
@ -26,10 +26,10 @@ fun main(args: Array<String>) {
|
|||||||
}
|
}
|
||||||
println("Array mapping finished in $time2 millis")
|
println("Array mapping finished in $time2 millis")
|
||||||
|
|
||||||
val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 })
|
val buffer = RealBuffer(DoubleArray(n * n) { 1.0 })
|
||||||
|
|
||||||
val time3 = measureTimeMillis {
|
val time3 = measureTimeMillis {
|
||||||
val target = DoubleBuffer(DoubleArray(n * n))
|
val target = RealBuffer(DoubleArray(n * n))
|
||||||
val res = array.forEachIndexed { index, value ->
|
val res = array.forEachIndexed { index, value ->
|
||||||
target[index] = value + 1
|
target[index] = value + 1
|
||||||
}
|
}
|
||||||
|
BIN
gradle/wrapper/gradle-wrapper.jar
vendored
BIN
gradle/wrapper/gradle-wrapper.jar
vendored
Binary file not shown.
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
|||||||
distributionBase=GRADLE_USER_HOME
|
distributionBase=GRADLE_USER_HOME
|
||||||
distributionPath=wrapper/dists
|
distributionPath=wrapper/dists
|
||||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.3-bin.zip
|
distributionUrl=https\://services.gradle.org/distributions/gradle-6.5.1-bin.zip
|
||||||
zipStoreBase=GRADLE_USER_HOME
|
zipStoreBase=GRADLE_USER_HOME
|
||||||
zipStorePath=wrapper/dists
|
zipStorePath=wrapper/dists
|
||||||
|
2
gradlew
vendored
2
gradlew
vendored
@ -82,6 +82,7 @@ esac
|
|||||||
|
|
||||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||||
|
|
||||||
|
|
||||||
# Determine the Java command to use to start the JVM.
|
# Determine the Java command to use to start the JVM.
|
||||||
if [ -n "$JAVA_HOME" ] ; then
|
if [ -n "$JAVA_HOME" ] ; then
|
||||||
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||||
@ -129,6 +130,7 @@ fi
|
|||||||
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
|
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
|
||||||
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
||||||
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
||||||
|
|
||||||
JAVACMD=`cygpath --unix "$JAVACMD"`
|
JAVACMD=`cygpath --unix "$JAVACMD"`
|
||||||
|
|
||||||
# We build the pattern for arguments to be converted via cygpath
|
# We build the pattern for arguments to be converted via cygpath
|
||||||
|
1
gradlew.bat
vendored
1
gradlew.bat
vendored
@ -84,6 +84,7 @@ set CMD_LINE_ARGS=%*
|
|||||||
|
|
||||||
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
||||||
|
|
||||||
|
|
||||||
@rem Execute Gradle
|
@rem Execute Gradle
|
||||||
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
|
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
|
||||||
|
|
||||||
|
91
kmath-ast/README.md
Normal file
91
kmath-ast/README.md
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
# Abstract Syntax Tree Expression Representation and Operations (`kmath-ast`)
|
||||||
|
|
||||||
|
This subproject implements the following features:
|
||||||
|
|
||||||
|
- Expression Language and its parser.
|
||||||
|
- MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation.
|
||||||
|
- Type-safe builder for MST.
|
||||||
|
- Evaluating expressions by traversing MST.
|
||||||
|
|
||||||
|
> #### Artifact:
|
||||||
|
> This module is distributed in the artifact `scientifik:kmath-ast:0.1.4-dev-8`.
|
||||||
|
>
|
||||||
|
> **Gradle:**
|
||||||
|
>
|
||||||
|
> ```gradle
|
||||||
|
> repositories {
|
||||||
|
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' }
|
||||||
|
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||||
|
> maven { url https://dl.bintray.com/hotkeytlt/maven' }
|
||||||
|
> }
|
||||||
|
>
|
||||||
|
> dependencies {
|
||||||
|
> implementation 'scientifik:kmath-ast:0.1.4-dev-8'
|
||||||
|
> }
|
||||||
|
> ```
|
||||||
|
> **Gradle Kotlin DSL:**
|
||||||
|
>
|
||||||
|
> ```kotlin
|
||||||
|
> repositories {
|
||||||
|
> maven("https://dl.bintray.com/mipt-npm/scientifik")
|
||||||
|
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
|
> }
|
||||||
|
>
|
||||||
|
> dependencies {
|
||||||
|
> implementation("scientifik:kmath-ast:0.1.4-dev-8")
|
||||||
|
> }
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
|
||||||
|
## Dynamic Expression Code Generation with ObjectWeb ASM
|
||||||
|
|
||||||
|
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
|
||||||
|
a special implementation of `Expression<T>` with implemented `invoke` function.
|
||||||
|
|
||||||
|
For example, the following builder:
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||||
|
```
|
||||||
|
|
||||||
|
… leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||||
|
|
||||||
|
```java
|
||||||
|
package scientifik.kmath.asm.generated;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import scientifik.kmath.asm.internal.MapIntrinsics;
|
||||||
|
import scientifik.kmath.expressions.Expression;
|
||||||
|
import scientifik.kmath.operations.RealField;
|
||||||
|
|
||||||
|
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
|
||||||
|
private final RealField algebra;
|
||||||
|
|
||||||
|
public final Double invoke(Map<String, ? extends Double> arguments) {
|
||||||
|
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(), 2.0D);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AsmCompiledExpression_1073786867_0(RealField algebra) {
|
||||||
|
this.algebra = algebra;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example Usage
|
||||||
|
|
||||||
|
This API extends MST and MstExpression, so you may optimize as both of them:
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||||
|
RealField.expression("x+2".parseMath())
|
||||||
|
```
|
||||||
|
|
||||||
|
### Known issues
|
||||||
|
|
||||||
|
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
|
||||||
|
class loading overhead.
|
||||||
|
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
||||||
|
|
||||||
|
Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis).
|
23
kmath-ast/build.gradle.kts
Normal file
23
kmath-ast/build.gradle.kts
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
plugins { id("scientifik.mpp") }
|
||||||
|
|
||||||
|
kotlin.sourceSets {
|
||||||
|
// all {
|
||||||
|
// languageSettings.apply{
|
||||||
|
// enableLanguageFeature("NewInference")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
commonMain {
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-core"))
|
||||||
|
implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jvmMain {
|
||||||
|
dependencies {
|
||||||
|
implementation("org.ow2.asm:asm:8.0.1")
|
||||||
|
implementation("org.ow2.asm:asm-commons:8.0.1")
|
||||||
|
implementation(kotlin("reflect"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
59
kmath-ast/reference/ArithmeticsEvaluator.g4
Normal file
59
kmath-ast/reference/ArithmeticsEvaluator.g4
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
grammar ArithmeticsEvaluator;
|
||||||
|
|
||||||
|
fragment DIGIT: '0'..'9';
|
||||||
|
fragment LETTER: 'a'..'z';
|
||||||
|
fragment CAPITAL_LETTER: 'A'..'Z';
|
||||||
|
fragment UNDERSCORE: '_';
|
||||||
|
|
||||||
|
ID: (LETTER | UNDERSCORE | CAPITAL_LETTER) (LETTER | UNDERSCORE | DIGIT | CAPITAL_LETTER)*;
|
||||||
|
NUM: (DIGIT | '.')+ ([eE] [-+]? DIGIT+)?;
|
||||||
|
MUL: '*';
|
||||||
|
DIV: '/';
|
||||||
|
PLUS: '+';
|
||||||
|
MINUS: '-';
|
||||||
|
POW: '^';
|
||||||
|
COMMA: ',';
|
||||||
|
LPAR: '(';
|
||||||
|
RPAR: ')';
|
||||||
|
WS: [ \n\t\r]+ -> skip;
|
||||||
|
|
||||||
|
num
|
||||||
|
: NUM
|
||||||
|
;
|
||||||
|
|
||||||
|
singular
|
||||||
|
: ID
|
||||||
|
;
|
||||||
|
|
||||||
|
unaryFunction
|
||||||
|
: ID LPAR subSumChain RPAR
|
||||||
|
;
|
||||||
|
|
||||||
|
binaryFunction
|
||||||
|
: ID LPAR subSumChain COMMA subSumChain RPAR
|
||||||
|
;
|
||||||
|
|
||||||
|
term
|
||||||
|
: num
|
||||||
|
| singular
|
||||||
|
| unaryFunction
|
||||||
|
| binaryFunction
|
||||||
|
| MINUS term
|
||||||
|
| LPAR subSumChain RPAR
|
||||||
|
;
|
||||||
|
|
||||||
|
powChain
|
||||||
|
: term (POW term)*
|
||||||
|
;
|
||||||
|
|
||||||
|
divMulChain
|
||||||
|
: powChain ((DIV | MUL) powChain)*
|
||||||
|
;
|
||||||
|
|
||||||
|
subSumChain
|
||||||
|
: divMulChain ((PLUS | MINUS) divMulChain)*
|
||||||
|
;
|
||||||
|
|
||||||
|
rootParser
|
||||||
|
: subSumChain EOF
|
||||||
|
;
|
87
kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt
Normal file
87
kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import scientifik.kmath.operations.NumericAlgebra
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Mathematical Syntax Tree node for mathematical expressions.
|
||||||
|
*/
|
||||||
|
sealed class MST {
|
||||||
|
/**
|
||||||
|
* A node containing raw string.
|
||||||
|
*
|
||||||
|
* @property value the value of this node.
|
||||||
|
*/
|
||||||
|
data class Symbolic(val value: String) : MST()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A node containing a numeric value or scalar.
|
||||||
|
*
|
||||||
|
* @property value the value of this number.
|
||||||
|
*/
|
||||||
|
data class Numeric(val value: Number) : MST()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A node containing an unary operation.
|
||||||
|
*
|
||||||
|
* @property operation the identifier of operation.
|
||||||
|
* @property value the argument of this operation.
|
||||||
|
*/
|
||||||
|
data class Unary(val operation: String, val value: MST) : MST() {
|
||||||
|
companion object
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A node containing binary operation.
|
||||||
|
*
|
||||||
|
* @property operation the identifier operation.
|
||||||
|
* @property left the left operand.
|
||||||
|
* @property right the right operand.
|
||||||
|
*/
|
||||||
|
data class Binary(val operation: String, val left: MST, val right: MST) : MST() {
|
||||||
|
companion object
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO add a function with named arguments
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interprets the [MST] node with this [Algebra].
|
||||||
|
*
|
||||||
|
* @receiver the algebra that provides operations.
|
||||||
|
* @param node the node to evaluate.
|
||||||
|
* @return the value of expression.
|
||||||
|
*/
|
||||||
|
fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||||
|
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||||
|
?: error("Numeric nodes are not supported by $this")
|
||||||
|
is MST.Symbolic -> symbol(node.value)
|
||||||
|
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||||
|
is MST.Binary -> when {
|
||||||
|
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||||
|
|
||||||
|
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||||
|
val number = RealField.binaryOperation(
|
||||||
|
node.operation,
|
||||||
|
node.left.value.toDouble(),
|
||||||
|
node.right.value.toDouble()
|
||||||
|
)
|
||||||
|
|
||||||
|
number(number)
|
||||||
|
}
|
||||||
|
|
||||||
|
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
||||||
|
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||||
|
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interprets the [MST] node with this [Algebra].
|
||||||
|
*
|
||||||
|
* @receiver the node to evaluate.
|
||||||
|
* @param algebra the algebra that provides operations.
|
||||||
|
* @return the value of expression.
|
||||||
|
*/
|
||||||
|
fun <T> MST.interpret(algebra: Algebra<T>): T = algebra.evaluate(this)
|
@ -0,0 +1,102 @@
|
|||||||
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [Algebra] over [MST] nodes.
|
||||||
|
*/
|
||||||
|
object MstAlgebra : NumericAlgebra<MST> {
|
||||||
|
override fun number(value: Number): MST = MST.Numeric(value)
|
||||||
|
|
||||||
|
override fun symbol(value: String): MST = MST.Symbolic(value)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: MST): MST =
|
||||||
|
MST.Unary(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MST.Binary(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [Space] over [MST] nodes.
|
||||||
|
*/
|
||||||
|
object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||||
|
override val zero: MST = number(0.0)
|
||||||
|
|
||||||
|
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||||
|
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||||
|
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MstAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [Ring] over [MST] nodes.
|
||||||
|
*/
|
||||||
|
object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||||
|
override val zero: MST = number(0.0)
|
||||||
|
override val one: MST = number(1.0)
|
||||||
|
|
||||||
|
override fun number(value: Number): MST = MstSpace.number(value)
|
||||||
|
override fun symbol(value: String): MST = MstSpace.symbol(value)
|
||||||
|
override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
|
||||||
|
|
||||||
|
override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k)
|
||||||
|
|
||||||
|
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MstSpace.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [Field] over [MST] nodes.
|
||||||
|
*/
|
||||||
|
object MstField : Field<MST> {
|
||||||
|
override val zero: MST = number(0.0)
|
||||||
|
override val one: MST = number(1.0)
|
||||||
|
|
||||||
|
override fun symbol(value: String): MST = MstRing.symbol(value)
|
||||||
|
override fun number(value: Number): MST = MstRing.number(value)
|
||||||
|
override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
|
||||||
|
override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
|
||||||
|
override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
|
||||||
|
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MstRing.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [ExtendedField] over [MST] nodes.
|
||||||
|
*/
|
||||||
|
object MstExtendedField : ExtendedField<MST> {
|
||||||
|
override val zero: MST = number(0.0)
|
||||||
|
override val one: MST = number(1.0)
|
||||||
|
|
||||||
|
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||||
|
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||||
|
override fun asin(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg)
|
||||||
|
override fun acos(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg)
|
||||||
|
override fun atan(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg)
|
||||||
|
override fun add(a: MST, b: MST): MST = MstField.add(a, b)
|
||||||
|
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
|
||||||
|
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
|
||||||
|
override fun divide(a: MST, b: MST): MST = MstField.divide(a, b)
|
||||||
|
override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||||
|
override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||||
|
override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MstField.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg)
|
||||||
|
}
|
@ -0,0 +1,88 @@
|
|||||||
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.expressions.*
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
|
||||||
|
* ASM-generated expressions.
|
||||||
|
*
|
||||||
|
* @property algebra the algebra that provides operations.
|
||||||
|
* @property mst the [MST] node.
|
||||||
|
*/
|
||||||
|
class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
|
||||||
|
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
|
||||||
|
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
||||||
|
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
|
algebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun number(value: Number): T = if (algebra is NumericAlgebra)
|
||||||
|
algebra.number(value)
|
||||||
|
else
|
||||||
|
error("Numeric nodes are not supported by $this")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds [MstExpression] over [Algebra].
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||||
|
mstAlgebra: E,
|
||||||
|
block: E.() -> MST
|
||||||
|
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds [MstExpression] over [Space].
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
|
||||||
|
MstExpression(this, MstSpace.block())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds [MstExpression] over [Ring].
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
|
||||||
|
MstExpression(this, MstRing.block())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds [MstExpression] over [Field].
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> =
|
||||||
|
MstExpression(this, MstField.block())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds [MstExpression] over [ExtendedField].
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> =
|
||||||
|
MstExpression(this, MstExtendedField.block())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds [MstExpression] over [FunctionalExpressionSpace].
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(
|
||||||
|
block: MstSpace.() -> MST
|
||||||
|
): MstExpression<T> = algebra.mstInSpace(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds [MstExpression] over [FunctionalExpressionRing].
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(
|
||||||
|
block: MstRing.() -> MST
|
||||||
|
): MstExpression<T> = algebra.mstInRing(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds [MstExpression] over [FunctionalExpressionField].
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(
|
||||||
|
block: MstField.() -> MST
|
||||||
|
): MstExpression<T> = algebra.mstInField(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds [MstExpression] over [FunctionalExpressionExtendedField].
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||||
|
block: MstExtendedField.() -> MST
|
||||||
|
): MstExpression<T> = algebra.mstInExtendedField(block)
|
@ -0,0 +1,97 @@
|
|||||||
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
|
import com.github.h0tk3y.betterParse.combinators.*
|
||||||
|
import com.github.h0tk3y.betterParse.grammar.Grammar
|
||||||
|
import com.github.h0tk3y.betterParse.grammar.parseToEnd
|
||||||
|
import com.github.h0tk3y.betterParse.grammar.parser
|
||||||
|
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.Token
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.TokenMatch
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.regexToken
|
||||||
|
import com.github.h0tk3y.betterParse.parser.ParseResult
|
||||||
|
import com.github.h0tk3y.betterParse.parser.Parser
|
||||||
|
import scientifik.kmath.operations.FieldOperations
|
||||||
|
import scientifik.kmath.operations.PowerOperations
|
||||||
|
import scientifik.kmath.operations.RingOperations
|
||||||
|
import scientifik.kmath.operations.SpaceOperations
|
||||||
|
|
||||||
|
/**
|
||||||
|
* TODO move to core
|
||||||
|
*/
|
||||||
|
object ArithmeticsEvaluator : Grammar<MST>() {
|
||||||
|
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released
|
||||||
|
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
|
||||||
|
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
|
||||||
|
private val lpar: Token by regexToken("\\(")
|
||||||
|
private val rpar: Token by regexToken("\\)")
|
||||||
|
private val comma: Token by regexToken(",")
|
||||||
|
private val mul: Token by regexToken("\\*")
|
||||||
|
private val pow: Token by regexToken("\\^")
|
||||||
|
private val div: Token by regexToken("/")
|
||||||
|
private val minus: Token by regexToken("-")
|
||||||
|
private val plus: Token by regexToken("\\+")
|
||||||
|
private val ws: Token by regexToken("\\s+", ignore = true)
|
||||||
|
|
||||||
|
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||||
|
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
|
||||||
|
|
||||||
|
private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||||
|
.map { (id, term) -> MST.Unary(id.text, term) }
|
||||||
|
|
||||||
|
private val binaryFunction: Parser<MST> by id
|
||||||
|
.and(skip(lpar))
|
||||||
|
.and(parser(::subSumChain))
|
||||||
|
.and(skip(comma))
|
||||||
|
.and(parser(::subSumChain))
|
||||||
|
.and(skip(rpar))
|
||||||
|
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
|
||||||
|
|
||||||
|
private val term: Parser<MST> by number
|
||||||
|
.or(binaryFunction)
|
||||||
|
.or(unaryFunction)
|
||||||
|
.or(singular)
|
||||||
|
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
|
||||||
|
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||||
|
|
||||||
|
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
|
||||||
|
MST.Binary(PowerOperations.POW_OPERATION, a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
private val divMulChain: Parser<MST> by leftAssociative(
|
||||||
|
term = powChain,
|
||||||
|
operator = div or mul use TokenMatch::type
|
||||||
|
) { a, op, b ->
|
||||||
|
if (op == div)
|
||||||
|
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
else
|
||||||
|
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
private val subSumChain: Parser<MST> by leftAssociative(
|
||||||
|
term = divMulChain,
|
||||||
|
operator = plus or minus use TokenMatch::type
|
||||||
|
) { a, op, b ->
|
||||||
|
if (op == plus)
|
||||||
|
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
else
|
||||||
|
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val rootParser: Parser<MST> by subSumChain
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tries to parse the string into [MST].
|
||||||
|
*
|
||||||
|
* @receiver the string to parse.
|
||||||
|
* @return the [MST] node.
|
||||||
|
*/
|
||||||
|
fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses the string into [MST].
|
||||||
|
*
|
||||||
|
* @receiver the string to parse.
|
||||||
|
* @return the [MST] node.
|
||||||
|
*/
|
||||||
|
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)
|
64
kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt
Normal file
64
kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package scientifik.kmath.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.internal.AsmBuilder
|
||||||
|
import scientifik.kmath.asm.internal.MstType
|
||||||
|
import scientifik.kmath.asm.internal.buildAlgebraOperationCall
|
||||||
|
import scientifik.kmath.asm.internal.buildName
|
||||||
|
import scientifik.kmath.ast.MST
|
||||||
|
import scientifik.kmath.ast.MstExpression
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compile given MST to an Expression using AST compiler
|
||||||
|
*/
|
||||||
|
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
|
fun AsmBuilder<T>.visit(node: MST) {
|
||||||
|
when (node) {
|
||||||
|
is MST.Symbolic -> {
|
||||||
|
val symbol = try {
|
||||||
|
algebra.symbol(node.value)
|
||||||
|
} catch (ignored: Throwable) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
|
||||||
|
if (symbol != null)
|
||||||
|
loadTConstant(symbol)
|
||||||
|
else
|
||||||
|
loadVariable(node.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
is MST.Numeric -> loadNumeric(node.value)
|
||||||
|
|
||||||
|
is MST.Unary -> buildAlgebraOperationCall(
|
||||||
|
context = algebra,
|
||||||
|
name = node.operation,
|
||||||
|
fallbackMethodName = "unaryOperation",
|
||||||
|
parameterTypes = arrayOf(MstType.fromMst(node.value))
|
||||||
|
) { visit(node.value) }
|
||||||
|
|
||||||
|
is MST.Binary -> buildAlgebraOperationCall(
|
||||||
|
context = algebra,
|
||||||
|
name = node.operation,
|
||||||
|
fallbackMethodName = "binaryOperation",
|
||||||
|
parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right))
|
||||||
|
) {
|
||||||
|
visit(node.left)
|
||||||
|
visit(node.right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compile an [MST] to ASM using given algebra
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optimize performance of an [MstExpression] using ASM codegen
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
@ -0,0 +1,568 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import org.objectweb.asm.*
|
||||||
|
import org.objectweb.asm.Opcodes.*
|
||||||
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
|
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||||
|
import scientifik.kmath.ast.MST
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import scientifik.kmath.operations.NumericAlgebra
|
||||||
|
import java.util.*
|
||||||
|
import java.util.stream.Collectors
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
|
||||||
|
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
|
||||||
|
*
|
||||||
|
* @property T the type of AsmExpression to unwrap.
|
||||||
|
* @property algebra the algebra the applied AsmExpressions use.
|
||||||
|
* @property className the unique class name of new loaded class.
|
||||||
|
* @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
|
||||||
|
*/
|
||||||
|
internal class AsmBuilder<T> internal constructor(
|
||||||
|
private val classOfT: KClass<*>,
|
||||||
|
private val algebra: Algebra<T>,
|
||||||
|
private val className: String,
|
||||||
|
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
||||||
|
) {
|
||||||
|
/**
|
||||||
|
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
||||||
|
*/
|
||||||
|
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||||
|
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The instance of [ClassLoader] used by this builder.
|
||||||
|
*/
|
||||||
|
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM Type for [algebra].
|
||||||
|
*/
|
||||||
|
private val tAlgebraType: Type = algebra::class.asm
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [T].
|
||||||
|
*/
|
||||||
|
internal val tType: Type = classOfT.asm
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for new class.
|
||||||
|
*/
|
||||||
|
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Index of `this` variable in invoke method of the built subclass.
|
||||||
|
*/
|
||||||
|
private val invokeThisVar: Int = 0
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Index of `arguments` variable in invoke method of the built subclass.
|
||||||
|
*/
|
||||||
|
private val invokeArgumentsVar: Int = 1
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List of constants to provide to the subclass.
|
||||||
|
*/
|
||||||
|
private val constants: MutableList<Any> = mutableListOf()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Method visitor of `invoke` method of the subclass.
|
||||||
|
*/
|
||||||
|
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||||
|
|
||||||
|
/**
|
||||||
|
* State if this [AsmBuilder] needs to generate constants field.
|
||||||
|
*/
|
||||||
|
private var hasConstants: Boolean = true
|
||||||
|
|
||||||
|
/**
|
||||||
|
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
||||||
|
*/
|
||||||
|
internal var primitiveMode: Boolean = false
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||||
|
*/
|
||||||
|
internal var primitiveMask: Type = OBJECT_TYPE
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||||
|
*/
|
||||||
|
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stack of useful objects types on stack to verify types.
|
||||||
|
*/
|
||||||
|
private val typeStack: ArrayDeque<Type> = ArrayDeque()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stack of useful objects types on stack expected by algebra calls.
|
||||||
|
*/
|
||||||
|
internal val expectationStack: ArrayDeque<Type> = ArrayDeque(listOf(tType))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The cache for instance built by this builder.
|
||||||
|
*/
|
||||||
|
private var generatedInstance: Expression<T>? = null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subclasses, loads and instantiates [Expression] for given parameters.
|
||||||
|
*
|
||||||
|
* The built instance is cached.
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
internal fun getInstance(): Expression<T> {
|
||||||
|
generatedInstance?.let { return it }
|
||||||
|
|
||||||
|
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
|
||||||
|
primitiveMode = true
|
||||||
|
primitiveMask = SIGNATURE_LETTERS.getValue(classOfT)
|
||||||
|
primitiveMaskBoxed = tType
|
||||||
|
}
|
||||||
|
|
||||||
|
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||||
|
visit(
|
||||||
|
V1_8,
|
||||||
|
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
|
||||||
|
classType.internalName,
|
||||||
|
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
|
||||||
|
OBJECT_TYPE.internalName,
|
||||||
|
arrayOf(EXPRESSION_TYPE.internalName)
|
||||||
|
)
|
||||||
|
|
||||||
|
visitMethod(
|
||||||
|
ACC_PUBLIC or ACC_FINAL,
|
||||||
|
"invoke",
|
||||||
|
Type.getMethodDescriptor(tType, MAP_TYPE),
|
||||||
|
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
||||||
|
null
|
||||||
|
).instructionAdapter {
|
||||||
|
invokeMethodVisitor = this
|
||||||
|
visitCode()
|
||||||
|
val l0 = label()
|
||||||
|
invokeLabel0Visitor()
|
||||||
|
areturn(tType)
|
||||||
|
val l1 = label()
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"this",
|
||||||
|
classType.descriptor,
|
||||||
|
null,
|
||||||
|
l0,
|
||||||
|
l1,
|
||||||
|
invokeThisVar
|
||||||
|
)
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"arguments",
|
||||||
|
MAP_TYPE.descriptor,
|
||||||
|
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;",
|
||||||
|
l0,
|
||||||
|
l1,
|
||||||
|
invokeArgumentsVar
|
||||||
|
)
|
||||||
|
|
||||||
|
visitMaxs(0, 2)
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
visitMethod(
|
||||||
|
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
||||||
|
"invoke",
|
||||||
|
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
||||||
|
null,
|
||||||
|
null
|
||||||
|
).instructionAdapter {
|
||||||
|
val thisVar = 0
|
||||||
|
val argumentsVar = 1
|
||||||
|
visitCode()
|
||||||
|
val l0 = label()
|
||||||
|
load(thisVar, OBJECT_TYPE)
|
||||||
|
load(argumentsVar, MAP_TYPE)
|
||||||
|
invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false)
|
||||||
|
areturn(tType)
|
||||||
|
val l1 = label()
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"this",
|
||||||
|
classType.descriptor,
|
||||||
|
null,
|
||||||
|
l0,
|
||||||
|
l1,
|
||||||
|
thisVar
|
||||||
|
)
|
||||||
|
|
||||||
|
visitMaxs(0, 2)
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
hasConstants = constants.isNotEmpty()
|
||||||
|
|
||||||
|
visitField(
|
||||||
|
access = ACC_PRIVATE or ACC_FINAL,
|
||||||
|
name = "algebra",
|
||||||
|
descriptor = tAlgebraType.descriptor,
|
||||||
|
signature = null,
|
||||||
|
value = null,
|
||||||
|
block = FieldVisitor::visitEnd
|
||||||
|
)
|
||||||
|
|
||||||
|
if (hasConstants)
|
||||||
|
visitField(
|
||||||
|
access = ACC_PRIVATE or ACC_FINAL,
|
||||||
|
name = "constants",
|
||||||
|
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
||||||
|
signature = null,
|
||||||
|
value = null,
|
||||||
|
block = FieldVisitor::visitEnd
|
||||||
|
)
|
||||||
|
|
||||||
|
visitMethod(
|
||||||
|
ACC_PUBLIC,
|
||||||
|
"<init>",
|
||||||
|
|
||||||
|
Type.getMethodDescriptor(
|
||||||
|
Type.VOID_TYPE,
|
||||||
|
tAlgebraType,
|
||||||
|
*OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
|
||||||
|
|
||||||
|
null,
|
||||||
|
null
|
||||||
|
).instructionAdapter {
|
||||||
|
val thisVar = 0
|
||||||
|
val algebraVar = 1
|
||||||
|
val constantsVar = 2
|
||||||
|
val l0 = label()
|
||||||
|
load(thisVar, classType)
|
||||||
|
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
|
||||||
|
label()
|
||||||
|
load(thisVar, classType)
|
||||||
|
load(algebraVar, tAlgebraType)
|
||||||
|
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
||||||
|
|
||||||
|
if (hasConstants) {
|
||||||
|
label()
|
||||||
|
load(thisVar, classType)
|
||||||
|
load(constantsVar, OBJECT_ARRAY_TYPE)
|
||||||
|
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||||
|
}
|
||||||
|
|
||||||
|
label()
|
||||||
|
visitInsn(RETURN)
|
||||||
|
val l4 = label()
|
||||||
|
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"algebra",
|
||||||
|
tAlgebraType.descriptor,
|
||||||
|
null,
|
||||||
|
l0,
|
||||||
|
l4,
|
||||||
|
algebraVar
|
||||||
|
)
|
||||||
|
|
||||||
|
if (hasConstants)
|
||||||
|
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
|
||||||
|
|
||||||
|
visitMaxs(0, 3)
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
val new = classLoader
|
||||||
|
.defineClass(className, classWriter.toByteArray())
|
||||||
|
.constructors
|
||||||
|
.first()
|
||||||
|
.newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression<T>
|
||||||
|
|
||||||
|
generatedInstance = new
|
||||||
|
return new
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads a [T] constant from [constants].
|
||||||
|
*/
|
||||||
|
internal fun loadTConstant(value: T) {
|
||||||
|
if (classOfT in INLINABLE_NUMBERS) {
|
||||||
|
val expectedType = expectationStack.pop()
|
||||||
|
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||||
|
loadNumberConstant(value as Number, mustBeBoxed)
|
||||||
|
|
||||||
|
if (mustBeBoxed)
|
||||||
|
invokeMethodVisitor.checkcast(tType)
|
||||||
|
|
||||||
|
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
loadObjectConstant(value as Any, tType)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Boxes the current value and pushes it.
|
||||||
|
*/
|
||||||
|
private fun box(primitive: Type) {
|
||||||
|
val r = PRIMITIVES_TO_BOXED.getValue(primitive)
|
||||||
|
|
||||||
|
invokeMethodVisitor.invokestatic(
|
||||||
|
r.internalName,
|
||||||
|
"valueOf",
|
||||||
|
Type.getMethodDescriptor(r, primitive),
|
||||||
|
false
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Unboxes the current boxed value and pushes it.
|
||||||
|
*/
|
||||||
|
private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual(
|
||||||
|
NUMBER_TYPE.internalName,
|
||||||
|
NUMBER_CONVERTER_METHODS.getValue(primitive),
|
||||||
|
Type.getMethodDescriptor(primitive),
|
||||||
|
false
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads [java.lang.Object] constant from constants.
|
||||||
|
*/
|
||||||
|
private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
||||||
|
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||||
|
loadThis()
|
||||||
|
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||||
|
iconst(idx)
|
||||||
|
visitInsn(AALOAD)
|
||||||
|
checkcast(type)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun loadNumeric(value: Number) {
|
||||||
|
if (expectationStack.peek() == NUMBER_TYPE) {
|
||||||
|
loadNumberConstant(value, true)
|
||||||
|
expectationStack.pop()
|
||||||
|
typeStack.push(NUMBER_TYPE)
|
||||||
|
} else (algebra as? NumericAlgebra<T>)?.number(value)?.let { loadTConstant(it) }
|
||||||
|
?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads this variable.
|
||||||
|
*/
|
||||||
|
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
|
||||||
|
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
|
||||||
|
* from it).
|
||||||
|
*/
|
||||||
|
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
|
||||||
|
val boxed = value::class.asm
|
||||||
|
val primitive = BOXED_TO_PRIMITIVES[boxed]
|
||||||
|
|
||||||
|
if (primitive != null) {
|
||||||
|
when (primitive) {
|
||||||
|
Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
|
Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
|
||||||
|
Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
|
||||||
|
Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
|
||||||
|
Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
|
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mustBeBoxed)
|
||||||
|
box(primitive)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
loadObjectConstant(value, boxed)
|
||||||
|
|
||||||
|
if (!mustBeBoxed)
|
||||||
|
unboxTo(primitiveMask)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
|
||||||
|
* provided.
|
||||||
|
*/
|
||||||
|
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
||||||
|
load(invokeArgumentsVar, MAP_TYPE)
|
||||||
|
aconst(name)
|
||||||
|
|
||||||
|
if (defaultValue != null)
|
||||||
|
loadTConstant(defaultValue)
|
||||||
|
|
||||||
|
invokestatic(
|
||||||
|
MAP_INTRINSICS_TYPE.internalName,
|
||||||
|
"getOrFail",
|
||||||
|
|
||||||
|
Type.getMethodDescriptor(
|
||||||
|
OBJECT_TYPE,
|
||||||
|
MAP_TYPE,
|
||||||
|
OBJECT_TYPE,
|
||||||
|
*OBJECT_TYPE.wrapToArrayIf { defaultValue != null }),
|
||||||
|
false
|
||||||
|
)
|
||||||
|
|
||||||
|
checkcast(tType)
|
||||||
|
val expectedType = expectationStack.pop()
|
||||||
|
|
||||||
|
if (expectedType.sort == Type.OBJECT)
|
||||||
|
typeStack.push(tType)
|
||||||
|
else {
|
||||||
|
unboxTo(primitiveMask)
|
||||||
|
typeStack.push(primitiveMask)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads algebra from according field of the class and casts it to class of [algebra] provided.
|
||||||
|
*/
|
||||||
|
internal fun loadAlgebra() {
|
||||||
|
loadThis()
|
||||||
|
invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is
|
||||||
|
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be
|
||||||
|
* called before the arguments and this operation.
|
||||||
|
*
|
||||||
|
* The result is casted to [T] automatically.
|
||||||
|
*/
|
||||||
|
internal fun invokeAlgebraOperation(
|
||||||
|
owner: String,
|
||||||
|
method: String,
|
||||||
|
descriptor: String,
|
||||||
|
expectedArity: Int,
|
||||||
|
opcode: Int = INVOKEINTERFACE
|
||||||
|
) {
|
||||||
|
run loop@{
|
||||||
|
repeat(expectedArity) {
|
||||||
|
if (typeStack.isEmpty()) return@loop
|
||||||
|
typeStack.pop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
invokeMethodVisitor.visitMethodInsn(
|
||||||
|
opcode,
|
||||||
|
owner,
|
||||||
|
method,
|
||||||
|
descriptor,
|
||||||
|
opcode == INVOKEINTERFACE
|
||||||
|
)
|
||||||
|
|
||||||
|
invokeMethodVisitor.checkcast(tType)
|
||||||
|
val isLastExpr = expectationStack.size == 1
|
||||||
|
val expectedType = expectationStack.pop()
|
||||||
|
|
||||||
|
if (expectedType.sort == Type.OBJECT || isLastExpr)
|
||||||
|
typeStack.push(tType)
|
||||||
|
else {
|
||||||
|
unboxTo(primitiveMask)
|
||||||
|
typeStack.push(primitiveMask)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Writes a LDC Instruction with string constant provided.
|
||||||
|
*/
|
||||||
|
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
|
||||||
|
|
||||||
|
internal companion object {
|
||||||
|
/**
|
||||||
|
* Maps JVM primitive numbers boxed types to their primitive ASM types.
|
||||||
|
*/
|
||||||
|
private val SIGNATURE_LETTERS: Map<KClass<out Any>, Type> by lazy {
|
||||||
|
hashMapOf(
|
||||||
|
java.lang.Byte::class to Type.BYTE_TYPE,
|
||||||
|
java.lang.Short::class to Type.SHORT_TYPE,
|
||||||
|
java.lang.Integer::class to Type.INT_TYPE,
|
||||||
|
java.lang.Long::class to Type.LONG_TYPE,
|
||||||
|
java.lang.Float::class to Type.FLOAT_TYPE,
|
||||||
|
java.lang.Double::class to Type.DOUBLE_TYPE
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
||||||
|
*/
|
||||||
|
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
||||||
|
*/
|
||||||
|
private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy {
|
||||||
|
BOXED_TO_PRIMITIVES.entries.stream().collect(
|
||||||
|
Collectors.toMap(
|
||||||
|
Map.Entry<Type, Type>::value,
|
||||||
|
Map.Entry<Type, Type>::key
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps primitive ASM types to [Number] functions unboxing them.
|
||||||
|
*/
|
||||||
|
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
|
||||||
|
hashMapOf(
|
||||||
|
Type.BYTE_TYPE to "byteValue",
|
||||||
|
Type.SHORT_TYPE to "shortValue",
|
||||||
|
Type.INT_TYPE to "intValue",
|
||||||
|
Type.LONG_TYPE to "longValue",
|
||||||
|
Type.FLOAT_TYPE to "floatValue",
|
||||||
|
Type.DOUBLE_TYPE to "doubleValue"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
||||||
|
*/
|
||||||
|
private val INLINABLE_NUMBERS: Set<KClass<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [Expression].
|
||||||
|
*/
|
||||||
|
internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [java.lang.Number].
|
||||||
|
*/
|
||||||
|
internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [java.util.Map].
|
||||||
|
*/
|
||||||
|
internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [java.lang.Object].
|
||||||
|
*/
|
||||||
|
internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for array of [java.lang.Object].
|
||||||
|
*/
|
||||||
|
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
||||||
|
internal val OBJECT_ARRAY_TYPE: Type by lazy { Array<java.lang.Object>::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [Algebra].
|
||||||
|
*/
|
||||||
|
internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [java.lang.String].
|
||||||
|
*/
|
||||||
|
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for MapIntrinsics.
|
||||||
|
*/
|
||||||
|
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") }
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,17 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import scientifik.kmath.ast.MST
|
||||||
|
|
||||||
|
internal enum class MstType {
|
||||||
|
GENERAL,
|
||||||
|
NUMBER;
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
fun fromMst(mst: MST): MstType {
|
||||||
|
if (mst is MST.Numeric)
|
||||||
|
return NUMBER
|
||||||
|
|
||||||
|
return GENERAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,178 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import org.objectweb.asm.*
|
||||||
|
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
|
||||||
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
|
import scientifik.kmath.ast.MST
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import java.lang.reflect.Method
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||||
|
hashMapOf(
|
||||||
|
"+" to 2 to "add",
|
||||||
|
"*" to 2 to "multiply",
|
||||||
|
"/" to 2 to "divide",
|
||||||
|
"+" to 1 to "unaryPlus",
|
||||||
|
"-" to 1 to "unaryMinus",
|
||||||
|
"-" to 2 to "minus"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal val KClass<*>.asm: Type
|
||||||
|
get() = Type.getType(java)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
|
||||||
|
*/
|
||||||
|
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> =
|
||||||
|
if (predicate(this)) arrayOf(this) else emptyArray()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [InstructionAdapter] from this [MethodVisitor].
|
||||||
|
*/
|
||||||
|
private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it.
|
||||||
|
*/
|
||||||
|
internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
|
||||||
|
instructionAdapter().apply(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructs a [Label], then applies it to this visitor.
|
||||||
|
*/
|
||||||
|
internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a class name for [Expression] subclassed to implement [mst] provided.
|
||||||
|
*
|
||||||
|
* This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there
|
||||||
|
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
|
||||||
|
*/
|
||||||
|
internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
|
||||||
|
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
||||||
|
|
||||||
|
try {
|
||||||
|
Class.forName(name)
|
||||||
|
} catch (ignored: ClassNotFoundException) {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildName(mst, collision + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("FunctionName")
|
||||||
|
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter =
|
||||||
|
ClassWriter(flags).apply(block)
|
||||||
|
|
||||||
|
internal inline fun ClassWriter.visitField(
|
||||||
|
access: Int,
|
||||||
|
name: String,
|
||||||
|
descriptor: String,
|
||||||
|
signature: String?,
|
||||||
|
value: Any?,
|
||||||
|
block: FieldVisitor.() -> Unit
|
||||||
|
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
|
||||||
|
|
||||||
|
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
|
||||||
|
context.javaClass.methods.find { method ->
|
||||||
|
val nameValid = method.name == name
|
||||||
|
val arityValid = method.parameters.size == parameterTypes.size
|
||||||
|
val notBridgeInPrimitive = !(primitiveMode && method.isBridge)
|
||||||
|
|
||||||
|
val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) ->
|
||||||
|
!(mstType != MstType.NUMBER && type == java.lang.Number::class.java)
|
||||||
|
}
|
||||||
|
|
||||||
|
nameValid && arityValid && notBridgeInPrimitive && paramsValid
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds
|
||||||
|
* type expectation stack for needed arity.
|
||||||
|
*
|
||||||
|
* @return `true` if contains, else `false`.
|
||||||
|
*/
|
||||||
|
private fun <T> AsmBuilder<T>.buildExpectationStack(
|
||||||
|
context: Algebra<T>,
|
||||||
|
name: String,
|
||||||
|
parameterTypes: Array<MstType>
|
||||||
|
): Boolean {
|
||||||
|
val arity = parameterTypes.size
|
||||||
|
val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes)
|
||||||
|
|
||||||
|
if (specific != null)
|
||||||
|
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
|
||||||
|
else
|
||||||
|
repeat(arity) { expectationStack.push(tType) }
|
||||||
|
|
||||||
|
return specific != null
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun <T> AsmBuilder<T>.mapTypes(method: Method, parameterTypes: Array<MstType>): List<Type> = method
|
||||||
|
.parameterTypes
|
||||||
|
.zip(parameterTypes)
|
||||||
|
.map { (type, mstType) ->
|
||||||
|
when {
|
||||||
|
type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE
|
||||||
|
else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts
|
||||||
|
* [AsmBuilder.invokeAlgebraOperation] of this method.
|
||||||
|
*
|
||||||
|
* @return `true` if contains, else `false`.
|
||||||
|
*/
|
||||||
|
private fun <T> AsmBuilder<T>.tryInvokeSpecific(
|
||||||
|
context: Algebra<T>,
|
||||||
|
name: String,
|
||||||
|
parameterTypes: Array<MstType>
|
||||||
|
): Boolean {
|
||||||
|
val arity = parameterTypes.size
|
||||||
|
val theName = methodNameAdapters[name to arity] ?: name
|
||||||
|
val spec = findSpecific(context, theName, parameterTypes) ?: return false
|
||||||
|
val owner = context::class.asm
|
||||||
|
|
||||||
|
invokeAlgebraOperation(
|
||||||
|
owner = owner.internalName,
|
||||||
|
method = theName,
|
||||||
|
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()),
|
||||||
|
expectedArity = arity,
|
||||||
|
opcode = INVOKEVIRTUAL
|
||||||
|
)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
|
||||||
|
*/
|
||||||
|
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
||||||
|
context: Algebra<T>,
|
||||||
|
name: String,
|
||||||
|
fallbackMethodName: String,
|
||||||
|
parameterTypes: Array<MstType>,
|
||||||
|
parameters: AsmBuilder<T>.() -> Unit
|
||||||
|
) {
|
||||||
|
val arity = parameterTypes.size
|
||||||
|
loadAlgebra()
|
||||||
|
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)
|
||||||
|
parameters()
|
||||||
|
|
||||||
|
if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation(
|
||||||
|
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
||||||
|
method = fallbackMethodName,
|
||||||
|
|
||||||
|
descriptor = Type.getMethodDescriptor(
|
||||||
|
AsmBuilder.OBJECT_TYPE,
|
||||||
|
AsmBuilder.STRING_TYPE,
|
||||||
|
*Array(arity) { AsmBuilder.OBJECT_TYPE }
|
||||||
|
),
|
||||||
|
|
||||||
|
expectedArity = arity
|
||||||
|
)
|
||||||
|
}
|
@ -0,0 +1,7 @@
|
|||||||
|
@file:JvmName("MapIntrinsics")
|
||||||
|
|
||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
@JvmOverloads
|
||||||
|
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
|
||||||
|
this[key] ?: default ?: error("Parameter not found: $key")
|
@ -0,0 +1,110 @@
|
|||||||
|
package scietifik.kmath.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.compile
|
||||||
|
import scientifik.kmath.ast.mstInField
|
||||||
|
import scientifik.kmath.ast.mstInRing
|
||||||
|
import scientifik.kmath.ast.mstInSpace
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.ByteRing
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal class TestAsmAlgebras {
|
||||||
|
@Test
|
||||||
|
fun space() {
|
||||||
|
val res1 = ByteRing.mstInSpace {
|
||||||
|
binaryOperation(
|
||||||
|
"+",
|
||||||
|
|
||||||
|
unaryOperation(
|
||||||
|
"+",
|
||||||
|
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||||
|
add(number(1), number(1)),
|
||||||
|
2
|
||||||
|
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||||
|
),
|
||||||
|
|
||||||
|
number(1)
|
||||||
|
) + symbol("x") + zero
|
||||||
|
}("x" to 2.toByte())
|
||||||
|
|
||||||
|
val res2 = ByteRing.mstInSpace {
|
||||||
|
binaryOperation(
|
||||||
|
"+",
|
||||||
|
|
||||||
|
unaryOperation(
|
||||||
|
"+",
|
||||||
|
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||||
|
add(number(1), number(1)),
|
||||||
|
2
|
||||||
|
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||||
|
),
|
||||||
|
|
||||||
|
number(1)
|
||||||
|
) + symbol("x") + zero
|
||||||
|
}.compile()("x" to 2.toByte())
|
||||||
|
|
||||||
|
assertEquals(res1, res2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun ring() {
|
||||||
|
val res1 = ByteRing.mstInRing {
|
||||||
|
binaryOperation(
|
||||||
|
"+",
|
||||||
|
|
||||||
|
unaryOperation(
|
||||||
|
"+",
|
||||||
|
(symbol("x") - (2.toByte() + (multiply(
|
||||||
|
add(number(1), number(1)),
|
||||||
|
2
|
||||||
|
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||||
|
),
|
||||||
|
|
||||||
|
number(1)
|
||||||
|
) * number(2)
|
||||||
|
}("x" to 3.toByte())
|
||||||
|
|
||||||
|
val res2 = ByteRing.mstInRing {
|
||||||
|
binaryOperation(
|
||||||
|
"+",
|
||||||
|
|
||||||
|
unaryOperation(
|
||||||
|
"+",
|
||||||
|
(symbol("x") - (2.toByte() + (multiply(
|
||||||
|
add(number(1), number(1)),
|
||||||
|
2
|
||||||
|
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||||
|
),
|
||||||
|
|
||||||
|
number(1)
|
||||||
|
) * number(2)
|
||||||
|
}.compile()("x" to 3.toByte())
|
||||||
|
|
||||||
|
assertEquals(res1, res2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun field() {
|
||||||
|
val res1 = RealField.mstInField {
|
||||||
|
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
|
||||||
|
"+",
|
||||||
|
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||||
|
+ number(1),
|
||||||
|
number(1) / 2 + number(2.0) * one
|
||||||
|
) + zero
|
||||||
|
}("x" to 2.0)
|
||||||
|
|
||||||
|
val res2 = RealField.mstInField {
|
||||||
|
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
|
||||||
|
"+",
|
||||||
|
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||||
|
+ number(1),
|
||||||
|
number(1) / 2 + number(2.0) * one
|
||||||
|
) + zero
|
||||||
|
}.compile()("x" to 2.0)
|
||||||
|
|
||||||
|
assertEquals(res1, res2)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,31 @@
|
|||||||
|
package scietifik.kmath.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.compile
|
||||||
|
import scientifik.kmath.ast.mstInField
|
||||||
|
import scientifik.kmath.ast.mstInSpace
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal class TestAsmExpressions {
|
||||||
|
@Test
|
||||||
|
fun testUnaryOperationInvocation() {
|
||||||
|
val expression = RealField.mstInSpace { -symbol("x") }.compile()
|
||||||
|
val res = expression("x" to 2.0)
|
||||||
|
assertEquals(-2.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBinaryOperationInvocation() {
|
||||||
|
val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile()
|
||||||
|
val res = expression("x" to 2.0)
|
||||||
|
assertEquals(-1.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testConstProductInvocation() {
|
||||||
|
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
|
||||||
|
assertEquals(4.0, res)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,55 @@
|
|||||||
|
package scietifik.kmath.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.compile
|
||||||
|
import scientifik.kmath.ast.mstInField
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal class TestAsmSpecialization {
|
||||||
|
@Test
|
||||||
|
fun testUnaryPlus() {
|
||||||
|
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
|
||||||
|
assertEquals(2.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testUnaryMinus() {
|
||||||
|
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
|
||||||
|
assertEquals(-2.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAdd() {
|
||||||
|
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
|
||||||
|
assertEquals(4.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSine() {
|
||||||
|
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile()
|
||||||
|
assertEquals(0.0, expr("x" to 0.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMinus() {
|
||||||
|
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
|
||||||
|
assertEquals(0.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDivide() {
|
||||||
|
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
||||||
|
assertEquals(1.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPower() {
|
||||||
|
val expr = RealField
|
||||||
|
.mstInField { binaryOperation("power", symbol("x"), number(2)) }
|
||||||
|
.compile()
|
||||||
|
|
||||||
|
assertEquals(4.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,22 @@
|
|||||||
|
package scietifik.kmath.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.ast.mstInRing
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.ByteRing
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertFailsWith
|
||||||
|
|
||||||
|
internal class TestAsmVariables {
|
||||||
|
@Test
|
||||||
|
fun testVariableWithoutDefault() {
|
||||||
|
val expr = ByteRing.mstInRing { symbol("x") }
|
||||||
|
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testVariableWithoutDefaultFails() {
|
||||||
|
val expr = ByteRing.mstInRing { symbol("x") }
|
||||||
|
assertFailsWith<IllegalStateException> { expr() }
|
||||||
|
}
|
||||||
|
}
|
25
kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt
Normal file
25
kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package scietifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.compile
|
||||||
|
import scientifik.kmath.asm.expression
|
||||||
|
import scientifik.kmath.ast.mstInField
|
||||||
|
import scientifik.kmath.ast.parseMath
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.Complex
|
||||||
|
import scientifik.kmath.operations.ComplexField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal class AsmTest {
|
||||||
|
@Test
|
||||||
|
fun `compile MST`() {
|
||||||
|
val res = ComplexField.expression("2+2*(2+2)".parseMath())()
|
||||||
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `compile MSTExpression`() {
|
||||||
|
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()()
|
||||||
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,36 @@
|
|||||||
|
package scietifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.ast.evaluate
|
||||||
|
import scientifik.kmath.ast.parseMath
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal class ParserPrecedenceTest {
|
||||||
|
private val f: Field<Double> = RealField
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
|
||||||
|
}
|
@ -0,0 +1,60 @@
|
|||||||
|
package scietifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.ast.evaluate
|
||||||
|
import scientifik.kmath.ast.mstInField
|
||||||
|
import scientifik.kmath.ast.parseMath
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import scientifik.kmath.operations.Complex
|
||||||
|
import scientifik.kmath.operations.ComplexField
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal class ParserTest {
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST`() {
|
||||||
|
val mst = "2+2*(2+2)".parseMath()
|
||||||
|
val res = ComplexField.evaluate(mst)
|
||||||
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MSTExpression`() {
|
||||||
|
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
||||||
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with singular`() {
|
||||||
|
val mst = "i".parseMath()
|
||||||
|
val res = ComplexField.evaluate(mst)
|
||||||
|
assertEquals(ComplexField.i, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with unary function`() {
|
||||||
|
val mst = "sin(0)".parseMath()
|
||||||
|
val res = RealField.evaluate(mst)
|
||||||
|
assertEquals(0.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with binary function`() {
|
||||||
|
val magicalAlgebra = object : Algebra<String> {
|
||||||
|
override fun symbol(value: String): String = value
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
|
||||||
|
"magic" -> "$left ★ $right"
|
||||||
|
else -> throw NotImplementedError()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val mst = "magic(a, b)".parseMath()
|
||||||
|
val res = magicalAlgebra.evaluate(mst)
|
||||||
|
assertEquals("a ★ b", res)
|
||||||
|
}
|
||||||
|
}
|
@ -2,7 +2,7 @@ package scientifik.kmath.commons.expressions
|
|||||||
|
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.expressions.ExpressionContext
|
import scientifik.kmath.expressions.ExpressionAlgebra
|
||||||
import scientifik.kmath.operations.ExtendedField
|
import scientifik.kmath.operations.ExtendedField
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import kotlin.properties.ReadOnlyProperty
|
import kotlin.properties.ReadOnlyProperty
|
||||||
@ -59,8 +59,10 @@ class DerivativeStructureField(
|
|||||||
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
|
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
|
||||||
|
|
||||||
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
||||||
|
|
||||||
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
||||||
|
override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
|
||||||
|
override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
|
||||||
|
override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
|
||||||
|
|
||||||
override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
|
override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
|
||||||
is Double -> arg.pow(pow)
|
is Double -> arg.pow(pow)
|
||||||
@ -74,10 +76,10 @@ class DerivativeStructureField(
|
|||||||
|
|
||||||
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
|
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
|
||||||
|
|
||||||
operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble())
|
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
|
||||||
operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble())
|
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
||||||
operator fun Number.plus(s: DerivativeStructure) = s + this
|
override operator fun Number.plus(b: DerivativeStructure) = b + this
|
||||||
operator fun Number.minus(s: DerivativeStructure) = s - this
|
override operator fun Number.minus(b: DerivativeStructure) = b - this
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -113,7 +115,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1)
|
|||||||
/**
|
/**
|
||||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
||||||
*/
|
*/
|
||||||
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> {
|
object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
|
||||||
override fun variable(name: String, default: Double?) =
|
override fun variable(name: String, default: Double?) =
|
||||||
DiffExpression { variable(name, default?.const()) }
|
DiffExpression { variable(name, default?.const()) }
|
||||||
|
|
||||||
@ -136,6 +138,3 @@ object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression>
|
|||||||
override fun divide(a: DiffExpression, b: DiffExpression) =
|
override fun divide(a: DiffExpression, b: DiffExpression) =
|
||||||
DiffExpression { a.function(this) / b.function(this) }
|
DiffExpression { a.function(this) / b.function(this) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import org.apache.commons.math3.linear.RealMatrix
|
|||||||
import org.apache.commons.math3.linear.RealVector
|
import org.apache.commons.math3.linear.RealVector
|
||||||
import scientifik.kmath.linear.*
|
import scientifik.kmath.linear.*
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
|
import scientifik.kmath.structures.NDStructure
|
||||||
|
|
||||||
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
|
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
|
||||||
FeaturedMatrix<Double> {
|
FeaturedMatrix<Double> {
|
||||||
@ -19,6 +20,16 @@ class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
|
|||||||
CMMatrix(origin, this.features + features)
|
CMMatrix(origin, this.features + features)
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
|
override fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = origin.hashCode()
|
||||||
|
result = 31 * result + features.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
package scientifik.kmath.commons.prob
|
|
||||||
|
|
||||||
import org.apache.commons.math3.random.JDKRandomGenerator
|
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator as CMRandom
|
|
||||||
|
|
||||||
inline class CMRandomGeneratorWrapper(val generator: CMRandom) : RandomGenerator {
|
|
||||||
override fun nextDouble(): Double = generator.nextDouble()
|
|
||||||
|
|
||||||
override fun nextInt(): Int = generator.nextInt()
|
|
||||||
|
|
||||||
override fun nextLong(): Long = generator.nextLong()
|
|
||||||
|
|
||||||
override fun nextBlock(size: Int): ByteArray = ByteArray(size).apply { generator.nextBytes(this) }
|
|
||||||
|
|
||||||
override fun fork(): RandomGenerator {
|
|
||||||
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun CMRandom.asKmathGenerator(): RandomGenerator = CMRandomGeneratorWrapper(this)
|
|
||||||
|
|
||||||
fun RandomGenerator.asCMGenerator(): CMRandom =
|
|
||||||
(this as? CMRandomGeneratorWrapper)?.generator ?: TODO("Implement reverse CM wrapper")
|
|
||||||
|
|
||||||
val RandomGenerator.Companion.default: RandomGenerator by lazy { JDKRandomGenerator().asKmathGenerator() }
|
|
||||||
|
|
||||||
fun RandomGenerator.Companion.jdk(seed: Int? = null): RandomGenerator = if (seed == null) {
|
|
||||||
JDKRandomGenerator()
|
|
||||||
} else {
|
|
||||||
JDKRandomGenerator(seed)
|
|
||||||
}.asKmathGenerator()
|
|
@ -1,82 +0,0 @@
|
|||||||
package scientifik.kmath.commons.prob
|
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.*
|
|
||||||
import scientifik.kmath.prob.Distribution
|
|
||||||
import scientifik.kmath.prob.RandomChain
|
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
|
||||||
import scientifik.kmath.prob.UnivariateDistribution
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator as CMRandom
|
|
||||||
|
|
||||||
class CMRealDistributionWrapper(val builder: (CMRandom?) -> RealDistribution) : UnivariateDistribution<Double> {
|
|
||||||
|
|
||||||
private val defaultDistribution by lazy { builder(null) }
|
|
||||||
|
|
||||||
override fun probability(arg: Double): Double = defaultDistribution.probability(arg)
|
|
||||||
|
|
||||||
override fun cumulative(arg: Double): Double = defaultDistribution.cumulativeProbability(arg)
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): RandomChain<Double> {
|
|
||||||
val distribution = builder(generator.asCMGenerator())
|
|
||||||
return RandomChain(generator) { distribution.sample() }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class CMIntDistributionWrapper(val builder: (CMRandom?) -> IntegerDistribution) : UnivariateDistribution<Int> {
|
|
||||||
|
|
||||||
private val defaultDistribution by lazy { builder(null) }
|
|
||||||
|
|
||||||
override fun probability(arg: Int): Double = defaultDistribution.probability(arg)
|
|
||||||
|
|
||||||
override fun cumulative(arg: Int): Double = defaultDistribution.cumulativeProbability(arg)
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): RandomChain<Int> {
|
|
||||||
val distribution = builder(generator.asCMGenerator())
|
|
||||||
return RandomChain(generator) { distribution.sample() }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
fun Distribution.Companion.normal(mean: Double = 0.0, sigma: Double = 1.0): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator -> NormalDistribution(generator, mean, sigma) }
|
|
||||||
|
|
||||||
fun Distribution.Companion.poisson(mean: Double): UnivariateDistribution<Int> = CMIntDistributionWrapper { generator ->
|
|
||||||
PoissonDistribution(
|
|
||||||
generator,
|
|
||||||
mean,
|
|
||||||
PoissonDistribution.DEFAULT_EPSILON,
|
|
||||||
PoissonDistribution.DEFAULT_MAX_ITERATIONS
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.binomial(trials: Int, p: Double): UnivariateDistribution<Int> =
|
|
||||||
CMIntDistributionWrapper { generator ->
|
|
||||||
BinomialDistribution(generator, trials, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.student(degreesOfFreedom: Double): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator ->
|
|
||||||
TDistribution(generator, degreesOfFreedom, TDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.chi2(degreesOfFreedom: Double): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator ->
|
|
||||||
ChiSquaredDistribution(generator, degreesOfFreedom)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.fisher(
|
|
||||||
numeratorDegreesOfFreedom: Double,
|
|
||||||
denominatorDegreesOfFreedom: Double
|
|
||||||
): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator ->
|
|
||||||
FDistribution(generator, numeratorDegreesOfFreedom, denominatorDegreesOfFreedom)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.exponential(mean: Double): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator ->
|
|
||||||
ExponentialDistribution(generator, mean)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.uniform(a: Double, b: Double): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator ->
|
|
||||||
UniformRealDistribution(generator, a, b)
|
|
||||||
}
|
|
@ -0,0 +1,38 @@
|
|||||||
|
package scientifik.kmath.commons.random
|
||||||
|
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
|
||||||
|
class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
|
||||||
|
org.apache.commons.math3.random.RandomGenerator {
|
||||||
|
private var generator = factory(intArrayOf())
|
||||||
|
|
||||||
|
override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||||
|
|
||||||
|
override fun nextFloat(): Float = generator.nextDouble().toFloat()
|
||||||
|
|
||||||
|
override fun setSeed(seed: Int) {
|
||||||
|
generator = factory(intArrayOf(seed))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun setSeed(seed: IntArray) {
|
||||||
|
generator = factory(seed)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun setSeed(seed: Long) {
|
||||||
|
setSeed(seed.toInt())
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun nextBytes(bytes: ByteArray) {
|
||||||
|
generator.fillBytes(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun nextInt(): Int = generator.nextInt()
|
||||||
|
|
||||||
|
override fun nextInt(n: Int): Int = generator.nextInt(n)
|
||||||
|
|
||||||
|
override fun nextGaussian(): Double = TODO()
|
||||||
|
|
||||||
|
override fun nextDouble(): Double = generator.nextDouble()
|
||||||
|
|
||||||
|
override fun nextLong(): Long = generator.nextLong()
|
||||||
|
}
|
@ -18,7 +18,7 @@ object Transformations {
|
|||||||
private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> =
|
private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> =
|
||||||
Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) }
|
Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) }
|
||||||
|
|
||||||
private fun Buffer<Double>.asArray() = if (this is DoubleBuffer) {
|
private fun Buffer<Double>.asArray() = if (this is RealBuffer) {
|
||||||
array
|
array
|
||||||
} else {
|
} else {
|
||||||
DoubleArray(size) { i -> get(i) }
|
DoubleArray(size) { i -> get(i) }
|
||||||
|
40
kmath-core/README.md
Normal file
40
kmath-core/README.md
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# The Core Module (`kmath-ast`)
|
||||||
|
|
||||||
|
The core features of KMath:
|
||||||
|
|
||||||
|
- Algebraic structures: contexts and elements.
|
||||||
|
- ND structures.
|
||||||
|
- Buffers.
|
||||||
|
- Functional Expressions.
|
||||||
|
- Domains.
|
||||||
|
- Automatic differentiation.
|
||||||
|
|
||||||
|
> #### Artifact:
|
||||||
|
> This module is distributed in the artifact `scientifik:kmath-core:0.1.4-dev-8`.
|
||||||
|
>
|
||||||
|
> **Gradle:**
|
||||||
|
>
|
||||||
|
> ```gradle
|
||||||
|
> repositories {
|
||||||
|
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' }
|
||||||
|
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||||
|
> maven { url https://dl.bintray.com/hotkeytlt/maven' }
|
||||||
|
> }
|
||||||
|
>
|
||||||
|
> dependencies {
|
||||||
|
> implementation 'scientifik:kmath-core:0.1.4-dev-8'
|
||||||
|
> }
|
||||||
|
> ```
|
||||||
|
> **Gradle Kotlin DSL:**
|
||||||
|
>
|
||||||
|
> ```kotlin
|
||||||
|
> repositories {
|
||||||
|
> maven("https://dl.bintray.com/mipt-npm/scientifik")
|
||||||
|
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
|
> }
|
||||||
|
>
|
||||||
|
> dependencies {``
|
||||||
|
> implementation("scientifik:kmath-core:0.1.4-dev-8")
|
||||||
|
> }
|
||||||
|
> ```
|
@ -1,11 +1,7 @@
|
|||||||
plugins {
|
plugins { id("scientifik.mpp") }
|
||||||
id("scientifik.mpp")
|
|
||||||
}
|
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies { api(project(":kmath-memory")) }
|
||||||
api(project(":kmath-memory"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -0,0 +1,20 @@
|
|||||||
|
package scientifik.kmath.domains
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple geometric domain.
|
||||||
|
*
|
||||||
|
* @param T the type of element of this domain.
|
||||||
|
*/
|
||||||
|
interface Domain<T : Any> {
|
||||||
|
/**
|
||||||
|
* Checks if the specified point is contained in this domain.
|
||||||
|
*/
|
||||||
|
operator fun contains(point: Point<T>): Boolean
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Number of hyperspace dimensions.
|
||||||
|
*/
|
||||||
|
val dimension: Int
|
||||||
|
}
|
@ -0,0 +1,68 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Alexander Nozik.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package scientifik.kmath.domains
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
import scientifik.kmath.structures.RealBuffer
|
||||||
|
import scientifik.kmath.structures.indices
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* HyperSquareDomain class.
|
||||||
|
*
|
||||||
|
* @author Alexander Nozik
|
||||||
|
*/
|
||||||
|
class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain {
|
||||||
|
|
||||||
|
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
|
||||||
|
point[i] in lower[i]..upper[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
override val dimension: Int get() = lower.size
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num]
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int): Double? = lower[num]
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num]
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int): Double? = upper[num]
|
||||||
|
|
||||||
|
override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
||||||
|
val res = DoubleArray(point.size) { i ->
|
||||||
|
when {
|
||||||
|
point[i] < lower[i] -> lower[i]
|
||||||
|
point[i] > upper[i] -> upper[i]
|
||||||
|
else -> point[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return RealBuffer(*res)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun volume(): Double {
|
||||||
|
var res = 1.0
|
||||||
|
for (i in 0 until dimension) {
|
||||||
|
if (lower[i].isInfinite() || upper[i].isInfinite()) {
|
||||||
|
return Double.POSITIVE_INFINITY
|
||||||
|
}
|
||||||
|
if (upper[i] > lower[i]) {
|
||||||
|
res *= upper[i] - lower[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,63 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Alexander Nozik.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package scientifik.kmath.domains
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
|
||||||
|
/**
|
||||||
|
* n-dimensional volume
|
||||||
|
*
|
||||||
|
* @author Alexander Nozik
|
||||||
|
*/
|
||||||
|
interface RealDomain : Domain<Double> {
|
||||||
|
fun nearestInDomain(point: Point<Double>): Point<Double>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The lower edge for the domain going down from point
|
||||||
|
* @param num
|
||||||
|
* @param point
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
fun getLowerBound(num: Int, point: Point<Double>): Double?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The upper edge of the domain going up from point
|
||||||
|
* @param num
|
||||||
|
* @param point
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
fun getUpperBound(num: Int, point: Point<Double>): Double?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Global lower edge
|
||||||
|
* @param num
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
fun getLowerBound(num: Int): Double?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Global upper edge
|
||||||
|
* @param num
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
fun getUpperBound(num: Int): Double?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Hyper volume
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
fun volume(): Double
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Alexander Nozik.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package scientifik.kmath.domains
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
|
||||||
|
class UnconstrainedDomain(override val dimension: Int) : RealDomain {
|
||||||
|
override operator fun contains(point: Point<Double>): Boolean = true
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int, point: Point<Double>): Double? = Double.POSITIVE_INFINITY
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY
|
||||||
|
|
||||||
|
override fun nearestInDomain(point: Point<Double>): Point<Double> = point
|
||||||
|
|
||||||
|
override fun volume(): Double = Double.POSITIVE_INFINITY
|
||||||
|
}
|
@ -0,0 +1,47 @@
|
|||||||
|
package scientifik.kmath.domains
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
|
||||||
|
inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : RealDomain {
|
||||||
|
operator fun contains(d: Double): Boolean = range.contains(d)
|
||||||
|
|
||||||
|
override operator fun contains(point: Point<Double>): Boolean {
|
||||||
|
require(point.size == 0)
|
||||||
|
return contains(point[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
||||||
|
require(point.size == 1)
|
||||||
|
val value = point[0]
|
||||||
|
return when {
|
||||||
|
value in range -> point
|
||||||
|
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
|
||||||
|
else -> doubleArrayOf(range.start).asBuffer()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int, point: Point<Double>): Double? {
|
||||||
|
require(num == 0)
|
||||||
|
return range.start
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int, point: Point<Double>): Double? {
|
||||||
|
require(num == 0)
|
||||||
|
return range.endInclusive
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int): Double? {
|
||||||
|
require(num == 0)
|
||||||
|
return range.start
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int): Double? {
|
||||||
|
require(num == 0)
|
||||||
|
return range.endInclusive
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun volume(): Double = range.endInclusive - range.start
|
||||||
|
|
||||||
|
override val dimension: Int get() = 1
|
||||||
|
}
|
@ -0,0 +1,31 @@
|
|||||||
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.ExtendedField
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.Ring
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a functional expression with this [Space].
|
||||||
|
*/
|
||||||
|
fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> =
|
||||||
|
FunctionalExpressionSpace(this).run(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a functional expression with this [Ring].
|
||||||
|
*/
|
||||||
|
fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> =
|
||||||
|
FunctionalExpressionRing(this).run(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a functional expression with this [Field].
|
||||||
|
*/
|
||||||
|
fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> =
|
||||||
|
FunctionalExpressionField(this).run(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a functional expression with this [ExtendedField].
|
||||||
|
*/
|
||||||
|
fun <T> ExtendedField<T>.fieldExpression(
|
||||||
|
block: FunctionalExpressionExtendedField<T, ExtendedField<T>>.() -> Expression<T>
|
||||||
|
): Expression<T> = FunctionalExpressionExtendedField(this).run(block)
|
@ -1,92 +1,49 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.Ring
|
|
||||||
import scientifik.kmath.operations.Space
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An elementary function that could be invoked on a map of arguments
|
* An elementary function that could be invoked on a map of arguments
|
||||||
*/
|
*/
|
||||||
interface Expression<T> {
|
interface Expression<T> {
|
||||||
|
/**
|
||||||
|
* Calls this expression from arguments.
|
||||||
|
*
|
||||||
|
* @param arguments the map of arguments.
|
||||||
|
* @return the value.
|
||||||
|
*/
|
||||||
operator fun invoke(arguments: Map<String, T>): T
|
operator fun invoke(arguments: Map<String, T>): T
|
||||||
|
|
||||||
|
companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create simple lazily evaluated expression inside given algebra
|
||||||
|
*/
|
||||||
|
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> =
|
||||||
|
object : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = block(arguments)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression from arguments.
|
||||||
|
*
|
||||||
|
* @param pairs the pair of arguments' names to values.
|
||||||
|
* @return the value.
|
||||||
|
*/
|
||||||
operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
|
operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context for expression construction
|
* A context for expression construction
|
||||||
*/
|
*/
|
||||||
interface ExpressionContext<T> {
|
interface ExpressionAlgebra<T, E> : Algebra<E> {
|
||||||
/**
|
/**
|
||||||
* Introduce a variable into expression context
|
* Introduce a variable into expression context
|
||||||
*/
|
*/
|
||||||
fun variable(name: String, default: T? = null): Expression<T>
|
fun variable(name: String, default: T? = null): E
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A constant expression which does not depend on arguments
|
* A constant expression which does not depend on arguments
|
||||||
*/
|
*/
|
||||||
fun const(value: T): Expression<T>
|
fun const(value: T): E
|
||||||
}
|
|
||||||
|
|
||||||
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T =
|
|
||||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class ConstantExpression<T>(val value: T) : Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = value
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class SumExpression<T>(val context: Space<T>, val first: Expression<T>, val second: Expression<T>) :
|
|
||||||
Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<T>, val second: Expression<T>) :
|
|
||||||
Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T =
|
|
||||||
context.multiply(first.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
|
|
||||||
Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
|
|
||||||
Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
open class ExpressionSpace<T>(val space: Space<T>) : Space<Expression<T>>, ExpressionContext<T> {
|
|
||||||
override val zero: Expression<T> = ConstantExpression(space.zero)
|
|
||||||
|
|
||||||
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
|
||||||
|
|
||||||
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
|
||||||
|
|
||||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
|
|
||||||
|
|
||||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
|
|
||||||
|
|
||||||
|
|
||||||
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
|
||||||
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
|
||||||
|
|
||||||
operator fun T.plus(arg: Expression<T>) = arg + this
|
|
||||||
operator fun T.minus(arg: Expression<T>) = arg - this
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ExpressionField<T>(val field: Field<T>) : Field<Expression<T>>, ExpressionSpace<T>(field) {
|
|
||||||
override val one: Expression<T> = ConstantExpression(field.one)
|
|
||||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
|
||||||
|
|
||||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
|
||||||
|
|
||||||
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
|
||||||
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
|
||||||
|
|
||||||
operator fun T.times(arg: Expression<T>) = arg * this
|
|
||||||
operator fun T.div(arg: Expression<T>) = arg / this
|
|
||||||
}
|
}
|
@ -0,0 +1,175 @@
|
|||||||
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
||||||
|
Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class FunctionalBinaryOperation<T>(
|
||||||
|
val context: Algebra<T>,
|
||||||
|
val name: String,
|
||||||
|
val first: Expression<T>,
|
||||||
|
val second: Expression<T>
|
||||||
|
) : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T =
|
||||||
|
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T =
|
||||||
|
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = value
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class FunctionalConstProductExpression<T>(
|
||||||
|
val context: Space<T>,
|
||||||
|
private val expr: Expression<T>,
|
||||||
|
val const: Number
|
||||||
|
) : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context class for [Expression] construction.
|
||||||
|
*
|
||||||
|
* @param algebra The algebra to provide for Expressions built.
|
||||||
|
*/
|
||||||
|
abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(val algebra: A) : ExpressionAlgebra<T, Expression<T>> {
|
||||||
|
/**
|
||||||
|
* Builds an Expression of constant expression which does not depend on arguments.
|
||||||
|
*/
|
||||||
|
override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression to access a variable.
|
||||||
|
*/
|
||||||
|
override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||||
|
*/
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||||
|
*/
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
FunctionalUnaryOperation(algebra, operation, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context class for [Expression] construction for [Space] algebras.
|
||||||
|
*/
|
||||||
|
open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||||
|
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
|
||||||
|
override val zero: Expression<T> get() = const(algebra.zero)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of addition of two another expressions.
|
||||||
|
*/
|
||||||
|
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
|
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of multiplication of expression by number.
|
||||||
|
*/
|
||||||
|
override fun multiply(a: Expression<T>, k: Number): Expression<T> =
|
||||||
|
FunctionalConstProductExpression(algebra, a, k)
|
||||||
|
|
||||||
|
operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
||||||
|
operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
||||||
|
operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
||||||
|
operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
|
||||||
|
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
|
||||||
|
override val one: Expression<T>
|
||||||
|
get() = const(algebra.one)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of multiplication of two expressions.
|
||||||
|
*/
|
||||||
|
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
|
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
|
operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||||
|
operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||||
|
FunctionalExpressionRing<T, A>(algebra),
|
||||||
|
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
|
||||||
|
/**
|
||||||
|
* Builds an Expression of division an expression by another one.
|
||||||
|
*/
|
||||||
|
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
|
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
|
operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||||
|
operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
||||||
|
FunctionalExpressionField<T, A>(algebra),
|
||||||
|
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
||||||
|
override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||||
|
override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||||
|
|
||||||
|
override fun asin(arg: Expression<T>): Expression<T> =
|
||||||
|
unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg)
|
||||||
|
|
||||||
|
override fun acos(arg: Expression<T>): Expression<T> =
|
||||||
|
unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg)
|
||||||
|
|
||||||
|
override fun atan(arg: Expression<T>): Expression<T> =
|
||||||
|
unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg)
|
||||||
|
|
||||||
|
override fun power(arg: Expression<T>, pow: Number): Expression<T> =
|
||||||
|
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||||
|
|
||||||
|
override fun exp(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||||
|
override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionField>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionField>.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =
|
||||||
|
FunctionalExpressionSpace(this).block()
|
||||||
|
|
||||||
|
inline fun <T, A : Ring<T>> A.expressionInRing(block: FunctionalExpressionRing<T, A>.() -> Expression<T>): Expression<T> =
|
||||||
|
FunctionalExpressionRing(this).block()
|
||||||
|
|
||||||
|
inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> =
|
||||||
|
FunctionalExpressionField(this).block()
|
||||||
|
|
||||||
|
inline fun <T, A : ExtendedField<T>> A.expressionInExtendedField(block: FunctionalExpressionExtendedField<T, A>.() -> Expression<T>): Expression<T> =
|
||||||
|
FunctionalExpressionExtendedField(this).block()
|
@ -19,22 +19,20 @@ class BufferMatrixContext<T : Any, R : Ring<T>>(
|
|||||||
|
|
||||||
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
||||||
|
|
||||||
companion object {
|
companion object
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
||||||
|
|
||||||
override val elementContext = RealField
|
override val elementContext: RealField get() = RealField
|
||||||
|
|
||||||
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
||||||
val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||||
return BufferMatrix(rows, columns, buffer)
|
return BufferMatrix(rows, columns, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = DoubleBuffer(size,initializer)
|
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size, initializer)
|
||||||
}
|
}
|
||||||
|
|
||||||
class BufferMatrix<T : Any>(
|
class BufferMatrix<T : Any>(
|
||||||
@ -52,7 +50,7 @@ class BufferMatrix<T : Any>(
|
|||||||
|
|
||||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||||
|
|
||||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
|
||||||
BufferMatrix(rowNum, colNum, buffer, this.features + features)
|
BufferMatrix(rowNum, colNum, buffer, this.features + features)
|
||||||
|
|
||||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
override fun get(index: IntArray): T = get(index[0], index[1])
|
||||||
@ -84,8 +82,8 @@ class BufferMatrix<T : Any>(
|
|||||||
override fun toString(): String {
|
override fun toString(): String {
|
||||||
return if (rowNum <= 5 && colNum <= 5) {
|
return if (rowNum <= 5 && colNum <= 5) {
|
||||||
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
|
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
|
||||||
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") {
|
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
|
||||||
it.asSequence().joinToString(separator = "\t") { it.toString() }
|
buffer.asSequence().joinToString(separator = "\t") { it.toString() }
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
|
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
|
||||||
@ -101,8 +99,15 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
|
|||||||
|
|
||||||
val array = DoubleArray(this.rowNum * other.colNum)
|
val array = DoubleArray(this.rowNum * other.colNum)
|
||||||
|
|
||||||
val a = this.buffer.array
|
//convert to array to insure there is not memory indirection
|
||||||
val b = other.buffer.array
|
fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is RealBuffer) {
|
||||||
|
array
|
||||||
|
} else {
|
||||||
|
DoubleArray(size) { get(it) }
|
||||||
|
}
|
||||||
|
|
||||||
|
val a = this.buffer.unsafeArray()
|
||||||
|
val b = other.buffer.unsafeArray()
|
||||||
|
|
||||||
for (i in (0 until rowNum)) {
|
for (i in (0 until rowNum)) {
|
||||||
for (j in (0 until other.colNum)) {
|
for (j in (0 until other.colNum)) {
|
||||||
@ -112,6 +117,6 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val buffer = DoubleBuffer(array)
|
val buffer = RealBuffer(array)
|
||||||
return BufferMatrix(rowNum, other.colNum, buffer)
|
return BufferMatrix(rowNum, other.colNum, buffer)
|
||||||
}
|
}
|
@ -23,12 +23,10 @@ interface FeaturedMatrix<T : Any> : Matrix<T> {
|
|||||||
*/
|
*/
|
||||||
fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
|
fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
|
||||||
|
|
||||||
companion object {
|
companion object
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> =
|
||||||
MatrixContext.real.produce(rows, columns, initializer)
|
MatrixContext.real.produce(rows, columns, initializer)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -41,7 +39,7 @@ fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T
|
|||||||
return BufferMatrix(size, size, buffer)
|
return BufferMatrix(size, size, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
val Matrix<*>.features get() = (this as? FeaturedMatrix)?.features?: emptySet()
|
val Matrix<*>.features: Set<MatrixFeature> get() = (this as? FeaturedMatrix)?.features ?: emptySet()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if matrix has the given feature class
|
* Check if matrix has the given feature class
|
||||||
@ -68,7 +66,7 @@ fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: In
|
|||||||
* A virtual matrix of zeroes
|
* A virtual matrix of zeroes
|
||||||
*/
|
*/
|
||||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
|
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||||
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero }
|
VirtualMatrix(rows, columns) { _, _ -> elementContext.zero }
|
||||||
|
|
||||||
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
|
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ class LUPDecomposition<T : Any>(
|
|||||||
private val even: Boolean
|
private val even: Boolean
|
||||||
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
||||||
|
|
||||||
val elementContext get() = context.elementContext
|
val elementContext: Field<T> get() = context.elementContext
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the matrix L of the decomposition.
|
* Returns the matrix L of the decomposition.
|
||||||
@ -67,7 +67,7 @@ class LUPDecomposition<T : Any>(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) =
|
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T): T =
|
||||||
if (value > elementContext.zero) value else with(elementContext) { -value }
|
if (value > elementContext.zero) value else with(elementContext) { -value }
|
||||||
|
|
||||||
|
|
||||||
@ -128,14 +128,14 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
|||||||
luRow[col] = sum
|
luRow[col] = sum
|
||||||
|
|
||||||
// maintain best permutation choice
|
// maintain best permutation choice
|
||||||
if (abs(sum) > largest) {
|
if (this@lup.abs(sum) > largest) {
|
||||||
largest = abs(sum)
|
largest = this@lup.abs(sum)
|
||||||
max = row
|
max = row
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Singularity check
|
// Singularity check
|
||||||
if (checkSingular(abs(lu[max, col]))) {
|
if (checkSingular(this@lup.abs(lu[max, col]))) {
|
||||||
error("The matrix is singular")
|
error("The matrix is singular")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,9 +169,10 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
|||||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
noinline checkSingular: (T) -> Boolean
|
noinline checkSingular: (T) -> Boolean
|
||||||
) = lup(T::class, matrix, checkSingular)
|
): LUPDecomposition<T> = lup(T::class, matrix, checkSingular)
|
||||||
|
|
||||||
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(Double::class, matrix) { it < 1e-11 }
|
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> =
|
||||||
|
lup(Double::class, matrix) { it < 1e-11 }
|
||||||
|
|
||||||
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
|
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
|
||||||
|
|
||||||
@ -185,7 +186,7 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
|
|||||||
// Apply permutations to b
|
// Apply permutations to b
|
||||||
val bp = create { _, _ -> zero }
|
val bp = create { _, _ -> zero }
|
||||||
|
|
||||||
for (row in 0 until pivot.size) {
|
for (row in pivot.indices) {
|
||||||
val bpRow = bp.row(row)
|
val bpRow = bp.row(row)
|
||||||
val pRow = pivot[row]
|
val pRow = pivot[row]
|
||||||
for (col in 0 until matrix.colNum) {
|
for (col in 0 until matrix.colNum) {
|
||||||
@ -194,7 +195,7 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Solve LY = b
|
// Solve LY = b
|
||||||
for (col in 0 until pivot.size) {
|
for (col in pivot.indices) {
|
||||||
val bpCol = bp.row(col)
|
val bpCol = bp.row(col)
|
||||||
for (i in col + 1 until pivot.size) {
|
for (i in col + 1 until pivot.size) {
|
||||||
val bpI = bp.row(i)
|
val bpI = bp.row(i)
|
||||||
@ -225,7 +226,7 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>) = solve(T::class, matrix)
|
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matrix<T> = solve(T::class, matrix)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Solve a linear equation **a*x = b**
|
* Solve a linear equation **a*x = b**
|
||||||
@ -240,13 +241,12 @@ inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.
|
|||||||
return decomposition.solve(T::class, b)
|
return decomposition.solve(T::class, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>) =
|
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> = solve(a, b) { it < 1e-11 }
|
||||||
solve(a, b) { it < 1e-11 }
|
|
||||||
|
|
||||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
|
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
|
||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
noinline checkSingular: (T) -> Boolean
|
noinline checkSingular: (T) -> Boolean
|
||||||
) = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
): Matrix<T> = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
||||||
|
|
||||||
fun RealMatrixContext.inverse(matrix: Matrix<Double>) =
|
fun RealMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> =
|
||||||
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }
|
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }
|
@ -1,12 +1,8 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
|
||||||
import scientifik.kmath.operations.Norm
|
|
||||||
import scientifik.kmath.operations.RealField
|
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import scientifik.kmath.structures.VirtualBuffer
|
import scientifik.kmath.structures.VirtualBuffer
|
||||||
import scientifik.kmath.structures.asSequence
|
|
||||||
|
|
||||||
typealias Point<T> = Buffer<T>
|
typealias Point<T> = Buffer<T>
|
||||||
|
|
||||||
@ -19,8 +15,6 @@ interface LinearSolver<T : Any> {
|
|||||||
fun inverse(a: Matrix<T>): Matrix<T>
|
fun inverse(a: Matrix<T>): Matrix<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
typealias RealMatrix = Matrix<Double>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert matrix to vector if it is possible
|
* Convert matrix to vector if it is possible
|
||||||
*/
|
*/
|
||||||
@ -31,4 +25,4 @@ fun <T : Any> Matrix<T>.asPoint(): Point<T> =
|
|||||||
error("Can't convert matrix with more than one column to vector")
|
error("Can't convert matrix with more than one column to vector")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Any> Point<T>.asMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
||||||
|
@ -1,14 +1,46 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.Buffer
|
||||||
|
import scientifik.kmath.structures.BufferFactory
|
||||||
import scientifik.kmath.structures.Structure2D
|
import scientifik.kmath.structures.Structure2D
|
||||||
import scientifik.kmath.structures.asBuffer
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
|
||||||
class MatrixBuilder<T : Any>(val rows: Int, val columns: Int) {
|
class MatrixBuilder(val rows: Int, val columns: Int) {
|
||||||
operator fun invoke(vararg elements: T): FeaturedMatrix<T> {
|
operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> {
|
||||||
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns")
|
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns")
|
||||||
val buffer = elements.asBuffer()
|
val buffer = elements.asBuffer()
|
||||||
return BufferMatrix(rows, columns, buffer)
|
return BufferMatrix(rows, columns, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//TODO add specific matrix builder functions like diagonal, etc
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Any> Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)
|
fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns)
|
||||||
|
|
||||||
|
fun <T : Any> Structure2D.Companion.row(vararg values: T): FeaturedMatrix<T> {
|
||||||
|
val buffer = values.asBuffer()
|
||||||
|
return BufferMatrix(1, values.size, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Structure2D.Companion.row(
|
||||||
|
size: Int,
|
||||||
|
factory: BufferFactory<T> = Buffer.Companion::auto,
|
||||||
|
noinline builder: (Int) -> T
|
||||||
|
): FeaturedMatrix<T> {
|
||||||
|
val buffer = factory(size, builder)
|
||||||
|
return BufferMatrix(1, size, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T : Any> Structure2D.Companion.column(vararg values: T): FeaturedMatrix<T> {
|
||||||
|
val buffer = values.asBuffer()
|
||||||
|
return BufferMatrix(values.size, 1, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Structure2D.Companion.column(
|
||||||
|
size: Int,
|
||||||
|
factory: BufferFactory<T> = Buffer.Companion::auto,
|
||||||
|
noinline builder: (Int) -> T
|
||||||
|
): FeaturedMatrix<T> {
|
||||||
|
val buffer = factory(size, builder)
|
||||||
|
return BufferMatrix(size, 1, buffer)
|
||||||
|
}
|
||||||
|
@ -29,7 +29,7 @@ interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
|
|||||||
/**
|
/**
|
||||||
* Non-boxing double matrix
|
* Non-boxing double matrix
|
||||||
*/
|
*/
|
||||||
val real = RealMatrixContext
|
val real: RealMatrixContext = RealMatrixContext
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A structured matrix with custom buffer
|
* A structured matrix with custom buffer
|
||||||
@ -82,12 +82,12 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Matrix<T>.unaryMinus() =
|
override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
|
||||||
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
||||||
|
|
||||||
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||||
if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]")
|
if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]")
|
||||||
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) + b[i, j] } }
|
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] + b[i, j] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||||
@ -96,7 +96,7 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
||||||
produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) * k } }
|
produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } }
|
||||||
|
|
||||||
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A marker interface representing some matrix feature like diagonal, sparce, zero, etc. Features used to optimize matrix
|
* A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix
|
||||||
* operations performance in some cases.
|
* operations performance in some cases.
|
||||||
*/
|
*/
|
||||||
interface MatrixFeature
|
interface MatrixFeature
|
||||||
@ -36,19 +36,19 @@ interface DeterminantFeature<T : Any> : MatrixFeature {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
fun <T: Any> DeterminantFeature(determinant: T) = object: DeterminantFeature<T>{
|
fun <T : Any> DeterminantFeature(determinant: T): DeterminantFeature<T> = object : DeterminantFeature<T> {
|
||||||
override val determinant: T = determinant
|
override val determinant: T = determinant
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Lower triangular matrix
|
* Lower triangular matrix
|
||||||
*/
|
*/
|
||||||
object LFeature: MatrixFeature
|
object LFeature : MatrixFeature
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Upper triangular feature
|
* Upper triangular feature
|
||||||
*/
|
*/
|
||||||
object UFeature: MatrixFeature
|
object UFeature : MatrixFeature
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO add documentation
|
* TODO add documentation
|
||||||
|
@ -54,7 +54,7 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
|||||||
size: Int,
|
size: Int,
|
||||||
space: S,
|
space: S,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||||
) = BufferVectorSpace(size, space, bufferFactory)
|
): BufferVectorSpace<T, S> = BufferVectorSpace(size, space, bufferFactory)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Automatic buffered vector, unboxed if it is possible
|
* Automatic buffered vector, unboxed if it is possible
|
||||||
@ -70,6 +70,6 @@ class BufferVectorSpace<T : Any, S : Space<T>>(
|
|||||||
override val space: S,
|
override val space: S,
|
||||||
val bufferFactory: BufferFactory<T>
|
val bufferFactory: BufferFactory<T>
|
||||||
) : VectorSpace<T, S> {
|
) : VectorSpace<T, S> {
|
||||||
override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer)
|
override fun produce(initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
|
||||||
//override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
|
//override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
|
||||||
}
|
}
|
@ -20,7 +20,7 @@ class VirtualMatrix<T : Any>(
|
|||||||
|
|
||||||
override fun get(i: Int, j: Int): T = generator(i, j)
|
override fun get(i: Int, j: Int): T = generator(i, j)
|
||||||
|
|
||||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> =
|
||||||
VirtualMatrix(rowNum, colNum, this.features + features, generator)
|
VirtualMatrix(rowNum, colNum, this.features + features, generator)
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
|
@ -22,12 +22,12 @@ class DerivationResult<T : Any>(
|
|||||||
val deriv: Map<Variable<T>, T>,
|
val deriv: Map<Variable<T>, T>,
|
||||||
val context: Field<T>
|
val context: Field<T>
|
||||||
) : Variable<T>(value) {
|
) : Variable<T>(value) {
|
||||||
fun deriv(variable: Variable<T>) = deriv[variable] ?: context.zero
|
fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* compute divergence
|
* compute divergence
|
||||||
*/
|
*/
|
||||||
fun div() = context.run { sum(deriv.values) }
|
fun div(): T = context.run { sum(deriv.values) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute a gradient for variables in given order
|
* Compute a gradient for variables in given order
|
||||||
@ -53,7 +53,7 @@ class DerivationResult<T : Any>(
|
|||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> =
|
fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> =
|
||||||
AutoDiffContext<T, F>(this).run {
|
AutoDiffContext(this).run {
|
||||||
val result = body()
|
val result = body()
|
||||||
result.d = context.one// computing derivative w.r.t result
|
result.d = context.one// computing derivative w.r.t result
|
||||||
runBackwardPass()
|
runBackwardPass()
|
||||||
@ -86,24 +86,24 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
|||||||
|
|
||||||
abstract fun variable(value: T): Variable<T>
|
abstract fun variable(value: T): Variable<T>
|
||||||
|
|
||||||
inline fun variable(block: F.() -> T) = variable(context.block())
|
inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
|
||||||
|
|
||||||
// Overloads for Double constants
|
// Overloads for Double constants
|
||||||
|
|
||||||
operator fun Number.plus(that: Variable<T>): Variable<T> =
|
override operator fun Number.plus(b: Variable<T>): Variable<T> =
|
||||||
derive(variable { this@plus.toDouble() * one + that.value }) { z ->
|
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
|
||||||
that.d += z.d
|
b.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
||||||
|
|
||||||
operator fun Number.minus(that: Variable<T>): Variable<T> =
|
override operator fun Number.minus(b: Variable<T>): Variable<T> =
|
||||||
derive(variable { this@minus.toDouble() * one - that.value }) { z ->
|
derive(variable { this@minus.toDouble() * one - b.value }) { z ->
|
||||||
that.d -= z.d
|
b.d -= z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun Variable<T>.minus(that: Number): Variable<T> =
|
override operator fun Variable<T>.minus(b: Number): Variable<T> =
|
||||||
derive(variable { this@minus.value - one * that.toDouble() }) { z ->
|
derive(variable { this@minus.value - one * b.toDouble() }) { z ->
|
||||||
this@minus.d += z.d
|
this@minus.d += z.d
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package scientifik.kmath.misc
|
package scientifik.kmath.misc
|
||||||
|
|
||||||
|
import kotlin.math.abs
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert double range to sequence.
|
* Convert double range to sequence.
|
||||||
*
|
*
|
||||||
@ -8,28 +10,36 @@ package scientifik.kmath.misc
|
|||||||
*
|
*
|
||||||
* If step is negative, the same goes from upper boundary downwards
|
* If step is negative, the same goes from upper boundary downwards
|
||||||
*/
|
*/
|
||||||
fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double> =
|
fun ClosedFloatingPointRange<Double>.toSequenceWithStep(step: Double): Sequence<Double> = when {
|
||||||
when {
|
step == 0.0 -> error("Zero step in double progression")
|
||||||
step == 0.0 -> error("Zero step in double progression")
|
step > 0 -> sequence {
|
||||||
step > 0 -> sequence {
|
var current = start
|
||||||
var current = start
|
while (current <= endInclusive) {
|
||||||
while (current <= endInclusive) {
|
yield(current)
|
||||||
yield(current)
|
current += step
|
||||||
current += step
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else -> sequence {
|
|
||||||
var current = endInclusive
|
|
||||||
while (current >= start) {
|
|
||||||
yield(current)
|
|
||||||
current += step
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
else -> sequence {
|
||||||
|
var current = endInclusive
|
||||||
|
while (current >= start) {
|
||||||
|
yield(current)
|
||||||
|
current += step
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert double range to sequence with the fixed number of points
|
||||||
|
*/
|
||||||
|
fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Sequence<Double> {
|
||||||
|
require(numPoints > 1) { "The number of points should be more than 2" }
|
||||||
|
return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
|
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
|
||||||
*/
|
*/
|
||||||
|
@Deprecated("Replace by 'toSequenceWithPoints'")
|
||||||
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
|
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
|
||||||
if (numPoints < 2) error("Can't create generic grid with less than two points")
|
if (numPoints < 2) error("Can't create generic grid with less than two points")
|
||||||
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
|
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
package scientifik.kmath.misc
|
package scientifik.kmath.misc
|
||||||
|
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generic cumulative operation on iterator
|
* Generic cumulative operation on iterator.
|
||||||
* @param T type of initial iterable
|
*
|
||||||
* @param R type of resulting iterable
|
* @param T the type of initial iterable.
|
||||||
* @param initial lazy evaluated
|
* @param R the type of resulting iterable.
|
||||||
|
* @param initial lazy evaluated.
|
||||||
*/
|
*/
|
||||||
fun <T, R> Iterator<T>.cumulative(initial: R, operation: (R, T) -> R): Iterator<R> = object : Iterator<R> {
|
fun <T, R> Iterator<T>.cumulative(initial: R, operation: (R, T) -> R): Iterator<R> = object : Iterator<R> {
|
||||||
var state: R = initial
|
var state: R = initial
|
||||||
@ -36,41 +37,41 @@ fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> =
|
|||||||
/**
|
/**
|
||||||
* Cumulative sum with custom space
|
* Cumulative sum with custom space
|
||||||
*/
|
*/
|
||||||
fun <T> Iterable<T>.cumulativeSum(space: Space<T>) = with(space) {
|
fun <T> Iterable<T>.cumulativeSum(space: Space<T>): Iterable<T> = space {
|
||||||
cumulative(zero) { element: T, sum: T -> sum + element }
|
cumulative(zero) { element: T, sum: T -> sum + element }
|
||||||
}
|
}
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfDouble")
|
@JvmName("cumulativeSumOfDouble")
|
||||||
fun Iterable<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
|
fun Iterable<Double>.cumulativeSum(): Iterable<Double> = this.cumulative(0.0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfInt")
|
@JvmName("cumulativeSumOfInt")
|
||||||
fun Iterable<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
|
fun Iterable<Int>.cumulativeSum(): Iterable<Int> = this.cumulative(0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfLong")
|
@JvmName("cumulativeSumOfLong")
|
||||||
fun Iterable<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
|
fun Iterable<Long>.cumulativeSum(): Iterable<Long> = this.cumulative(0L) { element, sum -> sum + element }
|
||||||
|
|
||||||
fun <T> Sequence<T>.cumulativeSum(space: Space<T>) = with(space) {
|
fun <T> Sequence<T>.cumulativeSum(space: Space<T>): Sequence<T> = with(space) {
|
||||||
cumulative(zero) { element: T, sum: T -> sum + element }
|
cumulative(zero) { element: T, sum: T -> sum + element }
|
||||||
}
|
}
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfDouble")
|
@JvmName("cumulativeSumOfDouble")
|
||||||
fun Sequence<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
|
fun Sequence<Double>.cumulativeSum(): Sequence<Double> = this.cumulative(0.0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfInt")
|
@JvmName("cumulativeSumOfInt")
|
||||||
fun Sequence<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
|
fun Sequence<Int>.cumulativeSum(): Sequence<Int> = this.cumulative(0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfLong")
|
@JvmName("cumulativeSumOfLong")
|
||||||
fun Sequence<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
|
fun Sequence<Long>.cumulativeSum(): Sequence<Long> = this.cumulative(0L) { element, sum -> sum + element }
|
||||||
|
|
||||||
fun <T> List<T>.cumulativeSum(space: Space<T>) = with(space) {
|
fun <T> List<T>.cumulativeSum(space: Space<T>): List<T> = with(space) {
|
||||||
cumulative(zero) { element: T, sum: T -> sum + element }
|
cumulative(zero) { element: T, sum: T -> sum + element }
|
||||||
}
|
}
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfDouble")
|
@JvmName("cumulativeSumOfDouble")
|
||||||
fun List<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
|
fun List<Double>.cumulativeSum(): List<Double> = this.cumulative(0.0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfInt")
|
@JvmName("cumulativeSumOfInt")
|
||||||
fun List<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
|
fun List<Int>.cumulativeSum(): List<Int> = this.cumulative(0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfLong")
|
@JvmName("cumulativeSumOfLong")
|
||||||
fun List<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
|
fun List<Long>.cumulativeSum(): List<Long> = this.cumulative(0L) { element, sum -> sum + element }
|
||||||
|
@ -1,95 +1,340 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stub for DSL the [Algebra] is.
|
||||||
|
*/
|
||||||
@DslMarker
|
@DslMarker
|
||||||
annotation class KMathContext
|
annotation class KMathContext
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Marker interface for any algebra
|
* Represents an algebraic structure.
|
||||||
|
*
|
||||||
|
* @param T the type of element of this structure.
|
||||||
*/
|
*/
|
||||||
interface Algebra<T>
|
interface Algebra<T> {
|
||||||
|
/**
|
||||||
|
* Wrap raw string or variable
|
||||||
|
*/
|
||||||
|
fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
|
||||||
|
|
||||||
inline operator fun <T : Algebra<*>, R> T.invoke(block: T.() -> R): R = run(block)
|
/**
|
||||||
|
* Dynamic call of unary operation with name [operation] on [arg]
|
||||||
|
*/
|
||||||
|
fun unaryOperation(operation: String, arg: T): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dynamic call of binary operation [operation] on [left] and [right]
|
||||||
|
*/
|
||||||
|
fun binaryOperation(operation: String, left: T, right: T): T
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Space-like operations without neutral element
|
* An algebraic structure where elements can have numeric representation.
|
||||||
|
*
|
||||||
|
* @param T the type of element of this structure.
|
||||||
|
*/
|
||||||
|
interface NumericAlgebra<T> : Algebra<T> {
|
||||||
|
/**
|
||||||
|
* Wraps a number.
|
||||||
|
*/
|
||||||
|
fun number(value: Number): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number].
|
||||||
|
*/
|
||||||
|
fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
||||||
|
binaryOperation(operation, number(left), right)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
|
||||||
|
*/
|
||||||
|
fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||||
|
leftSideNumberOperation(operation, right, left)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Call a block with an [Algebra] as receiver.
|
||||||
|
*/
|
||||||
|
inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as
|
||||||
|
* multiplication by scalars.
|
||||||
|
*
|
||||||
|
* @param T the type of element of this semispace.
|
||||||
*/
|
*/
|
||||||
interface SpaceOperations<T> : Algebra<T> {
|
interface SpaceOperations<T> : Algebra<T> {
|
||||||
/**
|
/**
|
||||||
* Addition operation for two context elements
|
* Addition of two elements.
|
||||||
|
*
|
||||||
|
* @param a the addend.
|
||||||
|
* @param b the augend.
|
||||||
|
* @return the sum.
|
||||||
*/
|
*/
|
||||||
fun add(a: T, b: T): T
|
fun add(a: T, b: T): T
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Multiplication operation for context element and real number
|
* Multiplication of element by scalar.
|
||||||
|
*
|
||||||
|
* @param a the multiplier.
|
||||||
|
* @param k the multiplicand.
|
||||||
|
* @return the produce.
|
||||||
*/
|
*/
|
||||||
fun multiply(a: T, k: Number): T
|
fun multiply(a: T, k: Number): T
|
||||||
|
|
||||||
//Operation to be performed in this context
|
// Operations to be performed in this context. Could be moved to extensions in case of KEEP-176
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The negation of this element.
|
||||||
|
*
|
||||||
|
* @receiver this value.
|
||||||
|
* @return the additive inverse of this value.
|
||||||
|
*/
|
||||||
operator fun T.unaryMinus(): T = multiply(this, -1.0)
|
operator fun T.unaryMinus(): T = multiply(this, -1.0)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns this value.
|
||||||
|
*
|
||||||
|
* @receiver this value.
|
||||||
|
* @return this value.
|
||||||
|
*/
|
||||||
|
operator fun T.unaryPlus(): T = this
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Addition of two elements.
|
||||||
|
*
|
||||||
|
* @receiver the addend.
|
||||||
|
* @param b the augend.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
operator fun T.plus(b: T): T = add(this, b)
|
operator fun T.plus(b: T): T = add(this, b)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subtraction of two elements.
|
||||||
|
*
|
||||||
|
* @receiver the minuend.
|
||||||
|
* @param b the subtrahend.
|
||||||
|
* @return the difference.
|
||||||
|
*/
|
||||||
operator fun T.minus(b: T): T = add(this, -b)
|
operator fun T.minus(b: T): T = add(this, -b)
|
||||||
operator fun T.times(k: Number) = multiply(this, k.toDouble())
|
|
||||||
operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble())
|
/**
|
||||||
operator fun Number.times(b: T) = b * this
|
* Multiplication of this element by a scalar.
|
||||||
|
*
|
||||||
|
* @receiver the multiplier.
|
||||||
|
* @param k the multiplicand.
|
||||||
|
* @return the product.
|
||||||
|
*/
|
||||||
|
operator fun T.times(k: Number): T = multiply(this, k.toDouble())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Division of this element by scalar.
|
||||||
|
*
|
||||||
|
* @receiver the dividend.
|
||||||
|
* @param k the divisor.
|
||||||
|
* @return the quotient.
|
||||||
|
*/
|
||||||
|
operator fun T.div(k: Number): T = multiply(this, 1.0 / k.toDouble())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplication of this number by element.
|
||||||
|
*
|
||||||
|
* @receiver the multiplier.
|
||||||
|
* @param b the multiplicand.
|
||||||
|
* @return the product.
|
||||||
|
*/
|
||||||
|
operator fun Number.times(b: T): T = b * this
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||||
|
PLUS_OPERATION -> arg
|
||||||
|
MINUS_OPERATION -> -arg
|
||||||
|
else -> error("Unary operation $operation not defined in $this")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||||
|
PLUS_OPERATION -> add(left, right)
|
||||||
|
MINUS_OPERATION -> left - right
|
||||||
|
else -> error("Binary operation $operation not defined in $this")
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* The identifier of addition.
|
||||||
|
*/
|
||||||
|
const val PLUS_OPERATION: String = "+"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The identifier of subtraction (and negation).
|
||||||
|
*/
|
||||||
|
const val MINUS_OPERATION: String = "-"
|
||||||
|
|
||||||
|
const val NOT_OPERATION: String = "!"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A general interface representing linear context of some kind.
|
* Represents linear space, i.e. algebraic structure with associative binary operation called "addition" and its neutral
|
||||||
* The context defines sum operation for its elements and multiplication by real value.
|
* element as well as multiplication by scalars.
|
||||||
* One must note that in some cases context is a singleton class, but in some cases it
|
|
||||||
* works as a context for operations inside it.
|
|
||||||
*
|
*
|
||||||
* TODO do we need non-commutative context?
|
* @param T the type of element of this group.
|
||||||
*/
|
*/
|
||||||
interface Space<T> : SpaceOperations<T> {
|
interface Space<T> : SpaceOperations<T> {
|
||||||
/**
|
/**
|
||||||
* Neutral element for sum operation
|
* The neutral element of addition.
|
||||||
*/
|
*/
|
||||||
val zero: T
|
val zero: T
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Operations on ring without multiplication neutral element
|
* Represents semiring, i.e. algebraic structure with two associative binary operations called "addition" and
|
||||||
|
* "multiplication".
|
||||||
|
*
|
||||||
|
* @param T the type of element of this semiring.
|
||||||
*/
|
*/
|
||||||
interface RingOperations<T> : SpaceOperations<T> {
|
interface RingOperations<T> : SpaceOperations<T> {
|
||||||
/**
|
/**
|
||||||
* Multiplication for two field elements
|
* Multiplies two elements.
|
||||||
|
*
|
||||||
|
* @param a the multiplier.
|
||||||
|
* @param b the multiplicand.
|
||||||
*/
|
*/
|
||||||
fun multiply(a: T, b: T): T
|
fun multiply(a: T, b: T): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplies this element by scalar.
|
||||||
|
*
|
||||||
|
* @receiver the multiplier.
|
||||||
|
* @param b the multiplicand.
|
||||||
|
*/
|
||||||
operator fun T.times(b: T): T = multiply(this, b)
|
operator fun T.times(b: T): T = multiply(this, b)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||||
|
TIMES_OPERATION -> multiply(left, right)
|
||||||
|
else -> super.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* The identifier of multiplication.
|
||||||
|
*/
|
||||||
|
const val TIMES_OPERATION: String = "*"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The same as {@link Space} but with additional multiplication operation
|
* Represents ring, i.e. algebraic structure with two associative binary operations called "addition" and
|
||||||
|
* "multiplication" and their neutral elements.
|
||||||
|
*
|
||||||
|
* @param T the type of element of this ring.
|
||||||
*/
|
*/
|
||||||
interface Ring<T> : Space<T>, RingOperations<T> {
|
interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
|
||||||
/**
|
/**
|
||||||
* neutral operation for multiplication
|
* neutral operation for multiplication
|
||||||
*/
|
*/
|
||||||
val one: T
|
val one: T
|
||||||
|
|
||||||
// operator fun T.plus(b: Number) = this.plus(b * one)
|
override fun number(value: Number): T = one * value.toDouble()
|
||||||
// operator fun Number.plus(b: T) = b + this
|
|
||||||
//
|
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
|
||||||
// operator fun T.minus(b: Number) = this.minus(b * one)
|
SpaceOperations.PLUS_OPERATION -> left + right
|
||||||
// operator fun Number.minus(b: T) = -b + this
|
SpaceOperations.MINUS_OPERATION -> left - right
|
||||||
|
RingOperations.TIMES_OPERATION -> left * right
|
||||||
|
else -> super.leftSideNumberOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||||
|
SpaceOperations.PLUS_OPERATION -> left + right
|
||||||
|
SpaceOperations.MINUS_OPERATION -> left - right
|
||||||
|
RingOperations.TIMES_OPERATION -> left * right
|
||||||
|
else -> super.rightSideNumberOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Addition of element and scalar.
|
||||||
|
*
|
||||||
|
* @receiver the addend.
|
||||||
|
* @param b the augend.
|
||||||
|
*/
|
||||||
|
operator fun T.plus(b: Number): T = this + number(b)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Addition of scalar and element.
|
||||||
|
*
|
||||||
|
* @receiver the addend.
|
||||||
|
* @param b the augend.
|
||||||
|
*/
|
||||||
|
operator fun Number.plus(b: T): T = b + this
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subtraction of element from number.
|
||||||
|
*
|
||||||
|
* @receiver the minuend.
|
||||||
|
* @param b the subtrahend.
|
||||||
|
* @receiver the difference.
|
||||||
|
*/
|
||||||
|
operator fun T.minus(b: Number): T = this - number(b)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subtraction of number from element.
|
||||||
|
*
|
||||||
|
* @receiver the minuend.
|
||||||
|
* @param b the subtrahend.
|
||||||
|
* @receiver the difference.
|
||||||
|
*/
|
||||||
|
operator fun Number.minus(b: T): T = -b + this
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* All ring operations but without neutral elements
|
* Represents semifield, i.e. algebraic structure with three operations: associative "addition" and "multiplication",
|
||||||
|
* and "division".
|
||||||
|
*
|
||||||
|
* @param T the type of element of this semifield.
|
||||||
*/
|
*/
|
||||||
interface FieldOperations<T> : RingOperations<T> {
|
interface FieldOperations<T> : RingOperations<T> {
|
||||||
|
/**
|
||||||
|
* Division of two elements.
|
||||||
|
*
|
||||||
|
* @param a the dividend.
|
||||||
|
* @param b the divisor.
|
||||||
|
* @return the quotient.
|
||||||
|
*/
|
||||||
fun divide(a: T, b: T): T
|
fun divide(a: T, b: T): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Division of two elements.
|
||||||
|
*
|
||||||
|
* @receiver the dividend.
|
||||||
|
* @param b the divisor.
|
||||||
|
* @return the quotient.
|
||||||
|
*/
|
||||||
operator fun T.div(b: T): T = divide(this, b)
|
operator fun T.div(b: T): T = divide(this, b)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||||
|
DIV_OPERATION -> divide(left, right)
|
||||||
|
else -> super.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* The identifier of division.
|
||||||
|
*/
|
||||||
|
const val DIV_OPERATION: String = "/"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Four operations algebra
|
* Represents field, i.e. algebraic structure with three operations: associative "addition" and "multiplication",
|
||||||
|
* and "division" and their neutral elements.
|
||||||
|
*
|
||||||
|
* @param T the type of element of this semifield.
|
||||||
*/
|
*/
|
||||||
interface Field<T> : Ring<T>, FieldOperations<T> {
|
interface Field<T> : Ring<T>, FieldOperations<T> {
|
||||||
operator fun Number.div(b: T) = this * divide(one, b)
|
/**
|
||||||
|
* Division of element by scalar.
|
||||||
|
*
|
||||||
|
* @receiver the dividend.
|
||||||
|
* @param b the divisor.
|
||||||
|
* @return the quotient.
|
||||||
|
*/
|
||||||
|
operator fun Number.div(b: T): T = this * divide(one, b)
|
||||||
}
|
}
|
||||||
|
@ -2,47 +2,107 @@ package scientifik.kmath.operations
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* The generic mathematics elements which is able to store its context
|
* The generic mathematics elements which is able to store its context
|
||||||
* @param T the type of space operation results
|
*
|
||||||
* @param I self type of the element. Needed for static type checking
|
* @param C the type of mathematical context for this element.
|
||||||
* @param C the type of mathematical context for this element
|
|
||||||
*/
|
*/
|
||||||
interface MathElement<C> {
|
interface MathElement<C> {
|
||||||
/**
|
/**
|
||||||
* The context this element belongs to
|
* The context this element belongs to.
|
||||||
*/
|
*/
|
||||||
val context: C
|
val context: C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents element that can be wrapped to its "primitive" value.
|
||||||
|
*
|
||||||
|
* @param T the type wrapped by this wrapper.
|
||||||
|
* @param I the type of this wrapper.
|
||||||
|
*/
|
||||||
interface MathWrapper<T, I> {
|
interface MathWrapper<T, I> {
|
||||||
|
/**
|
||||||
|
* Unwraps [I] to [T].
|
||||||
|
*/
|
||||||
fun unwrap(): T
|
fun unwrap(): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps [T] to [I].
|
||||||
|
*/
|
||||||
fun T.wrap(): I
|
fun T.wrap(): I
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The element of linear context
|
* The element of [Space].
|
||||||
* @param T the type of space operation results
|
*
|
||||||
* @param I self type of the element. Needed for static type checking
|
* @param T the type of space operation results.
|
||||||
* @param S the type of space
|
* @param I self type of the element. Needed for static type checking.
|
||||||
|
* @param S the type of space.
|
||||||
*/
|
*/
|
||||||
interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S>, MathWrapper<T, I> {
|
interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S>, MathWrapper<T, I> {
|
||||||
|
/**
|
||||||
|
* Adds element to this one.
|
||||||
|
*
|
||||||
|
* @param b the augend.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
|
operator fun plus(b: T): I = context.add(unwrap(), b).wrap()
|
||||||
|
|
||||||
operator fun plus(b: T) = context.add(unwrap(), b).wrap()
|
/**
|
||||||
operator fun minus(b: T) = context.add(unwrap(), context.multiply(b, -1.0)).wrap()
|
* Subtracts element from this one.
|
||||||
operator fun times(k: Number) = context.multiply(unwrap(), k.toDouble()).wrap()
|
*
|
||||||
operator fun div(k: Number) = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap()
|
* @param b the subtrahend.
|
||||||
|
* @return the difference.
|
||||||
|
*/
|
||||||
|
operator fun minus(b: T): I = context.add(unwrap(), context.multiply(b, -1.0)).wrap()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplies this element by number.
|
||||||
|
*
|
||||||
|
* @param k the multiplicand.
|
||||||
|
* @return the product.
|
||||||
|
*/
|
||||||
|
operator fun times(k: Number): I = context.multiply(unwrap(), k.toDouble()).wrap()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Divides this element by number.
|
||||||
|
*
|
||||||
|
* @param k the divisor.
|
||||||
|
* @return the quotient.
|
||||||
|
*/
|
||||||
|
operator fun div(k: Number): I = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Ring element
|
* The element of [Ring].
|
||||||
|
*
|
||||||
|
* @param T the type of space operation results.
|
||||||
|
* @param I self type of the element. Needed for static type checking.
|
||||||
|
* @param R the type of space.
|
||||||
*/
|
*/
|
||||||
interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
|
interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
|
||||||
operator fun times(b: T) = context.multiply(unwrap(), b).wrap()
|
/**
|
||||||
|
* Multiplies this element by another one.
|
||||||
|
*
|
||||||
|
* @param b the multiplicand.
|
||||||
|
* @return the product.
|
||||||
|
*/
|
||||||
|
operator fun times(b: T): I = context.multiply(unwrap(), b).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Field element
|
* The element of [Field].
|
||||||
|
*
|
||||||
|
* @param T the type of space operation results.
|
||||||
|
* @param I self type of the element. Needed for static type checking.
|
||||||
|
* @param F the type of field.
|
||||||
*/
|
*/
|
||||||
interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> {
|
interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> {
|
||||||
override val context: F
|
override val context: F
|
||||||
operator fun div(b: T) = context.divide(unwrap(), b).wrap()
|
|
||||||
|
/**
|
||||||
|
* Divides this element by another one.
|
||||||
|
*
|
||||||
|
* @param b the divisor.
|
||||||
|
* @return the quotient.
|
||||||
|
*/
|
||||||
|
operator fun div(b: T): I = context.divide(unwrap(), b).wrap()
|
||||||
}
|
}
|
@ -1,15 +1,107 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the sum of all elements in the iterable in this [Space].
|
||||||
|
*
|
||||||
|
* @receiver the algebra that provides addition.
|
||||||
|
* @param data the iterable to sum up.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
fun <T> Space<T>.sum(data: Iterable<T>): T = data.fold(zero) { left, right -> add(left, right) }
|
fun <T> Space<T>.sum(data: Iterable<T>): T = data.fold(zero) { left, right -> add(left, right) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the sum of all elements in the sequence in this [Space].
|
||||||
|
*
|
||||||
|
* @receiver the algebra that provides addition.
|
||||||
|
* @param data the sequence to sum up.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
fun <T> Space<T>.sum(data: Sequence<T>): T = data.fold(zero) { left, right -> add(left, right) }
|
fun <T> Space<T>.sum(data: Sequence<T>): T = data.fold(zero) { left, right -> add(left, right) }
|
||||||
|
|
||||||
fun <T : Any, S : Space<T>> Iterable<T>.sumWith(space: S): T = space.sum(this)
|
/**
|
||||||
|
* Returns an average value of elements in the iterable in this [Space].
|
||||||
|
*
|
||||||
|
* @receiver the algebra that provides addition and division.
|
||||||
|
* @param data the iterable to find average.
|
||||||
|
* @return the average value.
|
||||||
|
*/
|
||||||
|
fun <T> Space<T>.average(data: Iterable<T>): T = sum(data) / data.count()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an average value of elements in the sequence in this [Space].
|
||||||
|
*
|
||||||
|
* @receiver the algebra that provides addition and division.
|
||||||
|
* @param data the sequence to find average.
|
||||||
|
* @return the average value.
|
||||||
|
*/
|
||||||
|
fun <T> Space<T>.average(data: Sequence<T>): T = sum(data) / data.count()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the sum of all elements in the iterable in provided space.
|
||||||
|
*
|
||||||
|
* @receiver the collection to sum up.
|
||||||
|
* @param space the algebra that provides addition.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
|
fun <T> Iterable<T>.sumWith(space: Space<T>): T = space.sum(this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the sum of all elements in the sequence in provided space.
|
||||||
|
*
|
||||||
|
* @receiver the collection to sum up.
|
||||||
|
* @param space the algebra that provides addition.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
|
fun <T> Sequence<T>.sumWith(space: Space<T>): T = space.sum(this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an average value of elements in the iterable in this [Space].
|
||||||
|
*
|
||||||
|
* @receiver the iterable to find average.
|
||||||
|
* @param space the algebra that provides addition and division.
|
||||||
|
* @return the average value.
|
||||||
|
*/
|
||||||
|
fun <T> Iterable<T>.averageWith(space: Space<T>): T = space.average(this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an average value of elements in the sequence in this [Space].
|
||||||
|
*
|
||||||
|
* @receiver the sequence to find average.
|
||||||
|
* @param space the algebra that provides addition and division.
|
||||||
|
* @return the average value.
|
||||||
|
*/
|
||||||
|
fun <T> Sequence<T>.averageWith(space: Space<T>): T = space.average(this)
|
||||||
|
|
||||||
//TODO optimized power operation
|
//TODO optimized power operation
|
||||||
fun <T> RingOperations<T>.power(arg: T, power: Int): T {
|
|
||||||
|
/**
|
||||||
|
* Raises [arg] to the natural power [power].
|
||||||
|
*
|
||||||
|
* @receiver the algebra to provide multiplication.
|
||||||
|
* @param arg the base.
|
||||||
|
* @param power the exponent.
|
||||||
|
* @return the base raised to the power.
|
||||||
|
*/
|
||||||
|
fun <T> Ring<T>.power(arg: T, power: Int): T {
|
||||||
|
require(power >= 0) { "The power can't be negative." }
|
||||||
|
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
|
||||||
|
if (power == 0) return one
|
||||||
var res = arg
|
var res = arg
|
||||||
repeat(power - 1) {
|
repeat(power - 1) { res *= arg }
|
||||||
res *= arg
|
|
||||||
}
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Raises [arg] to the integer power [power].
|
||||||
|
*
|
||||||
|
* @receiver the algebra to provide multiplication and division.
|
||||||
|
* @param arg the base.
|
||||||
|
* @param power the exponent.
|
||||||
|
* @return the base raised to the power.
|
||||||
|
*/
|
||||||
|
fun <T> Field<T>.power(arg: T, power: Int): T {
|
||||||
|
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
|
||||||
|
if (power == 0) return one
|
||||||
|
if (power < 0) return one / (this as Ring<T>).power(arg, -power)
|
||||||
|
return (this as Ring<T>).power(arg, power)
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ package scientifik.kmath.operations
|
|||||||
|
|
||||||
import scientifik.kmath.operations.BigInt.Companion.BASE
|
import scientifik.kmath.operations.BigInt.Companion.BASE
|
||||||
import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE
|
import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE
|
||||||
|
import scientifik.kmath.structures.*
|
||||||
import kotlin.math.log2
|
import kotlin.math.log2
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
@ -193,8 +194,8 @@ class BigInt internal constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
infix fun or(other: BigInt): BigInt {
|
infix fun or(other: BigInt): BigInt {
|
||||||
if (this == ZERO) return other;
|
if (this == ZERO) return other
|
||||||
if (other == ZERO) return this;
|
if (other == ZERO) return this
|
||||||
val resSize = max(this.magnitude.size, other.magnitude.size)
|
val resSize = max(this.magnitude.size, other.magnitude.size)
|
||||||
val newMagnitude: Magnitude = Magnitude(resSize)
|
val newMagnitude: Magnitude = Magnitude(resSize)
|
||||||
for (i in 0 until resSize) {
|
for (i in 0 until resSize) {
|
||||||
@ -209,7 +210,7 @@ class BigInt internal constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
infix fun and(other: BigInt): BigInt {
|
infix fun and(other: BigInt): BigInt {
|
||||||
if ((this == ZERO) or (other == ZERO)) return ZERO;
|
if ((this == ZERO) or (other == ZERO)) return ZERO
|
||||||
val resSize = min(this.magnitude.size, other.magnitude.size)
|
val resSize = min(this.magnitude.size, other.magnitude.size)
|
||||||
val newMagnitude: Magnitude = Magnitude(resSize)
|
val newMagnitude: Magnitude = Magnitude(resSize)
|
||||||
for (i in 0 until resSize) {
|
for (i in 0 until resSize) {
|
||||||
@ -259,7 +260,7 @@ class BigInt internal constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
const val BASE = 0xffffffffUL
|
const val BASE: ULong = 0xffffffffUL
|
||||||
const val BASE_SIZE: Int = 32
|
const val BASE_SIZE: Int = 32
|
||||||
val ZERO: BigInt = BigInt(0, uintArrayOf())
|
val ZERO: BigInt = BigInt(0, uintArrayOf())
|
||||||
val ONE: BigInt = BigInt(1, uintArrayOf(1u))
|
val ONE: BigInt = BigInt(1, uintArrayOf(1u))
|
||||||
@ -393,12 +394,12 @@ fun abs(x: BigInt): BigInt = x.abs()
|
|||||||
/**
|
/**
|
||||||
* Convert this [Int] to [BigInt]
|
* Convert this [Int] to [BigInt]
|
||||||
*/
|
*/
|
||||||
fun Int.toBigInt() = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt()))
|
fun Int.toBigInt(): BigInt = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt()))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert this [Long] to [BigInt]
|
* Convert this [Long] to [BigInt]
|
||||||
*/
|
*/
|
||||||
fun Long.toBigInt() = BigInt(
|
fun Long.toBigInt(): BigInt = BigInt(
|
||||||
sign.toByte(), stripLeadingZeros(
|
sign.toByte(), stripLeadingZeros(
|
||||||
uintArrayOf(
|
uintArrayOf(
|
||||||
(kotlin.math.abs(this).toULong() and BASE).toUInt(),
|
(kotlin.math.abs(this).toULong() and BASE).toUInt(),
|
||||||
@ -410,17 +411,17 @@ fun Long.toBigInt() = BigInt(
|
|||||||
/**
|
/**
|
||||||
* Convert UInt to [BigInt]
|
* Convert UInt to [BigInt]
|
||||||
*/
|
*/
|
||||||
fun UInt.toBigInt() = BigInt(1, uintArrayOf(this))
|
fun UInt.toBigInt(): BigInt = BigInt(1, uintArrayOf(this))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert ULong to [BigInt]
|
* Convert ULong to [BigInt]
|
||||||
*/
|
*/
|
||||||
fun ULong.toBigInt() = BigInt(
|
fun ULong.toBigInt(): BigInt = BigInt(
|
||||||
1,
|
1,
|
||||||
stripLeadingZeros(
|
stripLeadingZeros(
|
||||||
uintArrayOf(
|
uintArrayOf(
|
||||||
(this and BigInt.BASE).toUInt(),
|
(this and BASE).toUInt(),
|
||||||
((this shr BigInt.BASE_SIZE) and BigInt.BASE).toUInt()
|
((this shr BASE_SIZE) and BASE).toUInt()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -433,7 +434,7 @@ fun UIntArray.toBigInt(sign: Byte): BigInt {
|
|||||||
return BigInt(sign, this.copyOf())
|
return BigInt(sign, this.copyOf())
|
||||||
}
|
}
|
||||||
|
|
||||||
val hexChToInt = hashMapOf(
|
val hexChToInt: MutableMap<Char, Int> = hashMapOf(
|
||||||
'0' to 0, '1' to 1, '2' to 2, '3' to 3,
|
'0' to 0, '1' to 1, '2' to 2, '3' to 3,
|
||||||
'4' to 4, '5' to 5, '6' to 6, '7' to 7,
|
'4' to 4, '5' to 5, '6' to 6, '7' to 7,
|
||||||
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
|
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
|
||||||
@ -482,3 +483,18 @@ fun String.parseBigInteger(): BigInt? {
|
|||||||
}
|
}
|
||||||
return res * sign
|
return res * sign
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
||||||
|
boxing(size, initializer)
|
||||||
|
|
||||||
|
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
|
||||||
|
boxing(size, initializer)
|
||||||
|
|
||||||
|
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
|
||||||
|
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
|
||||||
|
|
||||||
|
fun NDElement.Companion.bigInt(
|
||||||
|
vararg shape: Int,
|
||||||
|
initializer: BigIntField.(IntArray) -> BigInt
|
||||||
|
): BufferedNDRingElement<BigInt, BigIntField> =
|
||||||
|
NDAlgebra.bigInt(*shape).produce(initializer)
|
||||||
|
@ -8,15 +8,20 @@ import scientifik.memory.MemorySpec
|
|||||||
import scientifik.memory.MemoryWriter
|
import scientifik.memory.MemoryWriter
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
|
private val PI_DIV_2 = Complex(PI / 2, 0)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field for complex numbers
|
* A field of [Complex].
|
||||||
*/
|
*/
|
||||||
object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
object ComplexField : ExtendedField<Complex> {
|
||||||
override val zero: Complex = Complex(0.0, 0.0)
|
override val zero: Complex = Complex(0.0, 0.0)
|
||||||
|
|
||||||
override val one: Complex = Complex(1.0, 0.0)
|
override val one: Complex = Complex(1.0, 0.0)
|
||||||
|
|
||||||
val i = Complex(0.0, 1.0)
|
/**
|
||||||
|
* The imaginary unit.
|
||||||
|
*/
|
||||||
|
val i: Complex = Complex(0.0, 1.0)
|
||||||
|
|
||||||
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
|
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
|
||||||
|
|
||||||
@ -30,9 +35,11 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
|||||||
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
|
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun sin(arg: Complex): Complex = i / 2 * (exp(-i * arg) - exp(i * arg))
|
override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2
|
||||||
|
|
||||||
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
|
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
|
||||||
|
override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg)
|
||||||
|
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg)
|
||||||
|
override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2
|
||||||
|
|
||||||
override fun power(arg: Complex, pow: Number): Complex =
|
override fun power(arg: Complex, pow: Number): Complex =
|
||||||
arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta))
|
arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta))
|
||||||
@ -41,19 +48,59 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
|||||||
|
|
||||||
override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re)
|
override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re)
|
||||||
|
|
||||||
operator fun Double.plus(c: Complex) = add(this.toComplex(), c)
|
/**
|
||||||
|
* Adds complex number to real one.
|
||||||
|
*
|
||||||
|
* @receiver the addend.
|
||||||
|
* @param c the augend.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
|
operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c)
|
||||||
|
|
||||||
operator fun Double.minus(c: Complex) = add(this.toComplex(), -c)
|
/**
|
||||||
|
* Subtracts complex number from real one.
|
||||||
|
*
|
||||||
|
* @receiver the minuend.
|
||||||
|
* @param c the subtrahend.
|
||||||
|
* @return the difference.
|
||||||
|
*/
|
||||||
|
operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c)
|
||||||
|
|
||||||
operator fun Complex.plus(d: Double) = d + this
|
/**
|
||||||
|
* Adds real number to complex one.
|
||||||
|
*
|
||||||
|
* @receiver the addend.
|
||||||
|
* @param d the augend.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
|
operator fun Complex.plus(d: Double): Complex = d + this
|
||||||
|
|
||||||
operator fun Complex.minus(d: Double) = add(this, -d.toComplex())
|
/**
|
||||||
|
* Subtracts real number from complex one.
|
||||||
|
*
|
||||||
|
* @receiver the minuend.
|
||||||
|
* @param d the subtrahend.
|
||||||
|
* @return the difference.
|
||||||
|
*/
|
||||||
|
operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex())
|
||||||
|
|
||||||
operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this)
|
/**
|
||||||
|
* Multiplies real number by complex one.
|
||||||
|
*
|
||||||
|
* @receiver the multiplier.
|
||||||
|
* @param c the multiplicand.
|
||||||
|
* @receiver the product.
|
||||||
|
*/
|
||||||
|
operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this)
|
||||||
|
|
||||||
|
override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Complex number class
|
* Represents complex number.
|
||||||
|
*
|
||||||
|
* @property re The real part.
|
||||||
|
* @property im The imaginary part.
|
||||||
*/
|
*/
|
||||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
|
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
|
||||||
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
|
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
|
||||||
@ -94,7 +141,13 @@ val Complex.r: Double get() = sqrt(re * re + im * im)
|
|||||||
*/
|
*/
|
||||||
val Complex.theta: Double get() = atan(im / re)
|
val Complex.theta: Double get() = atan(im / re)
|
||||||
|
|
||||||
fun Double.toComplex() = Complex(this, 0.0)
|
/**
|
||||||
|
* Creates a complex number with real part equal to this real.
|
||||||
|
*
|
||||||
|
* @receiver the real part.
|
||||||
|
* @return the new complex number.
|
||||||
|
*/
|
||||||
|
fun Double.toComplex(): Complex = Complex(this, 0.0)
|
||||||
|
|
||||||
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
||||||
return MemoryBuffer.create(Complex, size, init)
|
return MemoryBuffer.create(Complex, size, init)
|
||||||
|
@ -4,19 +4,45 @@ import kotlin.math.abs
|
|||||||
import kotlin.math.pow as kpow
|
import kotlin.math.pow as kpow
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Advanced Number-like field that implements basic operations
|
* Advanced Number-like semifield that implements basic operations.
|
||||||
*/
|
*/
|
||||||
interface ExtendedFieldOperations<T> :
|
interface ExtendedFieldOperations<T> :
|
||||||
FieldOperations<T>,
|
InverseTrigonometricOperations<T>,
|
||||||
TrigonometricOperations<T>,
|
|
||||||
PowerOperations<T>,
|
PowerOperations<T>,
|
||||||
ExponentialOperations<T>
|
ExponentialOperations<T> {
|
||||||
|
|
||||||
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>
|
override fun tan(arg: T): T = sin(arg) / cos(arg)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||||
|
TrigonometricOperations.COS_OPERATION -> cos(arg)
|
||||||
|
TrigonometricOperations.SIN_OPERATION -> sin(arg)
|
||||||
|
TrigonometricOperations.TAN_OPERATION -> tan(arg)
|
||||||
|
InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg)
|
||||||
|
InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg)
|
||||||
|
InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg)
|
||||||
|
PowerOperations.SQRT_OPERATION -> sqrt(arg)
|
||||||
|
ExponentialOperations.EXP_OPERATION -> exp(arg)
|
||||||
|
ExponentialOperations.LN_OPERATION -> ln(arg)
|
||||||
|
else -> super.unaryOperation(operation, arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Advanced Number-like field that implements basic operations.
|
||||||
|
*/
|
||||||
|
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
||||||
|
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||||
|
PowerOperations.POW_OPERATION -> power(left, right)
|
||||||
|
else -> super.rightSideNumberOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Real field element wrapping double.
|
* Real field element wrapping double.
|
||||||
*
|
*
|
||||||
|
* @property value the [Double] value wrapped by this [Real].
|
||||||
|
*
|
||||||
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
|
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
|
||||||
*/
|
*/
|
||||||
inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
||||||
@ -24,74 +50,90 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
|||||||
|
|
||||||
override fun Double.wrap(): Real = Real(value)
|
override fun Double.wrap(): Real = Real(value)
|
||||||
|
|
||||||
override val context get() = RealField
|
override val context: RealField get() = RealField
|
||||||
|
|
||||||
companion object
|
companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field for double without boxing. Does not produce appropriate field element
|
* A field for [Double] without boxing. Does not produce appropriate field element.
|
||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||||
override val zero: Double = 0.0
|
override val zero: Double = 0.0
|
||||||
override inline fun add(a: Double, b: Double) = a + b
|
override inline fun add(a: Double, b: Double): Double = a + b
|
||||||
override inline fun multiply(a: Double, b: Double) = a * b
|
override inline fun multiply(a: Double, b: Double): Double = a * b
|
||||||
override inline fun multiply(a: Double, k: Number) = a * k.toDouble()
|
override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
|
||||||
|
|
||||||
override val one: Double = 1.0
|
override val one: Double = 1.0
|
||||||
override inline fun divide(a: Double, b: Double) = a / b
|
override inline fun divide(a: Double, b: Double): Double = a / b
|
||||||
|
|
||||||
override inline fun sin(arg: Double) = kotlin.math.sin(arg)
|
override inline fun sin(arg: Double): Double = kotlin.math.sin(arg)
|
||||||
override inline fun cos(arg: Double) = kotlin.math.cos(arg)
|
override inline fun cos(arg: Double): Double = kotlin.math.cos(arg)
|
||||||
|
override inline fun tan(arg: Double): Double = kotlin.math.tan(arg)
|
||||||
|
override inline fun acos(arg: Double): Double = kotlin.math.acos(arg)
|
||||||
|
override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
|
||||||
|
override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
|
||||||
|
|
||||||
override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble())
|
override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble())
|
||||||
|
|
||||||
override inline fun exp(arg: Double) = kotlin.math.exp(arg)
|
override inline fun exp(arg: Double): Double = kotlin.math.exp(arg)
|
||||||
override inline fun ln(arg: Double) = kotlin.math.ln(arg)
|
override inline fun ln(arg: Double): Double = kotlin.math.ln(arg)
|
||||||
|
|
||||||
override inline fun norm(arg: Double) = abs(arg)
|
override inline fun norm(arg: Double): Double = abs(arg)
|
||||||
|
|
||||||
override inline fun Double.unaryMinus() = -this
|
override inline fun Double.unaryMinus(): Double = -this
|
||||||
|
|
||||||
override inline fun Double.plus(b: Double) = this + b
|
override inline fun Double.plus(b: Double): Double = this + b
|
||||||
|
|
||||||
override inline fun Double.minus(b: Double) = this - b
|
override inline fun Double.minus(b: Double): Double = this - b
|
||||||
|
|
||||||
override inline fun Double.times(b: Double) = this * b
|
override inline fun Double.times(b: Double): Double = this * b
|
||||||
|
|
||||||
override inline fun Double.div(b: Double) = this / b
|
override inline fun Double.div(b: Double): Double = this / b
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
|
||||||
|
PowerOperations.POW_OPERATION -> left pow right
|
||||||
|
else -> super.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A field for [Float] without boxing. Does not produce appropriate field element.
|
||||||
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||||
override val zero: Float = 0f
|
override val zero: Float = 0f
|
||||||
override inline fun add(a: Float, b: Float) = a + b
|
override inline fun add(a: Float, b: Float): Float = a + b
|
||||||
override inline fun multiply(a: Float, b: Float) = a * b
|
override inline fun multiply(a: Float, b: Float): Float = a * b
|
||||||
override inline fun multiply(a: Float, k: Number) = a * k.toFloat()
|
override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat()
|
||||||
|
|
||||||
override val one: Float = 1f
|
override val one: Float = 1f
|
||||||
override inline fun divide(a: Float, b: Float) = a / b
|
override inline fun divide(a: Float, b: Float): Float = a / b
|
||||||
|
|
||||||
override inline fun sin(arg: Float) = kotlin.math.sin(arg)
|
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
|
||||||
override inline fun cos(arg: Float) = kotlin.math.cos(arg)
|
override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
|
||||||
|
override inline fun tan(arg: Float): Float = kotlin.math.tan(arg)
|
||||||
|
override inline fun acos(arg: Float): Float = kotlin.math.acos(arg)
|
||||||
|
override inline fun asin(arg: Float): Float = kotlin.math.asin(arg)
|
||||||
|
override inline fun atan(arg: Float): Float = kotlin.math.atan(arg)
|
||||||
|
|
||||||
override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat())
|
override inline fun power(arg: Float, pow: Number): Float = arg.pow(pow.toFloat())
|
||||||
|
|
||||||
override inline fun exp(arg: Float) = kotlin.math.exp(arg)
|
override inline fun exp(arg: Float): Float = kotlin.math.exp(arg)
|
||||||
override inline fun ln(arg: Float) = kotlin.math.ln(arg)
|
override inline fun ln(arg: Float): Float = kotlin.math.ln(arg)
|
||||||
|
|
||||||
override inline fun norm(arg: Float) = abs(arg)
|
override inline fun norm(arg: Float): Float = abs(arg)
|
||||||
|
|
||||||
override inline fun Float.unaryMinus() = -this
|
override inline fun Float.unaryMinus(): Float = -this
|
||||||
|
|
||||||
override inline fun Float.plus(b: Float) = this + b
|
override inline fun Float.plus(b: Float): Float = this + b
|
||||||
|
|
||||||
override inline fun Float.minus(b: Float) = this - b
|
override inline fun Float.minus(b: Float): Float = this - b
|
||||||
|
|
||||||
override inline fun Float.times(b: Float) = this * b
|
override inline fun Float.times(b: Float): Float = this * b
|
||||||
|
|
||||||
override inline fun Float.div(b: Float) = this / b
|
override inline fun Float.div(b: Float): Float = this / b
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -100,14 +142,14 @@ object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
|||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object IntRing : Ring<Int>, Norm<Int, Int> {
|
object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||||
override val zero: Int = 0
|
override val zero: Int = 0
|
||||||
override inline fun add(a: Int, b: Int) = a + b
|
override inline fun add(a: Int, b: Int): Int = a + b
|
||||||
override inline fun multiply(a: Int, b: Int) = a * b
|
override inline fun multiply(a: Int, b: Int): Int = a * b
|
||||||
override inline fun multiply(a: Int, k: Number) = (k * a)
|
override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
|
||||||
override val one: Int = 1
|
override val one: Int = 1
|
||||||
|
|
||||||
override inline fun norm(arg: Int) = abs(arg)
|
override inline fun norm(arg: Int): Int = abs(arg)
|
||||||
|
|
||||||
override inline fun Int.unaryMinus() = -this
|
override inline fun Int.unaryMinus(): Int = -this
|
||||||
|
|
||||||
override inline fun Int.plus(b: Int): Int = this + b
|
override inline fun Int.plus(b: Int): Int = this + b
|
||||||
|
|
||||||
@ -122,20 +164,20 @@ object IntRing : Ring<Int>, Norm<Int, Int> {
|
|||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object ShortRing : Ring<Short>, Norm<Short, Short> {
|
object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||||
override val zero: Short = 0
|
override val zero: Short = 0
|
||||||
override inline fun add(a: Short, b: Short) = (a + b).toShort()
|
override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
|
||||||
override inline fun multiply(a: Short, b: Short) = (a * b).toShort()
|
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
|
||||||
override inline fun multiply(a: Short, k: Number) = (a * k)
|
override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
|
||||||
override val one: Short = 1
|
override val one: Short = 1
|
||||||
|
|
||||||
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
||||||
|
|
||||||
override inline fun Short.unaryMinus() = (-this).toShort()
|
override inline fun Short.unaryMinus(): Short = (-this).toShort()
|
||||||
|
|
||||||
override inline fun Short.plus(b: Short) = (this + b).toShort()
|
override inline fun Short.plus(b: Short): Short = (this + b).toShort()
|
||||||
|
|
||||||
override inline fun Short.minus(b: Short) = (this - b).toShort()
|
override inline fun Short.minus(b: Short): Short = (this - b).toShort()
|
||||||
|
|
||||||
override inline fun Short.times(b: Short) = (this * b).toShort()
|
override inline fun Short.times(b: Short): Short = (this * b).toShort()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -144,20 +186,20 @@ object ShortRing : Ring<Short>, Norm<Short, Short> {
|
|||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||||
override val zero: Byte = 0
|
override val zero: Byte = 0
|
||||||
override inline fun add(a: Byte, b: Byte) = (a + b).toByte()
|
override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
||||||
override inline fun multiply(a: Byte, b: Byte) = (a * b).toByte()
|
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
|
||||||
override inline fun multiply(a: Byte, k: Number) = (a * k)
|
override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
|
||||||
override val one: Byte = 1
|
override val one: Byte = 1
|
||||||
|
|
||||||
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
|
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
|
||||||
|
|
||||||
override inline fun Byte.unaryMinus() = (-this).toByte()
|
override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
|
||||||
|
|
||||||
override inline fun Byte.plus(b: Byte) = (this + b).toByte()
|
override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
|
||||||
|
|
||||||
override inline fun Byte.minus(b: Byte) = (this - b).toByte()
|
override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
|
||||||
|
|
||||||
override inline fun Byte.times(b: Byte) = (this * b).toByte()
|
override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -166,18 +208,18 @@ object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
|||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object LongRing : Ring<Long>, Norm<Long, Long> {
|
object LongRing : Ring<Long>, Norm<Long, Long> {
|
||||||
override val zero: Long = 0
|
override val zero: Long = 0
|
||||||
override inline fun add(a: Long, b: Long) = (a + b)
|
override inline fun add(a: Long, b: Long): Long = (a + b)
|
||||||
override inline fun multiply(a: Long, b: Long) = (a * b)
|
override inline fun multiply(a: Long, b: Long): Long = (a * b)
|
||||||
override inline fun multiply(a: Long, k: Number) = (a * k)
|
override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
|
||||||
override val one: Long = 1
|
override val one: Long = 1
|
||||||
|
|
||||||
override fun norm(arg: Long): Long = abs(arg)
|
override fun norm(arg: Long): Long = abs(arg)
|
||||||
|
|
||||||
override inline fun Long.unaryMinus() = (-this)
|
override inline fun Long.unaryMinus(): Long = (-this)
|
||||||
|
|
||||||
override inline fun Long.plus(b: Long) = (this + b)
|
override inline fun Long.plus(b: Long): Long = (this + b)
|
||||||
|
|
||||||
override inline fun Long.minus(b: Long) = (this - b)
|
override inline fun Long.minus(b: Long): Long = (this - b)
|
||||||
|
|
||||||
override inline fun Long.times(b: Long) = (this * b)
|
override inline fun Long.times(b: Long): Long = (this * b)
|
||||||
}
|
}
|
@ -1,57 +1,214 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
|
||||||
/* Trigonometric operations */
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A container for trigonometric operations for specific type. Trigonometric operations are limited to fields.
|
* A container for trigonometric operations for specific type. They are limited to semifields.
|
||||||
*
|
*
|
||||||
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field.
|
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field.
|
||||||
* It also allows to override behavior for optional operations
|
* It also allows to override behavior for optional operations.
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
interface TrigonometricOperations<T> : FieldOperations<T> {
|
interface TrigonometricOperations<T> : FieldOperations<T> {
|
||||||
|
/**
|
||||||
|
* Computes the sine of [arg].
|
||||||
|
*/
|
||||||
fun sin(arg: T): T
|
fun sin(arg: T): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the cosine of [arg].
|
||||||
|
*/
|
||||||
fun cos(arg: T): T
|
fun cos(arg: T): T
|
||||||
|
|
||||||
fun tg(arg: T): T = sin(arg) / cos(arg)
|
/**
|
||||||
|
* Computes the tangent of [arg].
|
||||||
|
*/
|
||||||
|
fun tan(arg: T): T
|
||||||
|
|
||||||
fun ctg(arg: T): T = cos(arg) / sin(arg)
|
companion object {
|
||||||
|
/**
|
||||||
|
* The identifier of sine.
|
||||||
|
*/
|
||||||
|
const val SIN_OPERATION: String = "sin"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The identifier of cosine.
|
||||||
|
*/
|
||||||
|
const val COS_OPERATION: String = "cos"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The identifier of tangent.
|
||||||
|
*/
|
||||||
|
const val TAN_OPERATION: String = "tan"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
|
|
||||||
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
|
|
||||||
fun <T : MathElement<out TrigonometricOperations<T>>> tg(arg: T): T = arg.context.tg(arg)
|
|
||||||
fun <T : MathElement<out TrigonometricOperations<T>>> ctg(arg: T): T = arg.context.ctg(arg)
|
|
||||||
|
|
||||||
/* Power and roots */
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context extension to include power operations like square roots, etc
|
* A container for inverse trigonometric operations for specific type. They are limited to semifields.
|
||||||
|
*
|
||||||
|
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field.
|
||||||
|
* It also allows to override behavior for optional operations.
|
||||||
|
*/
|
||||||
|
interface InverseTrigonometricOperations<T> : TrigonometricOperations<T> {
|
||||||
|
/**
|
||||||
|
* Computes the inverse sine of [arg].
|
||||||
|
*/
|
||||||
|
fun asin(arg: T): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the inverse cosine of [arg].
|
||||||
|
*/
|
||||||
|
fun acos(arg: T): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the inverse tangent of [arg].
|
||||||
|
*/
|
||||||
|
fun atan(arg: T): T
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* The identifier of inverse sine.
|
||||||
|
*/
|
||||||
|
const val ASIN_OPERATION: String = "asin"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The identifier of inverse cosine.
|
||||||
|
*/
|
||||||
|
const val ACOS_OPERATION: String = "acos"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The identifier of inverse tangent.
|
||||||
|
*/
|
||||||
|
const val ATAN_OPERATION: String = "atan"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the sine of [arg].
|
||||||
|
*/
|
||||||
|
fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the cosine of [arg].
|
||||||
|
*/
|
||||||
|
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the tangent of [arg].
|
||||||
|
*/
|
||||||
|
fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the inverse sine of [arg].
|
||||||
|
*/
|
||||||
|
fun <T : MathElement<out InverseTrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the inverse cosine of [arg].
|
||||||
|
*/
|
||||||
|
fun <T : MathElement<out InverseTrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the inverse tangent of [arg].
|
||||||
|
*/
|
||||||
|
fun <T : MathElement<out InverseTrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context extension to include power operations based on exponentiation.
|
||||||
*/
|
*/
|
||||||
interface PowerOperations<T> : Algebra<T> {
|
interface PowerOperations<T> : Algebra<T> {
|
||||||
|
/**
|
||||||
|
* Raises [arg] to the power [pow].
|
||||||
|
*/
|
||||||
fun power(arg: T, pow: Number): T
|
fun power(arg: T, pow: Number): T
|
||||||
fun sqrt(arg: T) = power(arg, 0.5)
|
|
||||||
|
|
||||||
infix fun T.pow(pow: Number) = power(this, pow)
|
/**
|
||||||
|
* Computes the square root of the value [arg].
|
||||||
|
*/
|
||||||
|
fun sqrt(arg: T): T = power(arg, 0.5)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Raises this value to the power [pow].
|
||||||
|
*/
|
||||||
|
infix fun T.pow(pow: Number): T = power(this, pow)
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* The identifier of exponentiation.
|
||||||
|
*/
|
||||||
|
const val POW_OPERATION: String = "pow"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The identifier of square root.
|
||||||
|
*/
|
||||||
|
const val SQRT_OPERATION: String = "sqrt"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Raises this element to the power [pow].
|
||||||
|
*
|
||||||
|
* @receiver the base.
|
||||||
|
* @param power the exponent.
|
||||||
|
* @return the base raised to the power.
|
||||||
|
*/
|
||||||
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
|
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the square root of the value [arg].
|
||||||
|
*/
|
||||||
fun <T : MathElement<out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5
|
fun <T : MathElement<out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the square of the value [arg].
|
||||||
|
*/
|
||||||
fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
|
fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
|
||||||
|
|
||||||
/* Exponential */
|
/**
|
||||||
|
* A container for operations related to `exp` and `ln` functions.
|
||||||
interface ExponentialOperations<T>: Algebra<T> {
|
*/
|
||||||
|
interface ExponentialOperations<T> : Algebra<T> {
|
||||||
|
/**
|
||||||
|
* Computes Euler's number `e` raised to the power of the value [arg].
|
||||||
|
*/
|
||||||
fun exp(arg: T): T
|
fun exp(arg: T): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the natural logarithm (base `e`) of the value [arg].
|
||||||
|
*/
|
||||||
fun ln(arg: T): T
|
fun ln(arg: T): T
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* The identifier of exponential function.
|
||||||
|
*/
|
||||||
|
const val EXP_OPERATION: String = "exp"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The identifier of natural logarithm.
|
||||||
|
*/
|
||||||
|
const val LN_OPERATION: String = "ln"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The identifier of exponential function.
|
||||||
|
*/
|
||||||
fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
|
fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The identifier of natural logarithm.
|
||||||
|
*/
|
||||||
fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg)
|
fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A container for norm functional on element.
|
||||||
|
*/
|
||||||
interface Norm<in T : Any, out R> {
|
interface Norm<in T : Any, out R> {
|
||||||
|
/**
|
||||||
|
* Computes the norm of [arg] (i.e. absolute value or vector length).
|
||||||
|
*/
|
||||||
fun norm(arg: T): R
|
fun norm(arg: T): R
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the norm of [arg] (i.e. absolute value or vector length).
|
||||||
|
*/
|
||||||
fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg)
|
fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg)
|
@ -3,7 +3,6 @@ package scientifik.kmath.structures
|
|||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.FieldElement
|
import scientifik.kmath.operations.FieldElement
|
||||||
|
|
||||||
|
|
||||||
class BoxingNDField<T, F : Field<T>>(
|
class BoxingNDField<T, F : Field<T>>(
|
||||||
override val shape: IntArray,
|
override val shape: IntArray,
|
||||||
override val elementContext: F,
|
override val elementContext: F,
|
||||||
@ -19,10 +18,10 @@ class BoxingNDField<T, F : Field<T>>(
|
|||||||
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero by lazy { produce { zero } }
|
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
|
||||||
override val one by lazy { produce { one } }
|
override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
|
||||||
|
|
||||||
override fun produce(initializer: F.(IntArray) -> T) =
|
override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
|
||||||
BufferedNDFieldElement(
|
BufferedNDFieldElement(
|
||||||
this,
|
this,
|
||||||
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
|
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
|
||||||
|
@ -18,10 +18,10 @@ class BoxingNDRing<T, R : Ring<T>>(
|
|||||||
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero by lazy { produce { zero } }
|
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
|
||||||
override val one by lazy { produce { one } }
|
override val one: BufferedNDRingElement<T, R> by lazy { produce { one } }
|
||||||
|
|
||||||
override fun produce(initializer: R.(IntArray) -> T) =
|
override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement<T, R> =
|
||||||
BufferedNDRingElement(
|
BufferedNDRingElement(
|
||||||
this,
|
this,
|
||||||
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
|
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
|
||||||
|
@ -7,16 +7,16 @@ import kotlin.reflect.KClass
|
|||||||
*/
|
*/
|
||||||
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
|
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
|
||||||
|
|
||||||
operator fun Buffer<T>.get(i: Int, j: Int) = get(i + colNum * j)
|
operator fun Buffer<T>.get(i: Int, j: Int): T = get(i + colNum * j)
|
||||||
|
|
||||||
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
|
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
|
||||||
set(i + colNum * j, value)
|
set(i + colNum * j, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun create(init: (i: Int, j: Int) -> T) =
|
inline fun create(init: (i: Int, j: Int) -> T): MutableBuffer<T> =
|
||||||
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
||||||
|
|
||||||
fun create(mat: Structure2D<T>) = create { i, j -> mat[i, j] }
|
fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
|
||||||
|
|
||||||
//TODO optimize wrapper
|
//TODO optimize wrapper
|
||||||
fun MutableBuffer<T>.collect(): Structure2D<T> =
|
fun MutableBuffer<T>.collect(): Structure2D<T> =
|
||||||
@ -41,5 +41,5 @@ class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum
|
|||||||
/**
|
/**
|
||||||
* Get row
|
* Get row
|
||||||
*/
|
*/
|
||||||
fun MutableBuffer<T>.row(i: Int) = Row(this, i)
|
fun MutableBuffer<T>.row(i: Int): Row = Row(this, i)
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import scientifik.kmath.operations.*
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
|
interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
||||||
val strides: Strides
|
val strides: Strides
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>) {
|
override fun check(vararg elements: NDBuffer<T>) {
|
||||||
@ -11,7 +11,8 @@ interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert any [NDStructure] to buffered structure using strides from this context.
|
* Convert any [NDStructure] to buffered structure using strides from this context.
|
||||||
* If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over indexes
|
* If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over
|
||||||
|
* indices.
|
||||||
*
|
*
|
||||||
* If the argument is [NDBuffer] with different strides structure, the new element will be produced.
|
* If the argument is [NDBuffer] with different strides structure, the new element will be produced.
|
||||||
*/
|
*/
|
||||||
@ -30,7 +31,7 @@ interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T,S> {
|
interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T, S> {
|
||||||
override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>>
|
override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,12 +3,12 @@ package scientifik.kmath.structures
|
|||||||
import scientifik.kmath.operations.*
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base interface for an element with context, containing strides
|
* Base class for an element with context, containing strides
|
||||||
*/
|
*/
|
||||||
interface BufferedNDElement<T, C> : NDBuffer<T>, NDElement<T, C, NDBuffer<T>> {
|
abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer<T>> {
|
||||||
override val context: BufferedNDAlgebra<T, C>
|
abstract override val context: BufferedNDAlgebra<T, C>
|
||||||
|
|
||||||
override val strides get() = context.strides
|
override val strides: Strides get() = context.strides
|
||||||
|
|
||||||
override val shape: IntArray get() = context.shape
|
override val shape: IntArray get() = context.shape
|
||||||
}
|
}
|
||||||
@ -16,7 +16,7 @@ interface BufferedNDElement<T, C> : NDBuffer<T>, NDElement<T, C, NDBuffer<T>> {
|
|||||||
class BufferedNDSpaceElement<T, S : Space<T>>(
|
class BufferedNDSpaceElement<T, S : Space<T>>(
|
||||||
override val context: BufferedNDSpace<T, S>,
|
override val context: BufferedNDSpace<T, S>,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : BufferedNDElement<T, S>, SpaceElement<NDBuffer<T>, BufferedNDSpaceElement<T, S>, BufferedNDSpace<T, S>> {
|
) : BufferedNDElement<T, S>(), SpaceElement<NDBuffer<T>, BufferedNDSpaceElement<T, S>, BufferedNDSpace<T, S>> {
|
||||||
|
|
||||||
override fun unwrap(): NDBuffer<T> = this
|
override fun unwrap(): NDBuffer<T> = this
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ class BufferedNDSpaceElement<T, S : Space<T>>(
|
|||||||
class BufferedNDRingElement<T, R : Ring<T>>(
|
class BufferedNDRingElement<T, R : Ring<T>>(
|
||||||
override val context: BufferedNDRing<T, R>,
|
override val context: BufferedNDRing<T, R>,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : BufferedNDElement<T, R>, RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> {
|
) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> {
|
||||||
|
|
||||||
override fun unwrap(): NDBuffer<T> = this
|
override fun unwrap(): NDBuffer<T> = this
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ class BufferedNDRingElement<T, R : Ring<T>>(
|
|||||||
class BufferedNDFieldElement<T, F : Field<T>>(
|
class BufferedNDFieldElement<T, F : Field<T>>(
|
||||||
override val context: BufferedNDField<T, F>,
|
override val context: BufferedNDField<T, F>,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : BufferedNDElement<T, F>, FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
|
) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
|
||||||
|
|
||||||
override fun unwrap(): NDBuffer<T> = this
|
override fun unwrap(): NDBuffer<T> = this
|
||||||
|
|
||||||
@ -54,9 +54,9 @@ class BufferedNDFieldElement<T, F : Field<T>>(
|
|||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
* Element by element application of any operation on elements to the whole array. Just like in numpy.
|
||||||
*/
|
*/
|
||||||
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>) =
|
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>): MathElement<out BufferedNDAlgebra<T, F>> =
|
||||||
ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
|
ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
|
||||||
|
|
||||||
/* plus and minus */
|
/* plus and minus */
|
||||||
@ -64,13 +64,13 @@ operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedN
|
|||||||
/**
|
/**
|
||||||
* Summation operation for [BufferedNDElement] and single element
|
* Summation operation for [BufferedNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.plus(arg: T) =
|
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.plus(arg: T): NDElement<T, F, NDBuffer<T>> =
|
||||||
context.map(this) { it + arg }.wrap()
|
context.map(this) { it + arg }.wrap()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtraction operation between [BufferedNDElement] and single element
|
* Subtraction operation between [BufferedNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T) =
|
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T): NDElement<T, F, NDBuffer<T>> =
|
||||||
context.map(this) { it - arg }.wrap()
|
context.map(this) { it - arg }.wrap()
|
||||||
|
|
||||||
/* prod and div */
|
/* prod and div */
|
||||||
@ -78,11 +78,11 @@ operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T) =
|
|||||||
/**
|
/**
|
||||||
* Product operation for [BufferedNDElement] and single element
|
* Product operation for [BufferedNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T : Any, F : Ring<T>> BufferedNDElement<T, F>.times(arg: T) =
|
operator fun <T : Any, F : Ring<T>> BufferedNDElement<T, F>.times(arg: T): NDElement<T, F, NDBuffer<T>> =
|
||||||
context.map(this) { it * arg }.wrap()
|
context.map(this) { it * arg }.wrap()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Division operation between [BufferedNDElement] and single element
|
* Division operation between [BufferedNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T : Any, F : Field<T>> BufferedNDElement<T, F>.div(arg: T) =
|
operator fun <T : Any, F : Field<T>> BufferedNDElement<T, F>.div(arg: T): NDElement<T, F, NDBuffer<T>> =
|
||||||
context.map(this) { it / arg }.wrap()
|
context.map(this) { it / arg }.wrap()
|
@ -4,42 +4,51 @@ import scientifik.kmath.operations.Complex
|
|||||||
import scientifik.kmath.operations.complex
|
import scientifik.kmath.operations.complex
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Function that produces [Buffer] from its size and function that supplies values.
|
||||||
|
*
|
||||||
|
* @param T the type of buffer.
|
||||||
|
*/
|
||||||
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
||||||
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A generic random access structure for both primitives and objects
|
* Function that produces [MutableBuffer] from its size and function that supplies values.
|
||||||
|
*
|
||||||
|
* @param T the type of buffer.
|
||||||
|
*/
|
||||||
|
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A generic immutable random-access structure for both primitives and objects.
|
||||||
|
*
|
||||||
|
* @param T the type of elements contained in the buffer.
|
||||||
*/
|
*/
|
||||||
interface Buffer<T> {
|
interface Buffer<T> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The size of the buffer
|
* The size of this buffer.
|
||||||
*/
|
*/
|
||||||
val size: Int
|
val size: Int
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get element at given index
|
* Gets element at given index.
|
||||||
*/
|
*/
|
||||||
operator fun get(index: Int): T
|
operator fun get(index: Int): T
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Iterate over all elements
|
* Iterates over all elements.
|
||||||
*/
|
*/
|
||||||
operator fun iterator(): Iterator<T>
|
operator fun iterator(): Iterator<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check content eqiality with another buffer
|
* Checks content equality with another buffer.
|
||||||
*/
|
*/
|
||||||
fun contentEquals(other: Buffer<*>): Boolean =
|
fun contentEquals(other: Buffer<*>): Boolean =
|
||||||
asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
|
asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer {
|
||||||
inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer {
|
|
||||||
val array = DoubleArray(size) { initializer(it) }
|
val array = DoubleArray(size) { initializer(it) }
|
||||||
return DoubleBuffer(array)
|
return RealBuffer(array)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -51,7 +60,7 @@ interface Buffer<T> {
|
|||||||
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
||||||
//TODO add resolution based on Annotation or companion resolution
|
//TODO add resolution based on Annotation or companion resolution
|
||||||
return when (type) {
|
return when (type) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
||||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
|
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
|
||||||
@ -69,17 +78,34 @@ interface Buffer<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a sequence that returns all elements from this [Buffer].
|
||||||
|
*/
|
||||||
fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator)
|
fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator)
|
||||||
|
|
||||||
fun <T> Buffer<T>.asIterable(): Iterable<T> = asSequence().asIterable()
|
/**
|
||||||
|
* Creates an iterable that returns all elements from this [Buffer].
|
||||||
|
*/
|
||||||
|
fun <T> Buffer<T>.asIterable(): Iterable<T> = Iterable(::iterator)
|
||||||
|
|
||||||
val Buffer<*>.indices: IntRange get() = IntRange(0, size - 1)
|
/**
|
||||||
|
* Returns an [IntRange] of the valid indices for this [Buffer].
|
||||||
|
*/
|
||||||
|
val Buffer<*>.indices: IntRange get() = 0 until size
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A generic mutable random-access structure for both primitives and objects.
|
||||||
|
*
|
||||||
|
* @param T the type of elements contained in the buffer.
|
||||||
|
*/
|
||||||
interface MutableBuffer<T> : Buffer<T> {
|
interface MutableBuffer<T> : Buffer<T> {
|
||||||
|
/**
|
||||||
|
* Sets the array element at the specified [index] to the specified [value].
|
||||||
|
*/
|
||||||
operator fun set(index: Int, value: T)
|
operator fun set(index: Int, value: T)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A shallow copy of the buffer
|
* Returns a shallow copy of the buffer.
|
||||||
*/
|
*/
|
||||||
fun copy(): MutableBuffer<T>
|
fun copy(): MutableBuffer<T>
|
||||||
|
|
||||||
@ -93,7 +119,7 @@ interface MutableBuffer<T> : Buffer<T> {
|
|||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
||||||
return when (type) {
|
return when (type) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
||||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
|
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
|
||||||
@ -109,14 +135,18 @@ interface MutableBuffer<T> : Buffer<T> {
|
|||||||
auto(T::class, size, initializer)
|
auto(T::class, size, initializer)
|
||||||
|
|
||||||
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
|
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
|
||||||
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
RealBuffer(DoubleArray(size) { initializer(it) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [Buffer] implementation over [List].
|
||||||
|
*
|
||||||
|
* @param T the type of elements contained in the buffer.
|
||||||
|
* @property list The underlying list.
|
||||||
|
*/
|
||||||
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
||||||
|
|
||||||
override val size: Int
|
override val size: Int
|
||||||
get() = list.size
|
get() = list.size
|
||||||
|
|
||||||
@ -125,11 +155,26 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
|||||||
override fun iterator(): Iterator<T> = list.iterator()
|
override fun iterator(): Iterator<T> = list.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T> List<T>.asBuffer() = ListBuffer<T>(this)
|
/**
|
||||||
|
* Returns an [ListBuffer] that wraps the original list.
|
||||||
|
*/
|
||||||
|
fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
|
||||||
|
|
||||||
@Suppress("FunctionName")
|
/**
|
||||||
inline fun <T> ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer()
|
* Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||||
|
* [init] function.
|
||||||
|
*
|
||||||
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
|
* It should return the value for an array element given its index.
|
||||||
|
*/
|
||||||
|
inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = List(size, init).asBuffer()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [MutableBuffer] implementation over [MutableList].
|
||||||
|
*
|
||||||
|
* @param T the type of elements contained in the buffer.
|
||||||
|
* @property list The underlying list.
|
||||||
|
*/
|
||||||
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
||||||
|
|
||||||
override val size: Int
|
override val size: Int
|
||||||
@ -145,8 +190,14 @@ inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
|||||||
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [MutableBuffer] implementation over [Array].
|
||||||
|
*
|
||||||
|
* @param T the type of elements contained in the buffer.
|
||||||
|
* @property array The underlying array.
|
||||||
|
*/
|
||||||
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
||||||
//Can't inline because array is invariant
|
// Can't inline because array is invariant
|
||||||
override val size: Int
|
override val size: Int
|
||||||
get() = array.size
|
get() = array.size
|
||||||
|
|
||||||
@ -161,99 +212,30 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
|||||||
override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf())
|
override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf())
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T> Array<T>.asBuffer() = ArrayBuffer(this)
|
/**
|
||||||
|
* Returns an [ArrayBuffer] that wraps the original array.
|
||||||
inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
*/
|
||||||
override val size: Int get() = array.size
|
fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
|
||||||
|
|
||||||
override fun get(index: Int): Double = array[index]
|
|
||||||
|
|
||||||
override fun set(index: Int, value: Double) {
|
|
||||||
array[index] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun iterator() = array.iterator()
|
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Double> = DoubleBuffer(array.copyOf())
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("FunctionName")
|
|
||||||
inline fun DoubleBuffer(size: Int, init: (Int) -> Double) = DoubleBuffer(DoubleArray(size) { init(it) })
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Transform buffer of doubles into array for high performance operations
|
* Immutable wrapper for [MutableBuffer].
|
||||||
|
*
|
||||||
|
* @param T the type of elements contained in the buffer.
|
||||||
|
* @property buffer The underlying buffer.
|
||||||
*/
|
*/
|
||||||
val Buffer<out Double>.array: DoubleArray
|
|
||||||
get() = if (this is DoubleBuffer) {
|
|
||||||
array
|
|
||||||
} else {
|
|
||||||
DoubleArray(size) { get(it) }
|
|
||||||
}
|
|
||||||
|
|
||||||
fun DoubleArray.asBuffer() = DoubleBuffer(this)
|
|
||||||
|
|
||||||
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
|
||||||
override val size: Int get() = array.size
|
|
||||||
|
|
||||||
override fun get(index: Int): Short = array[index]
|
|
||||||
|
|
||||||
override fun set(index: Int, value: Short) {
|
|
||||||
array[index] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun iterator() = array.iterator()
|
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Short> = ShortBuffer(array.copyOf())
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
fun ShortArray.asBuffer() = ShortBuffer(this)
|
|
||||||
|
|
||||||
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
|
||||||
override val size: Int get() = array.size
|
|
||||||
|
|
||||||
override fun get(index: Int): Int = array[index]
|
|
||||||
|
|
||||||
override fun set(index: Int, value: Int) {
|
|
||||||
array[index] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun iterator() = array.iterator()
|
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Int> = IntBuffer(array.copyOf())
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
fun IntArray.asBuffer() = IntBuffer(this)
|
|
||||||
|
|
||||||
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
|
||||||
override val size: Int get() = array.size
|
|
||||||
|
|
||||||
override fun get(index: Int): Long = array[index]
|
|
||||||
|
|
||||||
override fun set(index: Int, value: Long) {
|
|
||||||
array[index] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun iterator() = array.iterator()
|
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Long> = LongBuffer(array.copyOf())
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
fun LongArray.asBuffer() = LongBuffer(this)
|
|
||||||
|
|
||||||
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
||||||
override val size: Int get() = buffer.size
|
override val size: Int get() = buffer.size
|
||||||
|
|
||||||
override fun get(index: Int): T = buffer.get(index)
|
override fun get(index: Int): T = buffer[index]
|
||||||
|
|
||||||
override fun iterator() = buffer.iterator()
|
override fun iterator(): Iterator<T> = buffer.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A buffer with content calculated on-demand. The calculated contect is not stored, so it is recalculated on each call.
|
* A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call.
|
||||||
* Useful when one needs single element from the buffer.
|
* Useful when one needs single element from the buffer.
|
||||||
|
*
|
||||||
|
* @param T the type of elements provided by the buffer.
|
||||||
*/
|
*/
|
||||||
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
|
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
|
||||||
override fun get(index: Int): T {
|
override fun get(index: Int): T {
|
||||||
@ -273,17 +255,16 @@ class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert this buffer to read-only buffer
|
* Convert this buffer to read-only buffer.
|
||||||
*/
|
*/
|
||||||
fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) {
|
fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) ReadOnlyBuffer(this) else this
|
||||||
ReadOnlyBuffer(this)
|
|
||||||
} else {
|
|
||||||
this
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Typealias for buffer transformations
|
* Typealias for buffer transformations.
|
||||||
*/
|
*/
|
||||||
typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R>
|
typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Typealias for buffer transformations with suspend function.
|
||||||
|
*/
|
||||||
typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R>
|
typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R>
|
@ -17,8 +17,8 @@ class ComplexNDField(override val shape: IntArray) :
|
|||||||
override val strides: Strides = DefaultStrides(shape)
|
override val strides: Strides = DefaultStrides(shape)
|
||||||
|
|
||||||
override val elementContext: ComplexField get() = ComplexField
|
override val elementContext: ComplexField get() = ComplexField
|
||||||
override val zero by lazy { produce { zero } }
|
override val zero: ComplexNDElement by lazy { produce { zero } }
|
||||||
override val one by lazy { produce { one } }
|
override val one: ComplexNDElement by lazy { produce { one } }
|
||||||
|
|
||||||
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> =
|
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> =
|
||||||
Buffer.complex(size) { initializer(it) }
|
Buffer.complex(size) { initializer(it) }
|
||||||
@ -69,16 +69,23 @@ class ComplexNDField(override val shape: IntArray) :
|
|||||||
override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> =
|
override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> =
|
||||||
BufferedNDFieldElement(this@ComplexNDField, buffer)
|
BufferedNDFieldElement(this@ComplexNDField, buffer)
|
||||||
|
|
||||||
override fun power(arg: NDBuffer<Complex>, pow: Number) = map(arg) { power(it, pow) }
|
override fun power(arg: NDBuffer<Complex>, pow: Number): ComplexNDElement = map(arg) { power(it, pow) }
|
||||||
|
|
||||||
override fun exp(arg: NDBuffer<Complex>) = map(arg) { exp(it) }
|
override fun exp(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { exp(it) }
|
||||||
|
|
||||||
override fun ln(arg: NDBuffer<Complex>) = map(arg) { ln(it) }
|
override fun ln(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { ln(it) }
|
||||||
|
|
||||||
override fun sin(arg: NDBuffer<Complex>) = map(arg) { sin(it) }
|
override fun sin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { sin(it) }
|
||||||
|
|
||||||
override fun cos(arg: NDBuffer<Complex>) = map(arg) { cos(it) }
|
override fun cos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { cos(it) }
|
||||||
|
|
||||||
|
override fun tan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { tan(it) }
|
||||||
|
|
||||||
|
override fun asin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { asin(it) }
|
||||||
|
|
||||||
|
override fun acos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { acos(it) }
|
||||||
|
|
||||||
|
override fun atan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { atan(it) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -91,13 +98,13 @@ inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline init
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map one [ComplexNDElement] using function with indexes
|
* Map one [ComplexNDElement] using function with indices.
|
||||||
*/
|
*/
|
||||||
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex) =
|
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement =
|
||||||
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
|
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map one [ComplexNDElement] using function without indexes
|
* Map one [ComplexNDElement] using function without indices.
|
||||||
*/
|
*/
|
||||||
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
|
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
|
||||||
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
|
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
|
||||||
@ -107,7 +114,7 @@ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) ->
|
|||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||||
*/
|
*/
|
||||||
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement) =
|
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement): ComplexNDElement =
|
||||||
ndElement.map { this@invoke(it) }
|
ndElement.map { this@invoke(it) }
|
||||||
|
|
||||||
|
|
||||||
@ -116,19 +123,18 @@ operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement) =
|
|||||||
/**
|
/**
|
||||||
* Summation operation for [BufferedNDElement] and single element
|
* Summation operation for [BufferedNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun ComplexNDElement.plus(arg: Complex) =
|
operator fun ComplexNDElement.plus(arg: Complex): ComplexNDElement = map { it + arg }
|
||||||
map { it + arg }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtraction operation between [BufferedNDElement] and single element
|
* Subtraction operation between [BufferedNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun ComplexNDElement.minus(arg: Complex) =
|
operator fun ComplexNDElement.minus(arg: Complex): ComplexNDElement =
|
||||||
map { it - arg }
|
map { it - arg }
|
||||||
|
|
||||||
operator fun ComplexNDElement.plus(arg: Double) =
|
operator fun ComplexNDElement.plus(arg: Double): ComplexNDElement =
|
||||||
map { it + arg }
|
map { it + arg }
|
||||||
|
|
||||||
operator fun ComplexNDElement.minus(arg: Double) =
|
operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement =
|
||||||
map { it - arg }
|
map { it - arg }
|
||||||
|
|
||||||
fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
|
fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import scientifik.kmath.operations.*
|
import scientifik.kmath.operations.ExtendedField
|
||||||
|
|
||||||
interface ExtendedNDField<T : Any, F, N : NDStructure<T>> :
|
|
||||||
NDField<T, F, N>,
|
|
||||||
TrigonometricOperations<N>,
|
|
||||||
PowerOperations<N>,
|
|
||||||
ExponentialOperations<N>
|
|
||||||
where F : ExtendedFieldOperations<T>, F : Field<T>
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [ExtendedField] over [NDStructure].
|
||||||
|
*
|
||||||
|
* @param T the type of the element contained in ND structure.
|
||||||
|
* @param N the type of ND structure.
|
||||||
|
* @param F the extended field of structure elements.
|
||||||
|
*/
|
||||||
|
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N>
|
||||||
|
|
||||||
///**
|
///**
|
||||||
// * NDField that supports [ExtendedField] operations on its elements
|
// * NDField that supports [ExtendedField] operations on its elements
|
||||||
@ -41,5 +42,3 @@ interface ExtendedNDField<T : Any, F, N : NDStructure<T>> :
|
|||||||
// return produce { with(elementContext) { cos(arg[it]) } }
|
// return produce { with(elementContext) { cos(arg[it]) } }
|
||||||
// }
|
// }
|
||||||
//}
|
//}
|
||||||
|
|
||||||
|
|
||||||
|
@ -0,0 +1,73 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.experimental.and
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents flags to supply additional info about values of buffer.
|
||||||
|
*
|
||||||
|
* @property mask bit mask value of this flag.
|
||||||
|
*/
|
||||||
|
enum class ValueFlag(val mask: Byte) {
|
||||||
|
/**
|
||||||
|
* Reports the value is NaN.
|
||||||
|
*/
|
||||||
|
NAN(0b0000_0001),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reports the value doesn't present in the buffer (when the type of value doesn't support `null`).
|
||||||
|
*/
|
||||||
|
MISSING(0b0000_0010),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reports the value is negative infinity.
|
||||||
|
*/
|
||||||
|
NEGATIVE_INFINITY(0b0000_0100),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reports the value is positive infinity
|
||||||
|
*/
|
||||||
|
POSITIVE_INFINITY(0b0000_1000)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A buffer with flagged values.
|
||||||
|
*/
|
||||||
|
interface FlaggedBuffer<T> : Buffer<T> {
|
||||||
|
fun getFlag(index: Int): Byte
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The value is valid if all flags are down
|
||||||
|
*/
|
||||||
|
fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte()
|
||||||
|
|
||||||
|
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte()
|
||||||
|
|
||||||
|
fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A real buffer which supports flags for each value like NaN or Missing
|
||||||
|
*/
|
||||||
|
class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> {
|
||||||
|
init {
|
||||||
|
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getFlag(index: Int): Byte = flags[index]
|
||||||
|
|
||||||
|
override val size: Int get() = values.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Double? = if (isValid(index)) values[index] else null
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
|
||||||
|
if (isValid(it)) values[it] else null
|
||||||
|
}.iterator()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
|
||||||
|
for (i in indices) {
|
||||||
|
if (isValid(i)) {
|
||||||
|
block(values[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,49 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialized [MutableBuffer] implementation over [FloatArray].
|
||||||
|
*
|
||||||
|
* @property array the underlying array.
|
||||||
|
*/
|
||||||
|
inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
|
||||||
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Float = array[index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: Float) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator(): FloatIterator = array.iterator()
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<Float> =
|
||||||
|
FloatBuffer(array.copyOf())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new [FloatBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||||
|
* [init] function.
|
||||||
|
*
|
||||||
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
|
* It should return the value for an buffer element given its index.
|
||||||
|
*/
|
||||||
|
inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a new [FloatBuffer] of given elements.
|
||||||
|
*/
|
||||||
|
fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a [FloatArray] containing all of the elements of this [MutableBuffer].
|
||||||
|
*/
|
||||||
|
val MutableBuffer<out Float>.array: FloatArray
|
||||||
|
get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns [FloatBuffer] over this array.
|
||||||
|
*
|
||||||
|
* @receiver the array.
|
||||||
|
* @return the new buffer.
|
||||||
|
*/
|
||||||
|
fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this)
|
@ -0,0 +1,50 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialized [MutableBuffer] implementation over [IntArray].
|
||||||
|
*
|
||||||
|
* @property array the underlying array.
|
||||||
|
*/
|
||||||
|
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
||||||
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Int = array[index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: Int) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator(): IntIterator = array.iterator()
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<Int> =
|
||||||
|
IntBuffer(array.copyOf())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new [IntBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||||
|
* [init] function.
|
||||||
|
*
|
||||||
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
|
* It should return the value for an buffer element given its index.
|
||||||
|
*/
|
||||||
|
inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a new [IntBuffer] of given elements.
|
||||||
|
*/
|
||||||
|
fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a [IntArray] containing all of the elements of this [MutableBuffer].
|
||||||
|
*/
|
||||||
|
val MutableBuffer<out Int>.array: IntArray
|
||||||
|
get() = (if (this is IntBuffer) array else IntArray(size) { get(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns [IntBuffer] over this array.
|
||||||
|
*
|
||||||
|
* @receiver the array.
|
||||||
|
* @return the new buffer.
|
||||||
|
*/
|
||||||
|
fun IntArray.asBuffer(): IntBuffer = IntBuffer(this)
|
@ -0,0 +1,50 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialized [MutableBuffer] implementation over [LongArray].
|
||||||
|
*
|
||||||
|
* @property array the underlying array.
|
||||||
|
*/
|
||||||
|
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
||||||
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Long = array[index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: Long) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator(): LongIterator = array.iterator()
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<Long> =
|
||||||
|
LongBuffer(array.copyOf())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new [LongBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||||
|
* [init] function.
|
||||||
|
*
|
||||||
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
|
* It should return the value for an buffer element given its index.
|
||||||
|
*/
|
||||||
|
inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a new [LongBuffer] of given elements.
|
||||||
|
*/
|
||||||
|
fun LongBuffer(vararg longs: Long): LongBuffer = LongBuffer(longs)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a [IntArray] containing all of the elements of this [MutableBuffer].
|
||||||
|
*/
|
||||||
|
val MutableBuffer<out Long>.array: LongArray
|
||||||
|
get() = (if (this is LongBuffer) array else LongArray(size) { get(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns [LongBuffer] over this array.
|
||||||
|
*
|
||||||
|
* @receiver the array.
|
||||||
|
* @return the new buffer.
|
||||||
|
*/
|
||||||
|
fun LongArray.asBuffer(): LongBuffer = LongBuffer(this)
|
@ -3,13 +3,16 @@ package scientifik.kmath.structures
|
|||||||
import scientifik.memory.*
|
import scientifik.memory.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A non-boxing buffer based on [ByteBuffer] storage
|
* A non-boxing buffer over [Memory] object.
|
||||||
|
*
|
||||||
|
* @param T the type of elements contained in the buffer.
|
||||||
|
* @property memory the underlying memory segment.
|
||||||
|
* @property spec the spec of [T] type.
|
||||||
*/
|
*/
|
||||||
open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> {
|
open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> {
|
||||||
|
|
||||||
override val size: Int get() = memory.size / spec.objectSize
|
override val size: Int get() = memory.size / spec.objectSize
|
||||||
|
|
||||||
private val reader = memory.reader()
|
private val reader: MemoryReader = memory.reader()
|
||||||
|
|
||||||
override fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
|
override fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
|
||||||
|
|
||||||
@ -17,7 +20,7 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
|
|||||||
|
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun <T : Any> create(spec: MemorySpec<T>, size: Int) =
|
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
|
||||||
MemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
|
MemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
|
||||||
|
|
||||||
inline fun <T : Any> create(
|
inline fun <T : Any> create(
|
||||||
@ -33,24 +36,31 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A mutable non-boxing buffer over [Memory] object.
|
||||||
|
*
|
||||||
|
* @param T the type of elements contained in the buffer.
|
||||||
|
* @property memory the underlying memory segment.
|
||||||
|
* @property spec the spec of [T] type.
|
||||||
|
*/
|
||||||
class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec),
|
class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec),
|
||||||
MutableBuffer<T> {
|
MutableBuffer<T> {
|
||||||
|
|
||||||
private val writer = memory.writer()
|
private val writer: MemoryWriter = memory.writer()
|
||||||
|
|
||||||
override fun set(index: Int, value: T) = writer.write(spec, spec.objectSize * index, value)
|
override fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
|
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun <T : Any> create(spec: MemorySpec<T>, size: Int) =
|
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MutableMemoryBuffer<T> =
|
||||||
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
|
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
|
||||||
|
|
||||||
inline fun <T : Any> create(
|
inline fun <T : Any> create(
|
||||||
spec: MemorySpec<T>,
|
spec: MemorySpec<T>,
|
||||||
size: Int,
|
size: Int,
|
||||||
crossinline initializer: (Int) -> T
|
crossinline initializer: (Int) -> T
|
||||||
) =
|
): MutableMemoryBuffer<T> =
|
||||||
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
|
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
|
||||||
(0 until size).forEach {
|
(0 until size).forEach {
|
||||||
buffer[it] = initializer(it)
|
buffer[it] = initializer(it)
|
||||||
|
@ -56,7 +56,7 @@ interface NDAlgebra<T, C, N : NDStructure<T>> {
|
|||||||
/**
|
/**
|
||||||
* element-by-element invoke a function working on [T] on a [NDStructure]
|
* element-by-element invoke a function working on [T] on a [NDStructure]
|
||||||
*/
|
*/
|
||||||
operator fun Function1<T, T>.invoke(structure: N) = map(structure) { value -> this@invoke(value) }
|
operator fun Function1<T, T>.invoke(structure: N): N = map(structure) { value -> this@invoke(value) }
|
||||||
|
|
||||||
companion object
|
companion object
|
||||||
}
|
}
|
||||||
@ -76,12 +76,12 @@ interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T,
|
|||||||
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
|
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
|
||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after KEEP-176
|
||||||
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) }
|
operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) }
|
||||||
|
|
||||||
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) }
|
operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) }
|
||||||
|
|
||||||
operator fun T.plus(arg: N) = map(arg) { value -> add(this@plus, value) }
|
operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) }
|
||||||
operator fun T.minus(arg: N) = map(arg) { value -> add(-this@minus, value) }
|
operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) }
|
||||||
|
|
||||||
companion object
|
companion object
|
||||||
}
|
}
|
||||||
@ -97,20 +97,19 @@ interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N>
|
|||||||
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after KEEP-176
|
||||||
operator fun N.times(arg: T) = map(this) { value -> multiply(arg, value) }
|
operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) }
|
||||||
|
|
||||||
operator fun T.times(arg: N) = map(arg) { value -> multiply(this@times, value) }
|
operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) }
|
||||||
|
|
||||||
companion object
|
companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Field for n-dimensional structures.
|
* Field of [NDStructure].
|
||||||
* @param shape - the list of dimensions of the array
|
*
|
||||||
* @param elementField - operations field defined on individual array element
|
* @param T the type of the element contained in ND structure.
|
||||||
* @param T - the type of the element contained in ND structure
|
* @param N the type of ND structure.
|
||||||
* @param F - field of structure elements
|
* @param F field of structure elements.
|
||||||
* @param R - actual nd-element type of this field
|
|
||||||
*/
|
*/
|
||||||
interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> {
|
interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> {
|
||||||
|
|
||||||
@ -120,9 +119,9 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
|||||||
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after KEEP-176
|
||||||
operator fun N.div(arg: T) = map(this) { value -> divide(arg, value) }
|
operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) }
|
||||||
|
|
||||||
operator fun T.div(arg: N) = map(arg) { divide(it, this@div) }
|
operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
|
||||||
@ -131,7 +130,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
|||||||
/**
|
/**
|
||||||
* Create a nd-field for [Double] values or pull it from cache if it was created previously
|
* Create a nd-field for [Double] values or pull it from cache if it was created previously
|
||||||
*/
|
*/
|
||||||
fun real(vararg shape: Int) = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a nd-field with boxing generic buffer
|
* Create a nd-field with boxing generic buffer
|
||||||
@ -140,7 +139,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
|||||||
field: F,
|
field: F,
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||||
) = BoxingNDField(shape, field, bufferFactory)
|
): BoxingNDField<T, F> = BoxingNDField(shape, field, bufferFactory)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a most suitable implementation for nd-field using reified class.
|
* Create a most suitable implementation for nd-field using reified class.
|
||||||
|
@ -23,19 +23,23 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
|||||||
/**
|
/**
|
||||||
* Create a optimized NDArray of doubles
|
* Create a optimized NDArray of doubles
|
||||||
*/
|
*/
|
||||||
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }) =
|
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement =
|
||||||
NDField.real(*shape).produce(initializer)
|
NDField.real(*shape).produce(initializer)
|
||||||
|
|
||||||
|
|
||||||
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }) =
|
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
|
||||||
real(intArrayOf(dim)) { initializer(it[0]) }
|
real(intArrayOf(dim)) { initializer(it[0]) }
|
||||||
|
|
||||||
|
|
||||||
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }) =
|
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): RealNDElement =
|
||||||
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||||
|
|
||||||
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }) =
|
fun real3D(
|
||||||
real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
dim1: Int,
|
||||||
|
dim2: Int,
|
||||||
|
dim3: Int,
|
||||||
|
initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }
|
||||||
|
): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -62,16 +66,17 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T) =
|
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement<T, C, N> =
|
||||||
context.mapIndexed(unwrap(), transform).wrap()
|
context.mapIndexed(unwrap(), transform).wrap()
|
||||||
|
|
||||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T) = context.map(unwrap(), transform).wrap()
|
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> =
|
||||||
|
context.map(unwrap(), transform).wrap()
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole [NDElement]
|
* Element by element application of any operation on elements to the whole [NDElement]
|
||||||
*/
|
*/
|
||||||
operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>) =
|
operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>): NDElement<T, C, N> =
|
||||||
ndElement.map { value -> this@invoke(value) }
|
ndElement.map { value -> this@invoke(value) }
|
||||||
|
|
||||||
/* plus and minus */
|
/* plus and minus */
|
||||||
@ -79,13 +84,13 @@ operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElem
|
|||||||
/**
|
/**
|
||||||
* Summation operation for [NDElement] and single element
|
* Summation operation for [NDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T) =
|
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T): NDElement<T, S, N> =
|
||||||
map { value -> arg + value }
|
map { value -> arg + value }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtraction operation between [NDElement] and single element
|
* Subtraction operation between [NDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T) =
|
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T): NDElement<T, S, N> =
|
||||||
map { value -> arg - value }
|
map { value -> arg - value }
|
||||||
|
|
||||||
/* prod and div */
|
/* prod and div */
|
||||||
@ -93,13 +98,13 @@ operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg:
|
|||||||
/**
|
/**
|
||||||
* Product operation for [NDElement] and single element
|
* Product operation for [NDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T) =
|
operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T): NDElement<T, R, N> =
|
||||||
map { value -> arg * value }
|
map { value -> arg * value }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Division operation between [NDElement] and single element
|
* Division operation between [NDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T) =
|
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> =
|
||||||
map { value -> arg / value }
|
map { value -> arg / value }
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,70 +3,138 @@ package scientifik.kmath.structures
|
|||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents n-dimensional structure, i.e. multidimensional container of items of the same type and size. The number
|
||||||
|
* of dimensions and items in an array is defined by its shape, which is a sequence of non-negative integers that
|
||||||
|
* specify the sizes of each dimension.
|
||||||
|
*
|
||||||
|
* @param T the type of items.
|
||||||
|
*/
|
||||||
interface NDStructure<T> {
|
interface NDStructure<T> {
|
||||||
|
/**
|
||||||
|
* The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of
|
||||||
|
* this structure.
|
||||||
|
*/
|
||||||
val shape: IntArray
|
val shape: IntArray
|
||||||
|
|
||||||
val dimension get() = shape.size
|
/**
|
||||||
|
* The count of dimensions in this structure. It should be equal to size of [shape].
|
||||||
|
*/
|
||||||
|
val dimension: Int get() = shape.size
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the value at the specified indices.
|
||||||
|
*
|
||||||
|
* @param index the indices.
|
||||||
|
* @return the value.
|
||||||
|
*/
|
||||||
operator fun get(index: IntArray): T
|
operator fun get(index: IntArray): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the sequence of all the elements associated by their indices.
|
||||||
|
*
|
||||||
|
* @return the lazy sequence of pairs of indices to values.
|
||||||
|
*/
|
||||||
fun elements(): Sequence<Pair<IntArray, T>>
|
fun elements(): Sequence<Pair<IntArray, T>>
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean
|
||||||
|
|
||||||
|
override fun hashCode(): Int
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
/**
|
||||||
|
* Indicates whether some [NDStructure] is equal to another one.
|
||||||
|
*/
|
||||||
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||||
return when {
|
if (st1 === st2) return true
|
||||||
st1 === st2 -> true
|
|
||||||
st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals(
|
// fast comparison of buffers if possible
|
||||||
st2.buffer
|
if (
|
||||||
)
|
st1 is NDBuffer &&
|
||||||
else -> st1.elements().all { (index, value) -> value == st2[index] }
|
st2 is NDBuffer &&
|
||||||
|
st1.strides == st2.strides
|
||||||
|
) {
|
||||||
|
return st1.buffer.contentEquals(st2.buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//element by element comparison if it could not be avoided
|
||||||
|
return st1.elements().all { (index, value) -> value == st2[index] }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a NDStructure with explicit buffer factory
|
* Creates a NDStructure with explicit buffer factory.
|
||||||
*
|
*
|
||||||
* Strides should be reused if possible
|
* Strides should be reused if possible.
|
||||||
*/
|
*/
|
||||||
fun <T> build(
|
fun <T> build(
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
initializer: (IntArray) -> T
|
initializer: (IntArray) -> T
|
||||||
) =
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
inline fun <reified T : Any> auto(
|
||||||
|
strides: Strides,
|
||||||
|
crossinline initializer: (IntArray) -> T
|
||||||
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
inline fun <T : Any> auto(type: KClass<T>, strides: Strides, crossinline initializer: (IntArray) -> T) =
|
inline fun <T : Any> auto(
|
||||||
|
type: KClass<T>,
|
||||||
|
strides: Strides,
|
||||||
|
crossinline initializer: (IntArray) -> T
|
||||||
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
fun <T> build(
|
fun <T> build(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
initializer: (IntArray) -> T
|
initializer: (IntArray) -> T
|
||||||
) = build(DefaultStrides(shape), bufferFactory, initializer)
|
): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
|
||||||
|
|
||||||
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
inline fun <reified T : Any> auto(
|
||||||
|
shape: IntArray,
|
||||||
|
crossinline initializer: (IntArray) -> T
|
||||||
|
): BufferNDStructure<T> =
|
||||||
auto(DefaultStrides(shape), initializer)
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
@JvmName("autoVarArg")
|
@JvmName("autoVarArg")
|
||||||
inline fun <reified T : Any> auto(vararg shape: Int, crossinline initializer: (IntArray) -> T) =
|
inline fun <reified T : Any> auto(
|
||||||
|
vararg shape: Int,
|
||||||
|
crossinline initializer: (IntArray) -> T
|
||||||
|
): BufferNDStructure<T> =
|
||||||
auto(DefaultStrides(shape), initializer)
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
inline fun <T : Any> auto(type: KClass<T>, vararg shape: Int, crossinline initializer: (IntArray) -> T) =
|
inline fun <T : Any> auto(
|
||||||
|
type: KClass<T>,
|
||||||
|
vararg shape: Int,
|
||||||
|
crossinline initializer: (IntArray) -> T
|
||||||
|
): BufferNDStructure<T> =
|
||||||
auto(type, DefaultStrides(shape), initializer)
|
auto(type, DefaultStrides(shape), initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the value at the specified indices.
|
||||||
|
*
|
||||||
|
* @param index the indices.
|
||||||
|
* @return the value.
|
||||||
|
*/
|
||||||
operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
|
operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents mutable [NDStructure].
|
||||||
|
*/
|
||||||
interface MutableNDStructure<T> : NDStructure<T> {
|
interface MutableNDStructure<T> : NDStructure<T> {
|
||||||
|
/**
|
||||||
|
* Inserts an item at the specified indices.
|
||||||
|
*
|
||||||
|
* @param index the indices.
|
||||||
|
* @param value the value.
|
||||||
|
*/
|
||||||
operator fun set(index: IntArray, value: T)
|
operator fun set(index: IntArray, value: T)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,7 +145,7 @@ inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A way to convert ND index to linear one and back
|
* A way to convert ND index to linear one and back.
|
||||||
*/
|
*/
|
||||||
interface Strides {
|
interface Strides {
|
||||||
/**
|
/**
|
||||||
@ -114,11 +182,14 @@ interface Strides {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simple implementation of [Strides].
|
||||||
|
*/
|
||||||
class DefaultStrides private constructor(override val shape: IntArray) : Strides {
|
class DefaultStrides private constructor(override val shape: IntArray) : Strides {
|
||||||
/**
|
/**
|
||||||
* Strides for memory access
|
* Strides for memory access
|
||||||
*/
|
*/
|
||||||
override val strides by lazy {
|
override val strides: List<Int> by lazy {
|
||||||
sequence {
|
sequence {
|
||||||
var current = 1
|
var current = 1
|
||||||
yield(1)
|
yield(1)
|
||||||
@ -153,19 +224,14 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
|||||||
override val linearSize: Int
|
override val linearSize: Int
|
||||||
get() = strides[shape.size]
|
get() = strides[shape.size]
|
||||||
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
if (other !is DefaultStrides) return false
|
if (other !is DefaultStrides) return false
|
||||||
|
|
||||||
if (!shape.contentEquals(other.shape)) return false
|
if (!shape.contentEquals(other.shape)) return false
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
override fun hashCode(): Int = shape.contentHashCode()
|
||||||
return shape.contentHashCode()
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
||||||
@ -177,15 +243,37 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
interface NDBuffer<T> : NDStructure<T> {
|
/**
|
||||||
val buffer: Buffer<T>
|
* Represents [NDStructure] over [Buffer].
|
||||||
val strides: Strides
|
*
|
||||||
|
* @param T the type of items.
|
||||||
|
*/
|
||||||
|
abstract class NDBuffer<T> : NDStructure<T> {
|
||||||
|
/**
|
||||||
|
* The underlying buffer.
|
||||||
|
*/
|
||||||
|
abstract val buffer: Buffer<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The strides to access elements of [Buffer] by linear indices.
|
||||||
|
*/
|
||||||
|
abstract val strides: Strides
|
||||||
|
|
||||||
override fun get(index: IntArray): T = buffer[strides.offset(index)]
|
override fun get(index: IntArray): T = buffer[strides.offset(index)]
|
||||||
|
|
||||||
override val shape: IntArray get() = strides.shape
|
override val shape: IntArray get() = strides.shape
|
||||||
|
|
||||||
override fun elements() = strides.indices().map { it to this[it] }
|
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] }
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = strides.hashCode()
|
||||||
|
result = 31 * result + buffer.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -194,34 +282,12 @@ interface NDBuffer<T> : NDStructure<T> {
|
|||||||
class BufferNDStructure<T>(
|
class BufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : NDBuffer<T> {
|
) : NDBuffer<T>() {
|
||||||
|
|
||||||
init {
|
init {
|
||||||
if (strides.linearSize != buffer.size) {
|
if (strides.linearSize != buffer.size) {
|
||||||
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun get(index: IntArray): T = buffer[strides.offset(index)]
|
|
||||||
|
|
||||||
override val shape: IntArray get() = strides.shape
|
|
||||||
|
|
||||||
override fun elements() = strides.indices().map { it to this[it] }
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
|
||||||
return when {
|
|
||||||
this === other -> true
|
|
||||||
other is BufferNDStructure<*> && this.strides == other.strides -> this.buffer.contentEquals(other.buffer)
|
|
||||||
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
|
|
||||||
else -> false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
|
||||||
var result = strides.hashCode()
|
|
||||||
result = 31 * result + buffer.hashCode()
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -240,20 +306,20 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mutable ND buffer based on linear [autoBuffer]
|
* Mutable ND buffer based on linear [MutableBuffer].
|
||||||
*/
|
*/
|
||||||
class MutableBufferNDStructure<T>(
|
class MutableBufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: MutableBuffer<T>
|
override val buffer: MutableBuffer<T>
|
||||||
) : NDBuffer<T>, MutableNDStructure<T> {
|
) : NDBuffer<T>(), MutableNDStructure<T> {
|
||||||
|
|
||||||
init {
|
init {
|
||||||
if (strides.linearSize != buffer.size) {
|
require(strides.linearSize == buffer.size) {
|
||||||
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
"Expected buffer side of ${strides.linearSize}, but found ${buffer.size}"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value)
|
override fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified T : Any> NDStructure<T>.combine(
|
inline fun <reified T : Any> NDStructure<T>.combine(
|
||||||
|
@ -0,0 +1,49 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialized [MutableBuffer] implementation over [DoubleArray].
|
||||||
|
*
|
||||||
|
* @property array the underlying array.
|
||||||
|
*/
|
||||||
|
inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
||||||
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Double = array[index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: Double) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator(): DoubleIterator = array.iterator()
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<Double> =
|
||||||
|
RealBuffer(array.copyOf())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new [RealBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||||
|
* [init] function.
|
||||||
|
*
|
||||||
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
|
* It should return the value for an buffer element given its index.
|
||||||
|
*/
|
||||||
|
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a new [RealBuffer] of given elements.
|
||||||
|
*/
|
||||||
|
fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a [DoubleArray] containing all of the elements of this [MutableBuffer].
|
||||||
|
*/
|
||||||
|
val MutableBuffer<out Double>.array: DoubleArray
|
||||||
|
get() = (if (this is RealBuffer) array else DoubleArray(size) { get(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns [RealBuffer] over this array.
|
||||||
|
*
|
||||||
|
* @receiver the array.
|
||||||
|
* @return the new buffer.
|
||||||
|
*/
|
||||||
|
fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this)
|
@ -6,148 +6,180 @@ import kotlin.math.*
|
|||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple field over linear buffers of [Double]
|
* [ExtendedFieldOperations] over [RealBuffer].
|
||||||
*/
|
*/
|
||||||
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||||
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
|
||||||
|
return if (a is RealBuffer && b is RealBuffer) {
|
||||||
val aArray = a.array
|
val aArray = a.array
|
||||||
val bArray = b.array
|
val bArray = b.array
|
||||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
|
RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
|
||||||
} else {
|
} else
|
||||||
DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] })
|
RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
|
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
|
||||||
val kValue = k.toDouble()
|
val kValue = k.toDouble()
|
||||||
return if (a is DoubleBuffer) {
|
|
||||||
|
return if (a is RealBuffer) {
|
||||||
val aArray = a.array
|
val aArray = a.array
|
||||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue })
|
RealBuffer(DoubleArray(a.size) { aArray[it] * kValue })
|
||||||
} else {
|
} else
|
||||||
DoubleBuffer(DoubleArray(a.size) { a[it] * kValue })
|
RealBuffer(DoubleArray(a.size) { a[it] * kValue })
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||||
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
|
||||||
|
return if (a is RealBuffer && b is RealBuffer) {
|
||||||
val aArray = a.array
|
val aArray = a.array
|
||||||
val bArray = b.array
|
val bArray = b.array
|
||||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
|
RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
|
||||||
} else {
|
} else
|
||||||
DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
RealBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||||
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
|
||||||
|
return if (a is RealBuffer && b is RealBuffer) {
|
||||||
val aArray = a.array
|
val aArray = a.array
|
||||||
val bArray = b.array
|
val bArray = b.array
|
||||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
|
RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
|
||||||
} else {
|
} else
|
||||||
DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] })
|
RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun sin(arg: Buffer<Double>): DoubleBuffer {
|
override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
return if (arg is DoubleBuffer) {
|
val array = arg.array
|
||||||
val array = arg.array
|
RealBuffer(DoubleArray(arg.size) { sin(array[it]) })
|
||||||
DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) })
|
} else {
|
||||||
} else {
|
RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
|
||||||
DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) })
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun cos(arg: Buffer<Double>): DoubleBuffer {
|
override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
return if (arg is DoubleBuffer) {
|
val array = arg.array
|
||||||
val array = arg.array
|
RealBuffer(DoubleArray(arg.size) { cos(array[it]) })
|
||||||
DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) })
|
} else
|
||||||
} else {
|
RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
|
||||||
DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) })
|
|
||||||
}
|
override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
|
val array = arg.array
|
||||||
|
RealBuffer(DoubleArray(arg.size) { tan(array[it]) })
|
||||||
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
|
||||||
|
|
||||||
|
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
|
val array = arg.array
|
||||||
|
RealBuffer(DoubleArray(arg.size) { asin(array[it]) })
|
||||||
|
} else {
|
||||||
|
RealBuffer(DoubleArray(arg.size) { asin(arg[it]) })
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer {
|
override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
return if (arg is DoubleBuffer) {
|
val array = arg.array
|
||||||
val array = arg.array
|
RealBuffer(DoubleArray(arg.size) { acos(array[it]) })
|
||||||
DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
|
} else
|
||||||
} else {
|
RealBuffer(DoubleArray(arg.size) { acos(arg[it]) })
|
||||||
DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun exp(arg: Buffer<Double>): DoubleBuffer {
|
override fun atan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
return if (arg is DoubleBuffer) {
|
val array = arg.array
|
||||||
val array = arg.array
|
RealBuffer(DoubleArray(arg.size) { atan(array[it]) })
|
||||||
DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) })
|
} else
|
||||||
} else {
|
RealBuffer(DoubleArray(arg.size) { atan(arg[it]) })
|
||||||
DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun ln(arg: Buffer<Double>): DoubleBuffer {
|
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
|
||||||
return if (arg is DoubleBuffer) {
|
val array = arg.array
|
||||||
val array = arg.array
|
RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
|
||||||
DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) })
|
} else
|
||||||
} else {
|
RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
|
||||||
DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) })
|
|
||||||
}
|
override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
}
|
val array = arg.array
|
||||||
|
RealBuffer(DoubleArray(arg.size) { exp(array[it]) })
|
||||||
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
|
||||||
|
|
||||||
|
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
|
val array = arg.array
|
||||||
|
RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
|
||||||
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [ExtendedField] over [RealBuffer].
|
||||||
|
*
|
||||||
|
* @property size the size of buffers to operate on.
|
||||||
|
*/
|
||||||
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
|
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
|
||||||
|
override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
|
||||||
|
override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
|
||||||
|
|
||||||
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
|
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||||
|
|
||||||
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
|
|
||||||
|
|
||||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
|
||||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||||
return RealBufferFieldOperations.add(a, b)
|
return RealBufferFieldOperations.add(a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
|
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
|
||||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||||
return RealBufferFieldOperations.multiply(a, k)
|
return RealBufferFieldOperations.multiply(a, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||||
return RealBufferFieldOperations.multiply(a, b)
|
return RealBufferFieldOperations.multiply(a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
|
||||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||||
return RealBufferFieldOperations.divide(a, b)
|
return RealBufferFieldOperations.divide(a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun sin(arg: Buffer<Double>): DoubleBuffer {
|
override fun sin(arg: Buffer<Double>): RealBuffer {
|
||||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
return RealBufferFieldOperations.sin(arg)
|
return RealBufferFieldOperations.sin(arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun cos(arg: Buffer<Double>): DoubleBuffer {
|
override fun cos(arg: Buffer<Double>): RealBuffer {
|
||||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
return RealBufferFieldOperations.cos(arg)
|
return RealBufferFieldOperations.cos(arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer {
|
override fun tan(arg: Buffer<Double>): RealBuffer {
|
||||||
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.tan(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun asin(arg: Buffer<Double>): RealBuffer {
|
||||||
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.asin(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun acos(arg: Buffer<Double>): RealBuffer {
|
||||||
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.acos(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun atan(arg: Buffer<Double>): RealBuffer {
|
||||||
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.atan(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer {
|
||||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
return RealBufferFieldOperations.power(arg, pow)
|
return RealBufferFieldOperations.power(arg, pow)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun exp(arg: Buffer<Double>): DoubleBuffer {
|
override fun exp(arg: Buffer<Double>): RealBuffer {
|
||||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
return RealBufferFieldOperations.exp(arg)
|
return RealBufferFieldOperations.exp(arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun ln(arg: Buffer<Double>): DoubleBuffer {
|
override fun ln(arg: Buffer<Double>): RealBuffer {
|
||||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
return RealBufferFieldOperations.ln(arg)
|
return RealBufferFieldOperations.ln(arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
@ -12,11 +12,11 @@ class RealNDField(override val shape: IntArray) :
|
|||||||
override val strides: Strides = DefaultStrides(shape)
|
override val strides: Strides = DefaultStrides(shape)
|
||||||
|
|
||||||
override val elementContext: RealField get() = RealField
|
override val elementContext: RealField get() = RealField
|
||||||
override val zero by lazy { produce { zero } }
|
override val zero: RealNDElement by lazy { produce { zero } }
|
||||||
override val one by lazy { produce { one } }
|
override val one: RealNDElement by lazy { produce { one } }
|
||||||
|
|
||||||
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
||||||
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
RealBuffer(DoubleArray(size) { initializer(it) })
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inline transform an NDStructure to
|
* Inline transform an NDStructure to
|
||||||
@ -64,16 +64,23 @@ class RealNDField(override val shape: IntArray) :
|
|||||||
override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> =
|
override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> =
|
||||||
BufferedNDFieldElement(this@RealNDField, buffer)
|
BufferedNDFieldElement(this@RealNDField, buffer)
|
||||||
|
|
||||||
override fun power(arg: NDBuffer<Double>, pow: Number) = map(arg) { power(it, pow) }
|
override fun power(arg: NDBuffer<Double>, pow: Number): RealNDElement = map(arg) { power(it, pow) }
|
||||||
|
|
||||||
override fun exp(arg: NDBuffer<Double>) = map(arg) { exp(it) }
|
override fun exp(arg: NDBuffer<Double>): RealNDElement = map(arg) { exp(it) }
|
||||||
|
|
||||||
override fun ln(arg: NDBuffer<Double>) = map(arg) { ln(it) }
|
override fun ln(arg: NDBuffer<Double>): RealNDElement = map(arg) { ln(it) }
|
||||||
|
|
||||||
override fun sin(arg: NDBuffer<Double>) = map(arg) { sin(it) }
|
override fun sin(arg: NDBuffer<Double>): RealNDElement = map(arg) { sin(it) }
|
||||||
|
|
||||||
override fun cos(arg: NDBuffer<Double>) = map(arg) { cos(it) }
|
override fun cos(arg: NDBuffer<Double>): RealNDElement = map(arg) { cos(it) }
|
||||||
|
|
||||||
|
override fun tan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { tan(it) }
|
||||||
|
|
||||||
|
override fun asin(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { asin(it) }
|
||||||
|
|
||||||
|
override fun acos(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { acos(it) }
|
||||||
|
|
||||||
|
override fun atan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { atan(it) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -82,27 +89,27 @@ class RealNDField(override val shape: IntArray) :
|
|||||||
*/
|
*/
|
||||||
inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement {
|
inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement {
|
||||||
val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) }
|
val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) }
|
||||||
return BufferedNDFieldElement(this, DoubleBuffer(array))
|
return BufferedNDFieldElement(this, RealBuffer(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map one [RealNDElement] using function with indexes
|
* Map one [RealNDElement] using function with indices.
|
||||||
*/
|
*/
|
||||||
inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double) =
|
inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double): RealNDElement =
|
||||||
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
|
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map one [RealNDElement] using function without indexes
|
* Map one [RealNDElement] using function without indices.
|
||||||
*/
|
*/
|
||||||
inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement {
|
inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement {
|
||||||
val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) }
|
val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) }
|
||||||
return BufferedNDFieldElement(context, DoubleBuffer(array))
|
return BufferedNDFieldElement(context, RealBuffer(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
* Element by element application of any operation on elements to the whole array. Just like in numpy.
|
||||||
*/
|
*/
|
||||||
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
|
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement): RealNDElement =
|
||||||
ndElement.map { this@invoke(it) }
|
ndElement.map { this@invoke(it) }
|
||||||
|
|
||||||
|
|
||||||
@ -111,13 +118,13 @@ operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
|
|||||||
/**
|
/**
|
||||||
* Summation operation for [BufferedNDElement] and single element
|
* Summation operation for [BufferedNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun RealNDElement.plus(arg: Double) =
|
operator fun RealNDElement.plus(arg: Double): RealNDElement =
|
||||||
map { it + arg }
|
map { it + arg }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtraction operation between [BufferedNDElement] and single element
|
* Subtraction operation between [BufferedNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun RealNDElement.minus(arg: Double) =
|
operator fun RealNDElement.minus(arg: Double): RealNDElement =
|
||||||
map { it - arg }
|
map { it - arg }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -0,0 +1,50 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialized [MutableBuffer] implementation over [ShortArray].
|
||||||
|
*
|
||||||
|
* @property array the underlying array.
|
||||||
|
*/
|
||||||
|
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
||||||
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Short = array[index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: Short) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator(): ShortIterator = array.iterator()
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<Short> =
|
||||||
|
ShortBuffer(array.copyOf())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new [ShortBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||||
|
* [init] function.
|
||||||
|
*
|
||||||
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
|
* It should return the value for an buffer element given its index.
|
||||||
|
*/
|
||||||
|
inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a new [ShortBuffer] of given elements.
|
||||||
|
*/
|
||||||
|
fun ShortBuffer(vararg shorts: Short): ShortBuffer = ShortBuffer(shorts)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a [ShortArray] containing all of the elements of this [MutableBuffer].
|
||||||
|
*/
|
||||||
|
val MutableBuffer<out Short>.array: ShortArray
|
||||||
|
get() = (if (this is ShortBuffer) array else ShortArray(size) { get(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns [ShortBuffer] over this array.
|
||||||
|
*
|
||||||
|
* @receiver the array.
|
||||||
|
* @return the new buffer.
|
||||||
|
*/
|
||||||
|
fun ShortArray.asBuffer(): ShortBuffer = ShortBuffer(this)
|
@ -12,8 +12,8 @@ class ShortNDRing(override val shape: IntArray) :
|
|||||||
override val strides: Strides = DefaultStrides(shape)
|
override val strides: Strides = DefaultStrides(shape)
|
||||||
|
|
||||||
override val elementContext: ShortRing get() = ShortRing
|
override val elementContext: ShortRing get() = ShortRing
|
||||||
override val zero by lazy { produce { ShortRing.zero } }
|
override val zero: ShortNDElement by lazy { produce { zero } }
|
||||||
override val one by lazy { produce { ShortRing.one } }
|
override val one: ShortNDElement by lazy { produce { one } }
|
||||||
|
|
||||||
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> =
|
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> =
|
||||||
ShortBuffer(ShortArray(size) { initializer(it) })
|
ShortBuffer(ShortArray(size) { initializer(it) })
|
||||||
@ -40,6 +40,7 @@ class ShortNDRing(override val shape: IntArray) :
|
|||||||
transform: ShortRing.(index: IntArray, Short) -> Short
|
transform: ShortRing.(index: IntArray, Short) -> Short
|
||||||
): ShortNDElement {
|
): ShortNDElement {
|
||||||
check(arg)
|
check(arg)
|
||||||
|
|
||||||
return BufferedNDRingElement(
|
return BufferedNDRingElement(
|
||||||
this,
|
this,
|
||||||
buildBuffer(arg.strides.linearSize) { offset ->
|
buildBuffer(arg.strides.linearSize) { offset ->
|
||||||
@ -67,7 +68,7 @@ class ShortNDRing(override val shape: IntArray) :
|
|||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fast element production using function inlining
|
* Fast element production using function inlining.
|
||||||
*/
|
*/
|
||||||
inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement {
|
inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement {
|
||||||
val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }
|
val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }
|
||||||
@ -75,22 +76,22 @@ inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initialize
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
* Element by element application of any operation on elements to the whole array.
|
||||||
*/
|
*/
|
||||||
operator fun Function1<Short, Short>.invoke(ndElement: ShortNDElement) =
|
operator fun Function1<Short, Short>.invoke(ndElement: ShortNDElement): ShortNDElement =
|
||||||
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
|
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
|
||||||
|
|
||||||
|
|
||||||
/* plus and minus */
|
/* plus and minus */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Summation operation for [StridedNDFieldElement] and single element
|
* Summation operation for [ShortNDElement] and single element.
|
||||||
*/
|
*/
|
||||||
operator fun ShortNDElement.plus(arg: Short) =
|
operator fun ShortNDElement.plus(arg: Short): ShortNDElement =
|
||||||
context.produceInline { i -> (buffer[i] + arg).toShort() }
|
context.produceInline { i -> (buffer[i] + arg).toShort() }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtraction operation between [StridedNDFieldElement] and single element
|
* Subtraction operation between [ShortNDElement] and single element.
|
||||||
*/
|
*/
|
||||||
operator fun ShortNDElement.minus(arg: Short) =
|
operator fun ShortNDElement.minus(arg: Short): ShortNDElement =
|
||||||
context.produceInline { i -> (buffer[i] - arg).toShort() }
|
context.produceInline { i -> (buffer[i] - arg).toShort() }
|
@ -39,14 +39,14 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
|
|||||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||||
asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
|
asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
|
||||||
|
|
||||||
override fun get(index: Int): T = buffer.get(index)
|
override fun get(index: Int): T = buffer[index]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||||
*/
|
*/
|
||||||
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||||
if( this is NDBuffer){
|
if (this is NDBuffer) {
|
||||||
Buffer1DWrapper(this.buffer)
|
Buffer1DWrapper(this.buffer)
|
||||||
} else {
|
} else {
|
||||||
Structure1DWrapper(this)
|
Structure1DWrapper(this)
|
||||||
|
@ -14,7 +14,6 @@ interface Structure2D<T> : NDStructure<T> {
|
|||||||
return get(index[0], index[1])
|
return get(index[0], index[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
val rows: Buffer<Buffer<T>>
|
val rows: Buffer<Buffer<T>>
|
||||||
get() = VirtualBuffer(rowNum) { i ->
|
get() = VirtualBuffer(rowNum) { i ->
|
||||||
VirtualBuffer(colNum) { j -> get(i, j) }
|
VirtualBuffer(colNum) { j -> get(i, j) }
|
||||||
@ -33,9 +32,7 @@ interface Structure2D<T> : NDStructure<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -58,22 +55,4 @@ fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
|||||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Represent this 2D structure as 1D if it has exactly one column. Throw error otherwise.
|
|
||||||
*/
|
|
||||||
fun <T> Structure2D<T>.as1D() = if (colNum == 1) {
|
|
||||||
object : Structure1D<T> {
|
|
||||||
override fun get(index: Int): T = get(index, 0)
|
|
||||||
|
|
||||||
override val shape: IntArray get() = intArrayOf(rowNum)
|
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = elements()
|
|
||||||
|
|
||||||
override val size: Int get() = rowNum
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
error("Can't convert matrix with more than one column to vector")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
typealias Matrix<T> = Structure2D<T>
|
typealias Matrix<T> = Structure2D<T>
|
@ -9,7 +9,7 @@ import kotlin.test.assertEquals
|
|||||||
class ExpressionFieldTest {
|
class ExpressionFieldTest {
|
||||||
@Test
|
@Test
|
||||||
fun testExpression() {
|
fun testExpression() {
|
||||||
val context = ExpressionField(RealField)
|
val context = FunctionalExpressionField(RealField)
|
||||||
val expression = with(context) {
|
val expression = with(context) {
|
||||||
val x = variable("x", 2.0)
|
val x = variable("x", 2.0)
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
@ -20,7 +20,7 @@ class ExpressionFieldTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testComplex() {
|
fun testComplex() {
|
||||||
val context = ExpressionField(ComplexField)
|
val context = FunctionalExpressionField(ComplexField)
|
||||||
val expression = with(context) {
|
val expression = with(context) {
|
||||||
val x = variable("x", Complex(2.0, 0.0))
|
val x = variable("x", Complex(2.0, 0.0))
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
@ -31,23 +31,23 @@ class ExpressionFieldTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun separateContext() {
|
fun separateContext() {
|
||||||
fun <T> ExpressionField<T>.expression(): Expression<T> {
|
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
|
||||||
val x = variable("x")
|
val x = variable("x")
|
||||||
return x * x + 2 * x + one
|
return x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
val expression = ExpressionField(RealField).expression()
|
val expression = FunctionalExpressionField(RealField).expression()
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression("x" to 1.0), 4.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun valueExpression() {
|
fun valueExpression() {
|
||||||
val expressionBuilder: ExpressionField<Double>.() -> Expression<Double> = {
|
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
|
||||||
val x = variable("x")
|
val x = variable("x")
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
val expression = ExpressionField(RealField).expressionBuilder()
|
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression("x" to 1.0), 4.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -17,7 +17,7 @@ class MatrixTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testBuilder() {
|
fun testBuilder() {
|
||||||
val matrix = Matrix.build<Double>(2, 3)(
|
val matrix = Matrix.build(2, 3)(
|
||||||
1.0, 0.0, 0.0,
|
1.0, 0.0, 0.0,
|
||||||
0.0, 1.0, 2.0
|
0.0, 1.0, 2.0
|
||||||
)
|
)
|
||||||
@ -49,17 +49,17 @@ class MatrixTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun test2DDot() {
|
fun test2DDot() {
|
||||||
val firstMatrix = NDStructure.auto(2,3){ (i, j) -> (i + j).toDouble() }.as2D()
|
val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D()
|
||||||
val secondMatrix = NDStructure.auto(3,2){ (i, j) -> (i + j).toDouble() }.as2D()
|
val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D()
|
||||||
MatrixContext.real.run {
|
MatrixContext.real.run {
|
||||||
// val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() }
|
// val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() }
|
||||||
// val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() }
|
// val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() }
|
||||||
val result = firstMatrix dot secondMatrix
|
val result = firstMatrix dot secondMatrix
|
||||||
assertEquals(2, result.rowNum)
|
assertEquals(2, result.rowNum)
|
||||||
assertEquals(2, result.colNum)
|
assertEquals(2, result.colNum)
|
||||||
assertEquals(8.0, result[0,1])
|
assertEquals(8.0, result[0, 1])
|
||||||
assertEquals(8.0, result[1,0])
|
assertEquals(8.0, result[1, 0])
|
||||||
assertEquals(14.0, result[1,1])
|
assertEquals(14.0, result[1, 1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user