Merge remote-tracking branch 'origin/dev' into mp-samplers

# Conflicts:
#	kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingIntChain.kt
#	kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingRealChain.kt
#	kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt
#	kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt
This commit is contained in:
Iaroslav Postovalov 2020-09-10 23:42:33 +07:00
commit b5fa1ed6e8
No known key found for this signature in database
GPG Key ID: 70D5F4DCB0972F1B
145 changed files with 3503 additions and 1751 deletions

38
CHANGELOG.md Normal file
View File

@ -0,0 +1,38 @@
# KMath
## [Unreleased]
### Added
- Functional Expressions API
- Mathematical Syntax Tree, its interpreter and API
- String to MST parser (https://github.com/mipt-npm/kmath/pull/120)
- MST to JVM bytecode translator (https://github.com/mipt-npm/kmath/pull/94)
- FloatBuffer (specialized MutableBuffer over FloatArray)
- FlaggedBuffer to associate primitive numbers buffer with flags (to mark values infinite or missing, etc.)
- Specialized builder functions for all primitive buffers like `IntBuffer(25) { it + 1 }` (https://github.com/mipt-npm/kmath/pull/125)
- Interface `NumericAlgebra` where `number` operation is available to convert numbers to algebraic elements
- Inverse trigonometric functions support in ExtendedField (`asin`, `acos`, `atan`) (https://github.com/mipt-npm/kmath/pull/114)
- New space extensions: `average` and `averageWith`
- Local coding conventions
- Geometric Domains API in `kmath-core`
- Blocking chains in `kmath-coroutines`
- Full hyperbolic functions support and default implementations within `ExtendedField`
- Norm support for `Complex`
### Changed
- `readAsMemory` now has `throws IOException` in JVM signature.
- Several functions taking functional types were made `inline`.
- Several functions taking functional types now have `callsInPlace` contracts.
- BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations
- `power(T, Int)` extension function has preconditions and supports `Field<T>`
- Memory objects have more preconditions (overflow checking)
- `tg` function is renamed to `tan` (https://github.com/mipt-npm/kmath/pull/114)
- Gradle version: 6.3 -> 6.6
- Moved probability distributions to commons-rng and to `kmath-prob`
### Fixed
- Missing copy method in Memory implementation on JS (https://github.com/mipt-npm/kmath/pull/106)
- D3.dim value in `kmath-dimensions`
- Multiplication in integer rings in `kmath-core` (https://github.com/mipt-npm/kmath/pull/101)
- Commons RNG compatibility (https://github.com/mipt-npm/kmath/issues/93)
- Multiplication of BigInt by scalar

View File

@ -5,12 +5,17 @@
Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/scientifik/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/scientifik/kmath-core/_latestVersion) Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/scientifik/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/scientifik/kmath-core/_latestVersion)
Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/scientifik/kmath-core/_latestVersion) Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion)
# KMath # KMath
Could be pronounced as `key-math`. Could be pronounced as `key-math`.
The Kotlin MATHematics library is intended as a Kotlin-based analog to Python's `numpy` library. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The Kotlin MATHematics library is intended as a Kotlin-based analog to Python's `numpy` library. In contrast to `numpy` and `scipy` it is modular and has a lightweight core.
## Publications
* [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2)
* [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814)
* [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103)
# Goal # Goal
* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM and JS for now and Native in future). * Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM and JS for now and Native in future).
* Provide basic multiplatform implementations for those abstractions (without significant performance optimization). * Provide basic multiplatform implementations for those abstractions (without significant performance optimization).

View File

@ -11,6 +11,7 @@ allprojects {
repositories { repositories {
jcenter() jcenter()
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/hotkeytlt/maven")
} }
group = "scientifik" group = "scientifik"

View File

@ -1,110 +1,124 @@
# Algebra and algebra elements # Algebraic Structures and Algebraic Elements
The mathematical operations in `kmath` are generally separated from mathematical objects. The mathematical operations in KMath are generally separated from mathematical objects. This means that to perform an
This means that in order to perform an operation, say `+`, one needs two objects of a type `T` and operation, say `+`, one needs two objects of a type `T` and an algebra context, which draws appropriate operation up,
and algebra context which defines appropriate operation, say `Space<T>`. Next one needs to run actual operation say `Space<T>`. Next one needs to run the actual operation in the context:
in the context:
```kotlin ```kotlin
val a: T import scientifik.kmath.operations.*
val b: T
val space: Space<T>
val c = space.run{a + b} val a: T = ...
val b: T = ...
val space: Space<T> = ...
val c = space { a + b }
``` ```
From the first glance, this distinction seems to be a needless complication, but in fact one needs At first glance, this distinction seems to be a needless complication, but in fact one needs to remember that in
to remember that in mathematics, one could define different operations on the same objects. For example, mathematics, one could draw up different operations on same objects. For example, one could use different types of
one could use different types of geometry for vectors. geometry for vectors.
## Algebra hierarchy ## Algebraic Structures
Mathematical contexts have the following hierarchy: Mathematical contexts have the following hierarchy:
**Space** <- **Ring** <- **Field** **Algebra** ← **Space****Ring** ← **Field**
All classes follow abstract mathematical constructs. These interfaces follow real algebraic structures:
[Space](http://mathworld.wolfram.com/Space.html) defines `zero` element, addition operation and multiplication by constant,
[Ring](http://mathworld.wolfram.com/Ring.html) adds multiplication and unit `one` element,
[Field](http://mathworld.wolfram.com/Field.html) adds division operation.
Typical case of `Field` is the `RealField` which works on doubles. And typical case of `Space` is a `VectorSpace`. - [Space](https://mathworld.wolfram.com/VectorSpace.html) defines addition, its neutral element (i.e. 0) and scalar
multiplication;
- [Ring](http://mathworld.wolfram.com/Ring.html) adds multiplication and its neutral element (i.e. 1);
- [Field](http://mathworld.wolfram.com/Field.html) adds division operation.
In some cases algebra context could hold additional operation like `exp` or `sin`, in this case it inherits appropriate A typical implementation of `Field<T>` is the `RealField` which works on doubles, and `VectorSpace` for `Space<T>`.
interface. Also a context could have an operation which produces an element outside of its context. For example
`Matrix` `dot` operation produces a matrix with new dimensions which can be incompatible with initial matrix in
terms of linear operations.
## Algebra element In some cases algebra context can hold additional operations like `exp` or `sin`, and then it inherits appropriate
interface. Also, contexts may have operations, which produce elements outside of the context. For example, `Matrix.dot`
operation produces a matrix with new dimensions, which can be incompatible with initial matrix in terms of linear
operations.
In order to achieve more familiar behavior (where you apply operations directly to mathematical objects), without involving contexts ## Algebraic Element
`kmath` introduces special type objects called `MathElement`. A `MathElement` is basically some object coupled to
To achieve more familiar behavior (where you apply operations directly to mathematical objects), without involving
contexts KMath submits special type objects called `MathElement`. A `MathElement` is basically some object coupled to
a mathematical context. For example `Complex` is the pair of real numbers representing real and imaginary parts, a mathematical context. For example `Complex` is the pair of real numbers representing real and imaginary parts,
but it also holds reference to the `ComplexField` singleton which allows to perform direct operations on `Complex` but it also holds reference to the `ComplexField` singleton, which allows performing direct operations on `Complex`
numbers without explicit involving the context like: numbers without explicit involving the context like:
```kotlin ```kotlin
val c1 = Complex(1.0, 1.0) import scientifik.kmath.operations.*
val c2 = Complex(1.0, -1.0)
val c3 = c1 + c2 + 3.0.toComplex() // Using elements
//or with field notation: val c1 = Complex(1.0, 1.0)
val c4 = ComplexField.run{c1 + i - 2.0} val c2 = Complex(1.0, -1.0)
val c3 = c1 + c2 + 3.0.toComplex()
// Using context
val c4 = ComplexField { c1 + i - 2.0 }
``` ```
Both notations have their pros and cons. Both notations have their pros and cons.
The hierarchy for algebra elements follows the hierarchy for the corresponding algebra. The hierarchy for algebraic elements follows the hierarchy for the corresponding algebraic structures.
**MathElement** <- **SpaceElement** <- **RingElement** <- **FieldElement** **MathElement** **SpaceElement****RingElement** ← **FieldElement**
**MathElement** is the generic common ancestor of the class with context. `MathElement<C>` is the generic common ancestor of the class with context.
One important distinction between algebra elements and algebra contexts is that algebra element has three type parameters: One major distinction between algebraic elements and algebraic contexts is that elements have three type
parameters:
1. The type of elements, field operates on. 1. The type of elements, the field operates on.
2. The self-type of the element returned from operation (must be algebra element). 2. The self-type of the element returned from operation (which has to be an algebraic element).
3. The type of the algebra over first type-parameter. 3. The type of the algebra over first type-parameter.
The middle type is needed in case algebra members do not store context. For example, it is not possible to add The middle type is needed for of algebra members do not store context. For example, it is impossible to add a context
a context to regular `Double`. The element performs automatic conversions from context types and back. to regular `Double`. The element performs automatic conversions from context types and back. One should use context
One should used context operations in all important places. The performance of element operations is not guaranteed. operations in all performance-critical places. The performance of element operations is not guaranteed.
## Spaces and fields ## Spaces and Fields
An obvious first choice of mathematical objects to implement in a context-oriented style are algebraic elements like spaces, KMath submits both contexts and elements for builtin algebraic structures:
rings and fields. Those are located in the `scientifik.kmath.operations.Algebra.kt` file. Alongside common contexts, the file includes definitions for algebra elements like `FieldElement`. A `FieldElement` object
stores a reference to the `Field` which contains additive and multiplicative operations, meaning
it has one fixed context attached and does not require explicit external context. So those `MathElements` can be operated without context:
```kotlin ```kotlin
import scientifik.kmath.operations.*
val c1 = Complex(1.0, 2.0) val c1 = Complex(1.0, 2.0)
val c2 = ComplexField.i val c2 = ComplexField.i
val c3 = c1 + c2 val c3 = c1 + c2
// or
val c3 = ComplexField { c1 + c2 }
``` ```
`ComplexField` also features special operations to mix complex and real numbers, for example: Also, `ComplexField` features special operations to mix complex and real numbers, for example:
```kotlin ```kotlin
import scientifik.kmath.operations.*
val c1 = Complex(1.0, 2.0) val c1 = Complex(1.0, 2.0)
val c2 = ComplexField.run{ c1 - 1.0} // Returns: [re:0.0, im: 2.0] val c2 = ComplexField { c1 - 1.0 } // Returns: Complex(re=0.0, im=2.0)
val c3 = ComplexField.run{ c1 - i*2.0} val c3 = ComplexField { c1 - i * 2.0 }
``` ```
**Note**: In theory it is possible to add behaviors directly to the context, but currently kotlin syntax does not support **Note**: In theory it is possible to add behaviors directly to the context, but as for now Kotlin does not support
that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and [KEEP-176](https://github.com/Kotlin/KEEP/pull/176) for updates. that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and
[KEEP-176](https://github.com/Kotlin/KEEP/pull/176) for updates.
## Nested fields ## Nested fields
Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex elements like so: Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex
elements like so:
```kotlin ```kotlin
val element = NDElement.complex(shape = intArrayOf(2,2)){ index: IntArray -> val element = NDElement.complex(shape = intArrayOf(2, 2)) { index: IntArray ->
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
} }
``` ```
The `element` in this example is a member of the `Field` of 2-d structures, each element of which is a member of its own The `element` in this example is a member of the `Field` of 2D structures, each element of which is a member of its own
`ComplexField`. The important thing is one does not need to create a special n-d class to hold complex `ComplexField`. It is important one does not need to create a special n-d class to hold complex
numbers and implement operations on it, one just needs to provide a field for its elements. numbers and implement operations on it, one just needs to provide a field for its elements.
**Note**: Fields themselves do not solve the problem of JVM boxing, but it is possible to solve with special contexts like **Note**: Fields themselves do not solve the problem of JVM boxing, but it is possible to solve with special contexts like

View File

@ -1,4 +1,5 @@
# Buffers # Buffers
Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`). Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`).
There are different types of buffers: There are different types of buffers:
@ -12,4 +13,5 @@ Some kmath features require a `BufferFactory` class to operate properly. A gener
buffer for given reified type (for types with custom memory buffer it still better to use their own `MemoryBuffer.create()` factory). buffer for given reified type (for types with custom memory buffer it still better to use their own `MemoryBuffer.create()` factory).
## Buffer performance ## Buffer performance
One should avoid using default boxing buffer wherever it is possible. Try to use primitive buffers or memory buffers instead One should avoid using default boxing buffer wherever it is possible. Try to use primitive buffers or memory buffers instead

34
doc/codestyle.md Normal file
View File

@ -0,0 +1,34 @@
# Coding Conventions
KMath code follows general [Kotlin conventions](https://kotlinlang.org/docs/reference/coding-conventions.html), but
with a number of small changes and clarifications.
## Utility Class Naming
Filename should coincide with a name of one of the classes contained in the file or start with small letter and
describe its contents.
The code convention [here](https://kotlinlang.org/docs/reference/coding-conventions.html#source-file-names) says that
file names should start with a capital letter even if file does not contain classes. Yet starting utility classes and
aggregators with a small letter seems to be a good way to visually separate those files.
This convention could be changed in future in a non-breaking way.
## Private Variable Naming
Private variables' names may start with underscore `_` for of the private mutable variable is shadowed by the public
read-only value with the same meaning.
This rule does not permit underscores in names, but it is sometimes useful to "underscore" the fact that public and
private versions draw up the same entity. It is allowed only for private variables.
This convention could be changed in future in a non-breaking way.
## Functions and Properties One-liners
Use one-liners when they occupy single code window line both for functions and properties with getters like
`val b: String get() = "fff"`. The same should be performed with multiline expressions when they could be
cleanly separated.
There is no universal consensus whenever use `fun a() = ...` or `fun a() { return ... }`. Yet from reader outlook
one-lines seem to better show that the property or function is easily calculated.

View File

@ -1,6 +1,6 @@
## Basic linear algebra layout ## Basic linear algebra layout
Kmath support for linear algebra organized in a context-oriented way. Meaning that operations are in most cases declared KMath support for linear algebra organized in a context-oriented way. Meaning that operations are in most cases declared
in context classes, and are not the members of classes that store data. This allows more flexible approach to maintain multiple in context classes, and are not the members of classes that store data. This allows more flexible approach to maintain multiple
back-ends. The new operations added as extensions to contexts instead of being member functions of data structures. back-ends. The new operations added as extensions to contexts instead of being member functions of data structures.

View File

@ -1,4 +1,4 @@
# Nd-structure generation and operations # ND-structure generation and operations
**TODO** **TODO**

View File

@ -1,22 +0,0 @@
# Local coding conventions
Kmath and other `scientifik` projects use general [kotlin code conventions](https://kotlinlang.org/docs/reference/coding-conventions.html), but with a number of small changes and clarifications.
## Utility class names
File name should coincide with a name of one of the classes contained in the file or start with small letter and describe its contents.
The code convention [here](https://kotlinlang.org/docs/reference/coding-conventions.html#source-file-names) says that file names should start with capital letter even if file does not contain classes. Yet starting utility classes and aggregators with a small letter seems to be a good way to clearly visually separate those files.
This convention could be changed in future in a non-breaking way.
## Private variable names
Private variable names could start with underscore `_` in case the private mutable variable is shadowed by the public read-only value with the same meaning.
Code convention do not permit underscores in names, but is is sometimes useful to "underscore" the fact that public and private versions define the same entity. It is allowed only for private variables.
This convention could be changed in future in a non-breaking way.
## Functions and properties one-liners
Use one-liners when they occupy single code window line both for functions and properties with getters like `val b: String get() = "fff"`. The same should be done with multiline expressions when they could be cleanly separated.
There is not general consensus whenever use `fun a() = {}` or `fun a(){return}`. Yet from reader perspective one-lines seem to better show that the property or function is easily calculated.

View File

@ -56,9 +56,16 @@ benchmark {
} }
} }
kotlin.sourceSets.all {
with(languageSettings) {
useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts")
useExperimentalAnnotation("kotlin.ExperimentalUnsignedTypes")
}
}
tasks.withType<KotlinCompile> { tasks.withType<KotlinCompile> {
kotlinOptions { kotlinOptions {
jvmTarget = Scientifik.JVM_TARGET.toString() jvmTarget = Scientifik.JVM_TARGET.toString()
freeCompilerArgs = freeCompilerArgs + "-Xopt-in=kotlin.RequiresOptIn"
} }
} }

View File

@ -4,46 +4,38 @@ import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State import org.openjdk.jmh.annotations.State
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
@State(Scope.Benchmark) @State(Scope.Benchmark)
class NDFieldBenchmark { class NDFieldBenchmark {
@Benchmark @Benchmark
fun autoFieldAdd() { fun autoFieldAdd() {
bufferedField.run { bufferedField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) { res += one }
res += one
}
} }
} }
@Benchmark @Benchmark
fun autoElementAdd() { fun autoElementAdd() {
var res = genericField.one var res = genericField.one
repeat(n) { repeat(n) { res += 1.0 }
res += 1.0
}
} }
@Benchmark @Benchmark
fun specializedFieldAdd() { fun specializedFieldAdd() {
specializedField.run { specializedField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) { res += 1.0 }
res += 1.0
}
} }
} }
@Benchmark @Benchmark
fun boxingFieldAdd() { fun boxingFieldAdd() {
genericField.run { genericField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) { res += one }
res += one
}
} }
} }

View File

@ -5,23 +5,22 @@ import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State import org.openjdk.jmh.annotations.State
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.viktor.ViktorNDField import scientifik.kmath.viktor.ViktorNDField
@State(Scope.Benchmark) @State(Scope.Benchmark)
class ViktorBenchmark { class ViktorBenchmark {
final val dim = 1000 final val dim = 1000
final val n = 100 final val n = 100
// automatically build context most suited for given type. // automatically build context most suited for given type.
final val autoField = NDField.auto(RealField, dim, dim) final val autoField: BufferedNDField<Double, RealField> = NDField.auto(RealField, dim, dim)
final val realField = NDField.real(dim, dim) final val realField: RealNDField = NDField.real(dim, dim)
final val viktorField: ViktorNDField = ViktorNDField(intArrayOf(dim, dim))
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
@Benchmark @Benchmark
fun automaticFieldAddition() { fun automaticFieldAddition() {
autoField.run { autoField {
var res = one var res = one
repeat(n) { res += one } repeat(n) { res += one }
} }
@ -29,7 +28,7 @@ class ViktorBenchmark {
@Benchmark @Benchmark
fun viktorFieldAddition() { fun viktorFieldAddition() {
viktorField.run { viktorField {
var res = one var res = one
repeat(n) { res += one } repeat(n) { res += one }
} }
@ -44,7 +43,7 @@ class ViktorBenchmark {
@Benchmark @Benchmark
fun realdFieldLog() { fun realdFieldLog() {
realField.run { realField {
val fortyTwo = produce { 42.0 } val fortyTwo = produce { 42.0 }
var res = one var res = one
repeat(n) { res = ln(fortyTwo) } repeat(n) { res = ln(fortyTwo) }

View File

@ -1,8 +1,11 @@
package scientifik.kmath.utils package scientifik.kmath.utils
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
internal inline fun measureAndPrint(title: String, block: () -> Unit) { internal inline fun measureAndPrint(title: String, block: () -> Unit) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
val time = measureTimeMillis(block) val time = measureTimeMillis(block)
println("$title completed in $time millis") println("$title completed in $time millis")
} }

View File

@ -5,6 +5,7 @@ import scientifik.kmath.commons.linear.CMMatrixContext
import scientifik.kmath.commons.linear.inverse import scientifik.kmath.commons.linear.inverse
import scientifik.kmath.commons.linear.toCM import scientifik.kmath.commons.linear.toCM
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import kotlin.contracts.ExperimentalContracts import kotlin.contracts.ExperimentalContracts
import kotlin.random.Random import kotlin.random.Random
@ -21,29 +22,18 @@ fun main() {
val n = 5000 // iterations val n = 5000 // iterations
MatrixContext.real.run { MatrixContext.real {
repeat(50) { val res = inverse(matrix) }
repeat(50) { val inverseTime = measureTimeMillis { repeat(n) { val res = inverse(matrix) } }
val res = inverse(matrix)
}
val inverseTime = measureTimeMillis {
repeat(n) {
val res = inverse(matrix)
}
}
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis") println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
} }
//commons-math //commons-math
val commonsTime = measureTimeMillis { val commonsTime = measureTimeMillis {
CMMatrixContext.run { CMMatrixContext {
val cm = matrix.toCM() //avoid overhead on conversion val cm = matrix.toCM() //avoid overhead on conversion
repeat(n) { repeat(n) { val res = inverse(cm) }
val res = inverse(cm)
}
} }
} }
@ -53,7 +43,7 @@ fun main() {
//koma-ejml //koma-ejml
val komaTime = measureTimeMillis { val komaTime = measureTimeMillis {
KomaMatrixContext(EJMLMatrixFactory(), RealField).run { (KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
val km = matrix.toKoma() //avoid overhead on conversion val km = matrix.toKoma() //avoid overhead on conversion
repeat(n) { repeat(n) {
val res = inverse(km) val res = inverse(km)

View File

@ -4,6 +4,7 @@ import koma.matrix.ejml.EJMLMatrixFactory
import scientifik.kmath.commons.linear.CMMatrixContext import scientifik.kmath.commons.linear.CMMatrixContext
import scientifik.kmath.commons.linear.toCM import scientifik.kmath.commons.linear.toCM
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import kotlin.random.Random import kotlin.random.Random
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
@ -18,7 +19,7 @@ fun main() {
// //warmup // //warmup
// matrix1 dot matrix2 // matrix1 dot matrix2
CMMatrixContext.run { CMMatrixContext {
val cmMatrix1 = matrix1.toCM() val cmMatrix1 = matrix1.toCM()
val cmMatrix2 = matrix2.toCM() val cmMatrix2 = matrix2.toCM()
@ -29,8 +30,7 @@ fun main() {
println("CM implementation time: $cmTime") println("CM implementation time: $cmTime")
} }
(KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
KomaMatrixContext(EJMLMatrixFactory(), RealField).run {
val komaMatrix1 = matrix1.toKoma() val komaMatrix1 = matrix1.toKoma()
val komaMatrix2 = matrix2.toKoma() val komaMatrix2 = matrix2.toKoma()

View File

@ -9,13 +9,11 @@ fun main() {
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
} }
val compute = (NDField.complex(8)) {
val compute = NDField.complex(8).run {
val a = produce { (it) -> i * it - it.toDouble() } val a = produce { (it) -> i * it - it.toDouble() }
val b = 3 val b = 3
val c = Complex(1.0, 1.0) val c = Complex(1.0, 1.0)
(a pow b) + c (a pow b) + c
} }
} }

View File

@ -13,9 +13,8 @@ fun main() {
val realField = NDField.real(dim, dim) val realField = NDField.real(dim, dim)
val complexField = NDField.complex(dim, dim) val complexField = NDField.complex(dim, dim)
val realTime = measureTimeMillis { val realTime = measureTimeMillis {
realField.run { realField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) {
res += 1.0 res += 1.0
@ -26,18 +25,15 @@ fun main() {
println("Real addition completed in $realTime millis") println("Real addition completed in $realTime millis")
val complexTime = measureTimeMillis { val complexTime = measureTimeMillis {
complexField.run { complexField {
var res: NDBuffer<Complex> = one var res: NDBuffer<Complex> = one
repeat(n) { repeat(n) { res += 1.0 }
res += 1.0
}
} }
} }
println("Complex addition completed in $complexTime millis") println("Complex addition completed in $complexTime millis")
} }
fun complexExample() { fun complexExample() {
//Create a context for 2-d structure with complex values //Create a context for 2-d structure with complex values
ComplexField { ComplexField {
@ -46,10 +42,7 @@ fun complexExample() {
val x = one * 2.5 val x = one * 2.5
operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im) operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im)
//a structure generator specific to this context //a structure generator specific to this context
val matrix = produce { (k, l) -> val matrix = produce { (k, l) -> k + l * i }
k + l * i
}
//Perform sum //Perform sum
val sum = matrix + x + 1.0 val sum = matrix + x + 1.0

View File

@ -2,14 +2,18 @@ package scientifik.kmath.structures
import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.GlobalScope
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
internal inline fun measureAndPrint(title: String, block: () -> Unit) { internal inline fun measureAndPrint(title: String, block: () -> Unit) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
val time = measureTimeMillis(block) val time = measureTimeMillis(block)
println("$title completed in $time millis") println("$title completed in $time millis")
} }
fun main() { fun main() {
val dim = 1000 val dim = 1000
val n = 1000 val n = 1000
@ -22,27 +26,21 @@ fun main() {
val genericField = NDField.boxing(RealField, dim, dim) val genericField = NDField.boxing(RealField, dim, dim)
measureAndPrint("Automatic field addition") { measureAndPrint("Automatic field addition") {
autoField.run { autoField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) { res += number(1.0) }
res += number(1.0)
}
} }
} }
measureAndPrint("Element addition") { measureAndPrint("Element addition") {
var res = genericField.one var res = genericField.one
repeat(n) { repeat(n) { res += 1.0 }
res += 1.0
}
} }
measureAndPrint("Specialized addition") { measureAndPrint("Specialized addition") {
specializedField.run { specializedField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) { res += 1.0 }
res += 1.0
}
} }
} }
@ -60,12 +58,11 @@ fun main() {
measureAndPrint("Generic addition") { measureAndPrint("Generic addition") {
//genericField.run(action) //genericField.run(action)
genericField.run { genericField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) {
res += one // con't avoid using `one` due to resolution ambiguity res += one // couldn't avoid using `one` due to resolution ambiguity }
} }
} }
} }
} }

View File

@ -23,13 +23,10 @@ fun DMatrixContext<Double, RealField>.custom() {
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() } val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() } val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() } val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }
(m1 dot m2) + m3 (m1 dot m2) + m3
} }
fun main() { fun main(): Unit = with(DMatrixContext.real) {
DMatrixContext.real.run {
simple() simple()
custom() custom()
}
} }

View File

@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-6.5-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-6.5.1-bin.zip
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

View File

@ -1,4 +1,4 @@
# Abstract syntax tree expression representation and operations (`kmath-ast`) # Abstract Syntax Tree Expression Representation and Operations (`kmath-ast`)
This subproject implements the following features: This subproject implements the following features:
@ -38,7 +38,7 @@ This subproject implements the following features:
> ``` > ```
> >
## Dynamic expression code generation with ObjectWeb ASM ## Dynamic Expression Code Generation with ObjectWeb ASM
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds `kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
a special implementation of `Expression<T>` with implemented `invoke` function. a special implementation of `Expression<T>` with implemented `invoke` function.
@ -46,7 +46,7 @@ a special implementation of `Expression<T>` with implemented `invoke` function.
For example, the following builder: For example, the following builder:
```kotlin ```kotlin
RealField.mstInField { symbol("x") + 2 }.compile() RealField.mstInField { symbol("x") + 2 }.compile()
``` ```
… leads to generation of bytecode, which can be decompiled to the following Java class: … leads to generation of bytecode, which can be decompiled to the following Java class:
@ -75,7 +75,7 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression<Doub
### Example Usage ### Example Usage
This API is an extension to MST and MstExpression, so you may optimize as both of them: This API extends MST and MstExpression, so you may optimize as both of them:
```kotlin ```kotlin
RealField.mstInField { symbol("x") + 2 }.compile() RealField.mstInField { symbol("x") + 2 }.compile()

View File

@ -1,37 +1,20 @@
plugins { plugins { id("scientifik.mpp") }
id("scientifik.mpp")
}
repositories {
maven("https://dl.bintray.com/hotkeytlt/maven")
}
kotlin.sourceSets { kotlin.sourceSets {
// all { all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") }
// languageSettings.apply{
// enableLanguageFeature("NewInference")
// }
// }
commonMain { commonMain {
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3") implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
} }
} }
jvmMain { jvmMain {
dependencies { dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
implementation("org.ow2.asm:asm:8.0.1") implementation("org.ow2.asm:asm:8.0.1")
implementation("org.ow2.asm:asm-commons:8.0.1") implementation("org.ow2.asm:asm-commons:8.0.1")
implementation(kotlin("reflect")) implementation(kotlin("reflect"))
} }
} }
jsMain {
dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
}
}
} }

View File

@ -0,0 +1,59 @@
grammar ArithmeticsEvaluator;
fragment DIGIT: '0'..'9';
fragment LETTER: 'a'..'z';
fragment CAPITAL_LETTER: 'A'..'Z';
fragment UNDERSCORE: '_';
ID: (LETTER | UNDERSCORE | CAPITAL_LETTER) (LETTER | UNDERSCORE | DIGIT | CAPITAL_LETTER)*;
NUM: (DIGIT | '.')+ ([eE] [-+]? DIGIT+)?;
MUL: '*';
DIV: '/';
PLUS: '+';
MINUS: '-';
POW: '^';
COMMA: ',';
LPAR: '(';
RPAR: ')';
WS: [ \n\t\r]+ -> skip;
num
: NUM
;
singular
: ID
;
unaryFunction
: ID LPAR subSumChain RPAR
;
binaryFunction
: ID LPAR subSumChain COMMA subSumChain RPAR
;
term
: num
| singular
| unaryFunction
| binaryFunction
| MINUS term
| LPAR subSumChain RPAR
;
powChain
: term (POW term)*
;
divMulChain
: powChain ((DIV | MUL) powChain)*
;
subSumChain
: divMulChain ((PLUS | MINUS) divMulChain)*
;
rootParser
: subSumChain EOF
;

View File

@ -5,63 +5,83 @@ import scientifik.kmath.operations.NumericAlgebra
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
/** /**
* A Mathematical Syntax Tree node for mathematical expressions * A Mathematical Syntax Tree node for mathematical expressions.
*/ */
sealed class MST { sealed class MST {
/** /**
* A node containing unparsed string * A node containing raw string.
*
* @property value the value of this node.
*/ */
data class Symbolic(val value: String) : MST() data class Symbolic(val value: String) : MST()
/** /**
* A node containing a number * A node containing a numeric value or scalar.
*
* @property value the value of this number.
*/ */
data class Numeric(val value: Number) : MST() data class Numeric(val value: Number) : MST()
/** /**
* A node containing an unary operation * A node containing an unary operation.
*
* @property operation the identifier of operation.
* @property value the argument of this operation.
*/ */
data class Unary(val operation: String, val value: MST) : MST() { data class Unary(val operation: String, val value: MST) : MST() {
companion object { companion object
const val ABS_OPERATION = "abs"
//TODO add operations
}
} }
/** /**
* A node containing binary operation * A node containing binary operation.
*
* @property operation the identifier operation.
* @property left the left operand.
* @property right the right operand.
*/ */
data class Binary(val operation: String, val left: MST, val right: MST) : MST() { data class Binary(val operation: String, val left: MST, val right: MST) : MST() {
companion object companion object
} }
} }
//TODO add a function with positional arguments // TODO add a function with named arguments
//TODO add a function with named arguments /**
* Interprets the [MST] node with this [Algebra].
fun <T> Algebra<T>.evaluate(node: MST): T { *
return when (node) { * @receiver the algebra that provides operations.
* @param node the node to evaluate.
* @return the value of expression.
*/
fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value) is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this") ?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value) is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when { is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> { node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation( val number = RealField.binaryOperation(
node.operation, node.operation,
node.left.value.toDouble(), node.left.value.toDouble(),
node.right.value.toDouble() node.right.value.toDouble()
) )
number(number) number(number)
} }
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
} }
}
} }
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this) /**
* Interprets the [MST] node with this [Algebra].
*
* @receiver the node to evaluate.
* @param algebra the algebra that provides operations.
* @return the value of expression.
*/
fun <T> MST.interpret(algebra: Algebra<T>): T = algebra.evaluate(this)

View File

@ -2,6 +2,9 @@ package scientifik.kmath.ast
import scientifik.kmath.operations.* import scientifik.kmath.operations.*
/**
* [Algebra] over [MST] nodes.
*/
object MstAlgebra : NumericAlgebra<MST> { object MstAlgebra : NumericAlgebra<MST> {
override fun number(value: Number): MST = MST.Numeric(value) override fun number(value: Number): MST = MST.Numeric(value)
@ -14,17 +17,16 @@ object MstAlgebra : NumericAlgebra<MST> {
MST.Binary(operation, left, right) MST.Binary(operation, left, right)
} }
/**
* [Space] over [MST] nodes.
*/
object MstSpace : Space<MST>, NumericAlgebra<MST> { object MstSpace : Space<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0) override val zero: MST = number(0.0)
override fun number(value: Number): MST = MstAlgebra.number(value) override fun number(value: Number): MST = MstAlgebra.number(value)
override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun symbol(value: String): MST = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun add(a: MST, b: MST): MST = override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST =
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
override fun binaryOperation(operation: String, left: MST, right: MST): MST = override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstAlgebra.binaryOperation(operation, left, right) MstAlgebra.binaryOperation(operation, left, right)
@ -32,41 +34,69 @@ object MstSpace : Space<MST>, NumericAlgebra<MST> {
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
} }
/**
* [Ring] over [MST] nodes.
*/
object MstRing : Ring<MST>, NumericAlgebra<MST> { object MstRing : Ring<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0) override val zero: MST = number(0.0)
override val one: MST = number(1.0) override val one: MST = number(1.0)
override fun number(value: Number): MST = MstAlgebra.number(value) override fun number(value: Number): MST = MstSpace.number(value)
override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun symbol(value: String): MST = MstSpace.symbol(value)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
override fun multiply(a: MST, k: Number): MST = override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k)
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstAlgebra.binaryOperation(operation, left, right) MstSpace.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
} }
/**
* [Field] over [MST] nodes.
*/
object MstField : Field<MST> { object MstField : Field<MST> {
override val zero: MST = number(0.0) override val zero: MST = number(0.0)
override val one: MST = number(1.0) override val one: MST = number(1.0)
override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun symbol(value: String): MST = MstRing.symbol(value)
override fun number(value: Number): MST = MstAlgebra.number(value) override fun number(value: Number): MST = MstRing.number(value)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
override fun multiply(a: MST, k: Number): MST = override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstAlgebra.binaryOperation(operation, left, right) MstRing.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
}
/**
* [ExtendedField] over [MST] nodes.
*/
object MstExtendedField : ExtendedField<MST> {
override val zero: MST = number(0.0)
override val one: MST = number(1.0)
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun add(a: MST, b: MST): MST = MstField.add(a, b)
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
override fun divide(a: MST, b: MST): MST = MstField.divide(a, b)
override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstField.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg)
} }

View File

@ -1,19 +1,19 @@
package scientifik.kmath.ast package scientifik.kmath.ast
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.*
import scientifik.kmath.expressions.FunctionalExpressionField
import scientifik.kmath.expressions.FunctionalExpressionRing
import scientifik.kmath.expressions.FunctionalExpressionSpace
import scientifik.kmath.operations.* import scientifik.kmath.operations.*
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/** /**
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
* ASM-generated expressions.
*
* @property algebra the algebra that provides operations.
* @property mst the [MST] node.
*/ */
class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> { class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
/**
* Substitute algebra raw value
*/
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> { private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
@ -27,29 +27,77 @@ class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
error("Numeric nodes are not supported by $this") error("Numeric nodes are not supported by $this")
} }
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst) override operator fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
} }
/**
* Builds [MstExpression] over [Algebra].
*/
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst( inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E, mstAlgebra: E,
block: E.() -> MST block: E.() -> MST
): MstExpression<T> = MstExpression(this, mstAlgebra.block()) ): MstExpression<T> = MstExpression(this, mstAlgebra.block())
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> = /**
MstExpression(this, MstSpace.block()) * Builds [MstExpression] over [Space].
*/
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstSpace.block())
}
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> = /**
MstExpression(this, MstRing.block()) * Builds [MstExpression] over [Ring].
*/
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstRing.block())
}
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> = /**
MstExpression(this, MstField.block()) * Builds [MstExpression] over [Field].
*/
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstField.block())
}
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> = /**
algebra.mstInSpace(block) * Builds [MstExpression] over [ExtendedField].
*/
inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstExtendedField.block())
}
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> = /**
algebra.mstInRing(block) * Builds [MstExpression] over [FunctionalExpressionSpace].
*/
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInSpace(block)
}
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> = /**
algebra.mstInField(block) * Builds [MstExpression] over [FunctionalExpressionRing].
*/
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInRing(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionField].
*/
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInField(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionExtendedField].
*/
inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block)
}

View File

@ -5,6 +5,9 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
import com.github.h0tk3y.betterParse.grammar.parseToEnd import com.github.h0tk3y.betterParse.grammar.parseToEnd
import com.github.h0tk3y.betterParse.grammar.parser import com.github.h0tk3y.betterParse.grammar.parser
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
import com.github.h0tk3y.betterParse.lexer.Token
import com.github.h0tk3y.betterParse.lexer.TokenMatch
import com.github.h0tk3y.betterParse.lexer.regexToken
import com.github.h0tk3y.betterParse.parser.ParseResult import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser import com.github.h0tk3y.betterParse.parser.Parser
import scientifik.kmath.operations.FieldOperations import scientifik.kmath.operations.FieldOperations
@ -13,47 +16,82 @@ import scientifik.kmath.operations.RingOperations
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
/** /**
* TODO move to common * TODO move to core
*/ */
private object ArithmeticsEvaluator : Grammar<MST>() { object ArithmeticsEvaluator : Grammar<MST>() {
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex()) // TODO replace with "...".toRegex() when better-parse 0.4.1 is released
val lpar by token("\\(".toRegex()) private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
val rpar by token("\\)".toRegex()) private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
val mul by token("\\*".toRegex()) private val lpar: Token by regexToken("\\(")
val pow by token("\\^".toRegex()) private val rpar: Token by regexToken("\\)")
val div by token("/".toRegex()) private val comma: Token by regexToken(",")
val minus by token("-".toRegex()) private val mul: Token by regexToken("\\*")
val plus by token("\\+".toRegex()) private val pow: Token by regexToken("\\^")
val ws by token("\\s+".toRegex(), ignore = true) private val div: Token by regexToken("/")
private val minus: Token by regexToken("-")
private val plus: Token by regexToken("\\+")
private val ws: Token by regexToken("\\s+", ignore = true)
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) } private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
val term: Parser<MST> by number or private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or .map { (id, term) -> MST.Unary(id.text, term) }
(skip(lpar) and parser(this::rootParser) and skip(rpar))
val powChain by leftAssociative(term, pow) { a, _, b -> private val binaryFunction: Parser<MST> by id
.and(skip(lpar))
.and(parser(::subSumChain))
.and(skip(comma))
.and(parser(::subSumChain))
.and(skip(rpar))
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
private val term: Parser<MST> by number
.or(binaryFunction)
.or(unaryFunction)
.or(singular)
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
MST.Binary(PowerOperations.POW_OPERATION, a, b) MST.Binary(PowerOperations.POW_OPERATION, a, b)
} }
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b -> private val divMulChain: Parser<MST> by leftAssociative(
if (op == div) { term = powChain,
operator = div or mul use TokenMatch::type
) { a, op, b ->
if (op == div)
MST.Binary(FieldOperations.DIV_OPERATION, a, b) MST.Binary(FieldOperations.DIV_OPERATION, a, b)
} else { else
MST.Binary(RingOperations.TIMES_OPERATION, a, b) MST.Binary(RingOperations.TIMES_OPERATION, a, b)
} }
}
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b -> private val subSumChain: Parser<MST> by leftAssociative(
if (op == plus) { term = divMulChain,
operator = plus or minus use TokenMatch::type
) { a, op, b ->
if (op == plus)
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b) MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
} else { else
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b) MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
} }
}
override val rootParser: Parser<MST> by subSumChain override val rootParser: Parser<MST> by subSumChain
} }
/**
* Tries to parse the string into [MST].
*
* @receiver the string to parse.
* @return the [MST] node.
*/
fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this) fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this)
/**
* Parses the string into [MST].
*
* @receiver the string to parse.
* @return the [MST] node.
*/
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this) fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)

View File

@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
import scientifik.kmath.ast.MstExpression import scientifik.kmath.ast.MstExpression
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> { fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST) { fun AsmBuilder<T>.visit(node: MST) {
when (node) { when (node) {
is MST.Symbolic -> loadVariable(node.value) is MST.Symbolic -> {
val symbol = try {
algebra.symbol(node.value)
} catch (ignored: Throwable) {
null
}
if (symbol != null)
loadTConstant(symbol)
else
loadVariable(node.value)
}
is MST.Numeric -> loadNumeric(node.value) is MST.Numeric -> loadNumeric(node.value)
is MST.Unary -> buildAlgebraOperationCall( is MST.Unary -> buildAlgebraOperationCall(

View File

@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* Loads a [T] constant from [constants]. * Loads a [T] constant from [constants].
*/ */
private fun loadTConstant(value: T) { internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) { if (classOfT in INLINABLE_NUMBERS) {
val expectedType = expectationStack.pop() val expectedType = expectationStack.pop()
val mustBeBoxed = expectedType.sort == Type.OBJECT val mustBeBoxed = expectedType.sort == Type.OBJECT
@ -340,7 +340,7 @@ internal class AsmBuilder<T> internal constructor(
checkcast(type) checkcast(type)
} }
fun loadNumeric(value: Number) { internal fun loadNumeric(value: Number) {
if (expectationStack.peek() == NUMBER_TYPE) { if (expectationStack.peek() == NUMBER_TYPE) {
loadNumberConstant(value, true) loadNumberConstant(value, true)
expectationStack.pop() expectationStack.pop()

View File

@ -7,6 +7,9 @@ import scientifik.kmath.ast.MST
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import java.lang.reflect.Method import java.lang.reflect.Method
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.reflect.KClass import kotlin.reflect.KClass
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy { private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
@ -26,8 +29,10 @@ internal val KClass<*>.asm: Type
/** /**
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
*/ */
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> = internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> {
if (predicate(this)) arrayOf(this) else emptyArray() contract { callsInPlace(predicate, InvocationKind.EXACTLY_ONCE) }
return if (predicate(this)) arrayOf(this) else emptyArray()
}
/** /**
* Creates an [InstructionAdapter] from this [MethodVisitor]. * Creates an [InstructionAdapter] from this [MethodVisitor].
@ -37,8 +42,10 @@ private fun MethodVisitor.instructionAdapter(): InstructionAdapter = Instruction
/** /**
* Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it.
*/ */
internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter {
instructionAdapter().apply(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return instructionAdapter().apply(block)
}
/** /**
* Constructs a [Label], then applies it to this visitor. * Constructs a [Label], then applies it to this visitor.
@ -64,8 +71,10 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
} }
@Suppress("FunctionName") @Suppress("FunctionName")
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter {
ClassWriter(flags).apply(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return ClassWriter(flags).apply(block)
}
internal inline fun ClassWriter.visitField( internal inline fun ClassWriter.visitField(
access: Int, access: Int,
@ -74,7 +83,10 @@ internal inline fun ClassWriter.visitField(
signature: String?, signature: String?,
value: Any?, value: Any?,
block: FieldVisitor.() -> Unit block: FieldVisitor.() -> Unit
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) ): FieldVisitor {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return visitField(access, name, descriptor, signature, value).apply(block)
}
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? = private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
context.javaClass.methods.find { method -> context.javaClass.methods.find { method ->
@ -158,6 +170,7 @@ internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
parameterTypes: Array<MstType>, parameterTypes: Array<MstType>,
parameters: AsmBuilder<T>.() -> Unit parameters: AsmBuilder<T>.() -> Unit
) { ) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
val arity = parameterTypes.size val arity = parameterTypes.size
loadAlgebra() loadAlgebra()
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name) if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)

View File

@ -0,0 +1,36 @@
package scietifik.kmath.ast
import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.parseMath
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ParserPrecedenceTest {
private val f: Field<Double> = RealField
@Test
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
@Test
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
@Test
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
@Test
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
@Test
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
@Test
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
@Test
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
@Test
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
}

View File

@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.mstInField import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -22,4 +24,37 @@ internal class ParserTest {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }
@Test
fun `evaluate MST with singular`() {
val mst = "i".parseMath()
val res = ComplexField.evaluate(mst)
assertEquals(ComplexField.i, res)
}
@Test
fun `evaluate MST with unary function`() {
val mst = "sin(0)".parseMath()
val res = RealField.evaluate(mst)
assertEquals(0.0, res)
}
@Test
fun `evaluate MST with binary function`() {
val magicalAlgebra = object : Algebra<String> {
override fun symbol(value: String): String = value
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
"magic" -> "$left$right"
else -> throw NotImplementedError()
}
}
val mst = "magic(a, b)".parseMath()
val res = magicalAlgebra.evaluate(mst)
assertEquals("a ★ b", res)
}
} }

View File

@ -1,7 +1,4 @@
plugins { plugins { id("scientifik.jvm") }
id("scientifik.jvm")
}
description = "Commons math binding for kmath" description = "Commons math binding for kmath"
dependencies { dependencies {
@ -11,3 +8,5 @@ dependencies {
api(project(":kmath-functions")) api(project(":kmath-functions"))
api("org.apache.commons:commons-math3:3.6.1") api("org.apache.commons:commons-math3:3.6.1")
} }
kotlin.sourceSets.all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") }

View File

@ -5,6 +5,7 @@ import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.invoke
import kotlin.properties.ReadOnlyProperty import kotlin.properties.ReadOnlyProperty
import kotlin.reflect.KProperty import kotlin.reflect.KProperty
@ -15,26 +16,22 @@ class DerivativeStructureField(
val order: Int, val order: Int,
val parameters: Map<String, Double> val parameters: Map<String, Double>
) : ExtendedField<DerivativeStructure> { ) : ExtendedField<DerivativeStructure> {
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) } override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) } override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) }
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) -> private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value) DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
} }
val variable = object : ReadOnlyProperty<Any?, DerivativeStructure> { val variable: ReadOnlyProperty<Any?, DerivativeStructure> = object : ReadOnlyProperty<Any?, DerivativeStructure> {
override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure { override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure =
return variables[property.name] ?: error("A variable with name ${property.name} does not exist") variables[property.name] ?: error("A variable with name ${property.name} does not exist")
}
} }
fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure = fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
variables[name] ?: default ?: error("A variable with name $name does not exist") variables[name] ?: default ?: error("A variable with name $name does not exist")
fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
fun Number.const() = DerivativeStructure(order, parameters.size, toDouble())
fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double { fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
return deriv(mapOf(parName to order)) return deriv(mapOf(parName to order))
@ -60,10 +57,18 @@ class DerivativeStructureField(
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
is Double -> arg.pow(pow) is Double -> arg.pow(pow)
is Int -> arg.pow(pow) is Int -> arg.pow(pow)
@ -71,23 +76,20 @@ class DerivativeStructureField(
} }
fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow) fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
override operator fun Number.plus(b: DerivativeStructure) = b + this override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
override operator fun Number.minus(b: DerivativeStructure) = b - this override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
} }
/** /**
* A constructs that creates a derivative structure with required order on-demand * A constructs that creates a derivative structure with required order on-demand
*/ */
class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> { class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> {
override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
override fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
0, 0,
arguments arguments
).run(function).value ).run(function).value
@ -96,45 +98,40 @@ class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStru
* Get the derivative expression with given orders * Get the derivative expression with given orders
* TODO make result [DiffExpression] * TODO make result [DiffExpression]
*/ */
fun derivative(orders: Map<String, Int>): Expression<Double> { fun derivative(orders: Map<String, Int>): Expression<Double> = object : Expression<Double> {
return object : Expression<Double> { override operator fun invoke(arguments: Map<String, Double>): Double =
override fun invoke(arguments: Map<String, Double>): Double = (DerivativeStructureField(orders.values.max() ?: 0, arguments)) { function().deriv(orders) }
DerivativeStructureField(orders.values.max() ?: 0, arguments)
.run {
function().deriv(orders)
}
}
} }
//TODO add gradient and maybe other vector operators //TODO add gradient and maybe other vector operators
} }
fun DiffExpression.derivative(vararg orders: Pair<String, Int>) = derivative(mapOf(*orders)) fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
fun DiffExpression.derivative(name: String) = derivative(name to 1) fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
/** /**
* A context for [DiffExpression] (not to be confused with [DerivativeStructure]) * A context for [DiffExpression] (not to be confused with [DerivativeStructure])
*/ */
object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> { object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
override fun variable(name: String, default: Double?) = override fun variable(name: String, default: Double?): DiffExpression =
DiffExpression { variable(name, default?.const()) } DiffExpression { variable(name, default?.const()) }
override fun const(value: Double): DiffExpression = override fun const(value: Double): DiffExpression =
DiffExpression { value.const() } DiffExpression { value.const() }
override fun add(a: DiffExpression, b: DiffExpression) = override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) + b.function(this) } DiffExpression { a.function(this) + b.function(this) }
override val zero = DiffExpression { 0.0.const() } override val zero: DiffExpression = DiffExpression { 0.0.const() }
override fun multiply(a: DiffExpression, k: Number) = override fun multiply(a: DiffExpression, k: Number): DiffExpression =
DiffExpression { a.function(this) * k } DiffExpression { a.function(this) * k }
override val one = DiffExpression { 1.0.const() } override val one: DiffExpression = DiffExpression { 1.0.const() }
override fun multiply(a: DiffExpression, b: DiffExpression) = override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) * b.function(this) } DiffExpression { a.function(this) * b.function(this) }
override fun divide(a: DiffExpression, b: DiffExpression) = override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) / b.function(this) } DiffExpression { a.function(this) / b.function(this) }
} }

View File

@ -1,8 +1,6 @@
package scientifik.kmath.commons.linear package scientifik.kmath.commons.linear
import org.apache.commons.math3.linear.* import org.apache.commons.math3.linear.*
import org.apache.commons.math3.linear.RealMatrix
import org.apache.commons.math3.linear.RealVector
import scientifik.kmath.linear.* import scientifik.kmath.linear.*
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.NDStructure import scientifik.kmath.structures.NDStructure
@ -14,12 +12,12 @@ class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> { override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
if (origin is DiagonalMatrix) yield(DiagonalFeature) if (origin is DiagonalMatrix) yield(DiagonalFeature)
}.toSet() }.toHashSet()
override fun suggestFeature(vararg features: MatrixFeature) = override fun suggestFeature(vararg features: MatrixFeature): CMMatrix =
CMMatrix(origin, this.features + features) CMMatrix(origin, this.features + features)
override fun get(i: Int, j: Int): Double = origin.getEntry(i, j) override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false) return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
@ -40,24 +38,22 @@ fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
CMMatrix(Array2DRowRealMatrix(array)) CMMatrix(Array2DRowRealMatrix(array))
} }
fun RealMatrix.asMatrix() = CMMatrix(this) fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this)
class CMVector(val origin: RealVector) : Point<Double> { class CMVector(val origin: RealVector) : Point<Double> {
override val size: Int get() = origin.dimension override val size: Int get() = origin.dimension
override fun get(index: Int): Double = origin.getEntry(index) override operator fun get(index: Int): Double = origin.getEntry(index)
override fun iterator(): Iterator<Double> = origin.toArray().iterator() override operator fun iterator(): Iterator<Double> = origin.toArray().iterator()
} }
fun Point<Double>.toCM(): CMVector = if (this is CMVector) { fun Point<Double>.toCM(): CMVector = if (this is CMVector) this else {
this
} else {
val array = DoubleArray(size) { this[it] } val array = DoubleArray(size) { this[it] }
CMVector(ArrayRealVector(array)) CMVector(ArrayRealVector(array))
} }
fun RealVector.toPoint() = CMVector(this) fun RealVector.toPoint(): CMVector = CMVector(this)
object CMMatrixContext : MatrixContext<Double> { object CMMatrixContext : MatrixContext<Double> {
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix { override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
@ -65,30 +61,31 @@ object CMMatrixContext : MatrixContext<Double> {
return CMMatrix(Array2DRowRealMatrix(array)) return CMMatrix(Array2DRowRealMatrix(array))
} }
override fun Matrix<Double>.dot(other: Matrix<Double>) = override fun Matrix<Double>.dot(other: Matrix<Double>): CMMatrix =
CMMatrix(this.toCM().origin.multiply(other.toCM().origin)) CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
override fun Matrix<Double>.dot(vector: Point<Double>): CMVector = override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
CMVector(this.toCM().origin.preMultiply(vector.toCM().origin)) CMVector(this.toCM().origin.preMultiply(vector.toCM().origin))
override fun Matrix<Double>.unaryMinus(): CMMatrix = override operator fun Matrix<Double>.unaryMinus(): CMMatrix =
produce(rowNum, colNum) { i, j -> -get(i, j) } produce(rowNum, colNum) { i, j -> -get(i, j) }
override fun add(a: Matrix<Double>, b: Matrix<Double>) = override fun add(a: Matrix<Double>, b: Matrix<Double>): CMMatrix =
CMMatrix(a.toCM().origin.multiply(b.toCM().origin)) CMMatrix(a.toCM().origin.multiply(b.toCM().origin))
override fun Matrix<Double>.minus(b: Matrix<Double>) = override operator fun Matrix<Double>.minus(b: Matrix<Double>): CMMatrix =
CMMatrix(this.toCM().origin.subtract(b.toCM().origin)) CMMatrix(this.toCM().origin.subtract(b.toCM().origin))
override fun multiply(a: Matrix<Double>, k: Number) = override fun multiply(a: Matrix<Double>, k: Number): CMMatrix =
CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble())) CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble()))
override fun Matrix<Double>.times(value: Double): Matrix<Double> = override operator fun Matrix<Double>.times(value: Double): Matrix<Double> =
produce(rowNum, colNum) { i, j -> get(i, j) * value } produce(rowNum, colNum) { i, j -> get(i, j) * value }
} }
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = operator fun CMMatrix.plus(other: CMMatrix): CMMatrix =
CMMatrix(this.origin.add(other.origin)) CMMatrix(this.origin.add(other.origin))
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = operator fun CMMatrix.minus(other: CMMatrix): CMMatrix =
CMMatrix(this.origin.subtract(other.origin)) CMMatrix(this.origin.subtract(other.origin))

View File

@ -4,10 +4,9 @@ import scientifik.kmath.prob.RandomGenerator
class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
org.apache.commons.math3.random.RandomGenerator { org.apache.commons.math3.random.RandomGenerator {
private var generator = factory(intArrayOf()) private var generator: RandomGenerator = factory(intArrayOf())
override fun nextBoolean(): Boolean = generator.nextBoolean() override fun nextBoolean(): Boolean = generator.nextBoolean()
override fun nextFloat(): Float = generator.nextDouble().toFloat() override fun nextFloat(): Float = generator.nextDouble().toFloat()
override fun setSeed(seed: Int) { override fun setSeed(seed: Int) {
@ -27,12 +26,8 @@ class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
} }
override fun nextInt(): Int = generator.nextInt() override fun nextInt(): Int = generator.nextInt()
override fun nextInt(n: Int): Int = generator.nextInt(n) override fun nextInt(n: Int): Int = generator.nextInt(n)
override fun nextGaussian(): Double = TODO() override fun nextGaussian(): Double = TODO()
override fun nextDouble(): Double = generator.nextDouble() override fun nextDouble(): Double = generator.nextDouble()
override fun nextLong(): Long = generator.nextLong() override fun nextLong(): Long = generator.nextLong()
} }

View File

@ -1,11 +1,15 @@
package scientifik.kmath.commons.expressions package scientifik.kmath.commons.expressions
import scientifik.kmath.expressions.invoke import scientifik.kmath.expressions.invoke
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R) = inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R): R {
DerivativeStructureField(order, mapOf(*parameters)).run(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
}
class AutoDiffTest { class AutoDiffTest {
@Test @Test

40
kmath-core/README.md Normal file
View File

@ -0,0 +1,40 @@
# The Core Module (`kmath-core`)
The core features of KMath:
- Algebraic structures: contexts and elements.
- ND structures.
- Buffers.
- Functional Expressions.
- Domains.
- Automatic differentiation.
> #### Artifact:
> This module is distributed in the artifact `scientifik:kmath-core:0.1.4-dev-8`.
>
> **Gradle:**
>
> ```gradle
> repositories {
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' }
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
> maven { url https://dl.bintray.com/hotkeytlt/maven' }
> }
>
> dependencies {
> implementation 'scientifik:kmath-core:0.1.4-dev-8'
> }
> ```
> **Gradle Kotlin DSL:**
>
> ```kotlin
> repositories {
> maven("https://dl.bintray.com/mipt-npm/scientifik")
> maven("https://dl.bintray.com/mipt-npm/dev")
> maven("https://dl.bintray.com/hotkeytlt/maven")
> }
>
> dependencies {``
> implementation("scientifik:kmath-core:0.1.4-dev-8")
> }
> ```

View File

@ -1,11 +1,6 @@
plugins { plugins { id("scientifik.mpp") }
id("scientifik.mpp")
}
kotlin.sourceSets { kotlin.sourceSets {
commonMain { all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") }
dependencies { commonMain { dependencies { api(project(":kmath-memory")) } }
api(project(":kmath-memory"))
}
}
} }

View File

@ -3,13 +3,18 @@ package scientifik.kmath.domains
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
/** /**
* A simple geometric domain * A simple geometric domain.
*
* @param T the type of element of this domain.
*/ */
interface Domain<T : Any> { interface Domain<T : Any> {
/**
* Checks if the specified point is contained in this domain.
*/
operator fun contains(point: Point<T>): Boolean operator fun contains(point: Point<T>): Boolean
/** /**
* Number of hyperspace dimensions * Number of hyperspace dimensions.
*/ */
val dimension: Int val dimension: Int
} }

View File

@ -42,13 +42,14 @@ class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBu
override fun getUpperBound(num: Int): Double? = upper[num] override fun getUpperBound(num: Int): Double? = upper[num]
override fun nearestInDomain(point: Point<Double>): Point<Double> { override fun nearestInDomain(point: Point<Double>): Point<Double> {
val res: DoubleArray = DoubleArray(point.size) { i -> val res = DoubleArray(point.size) { i ->
when { when {
point[i] < lower[i] -> lower[i] point[i] < lower[i] -> lower[i]
point[i] > upper[i] -> upper[i] point[i] > upper[i] -> upper[i]
else -> point[i] else -> point[i]
} }
} }
return RealBuffer(*res) return RealBuffer(*res)
} }

View File

@ -22,8 +22,7 @@ import scientifik.kmath.linear.Point
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
interface RealDomain: Domain<Double> { interface RealDomain : Domain<Double> {
fun nearestInDomain(point: Point<Double>): Point<Double> fun nearestInDomain(point: Point<Double>): Point<Double>
/** /**
@ -61,5 +60,4 @@ interface RealDomain: Domain<Double> {
* @return * @return
*/ */
fun volume(): Double fun volume(): Double
} }

View File

@ -18,7 +18,6 @@ package scientifik.kmath.domains
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
class UnconstrainedDomain(override val dimension: Int) : RealDomain { class UnconstrainedDomain(override val dimension: Int) : RealDomain {
override operator fun contains(point: Point<Double>): Boolean = true override operator fun contains(point: Point<Double>): Boolean = true
override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
@ -32,5 +31,4 @@ class UnconstrainedDomain(override val dimension: Int) : RealDomain {
override fun nearestInDomain(point: Point<Double>): Point<Double> = point override fun nearestInDomain(point: Point<Double>): Point<Double> = point
override fun volume(): Double = Double.POSITIVE_INFINITY override fun volume(): Double = Double.POSITIVE_INFINITY
} }

View File

@ -4,7 +4,6 @@ import scientifik.kmath.linear.Point
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : RealDomain { inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : RealDomain {
operator fun contains(d: Double): Boolean = range.contains(d) operator fun contains(d: Double): Boolean = range.contains(d)
override operator fun contains(point: Point<Double>): Boolean { override operator fun contains(point: Point<Double>): Boolean {
@ -15,7 +14,7 @@ inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : Rea
override fun nearestInDomain(point: Point<Double>): Point<Double> { override fun nearestInDomain(point: Point<Double>): Point<Double> {
require(point.size == 1) require(point.size == 1)
val value = point[0] val value = point[0]
return when{ return when {
value in range -> point value in range -> point
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer() value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
else -> doubleArrayOf(range.start).asBuffer() else -> doubleArrayOf(range.start).asBuffer()

View File

@ -1,23 +1,41 @@
package scientifik.kmath.expressions package scientifik.kmath.expressions
import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/** /**
* Create a functional expression on this [Space] * Creates a functional expression with this [Space].
*/ */
fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> = inline fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> {
FunctionalExpressionSpace(this).run(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionSpace(this).block()
}
/** /**
* Create a functional expression on this [Ring] * Creates a functional expression with this [Ring].
*/ */
fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> = inline fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> {
FunctionalExpressionRing(this).run(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionRing(this).block()
}
/** /**
* Create a functional expression on this [Field] * Creates a functional expression with this [Field].
*/ */
fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> = inline fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> {
FunctionalExpressionField(this).run(block) contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionField(this).block()
}
/**
* Creates a functional expression with this [ExtendedField].
*/
inline fun <T> ExtendedField<T>.extendedFieldExpression(block: FunctionalExpressionExtendedField<T, ExtendedField<T>>.() -> Expression<T>): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionExtendedField(this).block()
}

View File

@ -6,6 +6,12 @@ import scientifik.kmath.operations.Algebra
* An elementary function that could be invoked on a map of arguments * An elementary function that could be invoked on a map of arguments
*/ */
interface Expression<T> { interface Expression<T> {
/**
* Calls this expression from arguments.
*
* @param arguments the map of arguments.
* @return the value.
*/
operator fun invoke(arguments: Map<String, T>): T operator fun invoke(arguments: Map<String, T>): T
companion object companion object
@ -14,10 +20,17 @@ interface Expression<T> {
/** /**
* Create simple lazily evaluated expression inside given algebra * Create simple lazily evaluated expression inside given algebra
*/ */
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> = object: Expression<T> { fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> =
override fun invoke(arguments: Map<String, T>): T = block(arguments) object : Expression<T> {
} override operator fun invoke(arguments: Map<String, T>): T = block(arguments)
}
/**
* Calls this expression from arguments.
*
* @param pairs the pair of arguments' names to values.
* @return the value.
*/
operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs)) operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
/** /**

View File

@ -4,7 +4,7 @@ import scientifik.kmath.operations.*
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) : internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
Expression<T> { Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments)) override operator fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
} }
internal class FunctionalBinaryOperation<T>( internal class FunctionalBinaryOperation<T>(
@ -13,17 +13,17 @@ internal class FunctionalBinaryOperation<T>(
val first: Expression<T>, val first: Expression<T>,
val second: Expression<T> val second: Expression<T>
) : Expression<T> { ) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = override operator fun invoke(arguments: Map<String, T>): T =
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
} }
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> { internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = override operator fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name") arguments[name] ?: default ?: error("Parameter not found: $name")
} }
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> { internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = value override operator fun invoke(arguments: Map<String, T>): T = value
} }
internal class FunctionalConstProductExpression<T>( internal class FunctionalConstProductExpression<T>(
@ -31,7 +31,7 @@ internal class FunctionalConstProductExpression<T>(
private val expr: Expression<T>, private val expr: Expression<T>,
val const: Number val const: Number
) : Expression<T> { ) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const) override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
} }
/** /**
@ -40,7 +40,6 @@ internal class FunctionalConstProductExpression<T>(
* @param algebra The algebra to provide for Expressions built. * @param algebra The algebra to provide for Expressions built.
*/ */
abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(val algebra: A) : ExpressionAlgebra<T, Expression<T>> { abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(val algebra: A) : ExpressionAlgebra<T, Expression<T>> {
/** /**
* Builds an Expression of constant expression which does not depend on arguments. * Builds an Expression of constant expression which does not depend on arguments.
*/ */
@ -69,14 +68,13 @@ abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(val algebra: A) :
*/ */
open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) : open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> { FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
override val zero: Expression<T> get() = const(algebra.zero) override val zero: Expression<T> get() = const(algebra.zero)
/** /**
* Builds an Expression of addition of two another expressions. * Builds an Expression of addition of two another expressions.
*/ */
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
/** /**
* Builds an Expression of multiplication of expression by number. * Builds an Expression of multiplication of expression by number.
@ -105,7 +103,7 @@ open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpac
* Builds an Expression of multiplication of two expressions. * Builds an Expression of multiplication of two expressions.
*/ */
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) binaryOperation(RingOperations.TIMES_OPERATION, a, b)
operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg) operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
operator fun T.times(arg: Expression<T>): Expression<T> = arg * this operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
@ -124,7 +122,7 @@ open class FunctionalExpressionField<T, A>(algebra: A) :
* Builds an Expression of division an expression by another one. * Builds an Expression of division an expression by another one.
*/ */
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) binaryOperation(FieldOperations.DIV_OPERATION, a, b)
operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg) operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
operator fun T.div(arg: Expression<T>): Expression<T> = arg / this operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
@ -136,6 +134,28 @@ open class FunctionalExpressionField<T, A>(algebra: A) :
super<FunctionalExpressionRing>.binaryOperation(operation, left, right) super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
} }
open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
FunctionalExpressionField<T, A>(algebra),
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun asin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun acos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
override fun atan(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun power(arg: Expression<T>, pow: Number): Expression<T> =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
override fun exp(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionField>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionField>.binaryOperation(operation, left, right)
}
inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> = inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionSpace(this).block() FunctionalExpressionSpace(this).block()
@ -144,3 +164,6 @@ inline fun <T, A : Ring<T>> A.expressionInRing(block: FunctionalExpressionRing<T
inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> = inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionField(this).block() FunctionalExpressionField(this).block()
inline fun <T, A : ExtendedField<T>> A.expressionInExtendedField(block: FunctionalExpressionExtendedField<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionExtendedField(this).block()

View File

@ -19,22 +19,20 @@ class BufferMatrixContext<T : Any, R : Ring<T>>(
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer) override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
companion object { companion object
}
} }
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
object RealMatrixContext : GenericMatrixContext<Double, RealField> { object RealMatrixContext : GenericMatrixContext<Double, RealField> {
override val elementContext get() = RealField override val elementContext: RealField get() = RealField
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> { override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
} }
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size,initializer) override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size, initializer)
} }
class BufferMatrix<T : Any>( class BufferMatrix<T : Any>(
@ -52,19 +50,15 @@ class BufferMatrix<T : Any>(
override val shape: IntArray get() = intArrayOf(rowNum, colNum) override val shape: IntArray get() = intArrayOf(rowNum, colNum)
override fun suggestFeature(vararg features: MatrixFeature) = override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
BufferMatrix(rowNum, colNum, buffer, this.features + features) BufferMatrix(rowNum, colNum, buffer, this.features + features)
override fun get(index: IntArray): T = get(index[0], index[1]) override operator fun get(index: IntArray): T = get(index[0], index[1])
override fun get(i: Int, j: Int): T = buffer[i * colNum + j] override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
override fun elements(): Sequence<Pair<IntArray, T>> = sequence { override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
for (i in 0 until rowNum) { for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
for (j in 0 until colNum) {
yield(intArrayOf(i, j) to get(i, j))
}
}
} }
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
@ -84,8 +78,8 @@ class BufferMatrix<T : Any>(
override fun toString(): String { override fun toString(): String {
return if (rowNum <= 5 && colNum <= 5) { return if (rowNum <= 5 && colNum <= 5) {
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" + "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
it.asSequence().joinToString(separator = "\t") { it.toString() } buffer.asSequence().joinToString(separator = "\t") { it.toString() }
} }
} else { } else {
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)" "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
@ -97,7 +91,7 @@ class BufferMatrix<T : Any>(
* Optimized dot product for real matrices * Optimized dot product for real matrices
*/ */
infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> { infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> {
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})") require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
val array = DoubleArray(this.rowNum * other.colNum) val array = DoubleArray(this.rowNum * other.colNum)

View File

@ -4,6 +4,8 @@ import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D import scientifik.kmath.structures.Structure2D
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.math.sqrt import kotlin.math.sqrt
/** /**
@ -23,25 +25,25 @@ interface FeaturedMatrix<T : Any> : Matrix<T> {
*/ */
fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
companion object { companion object
}
} }
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> {
MatrixContext.real.produce(rows, columns, initializer) contract { callsInPlace(initializer) }
return MatrixContext.real.produce(rows, columns, initializer)
}
/** /**
* Build a square matrix from given elements. * Build a square matrix from given elements.
*/ */
fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> { fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
val size: Int = sqrt(elements.size.toDouble()).toInt() val size: Int = sqrt(elements.size.toDouble()).toInt()
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square") require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" }
val buffer = elements.asBuffer() val buffer = elements.asBuffer()
return BufferMatrix(size, size, buffer) return BufferMatrix(size, size, buffer)
} }
val Matrix<*>.features get() = (this as? FeaturedMatrix)?.features?: emptySet() val Matrix<*>.features: Set<MatrixFeature> get() = (this as? FeaturedMatrix)?.features ?: emptySet()
/** /**
* Check if matrix has the given feature class * Check if matrix has the given feature class
@ -68,7 +70,7 @@ fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: In
* A virtual matrix of zeroes * A virtual matrix of zeroes
*/ */
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> = fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero } VirtualMatrix(rows, columns) { _, _ -> elementContext.zero }
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature

View File

@ -3,6 +3,7 @@ package scientifik.kmath.linear
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.BufferAccessor2D import scientifik.kmath.structures.BufferAccessor2D
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D import scientifik.kmath.structures.Structure2D
@ -18,7 +19,7 @@ class LUPDecomposition<T : Any>(
private val even: Boolean private val even: Boolean
) : LUPDecompositionFeature<T>, DeterminantFeature<T> { ) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
val elementContext get() = context.elementContext val elementContext: Field<T> get() = context.elementContext
/** /**
* Returns the matrix L of the decomposition. * Returns the matrix L of the decomposition.
@ -60,15 +61,13 @@ class LUPDecomposition<T : Any>(
* @return determinant of the matrix * @return determinant of the matrix
*/ */
override val determinant: T by lazy { override val determinant: T by lazy {
with(elementContext) { elementContext { (0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } }
(0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] }
}
} }
} }
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) = fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T): T =
if (value > elementContext.zero) value else with(elementContext) { -value } if (value > elementContext.zero) value else elementContext { -value }
/** /**
@ -88,43 +87,34 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
//TODO just waits for KEEP-176 //TODO just waits for KEEP-176
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
elementContext.run { elementContext {
val lu = create(matrix) val lu = create(matrix)
// Initialize permutation array and parity // Initialize permutation array and parity
for (row in 0 until m) { for (row in 0 until m) pivot[row] = row
pivot[row] = row
}
var even = true var even = true
// Initialize permutation array and parity // Initialize permutation array and parity
for (row in 0 until m) { for (row in 0 until m) pivot[row] = row
pivot[row] = row
}
// Loop over columns // Loop over columns
for (col in 0 until m) { for (col in 0 until m) {
// upper // upper
for (row in 0 until col) { for (row in 0 until col) {
val luRow = lu.row(row) val luRow = lu.row(row)
var sum = luRow[col] var sum = luRow[col]
for (i in 0 until row) { for (i in 0 until row) sum -= luRow[i] * lu[i, col]
sum -= luRow[i] * lu[i, col]
}
luRow[col] = sum luRow[col] = sum
} }
// lower // lower
var max = col // permutation row var max = col // permutation row
var largest = -one var largest = -one
for (row in col until m) { for (row in col until m) {
val luRow = lu.row(row) val luRow = lu.row(row)
var sum = luRow[col] var sum = luRow[col]
for (i in 0 until col) { for (i in 0 until col) sum -= luRow[i] * lu[i, col]
sum -= luRow[i] * lu[i, col]
}
luRow[col] = sum luRow[col] = sum
// maintain best permutation choice // maintain best permutation choice
@ -135,19 +125,19 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
} }
// Singularity check // Singularity check
if (checkSingular(this@lup.abs(lu[max, col]))) { check(!checkSingular(this@lup.abs(lu[max, col]))) { "The matrix is singular" }
error("The matrix is singular")
}
// Pivot if necessary // Pivot if necessary
if (max != col) { if (max != col) {
val luMax = lu.row(max) val luMax = lu.row(max)
val luCol = lu.row(col) val luCol = lu.row(col)
for (i in 0 until m) { for (i in 0 until m) {
val tmp = luMax[i] val tmp = luMax[i]
luMax[i] = luCol[i] luMax[i] = luCol[i]
luCol[i] = tmp luCol[i] = tmp
} }
val temp = pivot[max] val temp = pivot[max]
pivot[max] = pivot[col] pivot[max] = pivot[col]
pivot[col] = temp pivot[col] = temp
@ -156,9 +146,7 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
// Divide the lower elements by the "winning" diagonal elt. // Divide the lower elements by the "winning" diagonal elt.
val luDiag = lu[col, col] val luDiag = lu[col, col]
for (row in col + 1 until m) { for (row in col + 1 until m) lu[row, col] /= luDiag
lu[row, col] /= luDiag
}
} }
return LUPDecomposition(this@lup, lu.collect(), pivot, even) return LUPDecomposition(this@lup, lu.collect(), pivot, even)
@ -169,33 +157,29 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup( inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
matrix: Matrix<T>, matrix: Matrix<T>,
noinline checkSingular: (T) -> Boolean noinline checkSingular: (T) -> Boolean
) = lup(T::class, matrix, checkSingular) ): LUPDecomposition<T> = lup(T::class, matrix, checkSingular)
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(Double::class, matrix) { it < 1e-11 } fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> =
lup(Double::class, matrix) { it < 1e-11 }
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> { fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" }
if (matrix.rowNum != pivot.size) {
error("Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}")
}
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
elementContext.run { elementContext {
// Apply permutations to b // Apply permutations to b
val bp = create { _, _ -> zero } val bp = create { _, _ -> zero }
for (row in 0 until pivot.size) { for (row in pivot.indices) {
val bpRow = bp.row(row) val bpRow = bp.row(row)
val pRow = pivot[row] val pRow = pivot[row]
for (col in 0 until matrix.colNum) { for (col in 0 until matrix.colNum) bpRow[col] = matrix[pRow, col]
bpRow[col] = matrix[pRow, col]
}
} }
// Solve LY = b // Solve LY = b
for (col in 0 until pivot.size) { for (col in pivot.indices) {
val bpCol = bp.row(col) val bpCol = bp.row(col)
for (i in col + 1 until pivot.size) { for (i in col + 1 until pivot.size) {
val bpI = bp.row(i) val bpI = bp.row(i)
val luICol = lu[i, col] val luICol = lu[i, col]
@ -209,23 +193,21 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
for (col in pivot.size - 1 downTo 0) { for (col in pivot.size - 1 downTo 0) {
val bpCol = bp.row(col) val bpCol = bp.row(col)
val luDiag = lu[col, col] val luDiag = lu[col, col]
for (j in 0 until matrix.colNum) { for (j in 0 until matrix.colNum) bpCol[j] /= luDiag
bpCol[j] /= luDiag
}
for (i in 0 until col) { for (i in 0 until col) {
val bpI = bp.row(i) val bpI = bp.row(i)
val luICol = lu[i, col] val luICol = lu[i, col]
for (j in 0 until matrix.colNum) { for (j in 0 until matrix.colNum) bpI[j] -= bpCol[j] * luICol
bpI[j] -= bpCol[j] * luICol
}
} }
} }
return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] } return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] }
} }
} }
} }
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>) = solve(T::class, matrix) inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matrix<T> = solve(T::class, matrix)
/** /**
* Solve a linear equation **a*x = b** * Solve a linear equation **a*x = b**
@ -240,13 +222,12 @@ inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.
return decomposition.solve(T::class, b) return decomposition.solve(T::class, b)
} }
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>) = fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> = solve(a, b) { it < 1e-11 }
solve(a, b) { it < 1e-11 }
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse( inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
matrix: Matrix<T>, matrix: Matrix<T>,
noinline checkSingular: (T) -> Boolean noinline checkSingular: (T) -> Boolean
) = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular) ): Matrix<T> = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
fun RealMatrixContext.inverse(matrix: Matrix<Double>) = fun RealMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> =
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 } solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }

View File

@ -25,4 +25,4 @@ fun <T : Any> Matrix<T>.asPoint(): Point<T> =
error("Can't convert matrix with more than one column to vector") error("Can't convert matrix with more than one column to vector")
} }
fun <T : Any> Point<T>.asMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) } fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }

View File

@ -7,7 +7,7 @@ import scientifik.kmath.structures.asBuffer
class MatrixBuilder(val rows: Int, val columns: Int) { class MatrixBuilder(val rows: Int, val columns: Int) {
operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> { operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> {
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns") require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" }
val buffer = elements.asBuffer() val buffer = elements.asBuffer()
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
} }

View File

@ -2,6 +2,7 @@ package scientifik.kmath.linear
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
import scientifik.kmath.operations.invoke
import scientifik.kmath.operations.sum import scientifik.kmath.operations.sum
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.BufferFactory
@ -29,7 +30,7 @@ interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
/** /**
* Non-boxing double matrix * Non-boxing double matrix
*/ */
val real = RealMatrixContext val real: RealMatrixContext = RealMatrixContext
/** /**
* A structured matrix with custom buffer * A structured matrix with custom buffer
@ -37,8 +38,7 @@ interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
fun <T : Any, R : Ring<T>> buffered( fun <T : Any, R : Ring<T>> buffered(
ring: R, ring: R,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
): GenericMatrixContext<T, R> = ): GenericMatrixContext<T, R> = BufferMatrixContext(ring, bufferFactory)
BufferMatrixContext(ring, bufferFactory)
/** /**
* Automatic buffered matrix, unboxed if it is possible * Automatic buffered matrix, unboxed if it is possible
@ -61,45 +61,49 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> { override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
//TODO add typed error //TODO add typed error
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})") require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
return produce(rowNum, other.colNum) { i, j -> return produce(rowNum, other.colNum) { i, j ->
val row = rows[i] val row = rows[i]
val column = other.columns[j] val column = other.columns[j]
with(elementContext) { elementContext { sum(row.asSequence().zip(column.asSequence(), ::multiply)) }
sum(row.asSequence().zip(column.asSequence(), ::multiply))
}
} }
} }
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> { override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
//TODO add typed error //TODO add typed error
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})") require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
return point(rowNum) { i -> return point(rowNum) { i ->
val row = rows[i] val row = rows[i]
with(elementContext) { elementContext { sum(row.asSequence().zip(vector.asSequence(), ::multiply)) }
sum(row.asSequence().zip(vector.asSequence(), ::multiply))
}
} }
} }
override operator fun Matrix<T>.unaryMinus() = override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } } produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } }
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> { override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]") require(a.rowNum == b.rowNum && a.colNum == b.colNum) {
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) + b[i, j] } } "Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]"
}
return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } }
} }
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> { override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]") require(rowNum == b.rowNum && colNum == b.colNum) {
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } } "Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]"
}
return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } }
} }
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> = override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) * k } } produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
override fun Matrix<T>.times(value: T): Matrix<T> = override operator fun Matrix<T>.times(value: T): Matrix<T> =
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } } produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
} }

View File

@ -1,7 +1,7 @@
package scientifik.kmath.linear package scientifik.kmath.linear
/** /**
* A marker interface representing some matrix feature like diagonal, sparce, zero, etc. Features used to optimize matrix * A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix
* operations performance in some cases. * operations performance in some cases.
*/ */
interface MatrixFeature interface MatrixFeature
@ -36,19 +36,19 @@ interface DeterminantFeature<T : Any> : MatrixFeature {
} }
@Suppress("FunctionName") @Suppress("FunctionName")
fun <T: Any> DeterminantFeature(determinant: T) = object: DeterminantFeature<T>{ fun <T : Any> DeterminantFeature(determinant: T): DeterminantFeature<T> = object : DeterminantFeature<T> {
override val determinant: T = determinant override val determinant: T = determinant
} }
/** /**
* Lower triangular matrix * Lower triangular matrix
*/ */
object LFeature: MatrixFeature object LFeature : MatrixFeature
/** /**
* Upper triangular feature * Upper triangular feature
*/ */
object UFeature: MatrixFeature object UFeature : MatrixFeature
/** /**
* TODO add documentation * TODO add documentation

View File

@ -2,6 +2,7 @@ package scientifik.kmath.linear
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.BufferFactory
@ -10,10 +11,9 @@ import scientifik.kmath.structures.BufferFactory
* Could be used on any point-like structure * Could be used on any point-like structure
*/ */
interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> { interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
val size: Int val size: Int
val space: S val space: S
override val zero: Point<T> get() = produce { space.zero }
fun produce(initializer: (Int) -> T): Point<T> fun produce(initializer: (Int) -> T): Point<T>
@ -22,30 +22,25 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
*/ */
//fun produceElement(initializer: (Int) -> T): Vector<T, S> //fun produceElement(initializer: (Int) -> T): Vector<T, S>
override val zero: Point<T> get() = produce { space.zero } override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { space { a[it] + b[it] } }
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } } override fun multiply(a: Point<T>, k: Number): Point<T> = produce { space { a[it] * k } }
override fun multiply(a: Point<T>, k: Number): Point<T> = produce { with(space) { a[it] * k } }
//TODO add basis //TODO add basis
companion object { companion object {
private val realSpaceCache: MutableMap<Int, BufferVectorSpace<Double, RealField>> = hashMapOf()
private val realSpaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
/** /**
* Non-boxing double vector space * Non-boxing double vector space
*/ */
fun real(size: Int): BufferVectorSpace<Double, RealField> { fun real(size: Int): BufferVectorSpace<Double, RealField> = realSpaceCache.getOrPut(size) {
return realSpaceCache.getOrPut(size) {
BufferVectorSpace( BufferVectorSpace(
size, size,
RealField, RealField,
Buffer.Companion::auto Buffer.Companion::auto
) )
} }
}
/** /**
* A structured vector space with custom buffer * A structured vector space with custom buffer
@ -54,7 +49,7 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
size: Int, size: Int,
space: S, space: S,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
) = BufferVectorSpace(size, space, bufferFactory) ): BufferVectorSpace<T, S> = BufferVectorSpace(size, space, bufferFactory)
/** /**
* Automatic buffered vector, unboxed if it is possible * Automatic buffered vector, unboxed if it is possible
@ -70,6 +65,6 @@ class BufferVectorSpace<T : Any, S : Space<T>>(
override val space: S, override val space: S,
val bufferFactory: BufferFactory<T> val bufferFactory: BufferFactory<T>
) : VectorSpace<T, S> { ) : VectorSpace<T, S> {
override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer) override fun produce(initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
//override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer)) //override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
} }

View File

@ -18,9 +18,9 @@ class VirtualMatrix<T : Any>(
override val shape: IntArray get() = intArrayOf(rowNum, colNum) override val shape: IntArray get() = intArrayOf(rowNum, colNum)
override fun get(i: Int, j: Int): T = generator(i, j) override operator fun get(i: Int, j: Int): T = generator(i, j)
override fun suggestFeature(vararg features: MatrixFeature) = override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> =
VirtualMatrix(rowNum, colNum, this.features + features, generator) VirtualMatrix(rowNum, colNum, this.features + features, generator)
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {

View File

@ -3,8 +3,12 @@ package scientifik.kmath.misc
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.invoke
import scientifik.kmath.operations.sum import scientifik.kmath.operations.sum
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/* /*
* Implementation of backward-mode automatic differentiation. * Implementation of backward-mode automatic differentiation.
@ -22,20 +26,19 @@ class DerivationResult<T : Any>(
val deriv: Map<Variable<T>, T>, val deriv: Map<Variable<T>, T>,
val context: Field<T> val context: Field<T>
) : Variable<T>(value) { ) : Variable<T>(value) {
fun deriv(variable: Variable<T>) = deriv[variable] ?: context.zero fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
/** /**
* compute divergence * compute divergence
*/ */
fun div() = context.run { sum(deriv.values) } fun div(): T = context { sum(deriv.values) }
/** /**
* Compute a gradient for variables in given order * Compute a gradient for variables in given order
*/ */
fun grad(vararg variables: Variable<T>): Point<T> = if (variables.isEmpty()) { fun grad(vararg variables: Variable<T>): Point<T> {
error("Variable order is not provided for gradient construction") check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
} else { return variables.map(::deriv).asBuffer()
variables.map(::deriv).asBuffer()
} }
} }
@ -52,19 +55,27 @@ class DerivationResult<T : Any>(
* assertEquals(9.0, x.d) // dy/dx * assertEquals(9.0, x.d) // dy/dx
* ``` * ```
*/ */
fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> = inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
AutoDiffContext<T, F>(this).run { contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return (AutoDiffContext(this)) {
val result = body() val result = body()
result.d = context.one// computing derivative w.r.t result result.d = context.one // computing derivative w.r.t result
runBackwardPass() runBackwardPass()
DerivationResult(result.value, derivatives, this@deriv) DerivationResult(result.value, derivatives, this@deriv)
} }
}
abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> { abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
abstract val context: F abstract val context: F
/**
* A variable accessing inner state of derivatives.
* Use this function in inner builders to avoid creating additional derivative bindings
*/
abstract var Variable<T>.d: T
/** /**
* Performs update of derivative after the rest of the formula in the back-pass. * Performs update of derivative after the rest of the formula in the back-pass.
* *
@ -78,15 +89,9 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
*/ */
abstract fun <R> derive(value: R, block: F.(R) -> Unit): R abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
/**
* A variable accessing inner state of derivatives.
* Use this function in inner builders to avoid creating additional derivative bindings
*/
abstract var Variable<T>.d: T
abstract fun variable(value: T): Variable<T> abstract fun variable(value: T): Variable<T>
inline fun variable(block: F.() -> T) = variable(context.block()) inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
// Overloads for Double constants // Overloads for Double constants
@ -98,46 +103,35 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this) override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
override operator fun Number.minus(b: Variable<T>): Variable<T> = override operator fun Number.minus(b: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - b.value }) { z -> derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
b.d -= z.d
}
override operator fun Variable<T>.minus(b: Number): Variable<T> = override operator fun Variable<T>.minus(b: Number): Variable<T> =
derive(variable { this@minus.value - one * b.toDouble() }) { z -> derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
this@minus.d += z.d
}
} }
/** /**
* Automatic Differentiation context class. * Automatic Differentiation context class.
*/ */
private class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() { @PublishedApi
internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
// this stack contains pairs of blocks and values to apply them to // this stack contains pairs of blocks and values to apply them to
private var stack = arrayOfNulls<Any?>(8) private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp = 0 private var sp: Int = 0
val derivatives: MutableMap<Variable<T>, T> = hashMapOf()
internal val derivatives = HashMap<Variable<T>, T>() override val zero: Variable<T> get() = Variable(context.zero)
override val one: Variable<T> get() = Variable(context.one)
/** /**
* A variable coupled with its derivative. For internal use only * A variable coupled with its derivative. For internal use only
*/ */
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x) private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
override fun variable(value: T): Variable<T> = override fun variable(value: T): Variable<T> =
VariableWithDeriv(value, context.zero) VariableWithDeriv(value, context.zero)
override var Variable<T>.d: T override var Variable<T>.d: T
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
set(value) { set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value
if (this is VariableWithDeriv) {
d = value
} else {
derivatives[this] = value
}
}
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: F.(R) -> Unit): R { override fun <R> derive(value: R, block: F.(R) -> Unit): R {
@ -160,67 +154,49 @@ private class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) :
// Basic math (+, -, *, /) // Basic math (+, -, *, /)
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
derive(variable { a.value + b.value }) { z ->
a.d += z.d a.d += z.d
b.d += z.d b.d += z.d
} }
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value * b.value }) { z ->
derive(variable { a.value * b.value }) { z ->
a.d += z.d * b.value a.d += z.d * b.value
b.d += z.d * a.value b.d += z.d * a.value
} }
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value / b.value }) { z ->
derive(variable { a.value / b.value }) { z ->
a.d += z.d / b.value a.d += z.d / b.value
b.d -= z.d * a.value / (b.value * b.value) b.d -= z.d * a.value / (b.value * b.value)
} }
override fun multiply(a: Variable<T>, k: Number): Variable<T> = override fun multiply(a: Variable<T>, k: Number): Variable<T> = derive(variable { k.toDouble() * a.value }) { z ->
derive(variable { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble() a.d += z.d * k.toDouble()
} }
override val zero: Variable<T> get() = Variable(context.zero)
override val one: Variable<T> get() = Variable(context.one)
} }
// Extensions for differentiation of various basic mathematical functions // Extensions for differentiation of various basic mathematical functions
// x ^ 2 // x ^ 2
fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> = fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
derive(variable { x.value * x.value }) { z -> derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
x.d += z.d * 2 * x.value
}
// x ^ 1/2 // x ^ 1/2
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
derive(variable { sqrt(x.value) }) { z -> derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
x.d += z.d * 0.5 / z.value
}
// x ^ y (const) // x ^ y (const)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
derive(variable { power(x.value, y) }) { z -> derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
x.d += z.d * y * power(x.value, y - 1)
}
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble()) fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
// exp(x) // exp(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
derive(variable { exp(x.value) }) { z -> derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
x.d += z.d * z.value
}
// ln(x) // ln(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> = derive( fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
variable { ln(x.value) } derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
) { z ->
x.d += z.d / x.value
}
// x ^ y (any) // x ^ y (any)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
@ -228,12 +204,8 @@ fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: V
// sin(x) // sin(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
derive(variable { sin(x.value) }) { z -> derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
x.d += z.d * cos(x.value)
}
// cos(x) // cos(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> = fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
derive(variable { cos(x.value) }) { z -> derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
x.d -= z.d * sin(x.value)
}

View File

@ -41,6 +41,6 @@ fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Seque
*/ */
@Deprecated("Replace by 'toSequenceWithPoints'") @Deprecated("Replace by 'toSequenceWithPoints'")
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray { fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
if (numPoints < 2) error("Can't create generic grid with less than two points") require(numPoints >= 2) { "Can't create generic grid with less than two points" }
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
} }

View File

@ -1,76 +1,80 @@
package scientifik.kmath.misc package scientifik.kmath.misc
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
/** /**
* Generic cumulative operation on iterator * Generic cumulative operation on iterator.
* @param T type of initial iterable *
* @param R type of resulting iterable * @param T the type of initial iterable.
* @param initial lazy evaluated * @param R the type of resulting iterable.
* @param initial lazy evaluated.
*/ */
fun <T, R> Iterator<T>.cumulative(initial: R, operation: (R, T) -> R): Iterator<R> = object : Iterator<R> { inline fun <T, R> Iterator<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator<R> {
contract { callsInPlace(operation) }
return object : Iterator<R> {
var state: R = initial var state: R = initial
override fun hasNext(): Boolean = this@cumulative.hasNext() override fun hasNext(): Boolean = this@cumulative.hasNext()
override fun next(): R { override fun next(): R {
state = operation(state, this@cumulative.next()) state = operation(state, this@cumulative.next())
return state return state
} }
}
} }
fun <T, R> Iterable<T>.cumulative(initial: R, operation: (R, T) -> R): Iterable<R> = object : Iterable<R> { inline fun <T, R> Iterable<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable<R> =
override fun iterator(): Iterator<R> = this@cumulative.iterator().cumulative(initial, operation) Iterable { this@cumulative.iterator().cumulative(initial, operation) }
}
fun <T, R> Sequence<T>.cumulative(initial: R, operation: (R, T) -> R): Sequence<R> = object : Sequence<R> { inline fun <T, R> Sequence<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence<R> = Sequence {
override fun iterator(): Iterator<R> = this@cumulative.iterator().cumulative(initial, operation) this@cumulative.iterator().cumulative(initial, operation)
} }
fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> = fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> =
this.iterator().cumulative(initial, operation).asSequence().toList() iterator().cumulative(initial, operation).asSequence().toList()
//Cumulative sum //Cumulative sum
/** /**
* Cumulative sum with custom space * Cumulative sum with custom space
*/ */
fun <T> Iterable<T>.cumulativeSum(space: Space<T>) = with(space) { fun <T> Iterable<T>.cumulativeSum(space: Space<T>): Iterable<T> =
cumulative(zero) { element: T, sum: T -> sum + element } space { cumulative(zero) { element: T, sum: T -> sum + element } }
}
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun Iterable<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } fun Iterable<Double>.cumulativeSum(): Iterable<Double> = cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun Iterable<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } fun Iterable<Int>.cumulativeSum(): Iterable<Int> = cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun Iterable<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } fun Iterable<Long>.cumulativeSum(): Iterable<Long> = cumulative(0L) { element, sum -> sum + element }
fun <T> Sequence<T>.cumulativeSum(space: Space<T>) = with(space) { fun <T> Sequence<T>.cumulativeSum(space: Space<T>): Sequence<T> =
cumulative(zero) { element: T, sum: T -> sum + element } space { cumulative(zero) { element: T, sum: T -> sum + element } }
}
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun Sequence<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } fun Sequence<Double>.cumulativeSum(): Sequence<Double> = cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun Sequence<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } fun Sequence<Int>.cumulativeSum(): Sequence<Int> = cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun Sequence<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } fun Sequence<Long>.cumulativeSum(): Sequence<Long> = cumulative(0L) { element, sum -> sum + element }
fun <T> List<T>.cumulativeSum(space: Space<T>) = with(space) { fun <T> List<T>.cumulativeSum(space: Space<T>): List<T> =
cumulative(zero) { element: T, sum: T -> sum + element } space { cumulative(zero) { element: T, sum: T -> sum + element } }
}
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
fun List<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } fun List<Double>.cumulativeSum(): List<Double> = cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt") @JvmName("cumulativeSumOfInt")
fun List<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } fun List<Int>.cumulativeSum(): List<Int> = cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
fun List<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } fun List<Long>.cumulativeSum(): List<Long> = cumulative(0L) { element, sum -> sum + element }

View File

@ -1,10 +1,15 @@
package scientifik.kmath.operations package scientifik.kmath.operations
/**
* Stub for DSL the [Algebra] is.
*/
@DslMarker @DslMarker
annotation class KMathContext annotation class KMathContext
/** /**
* Marker interface for any algebra * Represents an algebraic structure.
*
* @param T the type of element of this structure.
*/ */
interface Algebra<T> { interface Algebra<T> {
/** /**
@ -24,50 +29,121 @@ interface Algebra<T> {
} }
/** /**
* An algebra with numeric representation of members * An algebraic structure where elements can have numeric representation.
*
* @param T the type of element of this structure.
*/ */
interface NumericAlgebra<T> : Algebra<T> { interface NumericAlgebra<T> : Algebra<T> {
/** /**
* Wrap a number * Wraps a number.
*/ */
fun number(value: Number): T fun number(value: Number): T
/**
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number].
*/
fun leftSideNumberOperation(operation: String, left: Number, right: T): T = fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
binaryOperation(operation, number(left), right) binaryOperation(operation, number(left), right)
/**
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
*/
fun rightSideNumberOperation(operation: String, left: T, right: Number): T = fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
leftSideNumberOperation(operation, right, left) leftSideNumberOperation(operation, right, left)
} }
/** /**
* Call a block with an [Algebra] as receiver * Call a block with an [Algebra] as receiver.
*/ */
inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block) inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block)
/** /**
* Space-like operations without neutral element * Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as
* multiplication by scalars.
*
* @param T the type of element of this semispace.
*/ */
interface SpaceOperations<T> : Algebra<T> { interface SpaceOperations<T> : Algebra<T> {
/** /**
* Addition operation for two context elements * Addition of two elements.
*
* @param a the addend.
* @param b the augend.
* @return the sum.
*/ */
fun add(a: T, b: T): T fun add(a: T, b: T): T
/** /**
* Multiplication operation for context element and real number * Multiplication of element by scalar.
*
* @param a the multiplier.
* @param k the multiplicand.
* @return the produce.
*/ */
fun multiply(a: T, k: Number): T fun multiply(a: T, k: Number): T
//Operation to be performed in this context. Could be moved to extensions in case of KEEP-176 // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176
/**
* The negation of this element.
*
* @receiver this value.
* @return the additive inverse of this value.
*/
operator fun T.unaryMinus(): T = multiply(this, -1.0) operator fun T.unaryMinus(): T = multiply(this, -1.0)
/**
* Returns this value.
*
* @receiver this value.
* @return this value.
*/
operator fun T.unaryPlus(): T = this operator fun T.unaryPlus(): T = this
/**
* Addition of two elements.
*
* @receiver the addend.
* @param b the augend.
* @return the sum.
*/
operator fun T.plus(b: T): T = add(this, b) operator fun T.plus(b: T): T = add(this, b)
/**
* Subtraction of two elements.
*
* @receiver the minuend.
* @param b the subtrahend.
* @return the difference.
*/
operator fun T.minus(b: T): T = add(this, -b) operator fun T.minus(b: T): T = add(this, -b)
operator fun T.times(k: Number) = multiply(this, k.toDouble())
operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble()) /**
operator fun Number.times(b: T) = b * this * Multiplication of this element by a scalar.
*
* @receiver the multiplier.
* @param k the multiplicand.
* @return the product.
*/
operator fun T.times(k: Number): T = multiply(this, k.toDouble())
/**
* Division of this element by scalar.
*
* @receiver the dividend.
* @param k the divisor.
* @return the quotient.
*/
operator fun T.div(k: Number): T = multiply(this, 1.0 / k.toDouble())
/**
* Multiplication of this number by element.
*
* @receiver the multiplier.
* @param b the multiplicand.
* @return the product.
*/
operator fun Number.times(b: T): T = b * this
override fun unaryOperation(operation: String, arg: T): T = when (operation) { override fun unaryOperation(operation: String, arg: T): T = when (operation) {
PLUS_OPERATION -> arg PLUS_OPERATION -> arg
@ -82,37 +158,54 @@ interface SpaceOperations<T> : Algebra<T> {
} }
companion object { companion object {
const val PLUS_OPERATION = "+" /**
const val MINUS_OPERATION = "-" * The identifier of addition.
const val NOT_OPERATION = "!" */
const val PLUS_OPERATION: String = "+"
/**
* The identifier of subtraction (and negation).
*/
const val MINUS_OPERATION: String = "-"
const val NOT_OPERATION: String = "!"
} }
} }
/** /**
* A general interface representing linear context of some kind. * Represents linear space, i.e. algebraic structure with associative binary operation called "addition" and its neutral
* The context defines sum operation for its elements and multiplication by real value. * element as well as multiplication by scalars.
* One must note that in some cases context is a singleton class, but in some cases it
* works as a context for operations inside it.
* *
* TODO do we need non-commutative context? * @param T the type of element of this group.
*/ */
interface Space<T> : SpaceOperations<T> { interface Space<T> : SpaceOperations<T> {
/** /**
* Neutral element for sum operation * The neutral element of addition.
*/ */
val zero: T val zero: T
} }
/** /**
* Operations on ring without multiplication neutral element * Represents semiring, i.e. algebraic structure with two associative binary operations called "addition" and
* "multiplication".
*
* @param T the type of element of this semiring.
*/ */
interface RingOperations<T> : SpaceOperations<T> { interface RingOperations<T> : SpaceOperations<T> {
/** /**
* Multiplication for two field elements * Multiplies two elements.
*
* @param a the multiplier.
* @param b the multiplicand.
*/ */
fun multiply(a: T, b: T): T fun multiply(a: T, b: T): T
/**
* Multiplies this element by scalar.
*
* @receiver the multiplier.
* @param b the multiplicand.
*/
operator fun T.times(b: T): T = multiply(this, b) operator fun T.times(b: T): T = multiply(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
@ -121,12 +214,18 @@ interface RingOperations<T> : SpaceOperations<T> {
} }
companion object { companion object {
const val TIMES_OPERATION = "*" /**
* The identifier of multiplication.
*/
const val TIMES_OPERATION: String = "*"
} }
} }
/** /**
* The same as {@link Space} but with additional multiplication operation * Represents ring, i.e. algebraic structure with two associative binary operations called "addition" and
* "multiplication" and their neutral elements.
*
* @param T the type of element of this ring.
*/ */
interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> { interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
/** /**
@ -150,20 +249,64 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
else -> super.rightSideNumberOperation(operation, left, right) else -> super.rightSideNumberOperation(operation, left, right)
} }
/**
* Addition of element and scalar.
*
* @receiver the addend.
* @param b the augend.
*/
operator fun T.plus(b: Number): T = this + number(b)
operator fun T.plus(b: Number) = this.plus(number(b)) /**
operator fun Number.plus(b: T) = b + this * Addition of scalar and element.
*
* @receiver the addend.
* @param b the augend.
*/
operator fun Number.plus(b: T): T = b + this
operator fun T.minus(b: Number) = this.minus(number(b)) /**
operator fun Number.minus(b: T) = -b + this * Subtraction of element from number.
*
* @receiver the minuend.
* @param b the subtrahend.
* @receiver the difference.
*/
operator fun T.minus(b: Number): T = this - number(b)
/**
* Subtraction of number from element.
*
* @receiver the minuend.
* @param b the subtrahend.
* @receiver the difference.
*/
operator fun Number.minus(b: T): T = -b + this
} }
/** /**
* All ring operations but without neutral elements * Represents semifield, i.e. algebraic structure with three operations: associative "addition" and "multiplication",
* and "division".
*
* @param T the type of element of this semifield.
*/ */
interface FieldOperations<T> : RingOperations<T> { interface FieldOperations<T> : RingOperations<T> {
/**
* Division of two elements.
*
* @param a the dividend.
* @param b the divisor.
* @return the quotient.
*/
fun divide(a: T, b: T): T fun divide(a: T, b: T): T
/**
* Division of two elements.
*
* @receiver the dividend.
* @param b the divisor.
* @return the quotient.
*/
operator fun T.div(b: T): T = divide(this, b) operator fun T.div(b: T): T = divide(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
@ -172,13 +315,26 @@ interface FieldOperations<T> : RingOperations<T> {
} }
companion object { companion object {
const val DIV_OPERATION = "/" /**
* The identifier of division.
*/
const val DIV_OPERATION: String = "/"
} }
} }
/** /**
* Four operations algebra * Represents field, i.e. algebraic structure with three operations: associative "addition" and "multiplication",
* and "division" and their neutral elements.
*
* @param T the type of element of this semifield.
*/ */
interface Field<T> : Ring<T>, FieldOperations<T> { interface Field<T> : Ring<T>, FieldOperations<T> {
operator fun Number.div(b: T) = this * divide(one, b) /**
* Division of element by scalar.
*
* @receiver the dividend.
* @param b the divisor.
* @return the quotient.
*/
operator fun Number.div(b: T): T = this * divide(one, b)
} }

View File

@ -2,47 +2,107 @@ package scientifik.kmath.operations
/** /**
* The generic mathematics elements which is able to store its context * The generic mathematics elements which is able to store its context
* @param T the type of space operation results *
* @param I self type of the element. Needed for static type checking * @param C the type of mathematical context for this element.
* @param C the type of mathematical context for this element
*/ */
interface MathElement<C> { interface MathElement<C> {
/** /**
* The context this element belongs to * The context this element belongs to.
*/ */
val context: C val context: C
} }
/**
* Represents element that can be wrapped to its "primitive" value.
*
* @param T the type wrapped by this wrapper.
* @param I the type of this wrapper.
*/
interface MathWrapper<T, I> { interface MathWrapper<T, I> {
/**
* Unwraps [I] to [T].
*/
fun unwrap(): T fun unwrap(): T
/**
* Wraps [T] to [I].
*/
fun T.wrap(): I fun T.wrap(): I
} }
/** /**
* The element of linear context * The element of [Space].
* @param T the type of space operation results *
* @param I self type of the element. Needed for static type checking * @param T the type of space operation results.
* @param S the type of space * @param I self type of the element. Needed for static type checking.
* @param S the type of space.
*/ */
interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S>, MathWrapper<T, I> { interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S>, MathWrapper<T, I> {
/**
* Adds element to this one.
*
* @param b the augend.
* @return the sum.
*/
operator fun plus(b: T): I = context.add(unwrap(), b).wrap()
operator fun plus(b: T) = context.add(unwrap(), b).wrap() /**
operator fun minus(b: T) = context.add(unwrap(), context.multiply(b, -1.0)).wrap() * Subtracts element from this one.
operator fun times(k: Number) = context.multiply(unwrap(), k.toDouble()).wrap() *
operator fun div(k: Number) = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap() * @param b the subtrahend.
* @return the difference.
*/
operator fun minus(b: T): I = context.add(unwrap(), context.multiply(b, -1.0)).wrap()
/**
* Multiplies this element by number.
*
* @param k the multiplicand.
* @return the product.
*/
operator fun times(k: Number): I = context.multiply(unwrap(), k.toDouble()).wrap()
/**
* Divides this element by number.
*
* @param k the divisor.
* @return the quotient.
*/
operator fun div(k: Number): I = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap()
} }
/** /**
* Ring element * The element of [Ring].
*
* @param T the type of space operation results.
* @param I self type of the element. Needed for static type checking.
* @param R the type of space.
*/ */
interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> { interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
operator fun times(b: T) = context.multiply(unwrap(), b).wrap() /**
* Multiplies this element by another one.
*
* @param b the multiplicand.
* @return the product.
*/
operator fun times(b: T): I = context.multiply(unwrap(), b).wrap()
} }
/** /**
* Field element * The element of [Field].
*
* @param T the type of space operation results.
* @param I self type of the element. Needed for static type checking.
* @param F the type of field.
*/ */
interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> { interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> {
override val context: F override val context: F
operator fun div(b: T) = context.divide(unwrap(), b).wrap()
/**
* Divides this element by another one.
*
* @param b the divisor.
* @return the quotient.
*/
operator fun div(b: T): I = context.divide(unwrap(), b).wrap()
} }

View File

@ -1,15 +1,107 @@
package scientifik.kmath.operations package scientifik.kmath.operations
/**
* Returns the sum of all elements in the iterable in this [Space].
*
* @receiver the algebra that provides addition.
* @param data the iterable to sum up.
* @return the sum.
*/
fun <T> Space<T>.sum(data: Iterable<T>): T = data.fold(zero) { left, right -> add(left, right) } fun <T> Space<T>.sum(data: Iterable<T>): T = data.fold(zero) { left, right -> add(left, right) }
/**
* Returns the sum of all elements in the sequence in this [Space].
*
* @receiver the algebra that provides addition.
* @param data the sequence to sum up.
* @return the sum.
*/
fun <T> Space<T>.sum(data: Sequence<T>): T = data.fold(zero) { left, right -> add(left, right) } fun <T> Space<T>.sum(data: Sequence<T>): T = data.fold(zero) { left, right -> add(left, right) }
fun <T : Any, S : Space<T>> Iterable<T>.sumWith(space: S): T = space.sum(this) /**
* Returns an average value of elements in the iterable in this [Space].
*
* @receiver the algebra that provides addition and division.
* @param data the iterable to find average.
* @return the average value.
*/
fun <T> Space<T>.average(data: Iterable<T>): T = sum(data) / data.count()
/**
* Returns an average value of elements in the sequence in this [Space].
*
* @receiver the algebra that provides addition and division.
* @param data the sequence to find average.
* @return the average value.
*/
fun <T> Space<T>.average(data: Sequence<T>): T = sum(data) / data.count()
/**
* Returns the sum of all elements in the iterable in provided space.
*
* @receiver the collection to sum up.
* @param space the algebra that provides addition.
* @return the sum.
*/
fun <T> Iterable<T>.sumWith(space: Space<T>): T = space.sum(this)
/**
* Returns the sum of all elements in the sequence in provided space.
*
* @receiver the collection to sum up.
* @param space the algebra that provides addition.
* @return the sum.
*/
fun <T> Sequence<T>.sumWith(space: Space<T>): T = space.sum(this)
/**
* Returns an average value of elements in the iterable in this [Space].
*
* @receiver the iterable to find average.
* @param space the algebra that provides addition and division.
* @return the average value.
*/
fun <T> Iterable<T>.averageWith(space: Space<T>): T = space.average(this)
/**
* Returns an average value of elements in the sequence in this [Space].
*
* @receiver the sequence to find average.
* @param space the algebra that provides addition and division.
* @return the average value.
*/
fun <T> Sequence<T>.averageWith(space: Space<T>): T = space.average(this)
//TODO optimized power operation //TODO optimized power operation
fun <T> RingOperations<T>.power(arg: T, power: Int): T {
/**
* Raises [arg] to the natural power [power].
*
* @receiver the algebra to provide multiplication.
* @param arg the base.
* @param power the exponent.
* @return the base raised to the power.
*/
fun <T> Ring<T>.power(arg: T, power: Int): T {
require(power >= 0) { "The power can't be negative." }
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
if (power == 0) return one
var res = arg var res = arg
repeat(power - 1) { repeat(power - 1) { res *= arg }
res *= arg
}
return res return res
} }
/**
* Raises [arg] to the integer power [power].
*
* @receiver the algebra to provide multiplication and division.
* @param arg the base.
* @param power the exponent.
* @return the base raised to the power.
*/
fun <T> Field<T>.power(arg: T, power: Int): T {
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
if (power == 0) return one
if (power < 0) return one / (this as Ring<T>).power(arg, -power)
return (this as Ring<T>).power(arg, power)
}

View File

@ -3,12 +3,13 @@ package scientifik.kmath.operations
import scientifik.kmath.operations.BigInt.Companion.BASE import scientifik.kmath.operations.BigInt.Companion.BASE
import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.math.log2 import kotlin.math.log2
import kotlin.math.max import kotlin.math.max
import kotlin.math.min import kotlin.math.min
import kotlin.math.sign import kotlin.math.sign
typealias Magnitude = UIntArray typealias Magnitude = UIntArray
typealias TBase = ULong typealias TBase = ULong
@ -22,8 +23,9 @@ object BigIntField : Field<BigInt> {
override val one: BigInt = BigInt.ONE override val one: BigInt = BigInt.ONE
override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b) override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b)
override fun number(value: Number): BigInt = value.toLong().toBigInt()
override fun multiply(a: BigInt, k: Number): BigInt = a.times(k.toLong()) override fun multiply(a: BigInt, k: Number): BigInt = a.times(number(k))
override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b) override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b)
@ -194,8 +196,8 @@ class BigInt internal constructor(
} }
infix fun or(other: BigInt): BigInt { infix fun or(other: BigInt): BigInt {
if (this == ZERO) return other; if (this == ZERO) return other
if (other == ZERO) return this; if (other == ZERO) return this
val resSize = max(this.magnitude.size, other.magnitude.size) val resSize = max(this.magnitude.size, other.magnitude.size)
val newMagnitude: Magnitude = Magnitude(resSize) val newMagnitude: Magnitude = Magnitude(resSize)
for (i in 0 until resSize) { for (i in 0 until resSize) {
@ -210,7 +212,7 @@ class BigInt internal constructor(
} }
infix fun and(other: BigInt): BigInt { infix fun and(other: BigInt): BigInt {
if ((this == ZERO) or (other == ZERO)) return ZERO; if ((this == ZERO) or (other == ZERO)) return ZERO
val resSize = min(this.magnitude.size, other.magnitude.size) val resSize = min(this.magnitude.size, other.magnitude.size)
val newMagnitude: Magnitude = Magnitude(resSize) val newMagnitude: Magnitude = Magnitude(resSize)
for (i in 0 until resSize) { for (i in 0 until resSize) {
@ -260,7 +262,7 @@ class BigInt internal constructor(
} }
companion object { companion object {
const val BASE = 0xffffffffUL const val BASE: ULong = 0xffffffffUL
const val BASE_SIZE: Int = 32 const val BASE_SIZE: Int = 32
val ZERO: BigInt = BigInt(0, uintArrayOf()) val ZERO: BigInt = BigInt(0, uintArrayOf())
val ONE: BigInt = BigInt(1, uintArrayOf(1u)) val ONE: BigInt = BigInt(1, uintArrayOf(1u))
@ -394,12 +396,12 @@ fun abs(x: BigInt): BigInt = x.abs()
/** /**
* Convert this [Int] to [BigInt] * Convert this [Int] to [BigInt]
*/ */
fun Int.toBigInt() = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt())) fun Int.toBigInt(): BigInt = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt()))
/** /**
* Convert this [Long] to [BigInt] * Convert this [Long] to [BigInt]
*/ */
fun Long.toBigInt() = BigInt( fun Long.toBigInt(): BigInt = BigInt(
sign.toByte(), stripLeadingZeros( sign.toByte(), stripLeadingZeros(
uintArrayOf( uintArrayOf(
(kotlin.math.abs(this).toULong() and BASE).toUInt(), (kotlin.math.abs(this).toULong() and BASE).toUInt(),
@ -411,17 +413,17 @@ fun Long.toBigInt() = BigInt(
/** /**
* Convert UInt to [BigInt] * Convert UInt to [BigInt]
*/ */
fun UInt.toBigInt() = BigInt(1, uintArrayOf(this)) fun UInt.toBigInt(): BigInt = BigInt(1, uintArrayOf(this))
/** /**
* Convert ULong to [BigInt] * Convert ULong to [BigInt]
*/ */
fun ULong.toBigInt() = BigInt( fun ULong.toBigInt(): BigInt = BigInt(
1, 1,
stripLeadingZeros( stripLeadingZeros(
uintArrayOf( uintArrayOf(
(this and BigInt.BASE).toUInt(), (this and BASE).toUInt(),
((this shr BigInt.BASE_SIZE) and BigInt.BASE).toUInt() ((this shr BASE_SIZE) and BASE).toUInt()
) )
) )
) )
@ -430,11 +432,11 @@ fun ULong.toBigInt() = BigInt(
* Create a [BigInt] with this array of magnitudes with protective copy * Create a [BigInt] with this array of magnitudes with protective copy
*/ */
fun UIntArray.toBigInt(sign: Byte): BigInt { fun UIntArray.toBigInt(sign: Byte): BigInt {
if (sign == 0.toByte() && isNotEmpty()) error("") require(sign != 0.toByte() || !isNotEmpty())
return BigInt(sign, this.copyOf()) return BigInt(sign, copyOf())
} }
val hexChToInt = hashMapOf( val hexChToInt: MutableMap<Char, Int> = hashMapOf(
'0' to 0, '1' to 1, '2' to 2, '3' to 3, '0' to 0, '1' to 1, '2' to 2, '3' to 3,
'4' to 4, '5' to 5, '6' to 6, '7' to 7, '4' to 4, '5' to 5, '6' to 6, '7' to 7,
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11, '8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
@ -484,11 +486,15 @@ fun String.parseBigInteger(): BigInt? {
return res * sign return res * sign
} }
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> = inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> {
boxing(size, initializer) contract { callsInPlace(initializer) }
return boxing(size, initializer)
}
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> = inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> {
boxing(size, initializer) contract { callsInPlace(initializer) }
return boxing(size, initializer)
}
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> = fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt) BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
@ -496,5 +502,4 @@ fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntFi
fun NDElement.Companion.bigInt( fun NDElement.Companion.bigInt(
vararg shape: Int, vararg shape: Int,
initializer: BigIntField.(IntArray) -> BigInt initializer: BigIntField.(IntArray) -> BigInt
): BufferedNDRingElement<BigInt, BigIntField> = ): BufferedNDRingElement<BigInt, BigIntField> = NDAlgebra.bigInt(*shape).produce(initializer)
NDAlgebra.bigInt(*shape).produce(initializer)

View File

@ -6,19 +6,50 @@ import scientifik.kmath.structures.MutableBuffer
import scientifik.memory.MemoryReader import scientifik.memory.MemoryReader
import scientifik.memory.MemorySpec import scientifik.memory.MemorySpec
import scientifik.memory.MemoryWriter import scientifik.memory.MemoryWriter
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.math.* import kotlin.math.*
/**
* This complex's conjugate.
*/
val Complex.conjugate: Complex
get() = Complex(re, -im)
/**
* This complex's reciprocal.
*/
val Complex.reciprocal: Complex
get() {
val scale = re * re + im * im
return Complex(re / scale, -im / scale)
}
/**
* Absolute value of complex number.
*/
val Complex.r: Double
get() = sqrt(re * re + im * im)
/**
* An angle between vector represented by complex number and X axis.
*/
val Complex.theta: Double
get() = atan(im / re)
private val PI_DIV_2 = Complex(PI / 2, 0) private val PI_DIV_2 = Complex(PI / 2, 0)
/** /**
* A field for complex numbers * A field of [Complex].
*/ */
object ComplexField : ExtendedField<Complex> { object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
override val zero: Complex = Complex(0.0, 0.0) override val zero: Complex = 0.0.toComplex()
override val one: Complex = 1.0.toComplex()
override val one: Complex = Complex(1.0, 0.0) /**
* The imaginary unit.
val i = Complex(0.0, 1.0) */
val i: Complex = Complex(0.0, 1.0)
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
@ -27,53 +58,123 @@ object ComplexField : ExtendedField<Complex> {
override fun multiply(a: Complex, b: Complex): Complex = override fun multiply(a: Complex, b: Complex): Complex =
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
override fun divide(a: Complex, b: Complex): Complex { override fun divide(a: Complex, b: Complex): Complex = when {
val norm = b.re * b.re + b.im * b.im b.re.isNaN() || b.im.isNaN() -> Complex(Double.NaN, Double.NaN)
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
(if (b.im < 0) -b.im else +b.im) < (if (b.re < 0) -b.re else +b.re) -> {
val wr = b.im / b.re
val wd = b.re + wr * b.im
if (wd.isNaN() || wd == 0.0)
Complex(Double.NaN, Double.NaN)
else
Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd)
}
b.im == 0.0 -> Complex(Double.NaN, Double.NaN)
else -> {
val wr = b.re / b.im
val wd = b.im + wr * b.re
if (wd.isNaN() || wd == 0.0)
Complex(Double.NaN, Double.NaN)
else
Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd)
}
} }
override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg)
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg)
override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2
override fun power(arg: Complex, pow: Number): Complex = override fun tan(arg: Complex): Complex {
arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta)) val e1 = exp(-i * arg)
val e2 = exp(i * arg)
return i * (e1 - e2) / (e1 + e2)
}
override fun asin(arg: Complex): Complex = -i * ln(sqrt(1 - (arg * arg)) + i * arg)
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(1 - (arg * arg)) + i * arg)
override fun atan(arg: Complex): Complex {
val iArg = i * arg
return i * (ln(1 - iArg) - ln(1 + iArg)) / 2
}
override fun power(arg: Complex, pow: Number): Complex = if (arg.im == 0.0)
arg.re.pow(pow.toDouble()).toComplex()
else
exp(pow * ln(arg))
override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im)) override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im))
override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re) override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re)
operator fun Double.plus(c: Complex) = add(this.toComplex(), c) /**
* Adds complex number to real one.
*
* @receiver the addend.
* @param c the augend.
* @return the sum.
*/
operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c)
operator fun Double.minus(c: Complex) = add(this.toComplex(), -c) /**
* Subtracts complex number from real one.
*
* @receiver the minuend.
* @param c the subtrahend.
* @return the difference.
*/
operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c)
operator fun Complex.plus(d: Double) = d + this /**
* Adds real number to complex one.
*
* @receiver the addend.
* @param d the augend.
* @return the sum.
*/
operator fun Complex.plus(d: Double): Complex = d + this
operator fun Complex.minus(d: Double) = add(this, -d.toComplex()) /**
* Subtracts real number from complex one.
*
* @receiver the minuend.
* @param d the subtrahend.
* @return the difference.
*/
operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex())
operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this) /**
* Multiplies real number by complex one.
*
* @receiver the multiplier.
* @param c the multiplicand.
* @receiver the product.
*/
operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this)
override fun symbol(value: String): Complex = if (value == "i") { override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg)
i
} else { override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value)
super.symbol(value)
}
} }
/** /**
* Complex number class * Represents complex number.
*
* @property re The real part.
* @property im The imaginary part.
*/ */
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> { data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
override val context: ComplexField get() = ComplexField
override fun unwrap(): Complex = this override fun unwrap(): Complex = this
override fun Complex.wrap(): Complex = this override fun Complex.wrap(): Complex = this
override val context: ComplexField get() = ComplexField
override fun compareTo(other: Complex): Int = r.compareTo(other.r) override fun compareTo(other: Complex): Int = r.compareTo(other.r)
companion object : MemorySpec<Complex> { companion object : MemorySpec<Complex> {
@ -90,26 +191,19 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
} }
/** /**
* A complex conjugate * Creates a complex number with real part equal to this real.
*
* @receiver the real part.
* @return the new complex number.
*/ */
val Complex.conjugate: Complex get() = Complex(re, -im) fun Number.toComplex(): Complex = Complex(this, 0.0)
/**
* Absolute value of complex number
*/
val Complex.r: Double get() = sqrt(re * re + im * im)
/**
* An angle between vector represented by complex number and X axis
*/
val Complex.theta: Double get() = atan(im / re)
fun Double.toComplex() = Complex(this, 0.0)
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> { inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
contract { callsInPlace(init) }
return MemoryBuffer.create(Complex, size, init) return MemoryBuffer.create(Complex, size, init)
} }
inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> { inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
contract { callsInPlace(init) }
return MemoryBuffer.create(Complex, size, init) return MemoryBuffer.create(Complex, size, init)
} }

View File

@ -1,25 +1,35 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.RealField.pow
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow as kpow import kotlin.math.pow as kpow
/** /**
* Advanced Number-like field that implements basic operations * Advanced Number-like semifield that implements basic operations.
*/ */
interface ExtendedFieldOperations<T> : interface ExtendedFieldOperations<T> :
InverseTrigonometricOperations<T>, FieldOperations<T>,
TrigonometricOperations<T>,
HyperbolicOperations<T>,
PowerOperations<T>, PowerOperations<T>,
ExponentialOperations<T> { ExponentialOperations<T> {
override fun tan(arg: T): T = sin(arg) / cos(arg) override fun tan(arg: T): T = sin(arg) / cos(arg)
override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
override fun unaryOperation(operation: String, arg: T): T = when (operation) { override fun unaryOperation(operation: String, arg: T): T = when (operation) {
TrigonometricOperations.COS_OPERATION -> cos(arg) TrigonometricOperations.COS_OPERATION -> cos(arg)
TrigonometricOperations.SIN_OPERATION -> sin(arg) TrigonometricOperations.SIN_OPERATION -> sin(arg)
TrigonometricOperations.TAN_OPERATION -> tan(arg) TrigonometricOperations.TAN_OPERATION -> tan(arg)
InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg) TrigonometricOperations.ACOS_OPERATION -> acos(arg)
InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg) TrigonometricOperations.ASIN_OPERATION -> asin(arg)
InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg) TrigonometricOperations.ATAN_OPERATION -> atan(arg)
HyperbolicOperations.COSH_OPERATION -> cosh(arg)
HyperbolicOperations.SINH_OPERATION -> sinh(arg)
HyperbolicOperations.TANH_OPERATION -> tanh(arg)
HyperbolicOperations.ACOSH_OPERATION -> acosh(arg)
HyperbolicOperations.ASINH_OPERATION -> asinh(arg)
HyperbolicOperations.ATANH_OPERATION -> atanh(arg)
PowerOperations.SQRT_OPERATION -> sqrt(arg) PowerOperations.SQRT_OPERATION -> sqrt(arg)
ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.EXP_OPERATION -> exp(arg)
ExponentialOperations.LN_OPERATION -> ln(arg) ExponentialOperations.LN_OPERATION -> ln(arg)
@ -27,7 +37,18 @@ interface ExtendedFieldOperations<T> :
} }
} }
/**
* Advanced Number-like field that implements basic operations.
*/
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> { interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg)
override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
PowerOperations.POW_OPERATION -> power(left, right) PowerOperations.POW_OPERATION -> power(left, right)
else -> super.rightSideNumberOperation(operation, left, right) else -> super.rightSideNumberOperation(operation, left, right)
@ -37,175 +58,213 @@ interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
/** /**
* Real field element wrapping double. * Real field element wrapping double.
* *
* @property value the [Double] value wrapped by this [Real].
*
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
*/ */
inline class Real(val value: Double) : FieldElement<Double, Real, RealField> { inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
override val context: RealField
get() = RealField
override fun unwrap(): Double = value override fun unwrap(): Double = value
override fun Double.wrap(): Real = Real(value) override fun Double.wrap(): Real = Real(value)
override val context get() = RealField
companion object companion object
} }
/** /**
* A field for double without boxing. Does not produce appropriate field element * A field for [Double] without boxing. Does not produce appropriate field element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object RealField : ExtendedField<Double>, Norm<Double, Double> { object RealField : ExtendedField<Double>, Norm<Double, Double> {
override val zero: Double = 0.0 override val zero: Double
override inline fun add(a: Double, b: Double) = a + b get() = 0.0
override inline fun multiply(a: Double, b: Double) = a * b
override inline fun multiply(a: Double, k: Number) = a * k.toDouble()
override val one: Double = 1.0 override val one: Double
override inline fun divide(a: Double, b: Double) = a / b get() = 1.0
override inline fun sin(arg: Double) = kotlin.math.sin(arg) override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
override inline fun cos(arg: Double) = kotlin.math.cos(arg) PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
override inline fun add(a: Double, b: Double): Double = a + b
override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
override inline fun multiply(a: Double, b: Double): Double = a * b
override inline fun divide(a: Double, b: Double): Double = a / b
override inline fun sin(arg: Double): Double = kotlin.math.sin(arg)
override inline fun cos(arg: Double): Double = kotlin.math.cos(arg)
override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) override inline fun tan(arg: Double): Double = kotlin.math.tan(arg)
override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) override inline fun acos(arg: Double): Double = kotlin.math.acos(arg)
override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble()) override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg)
override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg)
override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg)
override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg)
override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg)
override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg)
override inline fun exp(arg: Double) = kotlin.math.exp(arg) override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble())
override inline fun ln(arg: Double) = kotlin.math.ln(arg) override inline fun exp(arg: Double): Double = kotlin.math.exp(arg)
override inline fun ln(arg: Double): Double = kotlin.math.ln(arg)
override inline fun norm(arg: Double) = abs(arg) override inline fun norm(arg: Double): Double = abs(arg)
override inline fun Double.unaryMinus() = -this override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double): Double = this + b
override inline fun Double.plus(b: Double) = this + b override inline fun Double.minus(b: Double): Double = this - b
override inline fun Double.times(b: Double): Double = this * b
override inline fun Double.minus(b: Double) = this - b override inline fun Double.div(b: Double): Double = this / b
override inline fun Double.times(b: Double) = this * b
override inline fun Double.div(b: Double) = this / b
}
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override val zero: Float = 0f
override inline fun add(a: Float, b: Float) = a + b
override inline fun multiply(a: Float, b: Float) = a * b
override inline fun multiply(a: Float, k: Number) = a * k.toFloat()
override val one: Float = 1f
override inline fun divide(a: Float, b: Float) = a / b
override inline fun sin(arg: Float) = kotlin.math.sin(arg)
override inline fun cos(arg: Float) = kotlin.math.cos(arg)
override inline fun tan(arg: Float) = kotlin.math.tan(arg)
override inline fun acos(arg: Float) = kotlin.math.acos(arg)
override inline fun asin(arg: Float) = kotlin.math.asin(arg)
override inline fun atan(arg: Float) = kotlin.math.atan(arg)
override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat())
override inline fun exp(arg: Float) = kotlin.math.exp(arg)
override inline fun ln(arg: Float) = kotlin.math.ln(arg)
override inline fun norm(arg: Float) = abs(arg)
override inline fun Float.unaryMinus() = -this
override inline fun Float.plus(b: Float) = this + b
override inline fun Float.minus(b: Float) = this - b
override inline fun Float.times(b: Float) = this * b
override inline fun Float.div(b: Float) = this / b
} }
/** /**
* A field for [Int] without boxing. Does not produce corresponding field element * A field for [Float] without boxing. Does not produce appropriate field element.
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override val zero: Float
get() = 0.0f
override val one: Float
get() = 1.0f
override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
override inline fun add(a: Float, b: Float): Float = a + b
override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat()
override inline fun multiply(a: Float, b: Float): Float = a * b
override inline fun divide(a: Float, b: Float): Float = a / b
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
override inline fun tan(arg: Float): Float = kotlin.math.tan(arg)
override inline fun acos(arg: Float): Float = kotlin.math.acos(arg)
override inline fun asin(arg: Float): Float = kotlin.math.asin(arg)
override inline fun atan(arg: Float): Float = kotlin.math.atan(arg)
override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg)
override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg)
override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg)
override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg)
override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg)
override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg)
override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat())
override inline fun exp(arg: Float): Float = kotlin.math.exp(arg)
override inline fun ln(arg: Float): Float = kotlin.math.ln(arg)
override inline fun norm(arg: Float): Float = abs(arg)
override inline fun Float.unaryMinus(): Float = -this
override inline fun Float.plus(b: Float): Float = this + b
override inline fun Float.minus(b: Float): Float = this - b
override inline fun Float.times(b: Float): Float = this * b
override inline fun Float.div(b: Float): Float = this / b
}
/**
* A field for [Int] without boxing. Does not produce corresponding ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object IntRing : Ring<Int>, Norm<Int, Int> { object IntRing : Ring<Int>, Norm<Int, Int> {
override val zero: Int = 0 override val zero: Int
override inline fun add(a: Int, b: Int) = a + b get() = 0
override inline fun multiply(a: Int, b: Int) = a * b
override inline fun multiply(a: Int, k: Number) = k.toInt() * a
override val one: Int = 1
override inline fun norm(arg: Int) = abs(arg) override val one: Int
get() = 1
override inline fun Int.unaryMinus() = -this override inline fun add(a: Int, b: Int): Int = a + b
override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
override inline fun multiply(a: Int, b: Int): Int = a * b
override inline fun norm(arg: Int): Int = abs(arg)
override inline fun Int.unaryMinus(): Int = -this
override inline fun Int.plus(b: Int): Int = this + b override inline fun Int.plus(b: Int): Int = this + b
override inline fun Int.minus(b: Int): Int = this - b override inline fun Int.minus(b: Int): Int = this - b
override inline fun Int.times(b: Int): Int = this * b override inline fun Int.times(b: Int): Int = this * b
} }
/** /**
* A field for [Short] without boxing. Does not produce appropriate field element * A field for [Short] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ShortRing : Ring<Short>, Norm<Short, Short> { object ShortRing : Ring<Short>, Norm<Short, Short> {
override val zero: Short = 0 override val zero: Short
override inline fun add(a: Short, b: Short) = (a + b).toShort() get() = 0
override inline fun multiply(a: Short, b: Short) = (a * b).toShort()
override inline fun multiply(a: Short, k: Number) = (a * k.toShort()).toShort() override val one: Short
override val one: Short = 1 get() = 1
override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
override inline fun Short.unaryMinus() = (-this).toShort() override inline fun Short.unaryMinus(): Short = (-this).toShort()
override inline fun Short.plus(b: Short): Short = (this + b).toShort()
override inline fun Short.plus(b: Short) = (this + b).toShort() override inline fun Short.minus(b: Short): Short = (this - b).toShort()
override inline fun Short.times(b: Short): Short = (this * b).toShort()
override inline fun Short.minus(b: Short) = (this - b).toShort()
override inline fun Short.times(b: Short) = (this * b).toShort()
} }
/** /**
* A field for [Byte] values * A field for [Byte] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ByteRing : Ring<Byte>, Norm<Byte, Byte> { object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
override val zero: Byte = 0 override val zero: Byte
override inline fun add(a: Byte, b: Byte) = (a + b).toByte() get() = 0
override inline fun multiply(a: Byte, b: Byte) = (a * b).toByte()
override inline fun multiply(a: Byte, k: Number) = (a * k.toByte()).toByte() override val one: Byte
override val one: Byte = 1 get() = 1
override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
override inline fun Byte.unaryMinus() = (-this).toByte() override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
override inline fun Byte.plus(b: Byte) = (this + b).toByte() override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
override inline fun Byte.minus(b: Byte) = (this - b).toByte()
override inline fun Byte.times(b: Byte) = (this * b).toByte()
} }
/** /**
* A field for [Long] values * A field for [Double] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object LongRing : Ring<Long>, Norm<Long, Long> { object LongRing : Ring<Long>, Norm<Long, Long> {
override val zero: Long = 0 override val zero: Long
override inline fun add(a: Long, b: Long) = (a + b) get() = 0
override inline fun multiply(a: Long, b: Long) = (a * b)
override inline fun multiply(a: Long, k: Number) = a * k.toLong() override val one: Long
override val one: Long = 1 get() = 1
override inline fun add(a: Long, b: Long): Long = a + b
override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
override inline fun multiply(a: Long, b: Long): Long = a * b
override fun norm(arg: Long): Long = abs(arg) override fun norm(arg: Long): Long = abs(arg)
override inline fun Long.unaryMinus() = (-this) override inline fun Long.unaryMinus(): Long = (-this)
override inline fun Long.plus(b: Long): Long = (this + b)
override inline fun Long.plus(b: Long) = (this + b) override inline fun Long.minus(b: Long): Long = (this - b)
override inline fun Long.times(b: Long): Long = (this * b)
override inline fun Long.minus(b: Long) = (this - b)
override inline fun Long.times(b: Long) = (this * b)
} }

View File

@ -1,84 +1,309 @@
package scientifik.kmath.operations package scientifik.kmath.operations
/* Trigonometric operations */
/** /**
* A container for trigonometric operations for specific type. Trigonometric operations are limited to fields. * A container for trigonometric operations for specific type.
*
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field.
* It also allows to override behavior for optional operations
* *
* @param T the type of element of this structure.
*/
interface TrigonometricOperations<T> : Algebra<T> {
/**
* Computes the sine of [arg].
*/ */
interface TrigonometricOperations<T> : FieldOperations<T> {
fun sin(arg: T): T fun sin(arg: T): T
/**
* Computes the cosine of [arg].
*/
fun cos(arg: T): T fun cos(arg: T): T
/**
* Computes the tangent of [arg].
*/
fun tan(arg: T): T fun tan(arg: T): T
companion object { /**
const val SIN_OPERATION = "sin" * Computes the inverse sine of [arg].
const val COS_OPERATION = "cos" */
const val TAN_OPERATION = "tan"
}
}
interface InverseTrigonometricOperations<T> : TrigonometricOperations<T> {
fun asin(arg: T): T fun asin(arg: T): T
/**
* Computes the inverse cosine of [arg].
*/
fun acos(arg: T): T fun acos(arg: T): T
/**
* Computes the inverse tangent of [arg].
*/
fun atan(arg: T): T fun atan(arg: T): T
companion object { companion object {
const val ASIN_OPERATION = "asin" /**
const val ACOS_OPERATION = "acos" * The identifier of sine.
const val ATAN_OPERATION = "atan" */
const val SIN_OPERATION: String = "sin"
/**
* The identifier of cosine.
*/
const val COS_OPERATION: String = "cos"
/**
* The identifier of tangent.
*/
const val TAN_OPERATION: String = "tan"
/**
* The identifier of inverse sine.
*/
const val ASIN_OPERATION: String = "asin"
/**
* The identifier of inverse cosine.
*/
const val ACOS_OPERATION: String = "acos"
/**
* The identifier of inverse tangent.
*/
const val ATAN_OPERATION: String = "atan"
} }
} }
fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
fun <T : MathElement<out InverseTrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
fun <T : MathElement<out InverseTrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
fun <T : MathElement<out InverseTrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
/* Power and roots */
/** /**
* A context extension to include power operations like square roots, etc * Computes the sine of [arg].
*/ */
interface PowerOperations<T> : Algebra<T> { fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
fun power(arg: T, pow: Number): T
fun sqrt(arg: T) = power(arg, 0.5)
infix fun T.pow(pow: Number) = power(this, pow) /**
* Computes the cosine of [arg].
*/
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
/**
* Computes the tangent of [arg].
*/
fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
/**
* Computes the inverse sine of [arg].
*/
fun <T : MathElement<out TrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
/**
* Computes the inverse cosine of [arg].
*/
fun <T : MathElement<out TrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
/**
* Computes the inverse tangent of [arg].
*/
fun <T : MathElement<out TrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
/**
* A container for hyperbolic trigonometric operations for specific type.
*
* @param T the type of element of this structure.
*/
interface HyperbolicOperations<T> : Algebra<T> {
/**
* Computes the hyperbolic sine of [arg].
*/
fun sinh(arg: T): T
/**
* Computes the hyperbolic cosine of [arg].
*/
fun cosh(arg: T): T
/**
* Computes the hyperbolic tangent of [arg].
*/
fun tanh(arg: T): T
/**
* Computes the inverse hyperbolic sine of [arg].
*/
fun asinh(arg: T): T
/**
* Computes the inverse hyperbolic cosine of [arg].
*/
fun acosh(arg: T): T
/**
* Computes the inverse hyperbolic tangent of [arg].
*/
fun atanh(arg: T): T
companion object { companion object {
const val POW_OPERATION = "pow" /**
const val SQRT_OPERATION = "sqrt" * The identifier of hyperbolic sine.
*/
const val SINH_OPERATION: String = "sinh"
/**
* The identifier of hyperbolic cosine.
*/
const val COSH_OPERATION: String = "cosh"
/**
* The identifier of hyperbolic tangent.
*/
const val TANH_OPERATION: String = "tanh"
/**
* The identifier of inverse hyperbolic sine.
*/
const val ASINH_OPERATION: String = "asinh"
/**
* The identifier of inverse hyperbolic cosine.
*/
const val ACOSH_OPERATION: String = "acosh"
/**
* The identifier of inverse hyperbolic tangent.
*/
const val ATANH_OPERATION: String = "atanh"
} }
} }
/**
* Computes the hyperbolic sine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> sinh(arg: T): T = arg.context.sinh(arg)
/**
* Computes the hyperbolic cosine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> cosh(arg: T): T = arg.context.cosh(arg)
/**
* Computes the hyperbolic tangent of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> tanh(arg: T): T = arg.context.tanh(arg)
/**
* Computes the inverse hyperbolic sine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> asinh(arg: T): T = arg.context.asinh(arg)
/**
* Computes the inverse hyperbolic cosine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> acosh(arg: T): T = arg.context.acosh(arg)
/**
* Computes the inverse hyperbolic tangent of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> atanh(arg: T): T = arg.context.atanh(arg)
/**
* A context extension to include power operations based on exponentiation.
*
* @param T the type of element of this structure.
*/
interface PowerOperations<T> : Algebra<T> {
/**
* Raises [arg] to the power [pow].
*/
fun power(arg: T, pow: Number): T
/**
* Computes the square root of the value [arg].
*/
fun sqrt(arg: T): T = power(arg, 0.5)
/**
* Raises this value to the power [pow].
*/
infix fun T.pow(pow: Number): T = power(this, pow)
companion object {
/**
* The identifier of exponentiation.
*/
const val POW_OPERATION: String = "pow"
/**
* The identifier of square root.
*/
const val SQRT_OPERATION: String = "sqrt"
}
}
/**
* Raises this element to the power [pow].
*
* @receiver the base.
* @param power the exponent.
* @return the base raised to the power.
*/
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power) infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
/**
* Computes the square root of the value [arg].
*/
fun <T : MathElement<out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5 fun <T : MathElement<out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5
/**
* Computes the square of the value [arg].
*/
fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0 fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
/* Exponential */ /**
* A container for operations related to `exp` and `ln` functions.
*
* @param T the type of element of this structure.
*/
interface ExponentialOperations<T> : Algebra<T> { interface ExponentialOperations<T> : Algebra<T> {
/**
* Computes Euler's number `e` raised to the power of the value [arg].
*/
fun exp(arg: T): T fun exp(arg: T): T
/**
* Computes the natural logarithm (base `e`) of the value [arg].
*/
fun ln(arg: T): T fun ln(arg: T): T
companion object { companion object {
const val EXP_OPERATION = "exp" /**
const val LN_OPERATION = "ln" * The identifier of exponential function.
*/
const val EXP_OPERATION: String = "exp"
/**
* The identifier of natural logarithm.
*/
const val LN_OPERATION: String = "ln"
} }
} }
/**
* The identifier of exponential function.
*/
fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg) fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
/**
* The identifier of natural logarithm.
*/
fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg) fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg)
/**
* A container for norm functional on element.
*
* @param T the type of element having norm defined.
* @param R the type of norm.
*/
interface Norm<in T : Any, out R> { interface Norm<in T : Any, out R> {
/**
* Computes the norm of [arg] (i.e. absolute value or vector length).
*/
fun norm(arg: T): R fun norm(arg: T): R
} }
/**
* Computes the norm of [arg] (i.e. absolute value or vector length).
*/
fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg) fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg)

View File

@ -3,32 +3,30 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement import scientifik.kmath.operations.FieldElement
class BoxingNDField<T, F : Field<T>>( class BoxingNDField<T, F : Field<T>>(
override val shape: IntArray, override val shape: IntArray,
override val elementContext: F, override val elementContext: F,
val bufferFactory: BufferFactory<T> val bufferFactory: BufferFactory<T>
) : BufferedNDField<T, F> { ) : BufferedNDField<T, F> {
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
bufferFactory(size, initializer) bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>) { override fun check(vararg elements: NDBuffer<T>) {
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
} }
override val zero by lazy { produce { zero } } override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
override val one by lazy { produce { one } }
override fun produce(initializer: F.(IntArray) -> T) =
BufferedNDFieldElement( BufferedNDFieldElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }) buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> { override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> elementContext.transform(arg.buffer[offset]) }) buildBuffer(arg.strides.linearSize) { offset -> elementContext.transform(arg.buffer[offset]) })

View File

@ -8,20 +8,17 @@ class BoxingNDRing<T, R : Ring<T>>(
override val elementContext: R, override val elementContext: R,
val bufferFactory: BufferFactory<T> val bufferFactory: BufferFactory<T>
) : BufferedNDRing<T, R> { ) : BufferedNDRing<T, R> {
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
override val one: BufferedNDRingElement<T, R> by lazy { produce { one } }
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>) { override fun check(vararg elements: NDBuffer<T>) {
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
} }
override val zero by lazy { produce { zero } } override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement<T, R> =
override val one by lazy { produce { one } }
override fun produce(initializer: R.(IntArray) -> T) =
BufferedNDRingElement( BufferedNDRingElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }) buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })

View File

@ -6,17 +6,16 @@ import kotlin.reflect.KClass
* A context that allows to operate on a [MutableBuffer] as on 2d array * A context that allows to operate on a [MutableBuffer] as on 2d array
*/ */
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) { class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
operator fun Buffer<T>.get(i: Int, j: Int): T = get(i + colNum * j)
operator fun Buffer<T>.get(i: Int, j: Int) = get(i + colNum * j)
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) { operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
set(i + colNum * j, value) set(i + colNum * j, value)
} }
inline fun create(init: (i: Int, j: Int) -> T) = inline fun create(init: (i: Int, j: Int) -> T): MutableBuffer<T> =
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) } MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
fun create(mat: Structure2D<T>) = create { i, j -> mat[i, j] } fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
//TODO optimize wrapper //TODO optimize wrapper
fun MutableBuffer<T>.collect(): Structure2D<T> = fun MutableBuffer<T>.collect(): Structure2D<T> =
@ -26,20 +25,19 @@ class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> { inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
override val size: Int get() = colNum override val size: Int get() = colNum
override fun get(index: Int): T = buffer[rowIndex, index] override operator fun get(index: Int): T = buffer[rowIndex, index]
override fun set(index: Int, value: T) { override operator fun set(index: Int, value: T) {
buffer[rowIndex, index] = value buffer[rowIndex, index] = value
} }
override fun copy(): MutableBuffer<T> = MutableBuffer.auto(type, colNum) { get(it) } override fun copy(): MutableBuffer<T> = MutableBuffer.auto(type, colNum) { get(it) }
override operator fun iterator(): Iterator<T> = (0 until colNum).map(::get).iterator()
override fun iterator(): Iterator<T> = (0 until colNum).map(::get).iterator()
} }
/** /**
* Get row * Get row
*/ */
fun MutableBuffer<T>.row(i: Int) = Row(this, i) fun MutableBuffer<T>.row(i: Int): Row = Row(this, i)
} }

View File

@ -2,16 +2,16 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.* import scientifik.kmath.operations.*
interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{ interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
val strides: Strides val strides: Strides
override fun check(vararg elements: NDBuffer<T>) { override fun check(vararg elements: NDBuffer<T>): Unit =
if (!elements.all { it.strides == this.strides }) error("Strides mismatch") require(elements.all { it.strides == strides }) { ("Strides mismatch") }
}
/** /**
* Convert any [NDStructure] to buffered structure using strides from this context. * Convert any [NDStructure] to buffered structure using strides from this context.
* If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over indexes * If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over
* indices.
* *
* If the argument is [NDBuffer] with different strides structure, the new element will be produced. * If the argument is [NDBuffer] with different strides structure, the new element will be produced.
*/ */
@ -30,7 +30,7 @@ interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
} }
interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T,S> { interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T, S> {
override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>> override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>>
} }

View File

@ -8,7 +8,7 @@ import scientifik.kmath.operations.*
abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer<T>> { abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer<T>> {
abstract override val context: BufferedNDAlgebra<T, C> abstract override val context: BufferedNDAlgebra<T, C>
override val strides get() = context.strides override val strides: Strides get() = context.strides
override val shape: IntArray get() = context.shape override val shape: IntArray get() = context.shape
} }
@ -30,7 +30,6 @@ class BufferedNDRingElement<T, R : Ring<T>>(
override val context: BufferedNDRing<T, R>, override val context: BufferedNDRing<T, R>,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> { ) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> {
override fun unwrap(): NDBuffer<T> = this override fun unwrap(): NDBuffer<T> = this
override fun NDBuffer<T>.wrap(): BufferedNDRingElement<T, R> { override fun NDBuffer<T>.wrap(): BufferedNDRingElement<T, R> {
@ -43,7 +42,6 @@ class BufferedNDFieldElement<T, F : Field<T>>(
override val context: BufferedNDField<T, F>, override val context: BufferedNDField<T, F>,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> { ) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
override fun unwrap(): NDBuffer<T> = this override fun unwrap(): NDBuffer<T> = this
override fun NDBuffer<T>.wrap(): BufferedNDFieldElement<T, F> { override fun NDBuffer<T>.wrap(): BufferedNDFieldElement<T, F> {
@ -54,9 +52,9 @@ class BufferedNDFieldElement<T, F : Field<T>>(
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy.
*/ */
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>) = operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>): MathElement<out BufferedNDAlgebra<T, F>> =
ndElement.context.run { map(ndElement) { invoke(it) }.toElement() } ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
/* plus and minus */ /* plus and minus */
@ -64,13 +62,13 @@ operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedN
/** /**
* Summation operation for [BufferedNDElement] and single element * Summation operation for [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.plus(arg: T) = operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.plus(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it + arg }.wrap() context.map(this) { it + arg }.wrap()
/** /**
* Subtraction operation between [BufferedNDElement] and single element * Subtraction operation between [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T) = operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it - arg }.wrap() context.map(this) { it - arg }.wrap()
/* prod and div */ /* prod and div */
@ -78,11 +76,11 @@ operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T) =
/** /**
* Product operation for [BufferedNDElement] and single element * Product operation for [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Ring<T>> BufferedNDElement<T, F>.times(arg: T) = operator fun <T : Any, F : Ring<T>> BufferedNDElement<T, F>.times(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it * arg }.wrap() context.map(this) { it * arg }.wrap()
/** /**
* Division operation between [BufferedNDElement] and single element * Division operation between [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Field<T>> BufferedNDElement<T, F>.div(arg: T) = operator fun <T : Any, F : Field<T>> BufferedNDElement<T, F>.div(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it / arg }.wrap() context.map(this) { it / arg }.wrap()

View File

@ -2,41 +2,52 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.complex import scientifik.kmath.operations.complex
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.reflect.KClass import kotlin.reflect.KClass
/**
* Function that produces [Buffer] from its size and function that supplies values.
*
* @param T the type of buffer.
*/
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T> typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
/** /**
* A generic random access structure for both primitives and objects * Function that produces [MutableBuffer] from its size and function that supplies values.
*
* @param T the type of buffer.
*/
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
/**
* A generic immutable random-access structure for both primitives and objects.
*
* @param T the type of elements contained in the buffer.
*/ */
interface Buffer<T> { interface Buffer<T> {
/** /**
* The size of the buffer * The size of this buffer.
*/ */
val size: Int val size: Int
/** /**
* Get element at given index * Gets element at given index.
*/ */
operator fun get(index: Int): T operator fun get(index: Int): T
/** /**
* Iterate over all elements * Iterates over all elements.
*/ */
operator fun iterator(): Iterator<T> operator fun iterator(): Iterator<T>
/** /**
* Check content eqiality with another buffer * Checks content equality with another buffer.
*/ */
fun contentEquals(other: Buffer<*>): Boolean = fun contentEquals(other: Buffer<*>): Boolean =
asSequence().mapIndexed { index, value -> value == other[index] }.all { it } asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
companion object { companion object {
inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer { inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer {
val array = DoubleArray(size) { initializer(it) } val array = DoubleArray(size) { initializer(it) }
return RealBuffer(array) return RealBuffer(array)
@ -69,17 +80,34 @@ interface Buffer<T> {
} }
} }
/**
* Creates a sequence that returns all elements from this [Buffer].
*/
fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator) fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator)
/**
* Creates an iterable that returns all elements from this [Buffer].
*/
fun <T> Buffer<T>.asIterable(): Iterable<T> = Iterable(::iterator) fun <T> Buffer<T>.asIterable(): Iterable<T> = Iterable(::iterator)
val Buffer<*>.indices: IntRange get() = IntRange(0, size - 1) /**
* Returns an [IntRange] of the valid indices for this [Buffer].
*/
val Buffer<*>.indices: IntRange get() = 0 until size
/**
* A generic mutable random-access structure for both primitives and objects.
*
* @param T the type of elements contained in the buffer.
*/
interface MutableBuffer<T> : Buffer<T> { interface MutableBuffer<T> : Buffer<T> {
/**
* Sets the array element at the specified [index] to the specified [value].
*/
operator fun set(index: Int, value: T) operator fun set(index: Int, value: T)
/** /**
* A shallow copy of the buffer * Returns a shallow copy of the buffer.
*/ */
fun copy(): MutableBuffer<T> fun copy(): MutableBuffer<T>
@ -91,15 +119,14 @@ interface MutableBuffer<T> : Buffer<T> {
MutableListBuffer(MutableList(size, initializer)) MutableListBuffer(MutableList(size, initializer))
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> { inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
return when (type) { when (type) {
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T> Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T> Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T> Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T> Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
else -> boxing(size, initializer) else -> boxing(size, initializer)
} }
}
/** /**
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible * Create most appropriate mutable buffer for given type avoiding boxing wherever possible
@ -114,73 +141,110 @@ interface MutableBuffer<T> : Buffer<T> {
} }
} }
/**
* [Buffer] implementation over [List].
*
* @param T the type of elements contained in the buffer.
* @property list The underlying list.
*/
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> { inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
override val size: Int override val size: Int
get() = list.size get() = list.size
override fun get(index: Int): T = list[index] override operator fun get(index: Int): T = list[index]
override operator fun iterator(): Iterator<T> = list.iterator()
override fun iterator(): Iterator<T> = list.iterator()
}
fun <T> List<T>.asBuffer() = ListBuffer<T>(this)
@Suppress("FunctionName")
inline fun <T> ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer()
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
override val size: Int
get() = list.size
override fun get(index: Int): T = list[index]
override fun set(index: Int, value: T) {
list[index] = value
}
override fun iterator(): Iterator<T> = list.iterator()
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
}
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
//Can't inline because array is invariant
override val size: Int
get() = array.size
override fun get(index: Int): T = array[index]
override fun set(index: Int, value: T) {
array[index] = value
}
override fun iterator(): Iterator<T> = array.iterator()
override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf())
}
fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
override val size: Int get() = buffer.size
override fun get(index: Int): T = buffer.get(index)
override fun iterator() = buffer.iterator()
} }
/** /**
* A buffer with content calculated on-demand. The calculated contect is not stored, so it is recalculated on each call. * Returns an [ListBuffer] that wraps the original list.
*/
fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
/**
* Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified
* [init] function.
*
* The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an array element given its index.
*/
inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> {
contract { callsInPlace(init) }
return List(size, init).asBuffer()
}
/**
* [MutableBuffer] implementation over [MutableList].
*
* @param T the type of elements contained in the buffer.
* @property list The underlying list.
*/
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
override val size: Int
get() = list.size
override operator fun get(index: Int): T = list[index]
override operator fun set(index: Int, value: T) {
list[index] = value
}
override operator fun iterator(): Iterator<T> = list.iterator()
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
}
/**
* [MutableBuffer] implementation over [Array].
*
* @param T the type of elements contained in the buffer.
* @property array The underlying array.
*/
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
// Can't inline because array is invariant
override val size: Int
get() = array.size
override operator fun get(index: Int): T = array[index]
override operator fun set(index: Int, value: T) {
array[index] = value
}
override operator fun iterator(): Iterator<T> = array.iterator()
override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf())
}
/**
* Returns an [ArrayBuffer] that wraps the original array.
*/
fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
/**
* Immutable wrapper for [MutableBuffer].
*
* @param T the type of elements contained in the buffer.
* @property buffer The underlying buffer.
*/
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
override val size: Int get() = buffer.size
override operator fun get(index: Int): T = buffer[index]
override operator fun iterator(): Iterator<T> = buffer.iterator()
}
/**
* A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call.
* Useful when one needs single element from the buffer. * Useful when one needs single element from the buffer.
*
* @param T the type of elements provided by the buffer.
*/ */
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> { class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
override fun get(index: Int): T { override operator fun get(index: Int): T {
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index") if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
return generator(index) return generator(index)
} }
override fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator() override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator()
override fun contentEquals(other: Buffer<*>): Boolean { override fun contentEquals(other: Buffer<*>): Boolean {
return if (other is VirtualBuffer) { return if (other is VirtualBuffer) {
@ -192,17 +256,16 @@ class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T
} }
/** /**
* Convert this buffer to read-only buffer * Convert this buffer to read-only buffer.
*/ */
fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) { fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) ReadOnlyBuffer(this) else this
ReadOnlyBuffer(this)
} else {
this
}
/** /**
* Typealias for buffer transformations * Typealias for buffer transformations.
*/ */
typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R> typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R>
/**
* Typealias for buffer transformations with suspend function.
*/
typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R> typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R>

View File

@ -4,6 +4,9 @@ import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.FieldElement import scientifik.kmath.operations.FieldElement
import scientifik.kmath.operations.complex import scientifik.kmath.operations.complex
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField> typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField>
@ -15,10 +18,9 @@ class ComplexNDField(override val shape: IntArray) :
ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> { ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> {
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
override val elementContext: ComplexField get() = ComplexField override val elementContext: ComplexField get() = ComplexField
override val zero by lazy { produce { zero } } override val zero: ComplexNDElement by lazy { produce { zero } }
override val one by lazy { produce { one } } override val one: ComplexNDElement by lazy { produce { one } }
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> = inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> =
Buffer.complex(size) { initializer(it) } Buffer.complex(size) { initializer(it) }
@ -45,6 +47,7 @@ class ComplexNDField(override val shape: IntArray) :
transform: ComplexField.(index: IntArray, Complex) -> Complex transform: ComplexField.(index: IntArray, Complex) -> Complex
): ComplexNDElement { ): ComplexNDElement {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> buildBuffer(arg.strides.linearSize) { offset ->
@ -61,6 +64,7 @@ class ComplexNDField(override val shape: IntArray) :
transform: ComplexField.(Complex, Complex) -> Complex transform: ComplexField.(Complex, Complex) -> Complex
): ComplexNDElement { ): ComplexNDElement {
check(a, b) check(a, b)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
@ -69,23 +73,25 @@ class ComplexNDField(override val shape: IntArray) :
override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> = override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> =
BufferedNDFieldElement(this@ComplexNDField, buffer) BufferedNDFieldElement(this@ComplexNDField, buffer)
override fun power(arg: NDBuffer<Complex>, pow: Number) = map(arg) { power(it, pow) } override fun power(arg: NDBuffer<Complex>, pow: Number): ComplexNDElement =
map(arg) { power(it, pow) }
override fun exp(arg: NDBuffer<Complex>) = map(arg) { exp(it) } override fun exp(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { exp(it) }
override fun ln(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { ln(it) }
override fun ln(arg: NDBuffer<Complex>) = map(arg) { ln(it) } override fun sin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { sin(it) }
override fun cos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { cos(it) }
override fun tan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { tan(it) }
override fun asin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { asin(it) }
override fun acos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { acos(it) }
override fun atan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { atan(it) }
override fun sin(arg: NDBuffer<Complex>) = map(arg) { sin(it) } override fun sinh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { sinh(it) }
override fun cosh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { cosh(it) }
override fun cos(arg: NDBuffer<Complex>) = map(arg) { cos(it) } override fun tanh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { tanh(it) }
override fun asinh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { asinh(it) }
override fun tan(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) { tan(it) } override fun acosh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { acosh(it) }
override fun atanh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { atanh(it) }
override fun asin(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) { asin(it) }
override fun acos(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) {acos(it)}
override fun atan(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) {atan(it)}
} }
@ -98,15 +104,16 @@ inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline init
} }
/** /**
* Map one [ComplexNDElement] using function with indexes * Map one [ComplexNDElement] using function with indices.
*/ */
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex) = inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement =
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
/** /**
* Map one [ComplexNDElement] using function without indexes * Map one [ComplexNDElement] using function without indices.
*/ */
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement { inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
contract { callsInPlace(transform) }
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) } val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
return BufferedNDFieldElement(context, buffer) return BufferedNDFieldElement(context, buffer)
} }
@ -114,7 +121,7 @@ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) ->
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy
*/ */
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement) = operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement): ComplexNDElement =
ndElement.map { this@invoke(it) } ndElement.map { this@invoke(it) }
@ -123,19 +130,18 @@ operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement) =
/** /**
* Summation operation for [BufferedNDElement] and single element * Summation operation for [BufferedNDElement] and single element
*/ */
operator fun ComplexNDElement.plus(arg: Complex) = operator fun ComplexNDElement.plus(arg: Complex): ComplexNDElement = map { it + arg }
map { it + arg }
/** /**
* Subtraction operation between [BufferedNDElement] and single element * Subtraction operation between [BufferedNDElement] and single element
*/ */
operator fun ComplexNDElement.minus(arg: Complex) = operator fun ComplexNDElement.minus(arg: Complex): ComplexNDElement =
map { it - arg } map { it - arg }
operator fun ComplexNDElement.plus(arg: Double) = operator fun ComplexNDElement.plus(arg: Double): ComplexNDElement =
map { it + arg } map { it + arg }
operator fun ComplexNDElement.minus(arg: Double) = operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement =
map { it - arg } map { it - arg }
fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape) fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
@ -147,5 +153,6 @@ fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(In
* Produce a context for n-dimensional operations inside this real field * Produce a context for n-dimensional operations inside this real field
*/ */
inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R { inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R {
return NDField.complex(*shape).run(action) contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return NDField.complex(*shape).action()
} }

View File

@ -2,9 +2,15 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.ExtendedField
/**
* [ExtendedField] over [NDStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param F the extended field of structure elements.
*/
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N> interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N>
///** ///**
// * NDField that supports [ExtendedField] operations on its elements // * NDField that supports [ExtendedField] operations on its elements
// */ // */
@ -36,5 +42,3 @@ interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : N
// return produce { with(elementContext) { cos(arg[it]) } } // return produce { with(elementContext) { cos(arg[it]) } }
// } // }
//} //}

View File

@ -1,16 +1,38 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.experimental.and import kotlin.experimental.and
/**
* Represents flags to supply additional info about values of buffer.
*
* @property mask bit mask value of this flag.
*/
enum class ValueFlag(val mask: Byte) { enum class ValueFlag(val mask: Byte) {
/**
* Reports the value is NaN.
*/
NAN(0b0000_0001), NAN(0b0000_0001),
/**
* Reports the value doesn't present in the buffer (when the type of value doesn't support `null`).
*/
MISSING(0b0000_0010), MISSING(0b0000_0010),
/**
* Reports the value is negative infinity.
*/
NEGATIVE_INFINITY(0b0000_0100), NEGATIVE_INFINITY(0b0000_0100),
/**
* Reports the value is positive infinity
*/
POSITIVE_INFINITY(0b0000_1000) POSITIVE_INFINITY(0b0000_1000)
} }
/** /**
* A buffer with flagged values * A buffer with flagged values.
*/ */
interface FlaggedBuffer<T> : Buffer<T> { interface FlaggedBuffer<T> : Buffer<T> {
fun getFlag(index: Int): Byte fun getFlag(index: Int): Byte
@ -19,11 +41,11 @@ interface FlaggedBuffer<T> : Buffer<T> {
/** /**
* The value is valid if all flags are down * The value is valid if all flags are down
*/ */
fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte() fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte()
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte() fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte()
fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING) fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING)
/** /**
* A real buffer which supports flags for each value like NaN or Missing * A real buffer which supports flags for each value like NaN or Missing
@ -37,17 +59,18 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged
override val size: Int get() = values.size override val size: Int get() = values.size
override fun get(index: Int): Double? = if (isValid(index)) values[index] else null override operator fun get(index: Int): Double? = if (isValid(index)) values[index] else null
override fun iterator(): Iterator<Double?> = values.indices.asSequence().map { override operator fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
if (isValid(it)) values[it] else null if (isValid(it)) values[it] else null
}.iterator() }.iterator()
} }
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
for(i in indices){ contract { callsInPlace(block) }
if(isValid(i)){
block(values[i]) indices
} .asSequence()
} .filter(::isValid)
.forEach { block(values[it]) }
} }

View File

@ -0,0 +1,55 @@
package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
/**
* Specialized [MutableBuffer] implementation over [FloatArray].
*
* @property array the underlying array.
*/
inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
override val size: Int get() = array.size
override operator fun get(index: Int): Float = array[index]
override operator fun set(index: Int, value: Float) {
array[index] = value
}
override operator fun iterator(): FloatIterator = array.iterator()
override fun copy(): MutableBuffer<Float> =
FloatBuffer(array.copyOf())
}
/**
* Creates a new [FloatBuffer] with the specified [size], where each element is calculated by calling the specified
* [init] function.
*
* The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index.
*/
inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer {
contract { callsInPlace(init) }
return FloatBuffer(FloatArray(size) { init(it) })
}
/**
* Returns a new [FloatBuffer] of given elements.
*/
fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats)
/**
* Returns a [FloatArray] containing all of the elements of this [MutableBuffer].
*/
val MutableBuffer<out Float>.array: FloatArray
get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) })
/**
* Returns [FloatBuffer] over this array.
*
* @receiver the array.
* @return the new buffer.
*/
fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this)

View File

@ -1,20 +1,56 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/**
* Specialized [MutableBuffer] implementation over [IntArray].
*
* @property array the underlying array.
*/
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> { inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Int = array[index] override operator fun get(index: Int): Int = array[index]
override fun set(index: Int, value: Int) { override operator fun set(index: Int, value: Int) {
array[index] = value array[index] = value
} }
override fun iterator() = array.iterator() override operator fun iterator(): IntIterator = array.iterator()
override fun copy(): MutableBuffer<Int> = override fun copy(): MutableBuffer<Int> =
IntBuffer(array.copyOf()) IntBuffer(array.copyOf())
} }
/**
* Creates a new [IntBuffer] with the specified [size], where each element is calculated by calling the specified
* [init] function.
*
* The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index.
*/
inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer {
contract { callsInPlace(init) }
return IntBuffer(IntArray(size) { init(it) })
}
fun IntArray.asBuffer() = IntBuffer(this) /**
* Returns a new [IntBuffer] of given elements.
*/
fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints)
/**
* Returns a [IntArray] containing all of the elements of this [MutableBuffer].
*/
val MutableBuffer<out Int>.array: IntArray
get() = (if (this is IntBuffer) array else IntArray(size) { get(it) })
/**
* Returns [IntBuffer] over this array.
*
* @receiver the array.
* @return the new buffer.
*/
fun IntArray.asBuffer(): IntBuffer = IntBuffer(this)

View File

@ -1,19 +1,56 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
/**
* Specialized [MutableBuffer] implementation over [LongArray].
*
* @property array the underlying array.
*/
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> { inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Long = array[index] override operator fun get(index: Int): Long = array[index]
override fun set(index: Int, value: Long) { override operator fun set(index: Int, value: Long) {
array[index] = value array[index] = value
} }
override fun iterator() = array.iterator() override operator fun iterator(): LongIterator = array.iterator()
override fun copy(): MutableBuffer<Long> = override fun copy(): MutableBuffer<Long> =
LongBuffer(array.copyOf()) LongBuffer(array.copyOf())
} }
fun LongArray.asBuffer() = LongBuffer(this) /**
* Creates a new [LongBuffer] with the specified [size], where each element is calculated by calling the specified
* [init] function.
*
* The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index.
*/
inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer {
contract { callsInPlace(init) }
return LongBuffer(LongArray(size) { init(it) })
}
/**
* Returns a new [LongBuffer] of given elements.
*/
fun LongBuffer(vararg longs: Long): LongBuffer = LongBuffer(longs)
/**
* Returns a [IntArray] containing all of the elements of this [MutableBuffer].
*/
val MutableBuffer<out Long>.array: LongArray
get() = (if (this is LongBuffer) array else LongArray(size) { get(it) })
/**
* Returns [LongBuffer] over this array.
*
* @receiver the array.
* @return the new buffer.
*/
fun LongArray.asBuffer(): LongBuffer = LongBuffer(this)

View File

@ -3,21 +3,22 @@ package scientifik.kmath.structures
import scientifik.memory.* import scientifik.memory.*
/** /**
* A non-boxing buffer based on [ByteBuffer] storage * A non-boxing buffer over [Memory] object.
*
* @param T the type of elements contained in the buffer.
* @property memory the underlying memory segment.
* @property spec the spec of [T] type.
*/ */
open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> { open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> {
override val size: Int get() = memory.size / spec.objectSize override val size: Int get() = memory.size / spec.objectSize
private val reader = memory.reader() private val reader: MemoryReader = memory.reader()
override fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
override fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
companion object { companion object {
fun <T : Any> create(spec: MemorySpec<T>, size: Int) = fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
MemoryBuffer(Memory.allocate(size * spec.objectSize), spec) MemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
inline fun <T : Any> create( inline fun <T : Any> create(
@ -33,24 +34,30 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
} }
} }
/**
* A mutable non-boxing buffer over [Memory] object.
*
* @param T the type of elements contained in the buffer.
* @property memory the underlying memory segment.
* @property spec the spec of [T] type.
*/
class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec), class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec),
MutableBuffer<T> { MutableBuffer<T> {
private val writer = memory.writer() private val writer: MemoryWriter = memory.writer()
override fun set(index: Int, value: T) = writer.write(spec, spec.objectSize * index, value)
override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec) override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
companion object { companion object {
fun <T : Any> create(spec: MemorySpec<T>, size: Int) = fun <T : Any> create(spec: MemorySpec<T>, size: Int): MutableMemoryBuffer<T> =
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec) MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
inline fun <T : Any> create( inline fun <T : Any> create(
spec: MemorySpec<T>, spec: MemorySpec<T>,
size: Int, size: Int,
crossinline initializer: (Int) -> T crossinline initializer: (Int) -> T
) = ): MutableMemoryBuffer<T> =
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
(0 until size).forEach { (0 until size).forEach {
buffer[it] = initializer(it) buffer[it] = initializer(it)

View File

@ -56,7 +56,7 @@ interface NDAlgebra<T, C, N : NDStructure<T>> {
/** /**
* element-by-element invoke a function working on [T] on a [NDStructure] * element-by-element invoke a function working on [T] on a [NDStructure]
*/ */
operator fun Function1<T, T>.invoke(structure: N) = map(structure) { value -> this@invoke(value) } operator fun Function1<T, T>.invoke(structure: N): N = map(structure) { value -> this@invoke(value) }
companion object companion object
} }
@ -76,12 +76,12 @@ interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T,
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) } override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after KEEP-176
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) } operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) }
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) } operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) }
operator fun T.plus(arg: N) = map(arg) { value -> add(this@plus, value) } operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) }
operator fun T.minus(arg: N) = map(arg) { value -> add(-this@minus, value) } operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) }
companion object companion object
} }
@ -97,20 +97,19 @@ interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N>
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after KEEP-176
operator fun N.times(arg: T) = map(this) { value -> multiply(arg, value) } operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) }
operator fun T.times(arg: N) = map(arg) { value -> multiply(this@times, value) } operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) }
companion object companion object
} }
/** /**
* Field for n-dimensional structures. * Field of [NDStructure].
* @param shape - the list of dimensions of the array *
* @param elementField - operations field defined on individual array element * @param T the type of the element contained in ND structure.
* @param T - the type of the element contained in ND structure * @param N the type of ND structure.
* @param F - field of structure elements * @param F field of structure elements.
* @param R - actual nd-element type of this field
*/ */
interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> { interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> {
@ -120,9 +119,9 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) } override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after KEEP-176
operator fun N.div(arg: T) = map(this) { value -> divide(arg, value) } operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) }
operator fun T.div(arg: N) = map(arg) { divide(it, this@div) } operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
companion object { companion object {
@ -131,7 +130,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
/** /**
* Create a nd-field for [Double] values or pull it from cache if it was created previously * Create a nd-field for [Double] values or pull it from cache if it was created previously
*/ */
fun real(vararg shape: Int) = realNDFieldCache.getOrPut(shape) { RealNDField(shape) } fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
/** /**
* Create a nd-field with boxing generic buffer * Create a nd-field with boxing generic buffer
@ -140,7 +139,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
field: F, field: F,
vararg shape: Int, vararg shape: Int,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
) = BoxingNDField(shape, field, bufferFactory) ): BoxingNDField<T, F> = BoxingNDField(shape, field, bufferFactory)
/** /**
* Create a most suitable implementation for nd-field using reified class. * Create a most suitable implementation for nd-field using reified class.

View File

@ -23,19 +23,24 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
/** /**
* Create a optimized NDArray of doubles * Create a optimized NDArray of doubles
*/ */
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }) = fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement =
NDField.real(*shape).produce(initializer) NDField.real(*shape).produce(initializer)
inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }) =
real(intArrayOf(dim)) { initializer(it[0]) } real(intArrayOf(dim)) { initializer(it[0]) }
inline fun real2D(
dim1: Int,
dim2: Int,
crossinline initializer: (Int, Int) -> Double = { _, _ -> 0.0 }
): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }) = inline fun real3D(
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } dim1: Int,
dim2: Int,
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }) = dim3: Int,
real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } crossinline initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }
): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
/** /**
@ -62,16 +67,16 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
} }
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T) = fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement<T, C, N> =
context.mapIndexed(unwrap(), transform).wrap() context.mapIndexed(unwrap(), transform).wrap()
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T) = context.map(unwrap(), transform).wrap() fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> =
context.map(unwrap(), transform).wrap()
/** /**
* Element by element application of any operation on elements to the whole [NDElement] * Element by element application of any operation on elements to the whole [NDElement]
*/ */
operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>) = operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>): NDElement<T, C, N> =
ndElement.map { value -> this@invoke(value) } ndElement.map { value -> this@invoke(value) }
/* plus and minus */ /* plus and minus */
@ -79,13 +84,13 @@ operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElem
/** /**
* Summation operation for [NDElement] and single element * Summation operation for [NDElement] and single element
*/ */
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T) = operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T): NDElement<T, S, N> =
map { value -> arg + value } map { value -> arg + value }
/** /**
* Subtraction operation between [NDElement] and single element * Subtraction operation between [NDElement] and single element
*/ */
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T) = operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T): NDElement<T, S, N> =
map { value -> arg - value } map { value -> arg - value }
/* prod and div */ /* prod and div */
@ -93,16 +98,15 @@ operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg:
/** /**
* Product operation for [NDElement] and single element * Product operation for [NDElement] and single element
*/ */
operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T) = operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T): NDElement<T, R, N> =
map { value -> arg * value } map { value -> arg * value }
/** /**
* Division operation between [NDElement] and single element * Division operation between [NDElement] and single element
*/ */
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T) = operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> =
map { value -> arg / value } map { value -> arg / value }
// /** // /**
// * Reverse sum operation // * Reverse sum operation
// */ // */

View File

@ -1,17 +1,42 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
import kotlin.reflect.KClass import kotlin.reflect.KClass
/**
* Represents n-dimensional structure, i.e. multidimensional container of items of the same type and size. The number
* of dimensions and items in an array is defined by its shape, which is a sequence of non-negative integers that
* specify the sizes of each dimension.
*
* @param T the type of items.
*/
interface NDStructure<T> { interface NDStructure<T> {
/**
* The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of
* this structure.
*/
val shape: IntArray val shape: IntArray
val dimension get() = shape.size /**
* The count of dimensions in this structure. It should be equal to size of [shape].
*/
val dimension: Int get() = shape.size
/**
* Returns the value at the specified indices.
*
* @param index the indices.
* @return the value.
*/
operator fun get(index: IntArray): T operator fun get(index: IntArray): T
/**
* Returns the sequence of all the elements associated by their indices.
*
* @return the lazy sequence of pairs of indices to values.
*/
fun elements(): Sequence<Pair<IntArray, T>> fun elements(): Sequence<Pair<IntArray, T>>
override fun equals(other: Any?): Boolean override fun equals(other: Any?): Boolean
@ -19,6 +44,9 @@ interface NDStructure<T> {
override fun hashCode(): Int override fun hashCode(): Int
companion object { companion object {
/**
* Indicates whether some [NDStructure] is equal to another one.
*/
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
if (st1 === st2) return true if (st1 === st2) return true
@ -36,58 +64,89 @@ interface NDStructure<T> {
} }
/** /**
* Create a NDStructure with explicit buffer factory * Creates a NDStructure with explicit buffer factory.
* *
* Strides should be reused if possible * Strides should be reused if possible.
*/ */
fun <T> build( fun <T> build(
strides: Strides, strides: Strides,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T initializer: (IntArray) -> T
) = ): BufferNDStructure<T> =
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
/** /**
* Inline create NDStructure with non-boxing buffer implementation if it is possible * Inline create NDStructure with non-boxing buffer implementation if it is possible
*/ */
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> auto(
strides: Strides,
crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> =
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
inline fun <T : Any> auto(type: KClass<T>, strides: Strides, crossinline initializer: (IntArray) -> T) = inline fun <T : Any> auto(
type: KClass<T>,
strides: Strides,
crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> =
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
fun <T> build( fun <T> build(
shape: IntArray, shape: IntArray,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T initializer: (IntArray) -> T
) = build(DefaultStrides(shape), bufferFactory, initializer) ): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> auto(
shape: IntArray,
crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> =
auto(DefaultStrides(shape), initializer) auto(DefaultStrides(shape), initializer)
@JvmName("autoVarArg") @JvmName("autoVarArg")
inline fun <reified T : Any> auto(vararg shape: Int, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> auto(
vararg shape: Int,
crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> =
auto(DefaultStrides(shape), initializer) auto(DefaultStrides(shape), initializer)
inline fun <T : Any> auto(type: KClass<T>, vararg shape: Int, crossinline initializer: (IntArray) -> T) = inline fun <T : Any> auto(
type: KClass<T>,
vararg shape: Int,
crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> =
auto(type, DefaultStrides(shape), initializer) auto(type, DefaultStrides(shape), initializer)
} }
} }
/**
* Returns the value at the specified indices.
*
* @param index the indices.
* @return the value.
*/
operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index) operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
/**
* Represents mutable [NDStructure].
*/
interface MutableNDStructure<T> : NDStructure<T> { interface MutableNDStructure<T> : NDStructure<T> {
/**
* Inserts an item at the specified indices.
*
* @param index the indices.
* @param value the value.
*/
operator fun set(index: IntArray, value: T) operator fun set(index: IntArray, value: T)
} }
inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) { inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
elements().forEach { (index, oldValue) -> contract { callsInPlace(action) }
this[index] = action(index, oldValue) elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
}
} }
/** /**
* A way to convert ND index to linear one and back * A way to convert ND index to linear one and back.
*/ */
interface Strides { interface Strides {
/** /**
@ -124,11 +183,14 @@ interface Strides {
} }
} }
/**
* Simple implementation of [Strides].
*/
class DefaultStrides private constructor(override val shape: IntArray) : Strides { class DefaultStrides private constructor(override val shape: IntArray) : Strides {
/** /**
* Strides for memory access * Strides for memory access
*/ */
override val strides by lazy { override val strides: List<Int> by lazy {
sequence { sequence {
var current = 1 var current = 1
yield(1) yield(1)
@ -139,14 +201,12 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
}.toList() }.toList()
} }
override fun offset(index: IntArray): Int { override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
return index.mapIndexed { i, value -> if (value < 0 || value >= this.shape[i])
if (value < 0 || value >= this.shape[i]) { throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
throw RuntimeException("Index $value out of shape bounds: (0,${this.shape[i]})")
}
value * strides[i] value * strides[i]
}.sum() }.sum()
}
override fun index(offset: Int): IntArray { override fun index(offset: Int): IntArray {
val res = IntArray(shape.size) val res = IntArray(shape.size)
@ -163,19 +223,14 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
override val linearSize: Int override val linearSize: Int
get() = strides[shape.size] get() = strides[shape.size]
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other !is DefaultStrides) return false if (other !is DefaultStrides) return false
if (!shape.contentEquals(other.shape)) return false if (!shape.contentEquals(other.shape)) return false
return true return true
} }
override fun hashCode(): Int { override fun hashCode(): Int = shape.contentHashCode()
return shape.contentHashCode()
}
companion object { companion object {
private val defaultStridesCache = HashMap<IntArray, Strides>() private val defaultStridesCache = HashMap<IntArray, Strides>()
@ -187,11 +242,23 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
} }
} }
/**
* Represents [NDStructure] over [Buffer].
*
* @param T the type of items.
*/
abstract class NDBuffer<T> : NDStructure<T> { abstract class NDBuffer<T> : NDStructure<T> {
/**
* The underlying buffer.
*/
abstract val buffer: Buffer<T> abstract val buffer: Buffer<T>
/**
* The strides to access elements of [Buffer] by linear indices.
*/
abstract val strides: Strides abstract val strides: Strides
override fun get(index: IntArray): T = buffer[strides.offset(index)] override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
override val shape: IntArray get() = strides.shape override val shape: IntArray get() = strides.shape
@ -238,7 +305,7 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
} }
/** /**
* Mutable ND buffer based on linear [autoBuffer] * Mutable ND buffer based on linear [MutableBuffer].
*/ */
class MutableBufferNDStructure<T>( class MutableBufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
@ -246,18 +313,18 @@ class MutableBufferNDStructure<T>(
) : NDBuffer<T>(), MutableNDStructure<T> { ) : NDBuffer<T>(), MutableNDStructure<T> {
init { init {
if (strides.linearSize != buffer.size) { require(strides.linearSize == buffer.size) {
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}") "Expected buffer side of ${strides.linearSize}, but found ${buffer.size}"
} }
} }
override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value) override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
} }
inline fun <reified T : Any> NDStructure<T>.combine( inline fun <reified T : Any> NDStructure<T>.combine(
struct: NDStructure<T>, struct: NDStructure<T>,
crossinline block: (T, T) -> T crossinline block: (T, T) -> T
): NDStructure<T> { ): NDStructure<T> {
if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination") require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
return NDStructure.auto(shape) { block(this[it], struct[it]) } return NDStructure.auto(shape) { block(this[it], struct[it]) }
} }

View File

@ -1,34 +1,55 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
/**
* Specialized [MutableBuffer] implementation over [DoubleArray].
*
* @property array the underlying array.
*/
inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> { inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Double = array[index] override operator fun get(index: Int): Double = array[index]
override fun set(index: Int, value: Double) { override operator fun set(index: Int, value: Double) {
array[index] = value array[index] = value
} }
override fun iterator() = array.iterator() override operator fun iterator(): DoubleIterator = array.iterator()
override fun copy(): MutableBuffer<Double> = override fun copy(): MutableBuffer<Double> =
RealBuffer(array.copyOf()) RealBuffer(array.copyOf())
} }
@Suppress("FunctionName") /**
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) }) * Creates a new [RealBuffer] with the specified [size], where each element is calculated by calling the specified
* [init] function.
*
* The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index.
*/
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer {
contract { callsInPlace(init) }
return RealBuffer(DoubleArray(size) { init(it) })
}
@Suppress("FunctionName") /**
* Returns a new [RealBuffer] of given elements.
*/
fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles) fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
/** /**
* Transform buffer of doubles into array for high performance operations * Returns a [DoubleArray] containing all of the elements of this [MutableBuffer].
*/ */
val MutableBuffer<out Double>.array: DoubleArray val MutableBuffer<out Double>.array: DoubleArray
get() = if (this is RealBuffer) { get() = (if (this is RealBuffer) array else DoubleArray(size) { get(it) })
array
} else {
DoubleArray(size) { get(it) }
}
fun DoubleArray.asBuffer() = RealBuffer(this) /**
* Returns [RealBuffer] over this array.
*
* @receiver the array.
* @return the new buffer.
*/
fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this)

View File

@ -6,18 +6,19 @@ import kotlin.math.*
/** /**
* A simple field over linear buffers of [Double] * [ExtendedFieldOperations] over [RealBuffer].
*/ */
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> { object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
}
return if (a is RealBuffer && b is RealBuffer) { return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
} else } else RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
} }
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer { override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
@ -26,12 +27,13 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
return if (a is RealBuffer) { return if (a is RealBuffer) {
val aArray = a.array val aArray = a.array
RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) RealBuffer(DoubleArray(a.size) { aArray[it] * kValue })
} else } else RealBuffer(DoubleArray(a.size) { a[it] * kValue })
RealBuffer(DoubleArray(a.size) { a[it] * kValue })
} }
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
}
return if (a is RealBuffer && b is RealBuffer) { return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
@ -42,34 +44,31 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
} }
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
}
return if (a is RealBuffer && b is RealBuffer) { return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
} else } else RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
} }
override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) RealBuffer(DoubleArray(arg.size) { sin(array[it]) })
} else { } else RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
}
override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) RealBuffer(DoubleArray(arg.size) { cos(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) RealBuffer(DoubleArray(arg.size) { tan(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
@ -90,25 +89,57 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
} else } else
RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) RealBuffer(DoubleArray(arg.size) { atan(arg[it]) })
override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { sinh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) })
override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { cosh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) })
override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { tanh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) })
override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { asinh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) })
override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { acosh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) })
override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { atanh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) })
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) { override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
} else } else RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) RealBuffer(DoubleArray(arg.size) { exp(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
} }
/**
* [ExtendedField] over [RealBuffer].
*
* @property size the size of buffers to operate on.
*/
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> { class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } } override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } } override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
@ -163,6 +194,36 @@ class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
return RealBufferFieldOperations.atan(arg) return RealBufferFieldOperations.atan(arg)
} }
override fun sinh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.sinh(arg)
}
override fun cosh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.cosh(arg)
}
override fun tanh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.tanh(arg)
}
override fun asinh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.asinh(arg)
}
override fun acosh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.acosh(arg)
}
override fun atanh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.atanh(arg)
}
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer { override fun power(arg: Buffer<Double>, pow: Number): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.power(arg, pow) return RealBufferFieldOperations.power(arg, pow)

View File

@ -12,8 +12,8 @@ class RealNDField(override val shape: IntArray) :
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
override val elementContext: RealField get() = RealField override val elementContext: RealField get() = RealField
override val zero by lazy { produce { zero } } override val zero: RealNDElement by lazy { produce { zero } }
override val one by lazy { produce { one } } override val one: RealNDElement by lazy { produce { one } }
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> = inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
RealBuffer(DoubleArray(size) { initializer(it) }) RealBuffer(DoubleArray(size) { initializer(it) })
@ -40,6 +40,7 @@ class RealNDField(override val shape: IntArray) :
transform: RealField.(index: IntArray, Double) -> Double transform: RealField.(index: IntArray, Double) -> Double
): RealNDElement { ): RealNDElement {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> buildBuffer(arg.strides.linearSize) { offset ->
@ -64,23 +65,25 @@ class RealNDField(override val shape: IntArray) :
override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> = override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> =
BufferedNDFieldElement(this@RealNDField, buffer) BufferedNDFieldElement(this@RealNDField, buffer)
override fun power(arg: NDBuffer<Double>, pow: Number) = map(arg) { power(it, pow) } override fun power(arg: NDBuffer<Double>, pow: Number): RealNDElement = map(arg) { power(it, pow) }
override fun exp(arg: NDBuffer<Double>) = map(arg) { exp(it) } override fun exp(arg: NDBuffer<Double>): RealNDElement = map(arg) { exp(it) }
override fun ln(arg: NDBuffer<Double>) = map(arg) { ln(it) } override fun ln(arg: NDBuffer<Double>): RealNDElement = map(arg) { ln(it) }
override fun sin(arg: NDBuffer<Double>) = map(arg) { sin(it) } override fun sin(arg: NDBuffer<Double>): RealNDElement = map(arg) { sin(it) }
override fun cos(arg: NDBuffer<Double>): RealNDElement = map(arg) { cos(it) }
override fun tan(arg: NDBuffer<Double>): RealNDElement = map(arg) { tan(it) }
override fun asin(arg: NDBuffer<Double>): RealNDElement = map(arg) { asin(it) }
override fun acos(arg: NDBuffer<Double>): RealNDElement = map(arg) { acos(it) }
override fun atan(arg: NDBuffer<Double>): RealNDElement = map(arg) { atan(it) }
override fun cos(arg: NDBuffer<Double>) = map(arg) { cos(it) } override fun sinh(arg: NDBuffer<Double>): RealNDElement = map(arg) { sinh(it) }
override fun cosh(arg: NDBuffer<Double>): RealNDElement = map(arg) { cosh(it) }
override fun tan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { tan(it) } override fun tanh(arg: NDBuffer<Double>): RealNDElement = map(arg) { tanh(it) }
override fun asinh(arg: NDBuffer<Double>): RealNDElement = map(arg) { asinh(it) }
override fun asin(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { asin(it) } override fun acosh(arg: NDBuffer<Double>): RealNDElement = map(arg) { acosh(it) }
override fun atanh(arg: NDBuffer<Double>): RealNDElement = map(arg) { atanh(it) }
override fun acos(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { acos(it) }
override fun atan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { atan(it) }
} }
@ -93,13 +96,13 @@ inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initiali
} }
/** /**
* Map one [RealNDElement] using function with indexes * Map one [RealNDElement] using function with indices.
*/ */
inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double) = inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double): RealNDElement =
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
/** /**
* Map one [RealNDElement] using function without indexes * Map one [RealNDElement] using function without indices.
*/ */
inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement {
val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) } val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) }
@ -107,9 +110,9 @@ inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double
} }
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy.
*/ */
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) = operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement): RealNDElement =
ndElement.map { this@invoke(it) } ndElement.map { this@invoke(it) }
@ -118,18 +121,17 @@ operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
/** /**
* Summation operation for [BufferedNDElement] and single element * Summation operation for [BufferedNDElement] and single element
*/ */
operator fun RealNDElement.plus(arg: Double) = operator fun RealNDElement.plus(arg: Double): RealNDElement =
map { it + arg } map { it + arg }
/** /**
* Subtraction operation between [BufferedNDElement] and single element * Subtraction operation between [BufferedNDElement] and single element
*/ */
operator fun RealNDElement.minus(arg: Double) = operator fun RealNDElement.minus(arg: Double): RealNDElement =
map { it - arg } map { it - arg }
/** /**
* Produce a context for n-dimensional operations inside this real field * Produce a context for n-dimensional operations inside this real field
*/ */
inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R {
return NDField.real(*shape).run(action) inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action)
}

View File

@ -1,20 +1,55 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
/**
* Specialized [MutableBuffer] implementation over [ShortArray].
*
* @property array the underlying array.
*/
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> { inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Short = array[index] override operator fun get(index: Int): Short = array[index]
override fun set(index: Int, value: Short) { override operator fun set(index: Int, value: Short) {
array[index] = value array[index] = value
} }
override fun iterator() = array.iterator() override operator fun iterator(): ShortIterator = array.iterator()
override fun copy(): MutableBuffer<Short> = override fun copy(): MutableBuffer<Short> =
ShortBuffer(array.copyOf()) ShortBuffer(array.copyOf())
} }
/**
* Creates a new [ShortBuffer] with the specified [size], where each element is calculated by calling the specified
* [init] function.
*
* The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index.
*/
inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer {
contract { callsInPlace(init) }
return ShortBuffer(ShortArray(size) { init(it) })
}
fun ShortArray.asBuffer() = ShortBuffer(this) /**
* Returns a new [ShortBuffer] of given elements.
*/
fun ShortBuffer(vararg shorts: Short): ShortBuffer = ShortBuffer(shorts)
/**
* Returns a [ShortArray] containing all of the elements of this [MutableBuffer].
*/
val MutableBuffer<out Short>.array: ShortArray
get() = (if (this is ShortBuffer) array else ShortArray(size) { get(it) })
/**
* Returns [ShortBuffer] over this array.
*
* @receiver the array.
* @return the new buffer.
*/
fun ShortArray.asBuffer(): ShortBuffer = ShortBuffer(this)

View File

@ -12,8 +12,8 @@ class ShortNDRing(override val shape: IntArray) :
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
override val elementContext: ShortRing get() = ShortRing override val elementContext: ShortRing get() = ShortRing
override val zero by lazy { produce { ShortRing.zero } } override val zero: ShortNDElement by lazy { produce { zero } }
override val one by lazy { produce { ShortRing.one } } override val one: ShortNDElement by lazy { produce { one } }
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> = inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> =
ShortBuffer(ShortArray(size) { initializer(it) }) ShortBuffer(ShortArray(size) { initializer(it) })
@ -40,6 +40,7 @@ class ShortNDRing(override val shape: IntArray) :
transform: ShortRing.(index: IntArray, Short) -> Short transform: ShortRing.(index: IntArray, Short) -> Short
): ShortNDElement { ): ShortNDElement {
check(arg) check(arg)
return BufferedNDRingElement( return BufferedNDRingElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> buildBuffer(arg.strides.linearSize) { offset ->
@ -67,7 +68,7 @@ class ShortNDRing(override val shape: IntArray) :
/** /**
* Fast element production using function inlining * Fast element production using function inlining.
*/ */
inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement { inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement {
val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) } val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }
@ -75,22 +76,22 @@ inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initialize
} }
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array.
*/ */
operator fun Function1<Short, Short>.invoke(ndElement: ShortNDElement) = operator fun Function1<Short, Short>.invoke(ndElement: ShortNDElement): ShortNDElement =
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) } ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
/* plus and minus */ /* plus and minus */
/** /**
* Summation operation for [StridedNDFieldElement] and single element * Summation operation for [ShortNDElement] and single element.
*/ */
operator fun ShortNDElement.plus(arg: Short) = operator fun ShortNDElement.plus(arg: Short): ShortNDElement =
context.produceInline { i -> (buffer[i] + arg).toShort() } context.produceInline { i -> (buffer[i] + arg).toShort() }
/** /**
* Subtraction operation between [StridedNDFieldElement] and single element * Subtraction operation between [ShortNDElement] and single element.
*/ */
operator fun ShortNDElement.minus(arg: Short) = operator fun ShortNDElement.minus(arg: Short): ShortNDElement =
context.produceInline { i -> (buffer[i] - arg).toShort() } context.produceInline { i -> (buffer[i] - arg).toShort() }

View File

@ -6,23 +6,23 @@ package scientifik.kmath.structures
interface Structure1D<T> : NDStructure<T>, Buffer<T> { interface Structure1D<T> : NDStructure<T>, Buffer<T> {
override val dimension: Int get() = 1 override val dimension: Int get() = 1
override fun get(index: IntArray): T { override operator fun get(index: IntArray): T {
if (index.size != 1) error("Index dimension mismatch. Expected 1 but found ${index.size}") require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
return get(index[0]) return get(index[0])
} }
override fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator() override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
} }
/** /**
* A 1D wrapper for nd-structure * A 1D wrapper for nd-structure
*/ */
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T>{ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> {
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override val size: Int get() = structure.shape[0] override val size: Int get() = structure.shape[0]
override fun get(index: Int): T = structure[index] override operator fun get(index: Int): T = structure[index]
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements() override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
} }
@ -39,14 +39,14 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
override fun elements(): Sequence<Pair<IntArray, T>> = override fun elements(): Sequence<Pair<IntArray, T>> =
asSequence().mapIndexed { index, value -> intArrayOf(index) to value } asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
override fun get(index: Int): T = buffer.get(index) override operator fun get(index: Int): T = buffer[index]
} }
/** /**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/ */
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) { fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
if( this is NDBuffer){ if (this is NDBuffer) {
Buffer1DWrapper(this.buffer) Buffer1DWrapper(this.buffer)
} else { } else {
Structure1DWrapper(this) Structure1DWrapper(this)

View File

@ -9,8 +9,8 @@ interface Structure2D<T> : NDStructure<T> {
operator fun get(i: Int, j: Int): T operator fun get(i: Int, j: Int): T
override fun get(index: IntArray): T { override operator fun get(index: IntArray): T {
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}") require(index.size == 2) { "Index dimension mismatch. Expected 2 but found ${index.size}" }
return get(index[0], index[1]) return get(index[0], index[1])
} }
@ -32,19 +32,17 @@ interface Structure2D<T> : NDStructure<T> {
} }
} }
companion object { companion object
}
} }
/** /**
* A 2D wrapper for nd-structure * A 2D wrapper for nd-structure
*/ */
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> { private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
override fun get(i: Int, j: Int): T = structure[i, j]
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override operator fun get(i: Int, j: Int): T = structure[i, j]
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements() override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
} }

View File

@ -3,6 +3,7 @@ package scientifik.kmath.expressions
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -10,10 +11,12 @@ class ExpressionFieldTest {
@Test @Test
fun testExpression() { fun testExpression() {
val context = FunctionalExpressionField(RealField) val context = FunctionalExpressionField(RealField)
val expression = with(context) {
val expression = context {
val x = variable("x", 2.0) val x = variable("x", 2.0)
x * x + 2 * x + one x * x + 2 * x + one
} }
assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression("x" to 1.0), 4.0)
assertEquals(expression(), 9.0) assertEquals(expression(), 9.0)
} }
@ -21,17 +24,19 @@ class ExpressionFieldTest {
@Test @Test
fun testComplex() { fun testComplex() {
val context = FunctionalExpressionField(ComplexField) val context = FunctionalExpressionField(ComplexField)
val expression = with(context) {
val expression = context {
val x = variable("x", Complex(2.0, 0.0)) val x = variable("x", Complex(2.0, 0.0))
x * x + 2 * x + one x * x + 2 * x + one
} }
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0)) assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0))
assertEquals(expression(), Complex(9.0, 0.0)) assertEquals(expression(), Complex(9.0, 0.0))
} }
@Test @Test
fun separateContext() { fun separateContext() {
fun <T> FunctionalExpressionField<T,*>.expression(): Expression<T> { fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
val x = variable("x") val x = variable("x")
return x * x + 2 * x + one return x * x + 2 * x + one
} }
@ -42,7 +47,7 @@ class ExpressionFieldTest {
@Test @Test
fun valueExpression() { fun valueExpression() {
val expressionBuilder: FunctionalExpressionField<Double,*>.() -> Expression<Double> = { val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
val x = variable("x") val x = variable("x")
x * x + 2 * x + one x * x + 2 * x + one
} }

View File

@ -7,7 +7,6 @@ import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class MatrixTest { class MatrixTest {
@Test @Test
fun testTranspose() { fun testTranspose() {
val matrix = MatrixContext.real.one(3, 3) val matrix = MatrixContext.real.one(3, 3)
@ -49,17 +48,18 @@ class MatrixTest {
@Test @Test
fun test2DDot() { fun test2DDot() {
val firstMatrix = NDStructure.auto(2,3){ (i, j) -> (i + j).toDouble() }.as2D() val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D()
val secondMatrix = NDStructure.auto(3,2){ (i, j) -> (i + j).toDouble() }.as2D() val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D()
MatrixContext.real.run { MatrixContext.real.run {
// val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() } // val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() }
// val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() } // val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() }
val result = firstMatrix dot secondMatrix val result = firstMatrix dot secondMatrix
assertEquals(2, result.rowNum) assertEquals(2, result.rowNum)
assertEquals(2, result.colNum) assertEquals(2, result.colNum)
assertEquals(8.0, result[0,1]) assertEquals(8.0, result[0, 1])
assertEquals(8.0, result[1,0]) assertEquals(8.0, result[1, 0])
assertEquals(14.0, result[1,1]) assertEquals(14.0, result[1, 1])
} }
} }
} }

View File

@ -8,10 +8,10 @@ import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
class AutoDiffTest { class AutoDiffTest {
fun Variable(int: Int): Variable<Double> = Variable(int.toDouble())
fun Variable(int: Int) = Variable(int.toDouble()) fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
RealField.deriv(body)
fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>) = RealField.deriv(body)
@Test @Test
fun testPlusX2() { fun testPlusX2() {
@ -178,5 +178,4 @@ class AutoDiffTest {
private fun assertApprox(a: Double, b: Double) { private fun assertApprox(a: Double, b: Double) {
if ((a - b) > 1e-10) assertEquals(a, b) if ((a - b) > 1e-10) assertEquals(a, b)
} }
} }

View File

@ -1,9 +1,13 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.internal.RingVerifier
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class BigIntAlgebraTest { internal class BigIntAlgebraTest {
@Test
fun verify() = BigIntField { RingVerifier(this, +"42", +"10", +"-12", 10).verify() }
@Test @Test
fun testKBigIntegerRingSum() { fun testKBigIntegerRingSum() {
val res = BigIntField { val res = BigIntField {
@ -47,4 +51,3 @@ class BigIntAlgebraTest {
} }
} }

View File

@ -19,7 +19,7 @@ class BigIntConstructorTest {
@Test @Test
fun testConstructor_0xffffffffaL() { fun testConstructor_0xffffffffaL() {
val x = -0xffffffffaL.toBigInt() val x = (-0xffffffffaL).toBigInt()
val y = uintArrayOf(0xfffffffaU, 0xfU).toBigInt(-1) val y = uintArrayOf(0xfffffffaU, 0xfU).toBigInt(-1)
assertEquals(x, y) assertEquals(x, y)
} }

View File

@ -19,7 +19,7 @@ class BigIntConversionsTest {
@Test @Test
fun testToString_0x17ead2ffffd() { fun testToString_0x17ead2ffffd() {
val x = -0x17ead2ffffdL.toBigInt() val x = (-0x17ead2ffffdL).toBigInt()
assertEquals("-0x17ead2ffffd", x.toString()) assertEquals("-0x17ead2ffffd", x.toString())
} }

View File

@ -31,7 +31,7 @@ class BigIntOperationsTest {
@Test @Test
fun testUnaryMinus() { fun testUnaryMinus() {
val x = 1234.toBigInt() val x = 1234.toBigInt()
val y = -1234.toBigInt() val y = (-1234).toBigInt()
assertEquals(-x, y) assertEquals(-x, y)
} }
@ -48,18 +48,18 @@ class BigIntOperationsTest {
@Test @Test
fun testMinus__2_1() { fun testMinus__2_1() {
val x = -2.toBigInt() val x = (-2).toBigInt()
val y = 1.toBigInt() val y = 1.toBigInt()
val res = x - y val res = x - y
val sum = -3.toBigInt() val sum = (-3).toBigInt()
assertEquals(sum, res) assertEquals(sum, res)
} }
@Test @Test
fun testMinus___2_1() { fun testMinus___2_1() {
val x = -2.toBigInt() val x = (-2).toBigInt()
val y = 1.toBigInt() val y = 1.toBigInt()
val res = -x - y val res = -x - y
@ -74,7 +74,7 @@ class BigIntOperationsTest {
val y = 0xffffffffaL.toBigInt() val y = 0xffffffffaL.toBigInt()
val res = x - y val res = x - y
val sum = -0xfffffcfc1L.toBigInt() val sum = (-0xfffffcfc1L).toBigInt()
assertEquals(sum, res) assertEquals(sum, res)
} }
@ -92,11 +92,11 @@ class BigIntOperationsTest {
@Test @Test
fun testMultiply__2_3() { fun testMultiply__2_3() {
val x = -2.toBigInt() val x = (-2).toBigInt()
val y = 3.toBigInt() val y = 3.toBigInt()
val res = x * y val res = x * y
val prod = -6.toBigInt() val prod = (-6).toBigInt()
assertEquals(prod, res) assertEquals(prod, res)
} }
@ -129,7 +129,7 @@ class BigIntOperationsTest {
val y = -0xfff456 val y = -0xfff456
val res = x * y val res = x * y
val prod = -0xffe579ad5dc2L.toBigInt() val prod = (-0xffe579ad5dc2L).toBigInt()
assertEquals(prod, res) assertEquals(prod, res)
} }
@ -259,7 +259,7 @@ class BigIntOperationsTest {
val y = -3 val y = -3
val res = x / y val res = x / y
val div = -6.toBigInt() val div = (-6).toBigInt()
assertEquals(div, res) assertEquals(div, res)
} }
@ -267,10 +267,10 @@ class BigIntOperationsTest {
@Test @Test
fun testBigDivision_20__3() { fun testBigDivision_20__3() {
val x = 20.toBigInt() val x = 20.toBigInt()
val y = -3.toBigInt() val y = (-3).toBigInt()
val res = x / y val res = x / y
val div = -6.toBigInt() val div = (-6).toBigInt()
assertEquals(div, res) assertEquals(div, res)
} }

View File

@ -0,0 +1,77 @@
package scientifik.kmath.operations
import scientifik.kmath.operations.internal.FieldVerifier
import kotlin.math.PI
import kotlin.math.abs
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class ComplexFieldTest {
@Test
fun verify() = ComplexField { FieldVerifier(this, 42.0 * i, 66.0 + 28 * i, 2.0 + 0 * i, 5).verify() }
@Test
fun testAddition() {
assertEquals(Complex(42, 42), ComplexField { Complex(16, 16) + Complex(26, 26) })
assertEquals(Complex(42, 16), ComplexField { Complex(16, 16) + 26 })
assertEquals(Complex(42, 16), ComplexField { 26 + Complex(16, 16) })
}
@Test
fun testSubtraction() {
assertEquals(Complex(42, 42), ComplexField { Complex(86, 55) - Complex(44, 13) })
assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 })
assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) })
}
@Test
fun testMultiplication() {
assertEquals(Complex(42, 42), ComplexField { Complex(4.2, 0) * Complex(10, 10) })
assertEquals(Complex(42, 21), ComplexField { Complex(4.2, 2.1) * 10 })
assertEquals(Complex(42, 21), ComplexField { 10 * Complex(4.2, 2.1) })
}
@Test
fun testDivision() {
assertEquals(Complex(42, 42), ComplexField { Complex(0, 168) / Complex(2, 2) })
assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 })
assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) })
assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(Double.NaN, Double.NaN) })
assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(0, 0) })
}
@Test
fun testSine() {
assertEquals(ComplexField { i * sinh(one) }, ComplexField { sin(i) })
assertEquals(ComplexField { i * sinh(PI.toComplex()) }, ComplexField { sin(i * PI.toComplex()) })
}
@Test
fun testInverseSine() {
assertEquals(Complex(0, -0.0), ComplexField { asin(zero) })
assertTrue(abs(ComplexField { i * asinh(one) }.r - ComplexField { asin(i) }.r) < 0.000000000000001)
}
@Test
fun testInverseHyperbolicSine() {
assertEquals(
ComplexField { i * PI.toComplex() / 2 },
ComplexField { asinh(i) })
}
@Test
fun testPower() {
assertEquals(ComplexField.zero, ComplexField { zero pow 2 })
assertEquals(ComplexField.zero, ComplexField { zero pow 2 })
assertEquals(
ComplexField { i * 8 }.let { it.im.toInt() to it.re.toInt() },
ComplexField { Complex(2, 2) pow 2 }.let { it.im.toInt() to it.re.toInt() })
}
@Test
fun testNorm() {
assertEquals(2.toComplex(), ComplexField { norm(2 * i) })
}
}

View File

@ -0,0 +1,38 @@
package scientifik.kmath.operations
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ComplexTest {
@Test
fun conjugate() {
assertEquals(
Complex(0, -42), (ComplexField.i * 42).conjugate
)
}
@Test
fun reciprocal() {
assertEquals(Complex(0.5, -0.0), 2.toComplex().reciprocal)
}
@Test
fun r() {
assertEquals(kotlin.math.sqrt(2.0), (ComplexField.i + 1.0.toComplex()).r)
}
@Test
fun theta() {
assertEquals(0.0, 1.toComplex().theta)
}
@Test
fun toComplex() {
assertEquals(Complex(42, 0), 42.toComplex())
assertEquals(Complex(42.0, 0), 42.0.toComplex())
assertEquals(Complex(42f, 0), 42f.toComplex())
assertEquals(Complex(42.0, 0), 42.0.toComplex())
assertEquals(Complex(42.toByte(), 0), 42.toByte().toComplex())
assertEquals(Complex(42.toShort(), 0), 42.toShort().toComplex())
}
}

View File

@ -1,14 +1,16 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.internal.FieldVerifier
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class RealFieldTest { internal class RealFieldTest {
@Test
fun verify() = FieldVerifier(RealField, 42.0, 66.0, 2.0, 5).verify()
@Test @Test
fun testSqrt() { fun testSqrt() {
val sqrt = RealField { val sqrt = RealField { sqrt(25 * one) }
sqrt(25 * one)
}
assertEquals(5.0, sqrt) assertEquals(5.0, sqrt)
} }
} }

View File

@ -0,0 +1,9 @@
package scientifik.kmath.operations.internal
import scientifik.kmath.operations.Algebra
internal interface AlgebraicVerifier<T, out A> where A : Algebra<T> {
val algebra: A
fun verify()
}

Some files were not shown because too many files have changed in this diff Show More