forked from kscience/kmath
Merge remote-tracking branch 'origin/dev' into hyp-trig-functions
# Conflicts: # kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt # kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt # kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt # kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt # kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt # kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt
This commit is contained in:
commit
40d7351515
@ -11,6 +11,7 @@ allprojects {
|
||||
repositories {
|
||||
jcenter()
|
||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
}
|
||||
|
||||
group = "scientifik"
|
||||
|
124
doc/algebra.md
124
doc/algebra.md
@ -1,110 +1,124 @@
|
||||
# Algebra and algebra elements
|
||||
# Algebraic Structures and Algebraic Elements
|
||||
|
||||
The mathematical operations in `kmath` are generally separated from mathematical objects.
|
||||
This means that in order to perform an operation, say `+`, one needs two objects of a type `T` and
|
||||
and algebra context which defines appropriate operation, say `Space<T>`. Next one needs to run actual operation
|
||||
in the context:
|
||||
The mathematical operations in KMath are generally separated from mathematical objects. This means that to perform an
|
||||
operation, say `+`, one needs two objects of a type `T` and an algebra context, which draws appropriate operation up,
|
||||
say `Space<T>`. Next one needs to run the actual operation in the context:
|
||||
|
||||
```kotlin
|
||||
val a: T
|
||||
val b: T
|
||||
val space: Space<T>
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
val c = space.run{a + b}
|
||||
val a: T = ...
|
||||
val b: T = ...
|
||||
val space: Space<T> = ...
|
||||
|
||||
val c = space { a + b }
|
||||
```
|
||||
|
||||
From the first glance, this distinction seems to be a needless complication, but in fact one needs
|
||||
to remember that in mathematics, one could define different operations on the same objects. For example,
|
||||
one could use different types of geometry for vectors.
|
||||
At first glance, this distinction seems to be a needless complication, but in fact one needs to remember that in
|
||||
mathematics, one could draw up different operations on same objects. For example, one could use different types of
|
||||
geometry for vectors.
|
||||
|
||||
## Algebra hierarchy
|
||||
## Algebraic Structures
|
||||
|
||||
Mathematical contexts have the following hierarchy:
|
||||
|
||||
**Space** <- **Ring** <- **Field**
|
||||
**Algebra** ← **Space** ← **Ring** ← **Field**
|
||||
|
||||
All classes follow abstract mathematical constructs.
|
||||
[Space](http://mathworld.wolfram.com/Space.html) defines `zero` element, addition operation and multiplication by constant,
|
||||
[Ring](http://mathworld.wolfram.com/Ring.html) adds multiplication and unit `one` element,
|
||||
[Field](http://mathworld.wolfram.com/Field.html) adds division operation.
|
||||
These interfaces follow real algebraic structures:
|
||||
|
||||
Typical case of `Field` is the `RealField` which works on doubles. And typical case of `Space` is a `VectorSpace`.
|
||||
- [Space](https://mathworld.wolfram.com/VectorSpace.html) defines addition, its neutral element (i.e. 0) and scalar
|
||||
multiplication;
|
||||
- [Ring](http://mathworld.wolfram.com/Ring.html) adds multiplication and its neutral element (i.e. 1);
|
||||
- [Field](http://mathworld.wolfram.com/Field.html) adds division operation.
|
||||
|
||||
In some cases algebra context could hold additional operation like `exp` or `sin`, in this case it inherits appropriate
|
||||
interface. Also a context could have an operation which produces an element outside of its context. For example
|
||||
`Matrix` `dot` operation produces a matrix with new dimensions which can be incompatible with initial matrix in
|
||||
terms of linear operations.
|
||||
A typical implementation of `Field<T>` is the `RealField` which works on doubles, and `VectorSpace` for `Space<T>`.
|
||||
|
||||
## Algebra element
|
||||
In some cases algebra context can hold additional operations like `exp` or `sin`, and then it inherits appropriate
|
||||
interface. Also, contexts may have operations, which produce elements outside of the context. For example, `Matrix.dot`
|
||||
operation produces a matrix with new dimensions, which can be incompatible with initial matrix in terms of linear
|
||||
operations.
|
||||
|
||||
In order to achieve more familiar behavior (where you apply operations directly to mathematical objects), without involving contexts
|
||||
`kmath` introduces special type objects called `MathElement`. A `MathElement` is basically some object coupled to
|
||||
## Algebraic Element
|
||||
|
||||
To achieve more familiar behavior (where you apply operations directly to mathematical objects), without involving
|
||||
contexts KMath submits special type objects called `MathElement`. A `MathElement` is basically some object coupled to
|
||||
a mathematical context. For example `Complex` is the pair of real numbers representing real and imaginary parts,
|
||||
but it also holds reference to the `ComplexField` singleton which allows to perform direct operations on `Complex`
|
||||
but it also holds reference to the `ComplexField` singleton, which allows performing direct operations on `Complex`
|
||||
numbers without explicit involving the context like:
|
||||
|
||||
```kotlin
|
||||
val c1 = Complex(1.0, 1.0)
|
||||
val c2 = Complex(1.0, -1.0)
|
||||
val c3 = c1 + c2 + 3.0.toComplex()
|
||||
//or with field notation:
|
||||
val c4 = ComplexField.run{c1 + i - 2.0}
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
// Using elements
|
||||
val c1 = Complex(1.0, 1.0)
|
||||
val c2 = Complex(1.0, -1.0)
|
||||
val c3 = c1 + c2 + 3.0.toComplex()
|
||||
|
||||
// Using context
|
||||
val c4 = ComplexField { c1 + i - 2.0 }
|
||||
```
|
||||
|
||||
Both notations have their pros and cons.
|
||||
|
||||
The hierarchy for algebra elements follows the hierarchy for the corresponding algebra.
|
||||
The hierarchy for algebraic elements follows the hierarchy for the corresponding algebraic structures.
|
||||
|
||||
**MathElement** <- **SpaceElement** <- **RingElement** <- **FieldElement**
|
||||
**MathElement** ← **SpaceElement** ← **RingElement** ← **FieldElement**
|
||||
|
||||
**MathElement** is the generic common ancestor of the class with context.
|
||||
`MathElement<C>` is the generic common ancestor of the class with context.
|
||||
|
||||
One important distinction between algebra elements and algebra contexts is that algebra element has three type parameters:
|
||||
One major distinction between algebraic elements and algebraic contexts is that elements have three type
|
||||
parameters:
|
||||
|
||||
1. The type of elements, field operates on.
|
||||
2. The self-type of the element returned from operation (must be algebra element).
|
||||
1. The type of elements, the field operates on.
|
||||
2. The self-type of the element returned from operation (which has to be an algebraic element).
|
||||
3. The type of the algebra over first type-parameter.
|
||||
|
||||
The middle type is needed in case algebra members do not store context. For example, it is not possible to add
|
||||
a context to regular `Double`. The element performs automatic conversions from context types and back.
|
||||
One should used context operations in all important places. The performance of element operations is not guaranteed.
|
||||
The middle type is needed for of algebra members do not store context. For example, it is impossible to add a context
|
||||
to regular `Double`. The element performs automatic conversions from context types and back. One should use context
|
||||
operations in all performance-critical places. The performance of element operations is not guaranteed.
|
||||
|
||||
## Spaces and fields
|
||||
## Spaces and Fields
|
||||
|
||||
An obvious first choice of mathematical objects to implement in a context-oriented style are algebraic elements like spaces,
|
||||
rings and fields. Those are located in the `scientifik.kmath.operations.Algebra.kt` file. Alongside common contexts, the file includes definitions for algebra elements like `FieldElement`. A `FieldElement` object
|
||||
stores a reference to the `Field` which contains additive and multiplicative operations, meaning
|
||||
it has one fixed context attached and does not require explicit external context. So those `MathElements` can be operated without context:
|
||||
KMath submits both contexts and elements for builtin algebraic structures:
|
||||
|
||||
```kotlin
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
val c1 = Complex(1.0, 2.0)
|
||||
val c2 = ComplexField.i
|
||||
|
||||
val c3 = c1 + c2
|
||||
// or
|
||||
val c3 = ComplexField { c1 + c2 }
|
||||
```
|
||||
|
||||
`ComplexField` also features special operations to mix complex and real numbers, for example:
|
||||
Also, `ComplexField` features special operations to mix complex and real numbers, for example:
|
||||
|
||||
```kotlin
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
val c1 = Complex(1.0, 2.0)
|
||||
val c2 = ComplexField.run{ c1 - 1.0} // Returns: [re:0.0, im: 2.0]
|
||||
val c3 = ComplexField.run{ c1 - i*2.0}
|
||||
val c2 = ComplexField { c1 - 1.0 } // Returns: Complex(re=0.0, im=2.0)
|
||||
val c3 = ComplexField { c1 - i * 2.0 }
|
||||
```
|
||||
|
||||
**Note**: In theory it is possible to add behaviors directly to the context, but currently kotlin syntax does not support
|
||||
that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and [KEEP-176](https://github.com/Kotlin/KEEP/pull/176) for updates.
|
||||
**Note**: In theory it is possible to add behaviors directly to the context, but as for now Kotlin does not support
|
||||
that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and
|
||||
[KEEP-176](https://github.com/Kotlin/KEEP/pull/176) for updates.
|
||||
|
||||
## Nested fields
|
||||
|
||||
Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex elements like so:
|
||||
Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex
|
||||
elements like so:
|
||||
|
||||
```kotlin
|
||||
val element = NDElement.complex(shape = intArrayOf(2,2)){ index: IntArray ->
|
||||
val element = NDElement.complex(shape = intArrayOf(2, 2)) { index: IntArray ->
|
||||
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
||||
}
|
||||
```
|
||||
|
||||
The `element` in this example is a member of the `Field` of 2-d structures, each element of which is a member of its own
|
||||
`ComplexField`. The important thing is one does not need to create a special n-d class to hold complex
|
||||
The `element` in this example is a member of the `Field` of 2D structures, each element of which is a member of its own
|
||||
`ComplexField`. It is important one does not need to create a special n-d class to hold complex
|
||||
numbers and implement operations on it, one just needs to provide a field for its elements.
|
||||
|
||||
**Note**: Fields themselves do not solve the problem of JVM boxing, but it is possible to solve with special contexts like
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Buffers
|
||||
|
||||
Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`).
|
||||
There are different types of buffers:
|
||||
|
||||
@ -12,4 +13,5 @@ Some kmath features require a `BufferFactory` class to operate properly. A gener
|
||||
buffer for given reified type (for types with custom memory buffer it still better to use their own `MemoryBuffer.create()` factory).
|
||||
|
||||
## Buffer performance
|
||||
One should avoid using default boxing buffer wherever it is possible. Try to use primitive buffers or memory buffers instead
|
||||
|
||||
One should avoid using default boxing buffer wherever it is possible. Try to use primitive buffers or memory buffers instead
|
||||
|
34
doc/codestyle.md
Normal file
34
doc/codestyle.md
Normal file
@ -0,0 +1,34 @@
|
||||
# Coding Conventions
|
||||
|
||||
KMath code follows general [Kotlin conventions](https://kotlinlang.org/docs/reference/coding-conventions.html), but
|
||||
with a number of small changes and clarifications.
|
||||
|
||||
## Utility Class Naming
|
||||
|
||||
Filename should coincide with a name of one of the classes contained in the file or start with small letter and
|
||||
describe its contents.
|
||||
|
||||
The code convention [here](https://kotlinlang.org/docs/reference/coding-conventions.html#source-file-names) says that
|
||||
file names should start with a capital letter even if file does not contain classes. Yet starting utility classes and
|
||||
aggregators with a small letter seems to be a good way to visually separate those files.
|
||||
|
||||
This convention could be changed in future in a non-breaking way.
|
||||
|
||||
## Private Variable Naming
|
||||
|
||||
Private variables' names may start with underscore `_` for of the private mutable variable is shadowed by the public
|
||||
read-only value with the same meaning.
|
||||
|
||||
This rule does not permit underscores in names, but it is sometimes useful to "underscore" the fact that public and
|
||||
private versions draw up the same entity. It is allowed only for private variables.
|
||||
|
||||
This convention could be changed in future in a non-breaking way.
|
||||
|
||||
## Functions and Properties One-liners
|
||||
|
||||
Use one-liners when they occupy single code window line both for functions and properties with getters like
|
||||
`val b: String get() = "fff"`. The same should be performed with multiline expressions when they could be
|
||||
cleanly separated.
|
||||
|
||||
There is no universal consensus whenever use `fun a() = ...` or `fun a() { return ... }`. Yet from reader outlook
|
||||
one-lines seem to better show that the property or function is easily calculated.
|
@ -1,6 +1,6 @@
|
||||
## Basic linear algebra layout
|
||||
|
||||
Kmath support for linear algebra organized in a context-oriented way. Meaning that operations are in most cases declared
|
||||
KMath support for linear algebra organized in a context-oriented way. Meaning that operations are in most cases declared
|
||||
in context classes, and are not the members of classes that store data. This allows more flexible approach to maintain multiple
|
||||
back-ends. The new operations added as extensions to contexts instead of being member functions of data structures.
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Nd-structure generation and operations
|
||||
# ND-structure generation and operations
|
||||
|
||||
**TODO**
|
||||
|
||||
|
@ -1,22 +0,0 @@
|
||||
# Local coding conventions
|
||||
|
||||
Kmath and other `scientifik` projects use general [kotlin code conventions](https://kotlinlang.org/docs/reference/coding-conventions.html), but with a number of small changes and clarifications.
|
||||
|
||||
## Utility class names
|
||||
File name should coincide with a name of one of the classes contained in the file or start with small letter and describe its contents.
|
||||
|
||||
The code convention [here](https://kotlinlang.org/docs/reference/coding-conventions.html#source-file-names) says that file names should start with capital letter even if file does not contain classes. Yet starting utility classes and aggregators with a small letter seems to be a good way to clearly visually separate those files.
|
||||
|
||||
This convention could be changed in future in a non-breaking way.
|
||||
|
||||
## Private variable names
|
||||
Private variable names could start with underscore `_` in case the private mutable variable is shadowed by the public read-only value with the same meaning.
|
||||
|
||||
Code convention do not permit underscores in names, but is is sometimes useful to "underscore" the fact that public and private versions define the same entity. It is allowed only for private variables.
|
||||
|
||||
This convention could be changed in future in a non-breaking way.
|
||||
|
||||
## Functions and properties one-liners
|
||||
Use one-liners when they occupy single code window line both for functions and properties with getters like `val b: String get() = "fff"`. The same should be done with multiline expressions when they could be cleanly separated.
|
||||
|
||||
There is not general consensus whenever use `fun a() = {}` or `fun a(){return}`. Yet from reader perspective one-lines seem to better show that the property or function is easily calculated.
|
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.5-bin.zip
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.5.1-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Abstract syntax tree expression representation and operations (`kmath-ast`)
|
||||
# Abstract Syntax Tree Expression Representation and Operations (`kmath-ast`)
|
||||
|
||||
This subproject implements the following features:
|
||||
|
||||
@ -38,7 +38,7 @@ This subproject implements the following features:
|
||||
> ```
|
||||
>
|
||||
|
||||
## Dynamic expression code generation with ObjectWeb ASM
|
||||
## Dynamic Expression Code Generation with ObjectWeb ASM
|
||||
|
||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
|
||||
a special implementation of `Expression<T>` with implemented `invoke` function.
|
||||
@ -46,7 +46,7 @@ a special implementation of `Expression<T>` with implemented `invoke` function.
|
||||
For example, the following builder:
|
||||
|
||||
```kotlin
|
||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
```
|
||||
|
||||
… leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||
@ -75,7 +75,7 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression<Doub
|
||||
|
||||
### Example Usage
|
||||
|
||||
This API is an extension to MST and MstExpression, so you may optimize as both of them:
|
||||
This API extends MST and MstExpression, so you may optimize as both of them:
|
||||
|
||||
```kotlin
|
||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
|
@ -1,10 +1,4 @@
|
||||
plugins {
|
||||
id("scientifik.mpp")
|
||||
}
|
||||
|
||||
repositories {
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
}
|
||||
plugins { id("scientifik.mpp") }
|
||||
|
||||
kotlin.sourceSets {
|
||||
// all {
|
||||
@ -15,23 +9,15 @@ kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3")
|
||||
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
|
||||
implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
|
||||
}
|
||||
}
|
||||
|
||||
jvmMain {
|
||||
dependencies {
|
||||
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
|
||||
implementation("org.ow2.asm:asm:8.0.1")
|
||||
implementation("org.ow2.asm:asm-commons:8.0.1")
|
||||
implementation(kotlin("reflect"))
|
||||
}
|
||||
}
|
||||
|
||||
jsMain {
|
||||
dependencies {
|
||||
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
|
||||
}
|
||||
}
|
||||
}
|
59
kmath-ast/reference/ArithmeticsEvaluator.g4
Normal file
59
kmath-ast/reference/ArithmeticsEvaluator.g4
Normal file
@ -0,0 +1,59 @@
|
||||
grammar ArithmeticsEvaluator;
|
||||
|
||||
fragment DIGIT: '0'..'9';
|
||||
fragment LETTER: 'a'..'z';
|
||||
fragment CAPITAL_LETTER: 'A'..'Z';
|
||||
fragment UNDERSCORE: '_';
|
||||
|
||||
ID: (LETTER | UNDERSCORE | CAPITAL_LETTER) (LETTER | UNDERSCORE | DIGIT | CAPITAL_LETTER)*;
|
||||
NUM: (DIGIT | '.')+ ([eE] [-+]? DIGIT+)?;
|
||||
MUL: '*';
|
||||
DIV: '/';
|
||||
PLUS: '+';
|
||||
MINUS: '-';
|
||||
POW: '^';
|
||||
COMMA: ',';
|
||||
LPAR: '(';
|
||||
RPAR: ')';
|
||||
WS: [ \n\t\r]+ -> skip;
|
||||
|
||||
num
|
||||
: NUM
|
||||
;
|
||||
|
||||
singular
|
||||
: ID
|
||||
;
|
||||
|
||||
unaryFunction
|
||||
: ID LPAR subSumChain RPAR
|
||||
;
|
||||
|
||||
binaryFunction
|
||||
: ID LPAR subSumChain COMMA subSumChain RPAR
|
||||
;
|
||||
|
||||
term
|
||||
: num
|
||||
| singular
|
||||
| unaryFunction
|
||||
| binaryFunction
|
||||
| MINUS term
|
||||
| LPAR subSumChain RPAR
|
||||
;
|
||||
|
||||
powChain
|
||||
: term (POW term)*
|
||||
;
|
||||
|
||||
divMulChain
|
||||
: powChain ((DIV | MUL) powChain)*
|
||||
;
|
||||
|
||||
subSumChain
|
||||
: divMulChain ((PLUS | MINUS) divMulChain)*
|
||||
;
|
||||
|
||||
rootParser
|
||||
: subSumChain EOF
|
||||
;
|
@ -5,63 +5,83 @@ import scientifik.kmath.operations.NumericAlgebra
|
||||
import scientifik.kmath.operations.RealField
|
||||
|
||||
/**
|
||||
* A Mathematical Syntax Tree node for mathematical expressions
|
||||
* A Mathematical Syntax Tree node for mathematical expressions.
|
||||
*/
|
||||
sealed class MST {
|
||||
|
||||
/**
|
||||
* A node containing unparsed string
|
||||
* A node containing raw string.
|
||||
*
|
||||
* @property value the value of this node.
|
||||
*/
|
||||
data class Symbolic(val value: String) : MST()
|
||||
|
||||
/**
|
||||
* A node containing a number
|
||||
* A node containing a numeric value or scalar.
|
||||
*
|
||||
* @property value the value of this number.
|
||||
*/
|
||||
data class Numeric(val value: Number) : MST()
|
||||
|
||||
/**
|
||||
* A node containing an unary operation
|
||||
* A node containing an unary operation.
|
||||
*
|
||||
* @property operation the identifier of operation.
|
||||
* @property value the argument of this operation.
|
||||
*/
|
||||
data class Unary(val operation: String, val value: MST) : MST() {
|
||||
companion object {
|
||||
const val ABS_OPERATION = "abs"
|
||||
//TODO add operations
|
||||
}
|
||||
companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* A node containing binary operation
|
||||
* A node containing binary operation.
|
||||
*
|
||||
* @property operation the identifier operation.
|
||||
* @property left the left operand.
|
||||
* @property right the right operand.
|
||||
*/
|
||||
data class Binary(val operation: String, val left: MST, val right: MST) : MST() {
|
||||
companion object
|
||||
}
|
||||
}
|
||||
|
||||
//TODO add a function with positional arguments
|
||||
// TODO add a function with named arguments
|
||||
|
||||
//TODO add a function with named arguments
|
||||
/**
|
||||
* Interprets the [MST] node with this [Algebra].
|
||||
*
|
||||
* @receiver the algebra that provides operations.
|
||||
* @param node the node to evaluate.
|
||||
* @return the value of expression.
|
||||
*/
|
||||
fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
is MST.Symbolic -> symbol(node.value)
|
||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||
is MST.Binary -> when {
|
||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
|
||||
fun <T> Algebra<T>.evaluate(node: MST): T {
|
||||
return when (node) {
|
||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
is MST.Symbolic -> symbol(node.value)
|
||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||
is MST.Binary -> when {
|
||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||
val number = RealField.binaryOperation(
|
||||
node.operation,
|
||||
node.left.value.toDouble(),
|
||||
node.right.value.toDouble()
|
||||
)
|
||||
number(number)
|
||||
}
|
||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||
val number = RealField.binaryOperation(
|
||||
node.operation,
|
||||
node.left.value.toDouble(),
|
||||
node.right.value.toDouble()
|
||||
)
|
||||
|
||||
number(number)
|
||||
}
|
||||
|
||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
}
|
||||
}
|
||||
|
||||
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)
|
||||
/**
|
||||
* Interprets the [MST] node with this [Algebra].
|
||||
*
|
||||
* @receiver the node to evaluate.
|
||||
* @param algebra the algebra that provides operations.
|
||||
* @return the value of expression.
|
||||
*/
|
||||
fun <T> MST.interpret(algebra: Algebra<T>): T = algebra.evaluate(this)
|
||||
|
@ -2,6 +2,9 @@ package scientifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
/**
|
||||
* [Algebra] over [MST] nodes.
|
||||
*/
|
||||
object MstAlgebra : NumericAlgebra<MST> {
|
||||
override fun number(value: Number): MST = MST.Numeric(value)
|
||||
|
||||
@ -14,17 +17,16 @@ object MstAlgebra : NumericAlgebra<MST> {
|
||||
MST.Binary(operation, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* [Space] over [MST] nodes.
|
||||
*/
|
||||
object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
|
||||
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||
|
||||
override fun add(a: MST, b: MST): MST =
|
||||
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
|
||||
override fun multiply(a: MST, k: Number): MST =
|
||||
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
MstAlgebra.binaryOperation(operation, left, right)
|
||||
@ -32,41 +34,69 @@ object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||
}
|
||||
|
||||
/**
|
||||
* [Ring] over [MST] nodes.
|
||||
*/
|
||||
object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
override val one: MST = number(1.0)
|
||||
|
||||
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
override fun number(value: Number): MST = MstSpace.number(value)
|
||||
override fun symbol(value: String): MST = MstSpace.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
|
||||
|
||||
override fun multiply(a: MST, k: Number): MST =
|
||||
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
|
||||
override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k)
|
||||
|
||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
MstAlgebra.binaryOperation(operation, left, right)
|
||||
MstSpace.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||
}
|
||||
|
||||
/**
|
||||
* [Field] over [MST] nodes.
|
||||
*/
|
||||
object MstField : Field<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
override val one: MST = number(1.0)
|
||||
|
||||
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
|
||||
override fun multiply(a: MST, k: Number): MST =
|
||||
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
|
||||
|
||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
override fun symbol(value: String): MST = MstRing.symbol(value)
|
||||
override fun number(value: Number): MST = MstRing.number(value)
|
||||
override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
|
||||
override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
|
||||
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
MstAlgebra.binaryOperation(operation, left, right)
|
||||
MstRing.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
|
||||
}
|
||||
|
||||
/**
|
||||
* [ExtendedField] over [MST] nodes.
|
||||
*/
|
||||
object MstExtendedField : ExtendedField<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
override val one: MST = number(1.0)
|
||||
|
||||
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||
override fun asin(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg)
|
||||
override fun acos(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg)
|
||||
override fun atan(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg)
|
||||
override fun add(a: MST, b: MST): MST = MstField.add(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
|
||||
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
|
||||
override fun divide(a: MST, b: MST): MST = MstField.divide(a, b)
|
||||
override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||
override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||
override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
MstField.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg)
|
||||
}
|
||||
|
@ -1,19 +1,16 @@
|
||||
package scientifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.expressions.FunctionalExpressionField
|
||||
import scientifik.kmath.expressions.FunctionalExpressionRing
|
||||
import scientifik.kmath.expressions.FunctionalExpressionSpace
|
||||
import scientifik.kmath.expressions.*
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
/**
|
||||
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
|
||||
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
|
||||
* ASM-generated expressions.
|
||||
*
|
||||
* @property algebra the algebra that provides operations.
|
||||
* @property mst the [MST] node.
|
||||
*/
|
||||
class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
|
||||
|
||||
/**
|
||||
* Substitute algebra raw value
|
||||
*/
|
||||
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
|
||||
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||
@ -30,26 +27,62 @@ class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [Algebra].
|
||||
*/
|
||||
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||
mstAlgebra: E,
|
||||
block: E.() -> MST
|
||||
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [Space].
|
||||
*/
|
||||
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
|
||||
MstExpression(this, MstSpace.block())
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [Ring].
|
||||
*/
|
||||
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
|
||||
MstExpression(this, MstRing.block())
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [Field].
|
||||
*/
|
||||
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> =
|
||||
MstExpression(this, MstField.block())
|
||||
|
||||
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
|
||||
algebra.mstInSpace(block)
|
||||
/**
|
||||
* Builds [MstExpression] over [ExtendedField].
|
||||
*/
|
||||
inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> =
|
||||
MstExpression(this, MstExtendedField.block())
|
||||
|
||||
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
|
||||
algebra.mstInRing(block)
|
||||
/**
|
||||
* Builds [MstExpression] over [FunctionalExpressionSpace].
|
||||
*/
|
||||
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(
|
||||
block: MstSpace.() -> MST
|
||||
): MstExpression<T> = algebra.mstInSpace(block)
|
||||
|
||||
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> =
|
||||
algebra.mstInField(block)
|
||||
/**
|
||||
* Builds [MstExpression] over [FunctionalExpressionRing].
|
||||
*/
|
||||
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(
|
||||
block: MstRing.() -> MST
|
||||
): MstExpression<T> = algebra.mstInRing(block)
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [FunctionalExpressionField].
|
||||
*/
|
||||
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(
|
||||
block: MstField.() -> MST
|
||||
): MstExpression<T> = algebra.mstInField(block)
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [FunctionalExpressionExtendedField].
|
||||
*/
|
||||
inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||
block: MstExtendedField.() -> MST
|
||||
): MstExpression<T> = algebra.mstInExtendedField(block)
|
||||
|
@ -5,6 +5,9 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
|
||||
import com.github.h0tk3y.betterParse.grammar.parseToEnd
|
||||
import com.github.h0tk3y.betterParse.grammar.parser
|
||||
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
||||
import com.github.h0tk3y.betterParse.lexer.Token
|
||||
import com.github.h0tk3y.betterParse.lexer.TokenMatch
|
||||
import com.github.h0tk3y.betterParse.lexer.regexToken
|
||||
import com.github.h0tk3y.betterParse.parser.ParseResult
|
||||
import com.github.h0tk3y.betterParse.parser.Parser
|
||||
import scientifik.kmath.operations.FieldOperations
|
||||
@ -13,47 +16,82 @@ import scientifik.kmath.operations.RingOperations
|
||||
import scientifik.kmath.operations.SpaceOperations
|
||||
|
||||
/**
|
||||
* TODO move to common
|
||||
* TODO move to core
|
||||
*/
|
||||
private object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
|
||||
val lpar by token("\\(".toRegex())
|
||||
val rpar by token("\\)".toRegex())
|
||||
val mul by token("\\*".toRegex())
|
||||
val pow by token("\\^".toRegex())
|
||||
val div by token("/".toRegex())
|
||||
val minus by token("-".toRegex())
|
||||
val plus by token("\\+".toRegex())
|
||||
val ws by token("\\s+".toRegex(), ignore = true)
|
||||
object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released
|
||||
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
|
||||
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
|
||||
private val lpar: Token by regexToken("\\(")
|
||||
private val rpar: Token by regexToken("\\)")
|
||||
private val comma: Token by regexToken(",")
|
||||
private val mul: Token by regexToken("\\*")
|
||||
private val pow: Token by regexToken("\\^")
|
||||
private val div: Token by regexToken("/")
|
||||
private val minus: Token by regexToken("-")
|
||||
private val plus: Token by regexToken("\\+")
|
||||
private val ws: Token by regexToken("\\s+", ignore = true)
|
||||
|
||||
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
|
||||
|
||||
val term: Parser<MST> by number or
|
||||
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or
|
||||
(skip(lpar) and parser(this::rootParser) and skip(rpar))
|
||||
private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||
.map { (id, term) -> MST.Unary(id.text, term) }
|
||||
|
||||
val powChain by leftAssociative(term, pow) { a, _, b ->
|
||||
private val binaryFunction: Parser<MST> by id
|
||||
.and(skip(lpar))
|
||||
.and(parser(::subSumChain))
|
||||
.and(skip(comma))
|
||||
.and(parser(::subSumChain))
|
||||
.and(skip(rpar))
|
||||
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
|
||||
|
||||
private val term: Parser<MST> by number
|
||||
.or(binaryFunction)
|
||||
.or(unaryFunction)
|
||||
.or(singular)
|
||||
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
|
||||
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||
|
||||
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
|
||||
MST.Binary(PowerOperations.POW_OPERATION, a, b)
|
||||
}
|
||||
|
||||
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b ->
|
||||
if (op == div) {
|
||||
private val divMulChain: Parser<MST> by leftAssociative(
|
||||
term = powChain,
|
||||
operator = div or mul use TokenMatch::type
|
||||
) { a, op, b ->
|
||||
if (op == div)
|
||||
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
||||
} else {
|
||||
else
|
||||
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
||||
}
|
||||
}
|
||||
|
||||
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b ->
|
||||
if (op == plus) {
|
||||
private val subSumChain: Parser<MST> by leftAssociative(
|
||||
term = divMulChain,
|
||||
operator = plus or minus use TokenMatch::type
|
||||
) { a, op, b ->
|
||||
if (op == plus)
|
||||
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
} else {
|
||||
else
|
||||
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
|
||||
}
|
||||
}
|
||||
|
||||
override val rootParser: Parser<MST> by subSumChain
|
||||
}
|
||||
|
||||
/**
|
||||
* Tries to parse the string into [MST].
|
||||
*
|
||||
* @receiver the string to parse.
|
||||
* @return the [MST] node.
|
||||
*/
|
||||
fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this)
|
||||
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)
|
||||
|
||||
/**
|
||||
* Parses the string into [MST].
|
||||
*
|
||||
* @receiver the string to parse.
|
||||
* @return the [MST] node.
|
||||
*/
|
||||
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)
|
||||
|
@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
|
||||
import scientifik.kmath.ast.MstExpression
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.NumericAlgebra
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
|
||||
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||
fun AsmBuilder<T>.visit(node: MST) {
|
||||
when (node) {
|
||||
is MST.Symbolic -> loadVariable(node.value)
|
||||
is MST.Symbolic -> {
|
||||
val symbol = try {
|
||||
algebra.symbol(node.value)
|
||||
} catch (ignored: Throwable) {
|
||||
null
|
||||
}
|
||||
|
||||
if (symbol != null)
|
||||
loadTConstant(symbol)
|
||||
else
|
||||
loadVariable(node.value)
|
||||
}
|
||||
|
||||
is MST.Numeric -> loadNumeric(node.value)
|
||||
|
||||
is MST.Unary -> buildAlgebraOperationCall(
|
||||
|
@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
/**
|
||||
* Loads a [T] constant from [constants].
|
||||
*/
|
||||
private fun loadTConstant(value: T) {
|
||||
internal fun loadTConstant(value: T) {
|
||||
if (classOfT in INLINABLE_NUMBERS) {
|
||||
val expectedType = expectationStack.pop()
|
||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||
@ -340,7 +340,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
checkcast(type)
|
||||
}
|
||||
|
||||
fun loadNumeric(value: Number) {
|
||||
internal fun loadNumeric(value: Number) {
|
||||
if (expectationStack.peek() == NUMBER_TYPE) {
|
||||
loadNumberConstant(value, true)
|
||||
expectationStack.pop()
|
||||
|
@ -0,0 +1,36 @@
|
||||
package scietifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.ast.evaluate
|
||||
import scientifik.kmath.ast.parseMath
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class ParserPrecedenceTest {
|
||||
private val f: Field<Double> = RealField
|
||||
|
||||
@Test
|
||||
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
|
||||
|
||||
@Test
|
||||
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
|
||||
|
||||
@Test
|
||||
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
|
||||
|
||||
@Test
|
||||
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
|
||||
|
||||
@Test
|
||||
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
|
||||
|
||||
@Test
|
||||
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
|
||||
|
||||
@Test
|
||||
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
|
||||
|
||||
@Test
|
||||
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
|
||||
}
|
@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
|
||||
import scientifik.kmath.ast.mstInField
|
||||
import scientifik.kmath.ast.parseMath
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.Complex
|
||||
import scientifik.kmath.operations.ComplexField
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@ -22,4 +24,37 @@ internal class ParserTest {
|
||||
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `evaluate MST with singular`() {
|
||||
val mst = "i".parseMath()
|
||||
val res = ComplexField.evaluate(mst)
|
||||
assertEquals(ComplexField.i, res)
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun `evaluate MST with unary function`() {
|
||||
val mst = "sin(0)".parseMath()
|
||||
val res = RealField.evaluate(mst)
|
||||
assertEquals(0.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `evaluate MST with binary function`() {
|
||||
val magicalAlgebra = object : Algebra<String> {
|
||||
override fun symbol(value: String): String = value
|
||||
|
||||
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
|
||||
|
||||
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
|
||||
"magic" -> "$left ★ $right"
|
||||
else -> throw NotImplementedError()
|
||||
}
|
||||
}
|
||||
|
||||
val mst = "magic(a, b)".parseMath()
|
||||
val res = magicalAlgebra.evaluate(mst)
|
||||
assertEquals("a ★ b", res)
|
||||
}
|
||||
}
|
||||
|
40
kmath-core/README.md
Normal file
40
kmath-core/README.md
Normal file
@ -0,0 +1,40 @@
|
||||
# The Core Module (`kmath-ast`)
|
||||
|
||||
The core features of KMath:
|
||||
|
||||
- Algebraic structures: contexts and elements.
|
||||
- ND structures.
|
||||
- Buffers.
|
||||
- Functional Expressions.
|
||||
- Domains.
|
||||
- Automatic differentiation.
|
||||
|
||||
> #### Artifact:
|
||||
> This module is distributed in the artifact `scientifik:kmath-core:0.1.4-dev-8`.
|
||||
>
|
||||
> **Gradle:**
|
||||
>
|
||||
> ```gradle
|
||||
> repositories {
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||
> maven { url https://dl.bintray.com/hotkeytlt/maven' }
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
> implementation 'scientifik:kmath-core:0.1.4-dev-8'
|
||||
> }
|
||||
> ```
|
||||
> **Gradle Kotlin DSL:**
|
||||
>
|
||||
> ```kotlin
|
||||
> repositories {
|
||||
> maven("https://dl.bintray.com/mipt-npm/scientifik")
|
||||
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
> }
|
||||
>
|
||||
> dependencies {``
|
||||
> implementation("scientifik:kmath-core:0.1.4-dev-8")
|
||||
> }
|
||||
> ```
|
@ -1,11 +1,7 @@
|
||||
plugins {
|
||||
id("scientifik.mpp")
|
||||
}
|
||||
plugins { id("scientifik.mpp") }
|
||||
|
||||
kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-memory"))
|
||||
}
|
||||
dependencies { api(project(":kmath-memory")) }
|
||||
}
|
||||
}
|
||||
|
@ -3,13 +3,18 @@ package scientifik.kmath.domains
|
||||
import scientifik.kmath.linear.Point
|
||||
|
||||
/**
|
||||
* A simple geometric domain
|
||||
* A simple geometric domain.
|
||||
*
|
||||
* @param T the type of element of this domain.
|
||||
*/
|
||||
interface Domain<T : Any> {
|
||||
/**
|
||||
* Checks if the specified point is contained in this domain.
|
||||
*/
|
||||
operator fun contains(point: Point<T>): Boolean
|
||||
|
||||
/**
|
||||
* Number of hyperspace dimensions
|
||||
* Number of hyperspace dimensions.
|
||||
*/
|
||||
val dimension: Int
|
||||
}
|
||||
}
|
||||
|
@ -42,13 +42,14 @@ class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBu
|
||||
override fun getUpperBound(num: Int): Double? = upper[num]
|
||||
|
||||
override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
||||
val res: DoubleArray = DoubleArray(point.size) { i ->
|
||||
val res = DoubleArray(point.size) { i ->
|
||||
when {
|
||||
point[i] < lower[i] -> lower[i]
|
||||
point[i] > upper[i] -> upper[i]
|
||||
else -> point[i]
|
||||
}
|
||||
}
|
||||
|
||||
return RealBuffer(*res)
|
||||
}
|
||||
|
||||
@ -64,4 +65,4 @@ class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBu
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -22,8 +22,7 @@ import scientifik.kmath.linear.Point
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
interface RealDomain: Domain<Double> {
|
||||
|
||||
interface RealDomain : Domain<Double> {
|
||||
fun nearestInDomain(point: Point<Double>): Point<Double>
|
||||
|
||||
/**
|
||||
@ -61,5 +60,4 @@ interface RealDomain: Domain<Double> {
|
||||
* @return
|
||||
*/
|
||||
fun volume(): Double
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -18,7 +18,6 @@ package scientifik.kmath.domains
|
||||
import scientifik.kmath.linear.Point
|
||||
|
||||
class UnconstrainedDomain(override val dimension: Int) : RealDomain {
|
||||
|
||||
override operator fun contains(point: Point<Double>): Boolean = true
|
||||
|
||||
override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
|
||||
@ -32,5 +31,4 @@ class UnconstrainedDomain(override val dimension: Int) : RealDomain {
|
||||
override fun nearestInDomain(point: Point<Double>): Point<Double> = point
|
||||
|
||||
override fun volume(): Double = Double.POSITIVE_INFINITY
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,6 @@ import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
|
||||
inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : RealDomain {
|
||||
|
||||
operator fun contains(d: Double): Boolean = range.contains(d)
|
||||
|
||||
override operator fun contains(point: Point<Double>): Boolean {
|
||||
@ -15,10 +14,10 @@ inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : Rea
|
||||
override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
||||
require(point.size == 1)
|
||||
val value = point[0]
|
||||
return when{
|
||||
return when {
|
||||
value in range -> point
|
||||
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
|
||||
else -> doubleArrayOf(range.start).asBuffer()
|
||||
else -> doubleArrayOf(range.start).asBuffer()
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,4 +44,4 @@ inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : Rea
|
||||
override fun volume(): Double = range.endInclusive - range.start
|
||||
|
||||
override val dimension: Int get() = 1
|
||||
}
|
||||
}
|
||||
|
@ -1,23 +1,31 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.operations.Space
|
||||
|
||||
/**
|
||||
* Create a functional expression on this [Space]
|
||||
* Creates a functional expression with this [Space].
|
||||
*/
|
||||
fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionSpace(this).run(block)
|
||||
|
||||
/**
|
||||
* Create a functional expression on this [Ring]
|
||||
* Creates a functional expression with this [Ring].
|
||||
*/
|
||||
fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionRing(this).run(block)
|
||||
|
||||
/**
|
||||
* Create a functional expression on this [Field]
|
||||
* Creates a functional expression with this [Field].
|
||||
*/
|
||||
fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionField(this).run(block)
|
||||
|
||||
/**
|
||||
* Creates a functional expression with this [ExtendedField].
|
||||
*/
|
||||
fun <T> ExtendedField<T>.fieldExpression(
|
||||
block: FunctionalExpressionExtendedField<T, ExtendedField<T>>.() -> Expression<T>
|
||||
): Expression<T> = FunctionalExpressionExtendedField(this).run(block)
|
||||
|
@ -6,6 +6,12 @@ import scientifik.kmath.operations.Algebra
|
||||
* An elementary function that could be invoked on a map of arguments
|
||||
*/
|
||||
interface Expression<T> {
|
||||
/**
|
||||
* Calls this expression from arguments.
|
||||
*
|
||||
* @param arguments the map of arguments.
|
||||
* @return the value.
|
||||
*/
|
||||
operator fun invoke(arguments: Map<String, T>): T
|
||||
|
||||
companion object
|
||||
@ -14,10 +20,17 @@ interface Expression<T> {
|
||||
/**
|
||||
* Create simple lazily evaluated expression inside given algebra
|
||||
*/
|
||||
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> = object: Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = block(arguments)
|
||||
}
|
||||
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> =
|
||||
object : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = block(arguments)
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls this expression from arguments.
|
||||
*
|
||||
* @param pairs the pair of arguments' names to values.
|
||||
* @return the value.
|
||||
*/
|
||||
operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
|
||||
|
||||
/**
|
||||
@ -33,4 +46,4 @@ interface ExpressionAlgebra<T, E> : Algebra<E> {
|
||||
* A constant expression which does not depend on arguments
|
||||
*/
|
||||
fun const(value: T): E
|
||||
}
|
||||
}
|
||||
|
@ -40,7 +40,6 @@ internal class FunctionalConstProductExpression<T>(
|
||||
* @param algebra The algebra to provide for Expressions built.
|
||||
*/
|
||||
abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(val algebra: A) : ExpressionAlgebra<T, Expression<T>> {
|
||||
|
||||
/**
|
||||
* Builds an Expression of constant expression which does not depend on arguments.
|
||||
*/
|
||||
@ -69,14 +68,13 @@ abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(val algebra: A) :
|
||||
*/
|
||||
open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
|
||||
|
||||
override val zero: Expression<T> get() = const(algebra.zero)
|
||||
|
||||
/**
|
||||
* Builds an Expression of addition of two another expressions.
|
||||
*/
|
||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
FunctionalBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b)
|
||||
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
|
||||
/**
|
||||
* Builds an Expression of multiplication of expression by number.
|
||||
@ -105,7 +103,7 @@ open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpac
|
||||
* Builds an Expression of multiplication of two expressions.
|
||||
*/
|
||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
FunctionalBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b)
|
||||
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
|
||||
operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||
operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||
@ -124,7 +122,7 @@ open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||
* Builds an Expression of division an expression by another one.
|
||||
*/
|
||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
FunctionalBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b)
|
||||
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||
|
||||
operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||
operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||
@ -136,6 +134,34 @@ open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
|
||||
}
|
||||
|
||||
open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
||||
FunctionalExpressionField<T, A>(algebra),
|
||||
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
||||
override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||
override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||
|
||||
override fun asin(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg)
|
||||
|
||||
override fun acos(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg)
|
||||
|
||||
override fun atan(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg)
|
||||
|
||||
override fun power(arg: Expression<T>, pow: Number): Expression<T> =
|
||||
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||
|
||||
override fun exp(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||
override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionField>.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionField>.binaryOperation(operation, left, right)
|
||||
}
|
||||
|
||||
inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionSpace(this).block()
|
||||
|
||||
@ -143,4 +169,7 @@ inline fun <T, A : Ring<T>> A.expressionInRing(block: FunctionalExpressionRing<T
|
||||
FunctionalExpressionRing(this).block()
|
||||
|
||||
inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionField(this).block()
|
||||
FunctionalExpressionField(this).block()
|
||||
|
||||
inline fun <T, A : ExtendedField<T>> A.expressionInExtendedField(block: FunctionalExpressionExtendedField<T, A>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionExtendedField(this).block()
|
||||
|
@ -19,22 +19,20 @@ class BufferMatrixContext<T : Any, R : Ring<T>>(
|
||||
|
||||
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
||||
|
||||
companion object {
|
||||
|
||||
}
|
||||
companion object
|
||||
}
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
||||
|
||||
override val elementContext get() = RealField
|
||||
override val elementContext: RealField get() = RealField
|
||||
|
||||
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
||||
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
}
|
||||
|
||||
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size,initializer)
|
||||
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size, initializer)
|
||||
}
|
||||
|
||||
class BufferMatrix<T : Any>(
|
||||
@ -52,7 +50,7 @@ class BufferMatrix<T : Any>(
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
||||
override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
|
||||
BufferMatrix(rowNum, colNum, buffer, this.features + features)
|
||||
|
||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
||||
@ -84,8 +82,8 @@ class BufferMatrix<T : Any>(
|
||||
override fun toString(): String {
|
||||
return if (rowNum <= 5 && colNum <= 5) {
|
||||
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
|
||||
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") {
|
||||
it.asSequence().joinToString(separator = "\t") { it.toString() }
|
||||
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
|
||||
buffer.asSequence().joinToString(separator = "\t") { it.toString() }
|
||||
}
|
||||
} else {
|
||||
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
|
||||
@ -121,4 +119,4 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
|
||||
|
||||
val buffer = RealBuffer(array)
|
||||
return BufferMatrix(rowNum, other.colNum, buffer)
|
||||
}
|
||||
}
|
||||
|
@ -23,12 +23,10 @@ interface FeaturedMatrix<T : Any> : Matrix<T> {
|
||||
*/
|
||||
fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
|
||||
|
||||
companion object {
|
||||
|
||||
}
|
||||
companion object
|
||||
}
|
||||
|
||||
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
||||
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> =
|
||||
MatrixContext.real.produce(rows, columns, initializer)
|
||||
|
||||
/**
|
||||
@ -41,7 +39,7 @@ fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T
|
||||
return BufferMatrix(size, size, buffer)
|
||||
}
|
||||
|
||||
val Matrix<*>.features get() = (this as? FeaturedMatrix)?.features?: emptySet()
|
||||
val Matrix<*>.features: Set<MatrixFeature> get() = (this as? FeaturedMatrix)?.features ?: emptySet()
|
||||
|
||||
/**
|
||||
* Check if matrix has the given feature class
|
||||
@ -68,7 +66,7 @@ fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: In
|
||||
* A virtual matrix of zeroes
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero }
|
||||
VirtualMatrix(rows, columns) { _, _ -> elementContext.zero }
|
||||
|
||||
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
|
||||
|
||||
@ -83,4 +81,4 @@ fun <T : Any> Matrix<T>.transpose(): Matrix<T> {
|
||||
) { i, j -> get(j, i) }
|
||||
}
|
||||
|
||||
infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = with(MatrixContext.real) { dot(other) }
|
||||
infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = with(MatrixContext.real) { dot(other) }
|
||||
|
@ -18,7 +18,7 @@ class LUPDecomposition<T : Any>(
|
||||
private val even: Boolean
|
||||
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
||||
|
||||
val elementContext get() = context.elementContext
|
||||
val elementContext: Field<T> get() = context.elementContext
|
||||
|
||||
/**
|
||||
* Returns the matrix L of the decomposition.
|
||||
@ -67,7 +67,7 @@ class LUPDecomposition<T : Any>(
|
||||
|
||||
}
|
||||
|
||||
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) =
|
||||
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T): T =
|
||||
if (value > elementContext.zero) value else with(elementContext) { -value }
|
||||
|
||||
|
||||
@ -169,9 +169,10 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||
matrix: Matrix<T>,
|
||||
noinline checkSingular: (T) -> Boolean
|
||||
) = lup(T::class, matrix, checkSingular)
|
||||
): LUPDecomposition<T> = lup(T::class, matrix, checkSingular)
|
||||
|
||||
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(Double::class, matrix) { it < 1e-11 }
|
||||
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> =
|
||||
lup(Double::class, matrix) { it < 1e-11 }
|
||||
|
||||
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
|
||||
|
||||
@ -185,7 +186,7 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
|
||||
// Apply permutations to b
|
||||
val bp = create { _, _ -> zero }
|
||||
|
||||
for (row in 0 until pivot.size) {
|
||||
for (row in pivot.indices) {
|
||||
val bpRow = bp.row(row)
|
||||
val pRow = pivot[row]
|
||||
for (col in 0 until matrix.colNum) {
|
||||
@ -194,7 +195,7 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
|
||||
}
|
||||
|
||||
// Solve LY = b
|
||||
for (col in 0 until pivot.size) {
|
||||
for (col in pivot.indices) {
|
||||
val bpCol = bp.row(col)
|
||||
for (i in col + 1 until pivot.size) {
|
||||
val bpI = bp.row(i)
|
||||
@ -225,7 +226,7 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
|
||||
}
|
||||
}
|
||||
|
||||
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>) = solve(T::class, matrix)
|
||||
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matrix<T> = solve(T::class, matrix)
|
||||
|
||||
/**
|
||||
* Solve a linear equation **a*x = b**
|
||||
@ -240,13 +241,12 @@ inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.
|
||||
return decomposition.solve(T::class, b)
|
||||
}
|
||||
|
||||
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>) =
|
||||
solve(a, b) { it < 1e-11 }
|
||||
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> = solve(a, b) { it < 1e-11 }
|
||||
|
||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
|
||||
matrix: Matrix<T>,
|
||||
noinline checkSingular: (T) -> Boolean
|
||||
) = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
||||
): Matrix<T> = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
||||
|
||||
fun RealMatrixContext.inverse(matrix: Matrix<Double>) =
|
||||
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }
|
||||
fun RealMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> =
|
||||
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }
|
||||
|
@ -25,4 +25,4 @@ fun <T : Any> Matrix<T>.asPoint(): Point<T> =
|
||||
error("Can't convert matrix with more than one column to vector")
|
||||
}
|
||||
|
||||
fun <T : Any> Point<T>.asMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
||||
fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
||||
|
@ -29,7 +29,7 @@ interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
|
||||
/**
|
||||
* Non-boxing double matrix
|
||||
*/
|
||||
val real = RealMatrixContext
|
||||
val real: RealMatrixContext = RealMatrixContext
|
||||
|
||||
/**
|
||||
* A structured matrix with custom buffer
|
||||
@ -82,12 +82,12 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
||||
}
|
||||
}
|
||||
|
||||
override operator fun Matrix<T>.unaryMinus() =
|
||||
override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
||||
|
||||
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||
if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]")
|
||||
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) + b[i, j] } }
|
||||
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] + b[i, j] } }
|
||||
}
|
||||
|
||||
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||
@ -96,7 +96,7 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
||||
}
|
||||
|
||||
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
||||
produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) * k } }
|
||||
produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } }
|
||||
|
||||
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
/**
|
||||
* A marker interface representing some matrix feature like diagonal, sparce, zero, etc. Features used to optimize matrix
|
||||
* A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix
|
||||
* operations performance in some cases.
|
||||
*/
|
||||
interface MatrixFeature
|
||||
@ -36,19 +36,19 @@ interface DeterminantFeature<T : Any> : MatrixFeature {
|
||||
}
|
||||
|
||||
@Suppress("FunctionName")
|
||||
fun <T: Any> DeterminantFeature(determinant: T) = object: DeterminantFeature<T>{
|
||||
fun <T : Any> DeterminantFeature(determinant: T): DeterminantFeature<T> = object : DeterminantFeature<T> {
|
||||
override val determinant: T = determinant
|
||||
}
|
||||
|
||||
/**
|
||||
* Lower triangular matrix
|
||||
*/
|
||||
object LFeature: MatrixFeature
|
||||
object LFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Upper triangular feature
|
||||
*/
|
||||
object UFeature: MatrixFeature
|
||||
object UFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* TODO add documentation
|
||||
@ -59,4 +59,4 @@ interface LUPDecompositionFeature<T : Any> : MatrixFeature {
|
||||
val p: FeaturedMatrix<T>
|
||||
}
|
||||
|
||||
//TODO add sparse matrix feature
|
||||
//TODO add sparse matrix feature
|
||||
|
@ -54,7 +54,7 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
||||
size: Int,
|
||||
space: S,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||
) = BufferVectorSpace(size, space, bufferFactory)
|
||||
): BufferVectorSpace<T, S> = BufferVectorSpace(size, space, bufferFactory)
|
||||
|
||||
/**
|
||||
* Automatic buffered vector, unboxed if it is possible
|
||||
@ -70,6 +70,6 @@ class BufferVectorSpace<T : Any, S : Space<T>>(
|
||||
override val space: S,
|
||||
val bufferFactory: BufferFactory<T>
|
||||
) : VectorSpace<T, S> {
|
||||
override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer)
|
||||
override fun produce(initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
|
||||
//override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
|
||||
}
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ class VirtualMatrix<T : Any>(
|
||||
|
||||
override fun get(i: Int, j: Int): T = generator(i, j)
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
||||
override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> =
|
||||
VirtualMatrix(rowNum, colNum, this.features + features, generator)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
@ -56,4 +56,4 @@ class VirtualMatrix<T : Any>(
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -22,12 +22,12 @@ class DerivationResult<T : Any>(
|
||||
val deriv: Map<Variable<T>, T>,
|
||||
val context: Field<T>
|
||||
) : Variable<T>(value) {
|
||||
fun deriv(variable: Variable<T>) = deriv[variable] ?: context.zero
|
||||
fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
|
||||
|
||||
/**
|
||||
* compute divergence
|
||||
*/
|
||||
fun div() = context.run { sum(deriv.values) }
|
||||
fun div(): T = context.run { sum(deriv.values) }
|
||||
|
||||
/**
|
||||
* Compute a gradient for variables in given order
|
||||
@ -53,7 +53,7 @@ class DerivationResult<T : Any>(
|
||||
* ```
|
||||
*/
|
||||
fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> =
|
||||
AutoDiffContext<T, F>(this).run {
|
||||
AutoDiffContext(this).run {
|
||||
val result = body()
|
||||
result.d = context.one// computing derivative w.r.t result
|
||||
runBackwardPass()
|
||||
@ -86,7 +86,7 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
||||
|
||||
abstract fun variable(value: T): Variable<T>
|
||||
|
||||
inline fun variable(block: F.() -> T) = variable(context.block())
|
||||
inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
@ -236,4 +236,4 @@ fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Var
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
|
||||
derive(variable { cos(x.value) }) { z ->
|
||||
x.d -= z.d * sin(x.value)
|
||||
}
|
||||
}
|
||||
|
@ -43,4 +43,4 @@ fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Seque
|
||||
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
|
||||
if (numPoints < 2) error("Can't create generic grid with less than two points")
|
||||
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
|
||||
}
|
||||
}
|
||||
|
@ -1,14 +1,15 @@
|
||||
package scientifik.kmath.misc
|
||||
|
||||
import scientifik.kmath.operations.Space
|
||||
import scientifik.kmath.operations.invoke
|
||||
import kotlin.jvm.JvmName
|
||||
|
||||
|
||||
/**
|
||||
* Generic cumulative operation on iterator
|
||||
* @param T type of initial iterable
|
||||
* @param R type of resulting iterable
|
||||
* @param initial lazy evaluated
|
||||
* Generic cumulative operation on iterator.
|
||||
*
|
||||
* @param T the type of initial iterable.
|
||||
* @param R the type of resulting iterable.
|
||||
* @param initial lazy evaluated.
|
||||
*/
|
||||
fun <T, R> Iterator<T>.cumulative(initial: R, operation: (R, T) -> R): Iterator<R> = object : Iterator<R> {
|
||||
var state: R = initial
|
||||
@ -36,41 +37,41 @@ fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> =
|
||||
/**
|
||||
* Cumulative sum with custom space
|
||||
*/
|
||||
fun <T> Iterable<T>.cumulativeSum(space: Space<T>) = with(space) {
|
||||
fun <T> Iterable<T>.cumulativeSum(space: Space<T>): Iterable<T> = space {
|
||||
cumulative(zero) { element: T, sum: T -> sum + element }
|
||||
}
|
||||
|
||||
@JvmName("cumulativeSumOfDouble")
|
||||
fun Iterable<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
|
||||
fun Iterable<Double>.cumulativeSum(): Iterable<Double> = this.cumulative(0.0) { element, sum -> sum + element }
|
||||
|
||||
@JvmName("cumulativeSumOfInt")
|
||||
fun Iterable<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
|
||||
fun Iterable<Int>.cumulativeSum(): Iterable<Int> = this.cumulative(0) { element, sum -> sum + element }
|
||||
|
||||
@JvmName("cumulativeSumOfLong")
|
||||
fun Iterable<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
|
||||
fun Iterable<Long>.cumulativeSum(): Iterable<Long> = this.cumulative(0L) { element, sum -> sum + element }
|
||||
|
||||
fun <T> Sequence<T>.cumulativeSum(space: Space<T>) = with(space) {
|
||||
fun <T> Sequence<T>.cumulativeSum(space: Space<T>): Sequence<T> = with(space) {
|
||||
cumulative(zero) { element: T, sum: T -> sum + element }
|
||||
}
|
||||
|
||||
@JvmName("cumulativeSumOfDouble")
|
||||
fun Sequence<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
|
||||
fun Sequence<Double>.cumulativeSum(): Sequence<Double> = this.cumulative(0.0) { element, sum -> sum + element }
|
||||
|
||||
@JvmName("cumulativeSumOfInt")
|
||||
fun Sequence<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
|
||||
fun Sequence<Int>.cumulativeSum(): Sequence<Int> = this.cumulative(0) { element, sum -> sum + element }
|
||||
|
||||
@JvmName("cumulativeSumOfLong")
|
||||
fun Sequence<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
|
||||
fun Sequence<Long>.cumulativeSum(): Sequence<Long> = this.cumulative(0L) { element, sum -> sum + element }
|
||||
|
||||
fun <T> List<T>.cumulativeSum(space: Space<T>) = with(space) {
|
||||
fun <T> List<T>.cumulativeSum(space: Space<T>): List<T> = with(space) {
|
||||
cumulative(zero) { element: T, sum: T -> sum + element }
|
||||
}
|
||||
|
||||
@JvmName("cumulativeSumOfDouble")
|
||||
fun List<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
|
||||
fun List<Double>.cumulativeSum(): List<Double> = this.cumulative(0.0) { element, sum -> sum + element }
|
||||
|
||||
@JvmName("cumulativeSumOfInt")
|
||||
fun List<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
|
||||
fun List<Int>.cumulativeSum(): List<Int> = this.cumulative(0) { element, sum -> sum + element }
|
||||
|
||||
@JvmName("cumulativeSumOfLong")
|
||||
fun List<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
|
||||
fun List<Long>.cumulativeSum(): List<Long> = this.cumulative(0L) { element, sum -> sum + element }
|
||||
|
@ -1,10 +1,15 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
/**
|
||||
* Stub for DSL the [Algebra] is.
|
||||
*/
|
||||
@DslMarker
|
||||
annotation class KMathContext
|
||||
|
||||
/**
|
||||
* Marker interface for any algebra
|
||||
* Represents an algebraic structure.
|
||||
*
|
||||
* @param T the type of element of this structure.
|
||||
*/
|
||||
interface Algebra<T> {
|
||||
/**
|
||||
@ -24,50 +29,123 @@ interface Algebra<T> {
|
||||
}
|
||||
|
||||
/**
|
||||
* An algebra with numeric representation of members
|
||||
* An algebraic structure where elements can have numeric representation.
|
||||
*
|
||||
* @param T the type of element of this structure.
|
||||
*/
|
||||
interface NumericAlgebra<T> : Algebra<T> {
|
||||
/**
|
||||
* Wrap a number
|
||||
* Wraps a number.
|
||||
*/
|
||||
fun number(value: Number): T
|
||||
|
||||
/**
|
||||
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number].
|
||||
*/
|
||||
fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
||||
binaryOperation(operation, number(left), right)
|
||||
|
||||
/**
|
||||
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
|
||||
*/
|
||||
fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||
leftSideNumberOperation(operation, right, left)
|
||||
}
|
||||
|
||||
/**
|
||||
* Call a block with an [Algebra] as receiver
|
||||
* Call a block with an [Algebra] as receiver.
|
||||
*/
|
||||
inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block)
|
||||
|
||||
/**
|
||||
* Space-like operations without neutral element
|
||||
* Represents semispace, i.e. algebraic structure with associative binary operation called "addition" as well as
|
||||
* multiplication by scalars.
|
||||
*
|
||||
* In KMath groups are called spaces, and also define multiplication of element by [Number].
|
||||
*
|
||||
* @param T the type of element of this semigroup.
|
||||
*/
|
||||
interface SpaceOperations<T> : Algebra<T> {
|
||||
/**
|
||||
* Addition operation for two context elements
|
||||
* Addition of two elements.
|
||||
*
|
||||
* @param a the addend.
|
||||
* @param b the augend.
|
||||
* @return the sum.
|
||||
*/
|
||||
fun add(a: T, b: T): T
|
||||
|
||||
/**
|
||||
* Multiplication operation for context element and real number
|
||||
* Multiplication of element by scalar.
|
||||
*
|
||||
* @param a the multiplier.
|
||||
* @param k the multiplicand.
|
||||
* @return the produce.
|
||||
*/
|
||||
fun multiply(a: T, k: Number): T
|
||||
|
||||
//Operation to be performed in this context. Could be moved to extensions in case of KEEP-176
|
||||
// Operations to be performed in this context. Could be moved to extensions in case of KEEP-176
|
||||
|
||||
/**
|
||||
* The negation of this element.
|
||||
*
|
||||
* @receiver this value.
|
||||
* @return the additive inverse of this value.
|
||||
*/
|
||||
operator fun T.unaryMinus(): T = multiply(this, -1.0)
|
||||
|
||||
/**
|
||||
* Returns this value.
|
||||
*
|
||||
* @receiver this value.
|
||||
* @return this value.
|
||||
*/
|
||||
operator fun T.unaryPlus(): T = this
|
||||
|
||||
/**
|
||||
* Addition of two elements.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
* @return the sum.
|
||||
*/
|
||||
operator fun T.plus(b: T): T = add(this, b)
|
||||
|
||||
/**
|
||||
* Subtraction of two elements.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @return the difference.
|
||||
*/
|
||||
operator fun T.minus(b: T): T = add(this, -b)
|
||||
operator fun T.times(k: Number) = multiply(this, k.toDouble())
|
||||
operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble())
|
||||
operator fun Number.times(b: T) = b * this
|
||||
|
||||
/**
|
||||
* Multiplication of this element by a scalar.
|
||||
*
|
||||
* @receiver the multiplier.
|
||||
* @param k the multiplicand.
|
||||
* @return the product.
|
||||
*/
|
||||
operator fun T.times(k: Number): T = multiply(this, k.toDouble())
|
||||
|
||||
/**
|
||||
* Division of this element by scalar.
|
||||
*
|
||||
* @receiver the dividend.
|
||||
* @param k the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
operator fun T.div(k: Number): T = multiply(this, 1.0 / k.toDouble())
|
||||
|
||||
/**
|
||||
* Multiplication of this number by element.
|
||||
*
|
||||
* @receiver the multiplier.
|
||||
* @param b the multiplicand.
|
||||
* @return the product.
|
||||
*/
|
||||
operator fun Number.times(b: T): T = b * this
|
||||
|
||||
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||
PLUS_OPERATION -> arg
|
||||
@ -82,37 +160,54 @@ interface SpaceOperations<T> : Algebra<T> {
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val PLUS_OPERATION = "+"
|
||||
const val MINUS_OPERATION = "-"
|
||||
const val NOT_OPERATION = "!"
|
||||
/**
|
||||
* The identifier of addition.
|
||||
*/
|
||||
const val PLUS_OPERATION: String = "+"
|
||||
|
||||
/**
|
||||
* The identifier of subtraction (and negation).
|
||||
*/
|
||||
const val MINUS_OPERATION: String = "-"
|
||||
|
||||
const val NOT_OPERATION: String = "!"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A general interface representing linear context of some kind.
|
||||
* The context defines sum operation for its elements and multiplication by real value.
|
||||
* One must note that in some cases context is a singleton class, but in some cases it
|
||||
* works as a context for operations inside it.
|
||||
* Represents linear space, i.e. algebraic structure with associative binary operation called "addition" and its neutral
|
||||
* element as well as multiplication by scalars.
|
||||
*
|
||||
* TODO do we need non-commutative context?
|
||||
* @param T the type of element of this group.
|
||||
*/
|
||||
interface Space<T> : SpaceOperations<T> {
|
||||
/**
|
||||
* Neutral element for sum operation
|
||||
* The neutral element of addition.
|
||||
*/
|
||||
val zero: T
|
||||
}
|
||||
|
||||
/**
|
||||
* Operations on ring without multiplication neutral element
|
||||
* Represents semiring, i.e. algebraic structure with two associative binary operations called "addition" and
|
||||
* "multiplication".
|
||||
*
|
||||
* @param T the type of element of this semiring.
|
||||
*/
|
||||
interface RingOperations<T> : SpaceOperations<T> {
|
||||
/**
|
||||
* Multiplication for two field elements
|
||||
* Multiplies two elements.
|
||||
*
|
||||
* @param a the multiplier.
|
||||
* @param b the multiplicand.
|
||||
*/
|
||||
fun multiply(a: T, b: T): T
|
||||
|
||||
/**
|
||||
* Multiplies this element by scalar.
|
||||
*
|
||||
* @receiver the multiplier.
|
||||
* @param b the multiplicand.
|
||||
*/
|
||||
operator fun T.times(b: T): T = multiply(this, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
@ -121,12 +216,18 @@ interface RingOperations<T> : SpaceOperations<T> {
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val TIMES_OPERATION = "*"
|
||||
/**
|
||||
* The identifier of multiplication.
|
||||
*/
|
||||
const val TIMES_OPERATION: String = "*"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The same as {@link Space} but with additional multiplication operation
|
||||
* Represents ring, i.e. algebraic structure with two associative binary operations called "addition" and
|
||||
* "multiplication" and their neutral elements.
|
||||
*
|
||||
* @param T the type of element of this ring.
|
||||
*/
|
||||
interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
|
||||
/**
|
||||
@ -150,20 +251,64 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
|
||||
else -> super.rightSideNumberOperation(operation, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Addition of element and scalar.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
*/
|
||||
operator fun T.plus(b: Number): T = this + number(b)
|
||||
|
||||
operator fun T.plus(b: Number) = this.plus(number(b))
|
||||
operator fun Number.plus(b: T) = b + this
|
||||
/**
|
||||
* Addition of scalar and element.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
*/
|
||||
operator fun Number.plus(b: T): T = b + this
|
||||
|
||||
operator fun T.minus(b: Number) = this.minus(number(b))
|
||||
operator fun Number.minus(b: T) = -b + this
|
||||
/**
|
||||
* Subtraction of element from number.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
operator fun T.minus(b: Number): T = this - number(b)
|
||||
|
||||
/**
|
||||
* Subtraction of number from element.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
operator fun Number.minus(b: T): T = -b + this
|
||||
}
|
||||
|
||||
/**
|
||||
* All ring operations but without neutral elements
|
||||
* Represents semifield, i.e. algebraic structure with three operations: associative "addition" and "multiplication",
|
||||
* and "division".
|
||||
*
|
||||
* @param T the type of element of this semifield.
|
||||
*/
|
||||
interface FieldOperations<T> : RingOperations<T> {
|
||||
/**
|
||||
* Division of two elements.
|
||||
*
|
||||
* @param a the dividend.
|
||||
* @param b the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
fun divide(a: T, b: T): T
|
||||
|
||||
/**
|
||||
* Division of two elements.
|
||||
*
|
||||
* @receiver the dividend.
|
||||
* @param b the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
operator fun T.div(b: T): T = divide(this, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
@ -172,13 +317,26 @@ interface FieldOperations<T> : RingOperations<T> {
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val DIV_OPERATION = "/"
|
||||
/**
|
||||
* The identifier of division.
|
||||
*/
|
||||
const val DIV_OPERATION: String = "/"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Four operations algebra
|
||||
* Represents field, i.e. algebraic structure with three operations: associative "addition" and "multiplication",
|
||||
* and "division" and their neutral elements.
|
||||
*
|
||||
* @param T the type of element of this semifield.
|
||||
*/
|
||||
interface Field<T> : Ring<T>, FieldOperations<T> {
|
||||
operator fun Number.div(b: T) = this * divide(one, b)
|
||||
/**
|
||||
* Division of element by scalar.
|
||||
*
|
||||
* @receiver the dividend.
|
||||
* @param b the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
operator fun Number.div(b: T): T = this * divide(one, b)
|
||||
}
|
||||
|
@ -2,47 +2,107 @@ package scientifik.kmath.operations
|
||||
|
||||
/**
|
||||
* The generic mathematics elements which is able to store its context
|
||||
* @param T the type of space operation results
|
||||
* @param I self type of the element. Needed for static type checking
|
||||
* @param C the type of mathematical context for this element
|
||||
*
|
||||
* @param C the type of mathematical context for this element.
|
||||
*/
|
||||
interface MathElement<C> {
|
||||
/**
|
||||
* The context this element belongs to
|
||||
* The context this element belongs to.
|
||||
*/
|
||||
val context: C
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents element that can be wrapped to its "primitive" value.
|
||||
*
|
||||
* @param T the type wrapped by this wrapper.
|
||||
* @param I the type of this wrapper.
|
||||
*/
|
||||
interface MathWrapper<T, I> {
|
||||
/**
|
||||
* Unwraps [I] to [T].
|
||||
*/
|
||||
fun unwrap(): T
|
||||
|
||||
/**
|
||||
* Wraps [T] to [I].
|
||||
*/
|
||||
fun T.wrap(): I
|
||||
}
|
||||
|
||||
/**
|
||||
* The element of linear context
|
||||
* @param T the type of space operation results
|
||||
* @param I self type of the element. Needed for static type checking
|
||||
* @param S the type of space
|
||||
* The element of [Space].
|
||||
*
|
||||
* @param T the type of space operation results.
|
||||
* @param I self type of the element. Needed for static type checking.
|
||||
* @param S the type of space.
|
||||
*/
|
||||
interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S>, MathWrapper<T, I> {
|
||||
/**
|
||||
* Adds element to this one.
|
||||
*
|
||||
* @param b the augend.
|
||||
* @return the sum.
|
||||
*/
|
||||
operator fun plus(b: T): I = context.add(unwrap(), b).wrap()
|
||||
|
||||
operator fun plus(b: T) = context.add(unwrap(), b).wrap()
|
||||
operator fun minus(b: T) = context.add(unwrap(), context.multiply(b, -1.0)).wrap()
|
||||
operator fun times(k: Number) = context.multiply(unwrap(), k.toDouble()).wrap()
|
||||
operator fun div(k: Number) = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap()
|
||||
/**
|
||||
* Subtracts element from this one.
|
||||
*
|
||||
* @param b the subtrahend.
|
||||
* @return the difference.
|
||||
*/
|
||||
operator fun minus(b: T): I = context.add(unwrap(), context.multiply(b, -1.0)).wrap()
|
||||
|
||||
/**
|
||||
* Multiplies this element by number.
|
||||
*
|
||||
* @param k the multiplicand.
|
||||
* @return the product.
|
||||
*/
|
||||
operator fun times(k: Number): I = context.multiply(unwrap(), k.toDouble()).wrap()
|
||||
|
||||
/**
|
||||
* Divides this element by number.
|
||||
*
|
||||
* @param k the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
operator fun div(k: Number): I = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* Ring element
|
||||
* The element of [Ring].
|
||||
*
|
||||
* @param T the type of space operation results.
|
||||
* @param I self type of the element. Needed for static type checking.
|
||||
* @param R the type of space.
|
||||
*/
|
||||
interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
|
||||
operator fun times(b: T) = context.multiply(unwrap(), b).wrap()
|
||||
/**
|
||||
* Multiplies this element by another one.
|
||||
*
|
||||
* @param b the multiplicand.
|
||||
* @return the product.
|
||||
*/
|
||||
operator fun times(b: T): I = context.multiply(unwrap(), b).wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* Field element
|
||||
* The element of [Field].
|
||||
*
|
||||
* @param T the type of space operation results.
|
||||
* @param I self type of the element. Needed for static type checking.
|
||||
* @param F the type of field.
|
||||
*/
|
||||
interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> {
|
||||
override val context: F
|
||||
operator fun div(b: T) = context.divide(unwrap(), b).wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* Divides this element by another one.
|
||||
*
|
||||
* @param b the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
operator fun div(b: T): I = context.divide(unwrap(), b).wrap()
|
||||
}
|
||||
|
@ -1,15 +1,62 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
/**
|
||||
* Returns the sum of all elements in the iterable in this [Space].
|
||||
*
|
||||
* @receiver the algebra that provides addition.
|
||||
* @param data the collection to sum up.
|
||||
* @return the sum.
|
||||
*/
|
||||
fun <T> Space<T>.sum(data: Iterable<T>): T = data.fold(zero) { left, right -> add(left, right) }
|
||||
|
||||
/**
|
||||
* Returns the sum of all elements in the sequence in this [Space].
|
||||
*
|
||||
* @receiver the algebra that provides addition.
|
||||
* @param data the collection to sum up.
|
||||
* @return the sum.
|
||||
*/
|
||||
fun <T> Space<T>.sum(data: Sequence<T>): T = data.fold(zero) { left, right -> add(left, right) }
|
||||
|
||||
/**
|
||||
* Returns the sum of all elements in the iterable in provided space.
|
||||
*
|
||||
* @receiver the collection to sum up.
|
||||
* @param space the algebra that provides addition.
|
||||
* @return the sum.
|
||||
*/
|
||||
fun <T : Any, S : Space<T>> Iterable<T>.sumWith(space: S): T = space.sum(this)
|
||||
|
||||
//TODO optimized power operation
|
||||
fun <T> RingOperations<T>.power(arg: T, power: Int): T {
|
||||
|
||||
/**
|
||||
* Raises [arg] to the natural power [power].
|
||||
*
|
||||
* @receiver the algebra to provide multiplication.
|
||||
* @param arg the base.
|
||||
* @param power the exponent.
|
||||
* @return the base raised to the power.
|
||||
*/
|
||||
fun <T> Ring<T>.power(arg: T, power: Int): T {
|
||||
require(power >= 0) { "The power can't be negative." }
|
||||
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
|
||||
if (power == 0) return one
|
||||
var res = arg
|
||||
repeat(power - 1) {
|
||||
res *= arg
|
||||
}
|
||||
repeat(power - 1) { res *= arg }
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Raises [arg] to the integer power [power].
|
||||
*
|
||||
* @receiver the algebra to provide multiplication and division.
|
||||
* @param arg the base.
|
||||
* @param power the exponent.
|
||||
* @return the base raised to the power.
|
||||
*/
|
||||
fun <T> Field<T>.power(arg: T, power: Int): T {
|
||||
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
|
||||
if (power == 0) return one
|
||||
if (power < 0) return one / (this as Ring<T>).power(arg, -power)
|
||||
return (this as Ring<T>).power(arg, power)
|
||||
}
|
||||
|
@ -194,8 +194,8 @@ class BigInt internal constructor(
|
||||
}
|
||||
|
||||
infix fun or(other: BigInt): BigInt {
|
||||
if (this == ZERO) return other;
|
||||
if (other == ZERO) return this;
|
||||
if (this == ZERO) return other
|
||||
if (other == ZERO) return this
|
||||
val resSize = max(this.magnitude.size, other.magnitude.size)
|
||||
val newMagnitude: Magnitude = Magnitude(resSize)
|
||||
for (i in 0 until resSize) {
|
||||
@ -210,7 +210,7 @@ class BigInt internal constructor(
|
||||
}
|
||||
|
||||
infix fun and(other: BigInt): BigInt {
|
||||
if ((this == ZERO) or (other == ZERO)) return ZERO;
|
||||
if ((this == ZERO) or (other == ZERO)) return ZERO
|
||||
val resSize = min(this.magnitude.size, other.magnitude.size)
|
||||
val newMagnitude: Magnitude = Magnitude(resSize)
|
||||
for (i in 0 until resSize) {
|
||||
@ -260,7 +260,7 @@ class BigInt internal constructor(
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val BASE = 0xffffffffUL
|
||||
const val BASE: ULong = 0xffffffffUL
|
||||
const val BASE_SIZE: Int = 32
|
||||
val ZERO: BigInt = BigInt(0, uintArrayOf())
|
||||
val ONE: BigInt = BigInt(1, uintArrayOf(1u))
|
||||
@ -394,12 +394,12 @@ fun abs(x: BigInt): BigInt = x.abs()
|
||||
/**
|
||||
* Convert this [Int] to [BigInt]
|
||||
*/
|
||||
fun Int.toBigInt() = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt()))
|
||||
fun Int.toBigInt(): BigInt = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt()))
|
||||
|
||||
/**
|
||||
* Convert this [Long] to [BigInt]
|
||||
*/
|
||||
fun Long.toBigInt() = BigInt(
|
||||
fun Long.toBigInt(): BigInt = BigInt(
|
||||
sign.toByte(), stripLeadingZeros(
|
||||
uintArrayOf(
|
||||
(kotlin.math.abs(this).toULong() and BASE).toUInt(),
|
||||
@ -411,17 +411,17 @@ fun Long.toBigInt() = BigInt(
|
||||
/**
|
||||
* Convert UInt to [BigInt]
|
||||
*/
|
||||
fun UInt.toBigInt() = BigInt(1, uintArrayOf(this))
|
||||
fun UInt.toBigInt(): BigInt = BigInt(1, uintArrayOf(this))
|
||||
|
||||
/**
|
||||
* Convert ULong to [BigInt]
|
||||
*/
|
||||
fun ULong.toBigInt() = BigInt(
|
||||
fun ULong.toBigInt(): BigInt = BigInt(
|
||||
1,
|
||||
stripLeadingZeros(
|
||||
uintArrayOf(
|
||||
(this and BigInt.BASE).toUInt(),
|
||||
((this shr BigInt.BASE_SIZE) and BigInt.BASE).toUInt()
|
||||
(this and BASE).toUInt(),
|
||||
((this shr BASE_SIZE) and BASE).toUInt()
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -434,7 +434,7 @@ fun UIntArray.toBigInt(sign: Byte): BigInt {
|
||||
return BigInt(sign, this.copyOf())
|
||||
}
|
||||
|
||||
val hexChToInt = hashMapOf(
|
||||
val hexChToInt: MutableMap<Char, Int> = hashMapOf(
|
||||
'0' to 0, '1' to 1, '2' to 2, '3' to 3,
|
||||
'4' to 4, '5' to 5, '6' to 6, '7' to 7,
|
||||
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
|
||||
@ -497,4 +497,4 @@ fun NDElement.Companion.bigInt(
|
||||
vararg shape: Int,
|
||||
initializer: BigIntField.(IntArray) -> BigInt
|
||||
): BufferedNDRingElement<BigInt, BigIntField> =
|
||||
NDAlgebra.bigInt(*shape).produce(initializer)
|
||||
NDAlgebra.bigInt(*shape).produce(initializer)
|
||||
|
@ -38,18 +38,19 @@ val Complex.theta: Double
|
||||
private val PI_DIV_2 = Complex(PI / 2, 0)
|
||||
|
||||
/**
|
||||
* A field for complex numbers.
|
||||
* A field of [Complex].
|
||||
*/
|
||||
object ComplexField : ExtendedField<Complex> {
|
||||
override val zero: Complex = 0.0.toComplex()
|
||||
override val one: Complex = 1.0.toComplex()
|
||||
|
||||
/**
|
||||
* The imaginary unit constant.
|
||||
* The imaginary unit.
|
||||
*/
|
||||
val i = Complex(0, 1)
|
||||
val i: Complex = Complex(0.0, 1.0)
|
||||
|
||||
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
|
||||
|
||||
override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble())
|
||||
|
||||
override fun multiply(a: Complex, b: Complex): Complex =
|
||||
@ -111,21 +112,62 @@ object ComplexField : ExtendedField<Complex> {
|
||||
exp(pow * ln(arg))
|
||||
|
||||
override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im))
|
||||
|
||||
override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re)
|
||||
|
||||
operator fun Double.plus(c: Complex): Complex = add(toComplex(), c)
|
||||
operator fun Double.minus(c: Complex): Complex = add(toComplex(), -c)
|
||||
/**
|
||||
* Adds complex number to real one.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param c the augend.
|
||||
* @return the sum.
|
||||
*/
|
||||
operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c)
|
||||
|
||||
/**
|
||||
* Subtracts complex number from real one.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param c the subtrahend.
|
||||
* @return the difference.
|
||||
*/
|
||||
operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c)
|
||||
|
||||
/**
|
||||
* Adds real number to complex one.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param d the augend.
|
||||
* @return the sum.
|
||||
*/
|
||||
operator fun Complex.plus(d: Double): Complex = d + this
|
||||
|
||||
/**
|
||||
* Subtracts real number from complex one.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param d the subtrahend.
|
||||
* @return the difference.
|
||||
*/
|
||||
operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex())
|
||||
|
||||
/**
|
||||
* Multiplies real number by complex one.
|
||||
*
|
||||
* @receiver the multiplier.
|
||||
* @param c the multiplicand.
|
||||
* @receiver the product.
|
||||
*/
|
||||
operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this)
|
||||
|
||||
override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value)
|
||||
}
|
||||
|
||||
/**
|
||||
* Complex number class.
|
||||
* Represents complex number.
|
||||
*
|
||||
* @property re the real part of the number.
|
||||
* @property im the imaginary part of the number.
|
||||
* @property re The real part.
|
||||
* @property im The imaginary part.
|
||||
*/
|
||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
|
||||
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
|
||||
@ -133,7 +175,9 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
|
||||
override val context: ComplexField get() = ComplexField
|
||||
|
||||
override fun unwrap(): Complex = this
|
||||
|
||||
override fun Complex.wrap(): Complex = this
|
||||
|
||||
override fun compareTo(other: Complex): Int = r.compareTo(other.r)
|
||||
|
||||
companion object : MemorySpec<Complex> {
|
||||
@ -150,12 +194,17 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a [Complex] with real part equal to this number.
|
||||
* Creates a complex number with real part equal to this real.
|
||||
*
|
||||
* @receiver the real part.
|
||||
* @return the new complex number.
|
||||
*/
|
||||
fun Number.toComplex(): Complex = Complex(this, 0.0)
|
||||
fun Number.toComplex() = Complex(this, 0.0)
|
||||
|
||||
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> =
|
||||
MemoryBuffer.create(Complex, size, init)
|
||||
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
||||
return MemoryBuffer.create(Complex, size, init)
|
||||
}
|
||||
|
||||
inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> =
|
||||
MemoryBuffer.create(Complex, size, init)
|
||||
inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
||||
return MemoryBuffer.create(Complex, size, init)
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import kotlin.math.abs
|
||||
import kotlin.math.pow as kpow
|
||||
|
||||
/**
|
||||
* Advanced Number-like field that implements basic operations
|
||||
* Advanced Number-like semifield that implements basic operations.
|
||||
*/
|
||||
interface ExtendedFieldOperations<T> :
|
||||
TrigonometricOperations<T>,
|
||||
@ -22,12 +22,12 @@ interface ExtendedFieldOperations<T> :
|
||||
TrigonometricOperations.ACOS_OPERATION -> acos(arg)
|
||||
TrigonometricOperations.ASIN_OPERATION -> asin(arg)
|
||||
TrigonometricOperations.ATAN_OPERATION -> atan(arg)
|
||||
HyperbolicTrigonometricOperations.COSH_OPERATION -> cos(arg)
|
||||
HyperbolicTrigonometricOperations.SINH_OPERATION -> sin(arg)
|
||||
HyperbolicTrigonometricOperations.TANH_OPERATION -> tan(arg)
|
||||
HyperbolicTrigonometricOperations.ACOSH_OPERATION -> acos(arg)
|
||||
HyperbolicTrigonometricOperations.ASINH_OPERATION -> asin(arg)
|
||||
HyperbolicTrigonometricOperations.ATANH_OPERATION -> atan(arg)
|
||||
HyperbolicTrigonometricOperations.COSH_OPERATION -> cosh(arg)
|
||||
HyperbolicTrigonometricOperations.SINH_OPERATION -> sinh(arg)
|
||||
HyperbolicTrigonometricOperations.TANH_OPERATION -> tanh(arg)
|
||||
HyperbolicTrigonometricOperations.ACOSH_OPERATION -> acosh(arg)
|
||||
HyperbolicTrigonometricOperations.ASINH_OPERATION -> asinh(arg)
|
||||
HyperbolicTrigonometricOperations.ATANH_OPERATION -> atanh(arg)
|
||||
PowerOperations.SQRT_OPERATION -> sqrt(arg)
|
||||
ExponentialOperations.EXP_OPERATION -> exp(arg)
|
||||
ExponentialOperations.LN_OPERATION -> ln(arg)
|
||||
@ -35,6 +35,10 @@ interface ExtendedFieldOperations<T> :
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Advanced Number-like field that implements basic operations.
|
||||
*/
|
||||
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
||||
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> power(left, right)
|
||||
@ -45,6 +49,8 @@ interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
||||
/**
|
||||
* Real field element wrapping double.
|
||||
*
|
||||
* @property value the [Double] value wrapped by this [Real].
|
||||
*
|
||||
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
|
||||
*/
|
||||
inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
||||
@ -59,7 +65,7 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
||||
}
|
||||
|
||||
/**
|
||||
* A field for double without boxing. Does not produce appropriate field element
|
||||
* A field for [Double] without boxing. Does not produce appropriate field element.
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
@ -103,6 +109,9 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
override inline fun Double.div(b: Double): Double = this / b
|
||||
}
|
||||
|
||||
/**
|
||||
* A field for [Float] without boxing. Does not produce appropriate field element.
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
override val zero: Float
|
||||
|
@ -1,14 +1,14 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
/**
|
||||
* A container for trigonometric operations for specific type. Trigonometric operations are limited to fields.
|
||||
* A container for trigonometric operations for specific type. They are limited to semifields.
|
||||
*
|
||||
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field.
|
||||
* It also allows to override behavior for optional operations.
|
||||
*/
|
||||
interface TrigonometricOperations<T> : FieldOperations<T> {
|
||||
/**
|
||||
* Computes the sine of [arg] .
|
||||
* Computes the sine of [arg].
|
||||
*/
|
||||
fun sin(arg: T): T
|
||||
|
||||
@ -38,20 +38,65 @@ interface TrigonometricOperations<T> : FieldOperations<T> {
|
||||
fun atan(arg: T): T
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* The identifier of sine.
|
||||
*/
|
||||
const val SIN_OPERATION = "sin"
|
||||
|
||||
/**
|
||||
* The identifier of cosine.
|
||||
*/
|
||||
const val COS_OPERATION = "cos"
|
||||
|
||||
/**
|
||||
* The identifier of tangent.
|
||||
*/
|
||||
const val TAN_OPERATION = "tan"
|
||||
|
||||
/**
|
||||
* The identifier of inverse sine.
|
||||
*/
|
||||
const val ASIN_OPERATION = "asin"
|
||||
|
||||
/**
|
||||
* The identifier of inverse cosine.
|
||||
*/
|
||||
const val ACOS_OPERATION = "acos"
|
||||
/**
|
||||
* The identifier of inverse tangent.
|
||||
*/
|
||||
const val ATAN_OPERATION = "atan"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the sine of [arg].
|
||||
*/
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
|
||||
|
||||
/**
|
||||
* Computes the cosine of [arg].
|
||||
*/
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
|
||||
|
||||
/**
|
||||
* Computes the tangent of [arg].
|
||||
*/
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
|
||||
|
||||
/**
|
||||
* Computes the inverse sine of [arg].
|
||||
*/
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
|
||||
|
||||
/**
|
||||
* Computes the inverse cosine of [arg].
|
||||
*/
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
|
||||
|
||||
/**
|
||||
* Computes the inverse tangent of [arg].
|
||||
*/
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
|
||||
|
||||
/**
|
||||
@ -110,39 +155,104 @@ fun <T : MathElement<out HyperbolicTrigonometricOperations<T>>> acosh(arg: T): T
|
||||
fun <T : MathElement<out HyperbolicTrigonometricOperations<T>>> atanh(arg: T): T = arg.context.atanh(arg)
|
||||
|
||||
/**
|
||||
* A context extension to include power operations like square roots, etc
|
||||
* A context extension to include power operations based on exponentiation.
|
||||
*/
|
||||
interface PowerOperations<T> : Algebra<T> {
|
||||
/**
|
||||
* Raises [arg] to the power [pow].
|
||||
*/
|
||||
fun power(arg: T, pow: Number): T
|
||||
fun sqrt(arg: T) = power(arg, 0.5)
|
||||
|
||||
infix fun T.pow(pow: Number) = power(this, pow)
|
||||
/**
|
||||
* Computes the square root of the value [arg].
|
||||
*/
|
||||
fun sqrt(arg: T): T = power(arg, 0.5)
|
||||
|
||||
/**
|
||||
* Raises this value to the power [pow].
|
||||
*/
|
||||
infix fun T.pow(pow: Number): T = power(this, pow)
|
||||
|
||||
companion object {
|
||||
const val POW_OPERATION = "pow"
|
||||
const val SQRT_OPERATION = "sqrt"
|
||||
/**
|
||||
* The identifier of exponentiation.
|
||||
*/
|
||||
const val POW_OPERATION: String = "pow"
|
||||
|
||||
/**
|
||||
* The identifier of square root.
|
||||
*/
|
||||
const val SQRT_OPERATION: String = "sqrt"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Raises this element to the power [pow].
|
||||
*
|
||||
* @receiver the base.
|
||||
* @param power the exponent.
|
||||
* @return the base raised to the power.
|
||||
*/
|
||||
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
|
||||
|
||||
/**
|
||||
* Computes the square root of the value [arg].
|
||||
*/
|
||||
fun <T : MathElement<out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5
|
||||
|
||||
/**
|
||||
* Computes the square of the value [arg].
|
||||
*/
|
||||
fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
|
||||
|
||||
/**
|
||||
* A container for operations related to `exp` and `ln` functions.
|
||||
*/
|
||||
interface ExponentialOperations<T> : Algebra<T> {
|
||||
/**
|
||||
* Computes Euler's number `e` raised to the power of the value [arg].
|
||||
*/
|
||||
fun exp(arg: T): T
|
||||
|
||||
/**
|
||||
* Computes the natural logarithm (base `e`) of the value [arg].
|
||||
*/
|
||||
fun ln(arg: T): T
|
||||
|
||||
companion object {
|
||||
const val EXP_OPERATION = "exp"
|
||||
const val LN_OPERATION = "ln"
|
||||
/**
|
||||
* The identifier of exponential function.
|
||||
*/
|
||||
const val EXP_OPERATION: String = "exp"
|
||||
|
||||
/**
|
||||
* The identifier of natural logarithm.
|
||||
*/
|
||||
const val LN_OPERATION: String = "ln"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The identifier of exponential function.
|
||||
*/
|
||||
fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
|
||||
|
||||
/**
|
||||
* The identifier of natural logarithm.
|
||||
*/
|
||||
fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg)
|
||||
|
||||
/**
|
||||
* A container for norm functional on element.
|
||||
*/
|
||||
interface Norm<in T : Any, out R> {
|
||||
/**
|
||||
* Computes the norm of [arg] (i.e. absolute value or vector length).
|
||||
*/
|
||||
fun norm(arg: T): R
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the norm of [arg] (i.e. absolute value or vector length).
|
||||
*/
|
||||
fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg)
|
||||
|
@ -3,7 +3,6 @@ package scientifik.kmath.structures
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.FieldElement
|
||||
|
||||
|
||||
class BoxingNDField<T, F : Field<T>>(
|
||||
override val shape: IntArray,
|
||||
override val elementContext: F,
|
||||
@ -19,10 +18,10 @@ class BoxingNDField<T, F : Field<T>>(
|
||||
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
||||
}
|
||||
|
||||
override val zero by lazy { produce { zero } }
|
||||
override val one by lazy { produce { one } }
|
||||
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
|
||||
override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
|
||||
|
||||
override fun produce(initializer: F.(IntArray) -> T) =
|
||||
override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
|
||||
BufferedNDFieldElement(
|
||||
this,
|
||||
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
|
||||
@ -79,4 +78,4 @@ inline fun <T : Any, F : Field<T>, R> F.nd(
|
||||
): R {
|
||||
val ndfield: BoxingNDField<T, F> = NDField.boxing(this, *shape, bufferFactory = bufferFactory)
|
||||
return ndfield.action()
|
||||
}
|
||||
}
|
||||
|
@ -18,10 +18,10 @@ class BoxingNDRing<T, R : Ring<T>>(
|
||||
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
||||
}
|
||||
|
||||
override val zero by lazy { produce { zero } }
|
||||
override val one by lazy { produce { one } }
|
||||
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
|
||||
override val one: BufferedNDRingElement<T, R> by lazy { produce { one } }
|
||||
|
||||
override fun produce(initializer: R.(IntArray) -> T) =
|
||||
override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement<T, R> =
|
||||
BufferedNDRingElement(
|
||||
this,
|
||||
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
|
||||
@ -69,4 +69,4 @@ class BoxingNDRing<T, R : Ring<T>>(
|
||||
|
||||
override fun NDBuffer<T>.toElement(): RingElement<NDBuffer<T>, *, out BufferedNDRing<T, R>> =
|
||||
BufferedNDRingElement(this@BoxingNDRing, buffer)
|
||||
}
|
||||
}
|
||||
|
@ -7,16 +7,16 @@ import kotlin.reflect.KClass
|
||||
*/
|
||||
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
|
||||
|
||||
operator fun Buffer<T>.get(i: Int, j: Int) = get(i + colNum * j)
|
||||
operator fun Buffer<T>.get(i: Int, j: Int): T = get(i + colNum * j)
|
||||
|
||||
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
|
||||
set(i + colNum * j, value)
|
||||
}
|
||||
|
||||
inline fun create(init: (i: Int, j: Int) -> T) =
|
||||
inline fun create(init: (i: Int, j: Int) -> T): MutableBuffer<T> =
|
||||
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
||||
|
||||
fun create(mat: Structure2D<T>) = create { i, j -> mat[i, j] }
|
||||
fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
|
||||
|
||||
//TODO optimize wrapper
|
||||
fun MutableBuffer<T>.collect(): Structure2D<T> =
|
||||
@ -41,5 +41,5 @@ class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum
|
||||
/**
|
||||
* Get row
|
||||
*/
|
||||
fun MutableBuffer<T>.row(i: Int) = Row(this, i)
|
||||
fun MutableBuffer<T>.row(i: Int): Row = Row(this, i)
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
|
||||
interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
||||
val strides: Strides
|
||||
|
||||
override fun check(vararg elements: NDBuffer<T>) {
|
||||
@ -11,7 +11,8 @@ interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
|
||||
|
||||
/**
|
||||
* Convert any [NDStructure] to buffered structure using strides from this context.
|
||||
* If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over indexes
|
||||
* If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over
|
||||
* indices.
|
||||
*
|
||||
* If the argument is [NDBuffer] with different strides structure, the new element will be produced.
|
||||
*/
|
||||
@ -30,7 +31,7 @@ interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
|
||||
}
|
||||
|
||||
|
||||
interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T,S> {
|
||||
interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T, S> {
|
||||
override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>>
|
||||
}
|
||||
|
||||
|
@ -8,7 +8,7 @@ import scientifik.kmath.operations.*
|
||||
abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer<T>> {
|
||||
abstract override val context: BufferedNDAlgebra<T, C>
|
||||
|
||||
override val strides get() = context.strides
|
||||
override val strides: Strides get() = context.strides
|
||||
|
||||
override val shape: IntArray get() = context.shape
|
||||
}
|
||||
@ -54,9 +54,9 @@ class BufferedNDFieldElement<T, F : Field<T>>(
|
||||
|
||||
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy.
|
||||
*/
|
||||
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>) =
|
||||
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>): MathElement<out BufferedNDAlgebra<T, F>> =
|
||||
ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
|
||||
|
||||
/* plus and minus */
|
||||
@ -64,13 +64,13 @@ operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedN
|
||||
/**
|
||||
* Summation operation for [BufferedNDElement] and single element
|
||||
*/
|
||||
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.plus(arg: T) =
|
||||
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.plus(arg: T): NDElement<T, F, NDBuffer<T>> =
|
||||
context.map(this) { it + arg }.wrap()
|
||||
|
||||
/**
|
||||
* Subtraction operation between [BufferedNDElement] and single element
|
||||
*/
|
||||
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T) =
|
||||
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T): NDElement<T, F, NDBuffer<T>> =
|
||||
context.map(this) { it - arg }.wrap()
|
||||
|
||||
/* prod and div */
|
||||
@ -78,11 +78,11 @@ operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T) =
|
||||
/**
|
||||
* Product operation for [BufferedNDElement] and single element
|
||||
*/
|
||||
operator fun <T : Any, F : Ring<T>> BufferedNDElement<T, F>.times(arg: T) =
|
||||
operator fun <T : Any, F : Ring<T>> BufferedNDElement<T, F>.times(arg: T): NDElement<T, F, NDBuffer<T>> =
|
||||
context.map(this) { it * arg }.wrap()
|
||||
|
||||
/**
|
||||
* Division operation between [BufferedNDElement] and single element
|
||||
*/
|
||||
operator fun <T : Any, F : Field<T>> BufferedNDElement<T, F>.div(arg: T) =
|
||||
context.map(this) { it / arg }.wrap()
|
||||
operator fun <T : Any, F : Field<T>> BufferedNDElement<T, F>.div(arg: T): NDElement<T, F, NDBuffer<T>> =
|
||||
context.map(this) { it / arg }.wrap()
|
||||
|
@ -4,39 +4,48 @@ import scientifik.kmath.operations.Complex
|
||||
import scientifik.kmath.operations.complex
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
|
||||
/**
|
||||
* Function that produces [Buffer] from its size and function that supplies values.
|
||||
*
|
||||
* @param T the type of buffer.
|
||||
*/
|
||||
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
||||
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
|
||||
|
||||
|
||||
/**
|
||||
* A generic random access structure for both primitives and objects
|
||||
* Function that produces [MutableBuffer] from its size and function that supplies values.
|
||||
*
|
||||
* @param T the type of buffer.
|
||||
*/
|
||||
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
|
||||
|
||||
/**
|
||||
* A generic immutable random-access structure for both primitives and objects.
|
||||
*
|
||||
* @param T the type of elements contained in the buffer.
|
||||
*/
|
||||
interface Buffer<T> {
|
||||
|
||||
/**
|
||||
* The size of the buffer
|
||||
* The size of this buffer.
|
||||
*/
|
||||
val size: Int
|
||||
|
||||
/**
|
||||
* Get element at given index
|
||||
* Gets element at given index.
|
||||
*/
|
||||
operator fun get(index: Int): T
|
||||
|
||||
/**
|
||||
* Iterate over all elements
|
||||
* Iterates over all elements.
|
||||
*/
|
||||
operator fun iterator(): Iterator<T>
|
||||
|
||||
/**
|
||||
* Check content eqiality with another buffer
|
||||
* Checks content equality with another buffer.
|
||||
*/
|
||||
fun contentEquals(other: Buffer<*>): Boolean =
|
||||
asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
|
||||
|
||||
companion object {
|
||||
|
||||
inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer {
|
||||
val array = DoubleArray(size) { initializer(it) }
|
||||
return RealBuffer(array)
|
||||
@ -69,17 +78,34 @@ interface Buffer<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a sequence that returns all elements from this [Buffer].
|
||||
*/
|
||||
fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator)
|
||||
|
||||
/**
|
||||
* Creates an iterable that returns all elements from this [Buffer].
|
||||
*/
|
||||
fun <T> Buffer<T>.asIterable(): Iterable<T> = Iterable(::iterator)
|
||||
|
||||
val Buffer<*>.indices: IntRange get() = IntRange(0, size - 1)
|
||||
/**
|
||||
* Returns an [IntRange] of the valid indices for this [Buffer].
|
||||
*/
|
||||
val Buffer<*>.indices: IntRange get() = 0 until size
|
||||
|
||||
/**
|
||||
* A generic mutable random-access structure for both primitives and objects.
|
||||
*
|
||||
* @param T the type of elements contained in the buffer.
|
||||
*/
|
||||
interface MutableBuffer<T> : Buffer<T> {
|
||||
/**
|
||||
* Sets the array element at the specified [index] to the specified [value].
|
||||
*/
|
||||
operator fun set(index: Int, value: T)
|
||||
|
||||
/**
|
||||
* A shallow copy of the buffer
|
||||
* Returns a shallow copy of the buffer.
|
||||
*/
|
||||
fun copy(): MutableBuffer<T>
|
||||
|
||||
@ -114,8 +140,13 @@ interface MutableBuffer<T> : Buffer<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* [Buffer] implementation over [List].
|
||||
*
|
||||
* @param T the type of elements contained in the buffer.
|
||||
* @property list The underlying list.
|
||||
*/
|
||||
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
||||
|
||||
override val size: Int
|
||||
get() = list.size
|
||||
|
||||
@ -124,11 +155,26 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
||||
override fun iterator(): Iterator<T> = list.iterator()
|
||||
}
|
||||
|
||||
fun <T> List<T>.asBuffer() = ListBuffer<T>(this)
|
||||
/**
|
||||
* Returns an [ListBuffer] that wraps the original list.
|
||||
*/
|
||||
fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
|
||||
|
||||
@Suppress("FunctionName")
|
||||
inline fun <T> ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer()
|
||||
/**
|
||||
* Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||
* [init] function.
|
||||
*
|
||||
* The function [init] is called for each array element sequentially starting from the first one.
|
||||
* It should return the value for an array element given its index.
|
||||
*/
|
||||
inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = List(size, init).asBuffer()
|
||||
|
||||
/**
|
||||
* [MutableBuffer] implementation over [MutableList].
|
||||
*
|
||||
* @param T the type of elements contained in the buffer.
|
||||
* @property list The underlying list.
|
||||
*/
|
||||
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
||||
|
||||
override val size: Int
|
||||
@ -144,8 +190,14 @@ inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
||||
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
||||
}
|
||||
|
||||
/**
|
||||
* [MutableBuffer] implementation over [Array].
|
||||
*
|
||||
* @param T the type of elements contained in the buffer.
|
||||
* @property array The underlying array.
|
||||
*/
|
||||
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
||||
//Can't inline because array is invariant
|
||||
// Can't inline because array is invariant
|
||||
override val size: Int
|
||||
get() = array.size
|
||||
|
||||
@ -160,19 +212,30 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
||||
override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf())
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an [ArrayBuffer] that wraps the original array.
|
||||
*/
|
||||
fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
|
||||
|
||||
/**
|
||||
* Immutable wrapper for [MutableBuffer].
|
||||
*
|
||||
* @param T the type of elements contained in the buffer.
|
||||
* @property buffer The underlying buffer.
|
||||
*/
|
||||
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
||||
override val size: Int get() = buffer.size
|
||||
|
||||
override fun get(index: Int): T = buffer.get(index)
|
||||
override fun get(index: Int): T = buffer[index]
|
||||
|
||||
override fun iterator() = buffer.iterator()
|
||||
override fun iterator(): Iterator<T> = buffer.iterator()
|
||||
}
|
||||
|
||||
/**
|
||||
* A buffer with content calculated on-demand. The calculated contect is not stored, so it is recalculated on each call.
|
||||
* A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call.
|
||||
* Useful when one needs single element from the buffer.
|
||||
*
|
||||
* @param T the type of elements provided by the buffer.
|
||||
*/
|
||||
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
|
||||
override fun get(index: Int): T {
|
||||
@ -192,17 +255,16 @@ class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert this buffer to read-only buffer
|
||||
* Convert this buffer to read-only buffer.
|
||||
*/
|
||||
fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) {
|
||||
ReadOnlyBuffer(this)
|
||||
} else {
|
||||
this
|
||||
}
|
||||
fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) ReadOnlyBuffer(this) else this
|
||||
|
||||
/**
|
||||
* Typealias for buffer transformations
|
||||
* Typealias for buffer transformations.
|
||||
*/
|
||||
typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R>
|
||||
|
||||
typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R>
|
||||
/**
|
||||
* Typealias for buffer transformations with suspend function.
|
||||
*/
|
||||
typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R>
|
||||
|
@ -16,8 +16,8 @@ class ComplexNDField(override val shape: IntArray) :
|
||||
|
||||
override val strides: Strides = DefaultStrides(shape)
|
||||
override val elementContext: ComplexField get() = ComplexField
|
||||
override val zero by lazy { produce { zero } }
|
||||
override val one by lazy { produce { one } }
|
||||
override val zero: ComplexNDElement by lazy { produce { zero } }
|
||||
override val one: ComplexNDElement by lazy { produce { one } }
|
||||
|
||||
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> =
|
||||
Buffer.complex(size) { initializer(it) }
|
||||
@ -101,13 +101,13 @@ inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline init
|
||||
}
|
||||
|
||||
/**
|
||||
* Map one [ComplexNDElement] using function with indexes
|
||||
* Map one [ComplexNDElement] using function with indices.
|
||||
*/
|
||||
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex) =
|
||||
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement =
|
||||
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
|
||||
|
||||
/**
|
||||
* Map one [ComplexNDElement] using function without indexes
|
||||
* Map one [ComplexNDElement] using function without indices.
|
||||
*/
|
||||
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
|
||||
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
|
||||
@ -117,7 +117,7 @@ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) ->
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||
*/
|
||||
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement) =
|
||||
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement): ComplexNDElement =
|
||||
ndElement.map { this@invoke(it) }
|
||||
|
||||
|
||||
@ -126,19 +126,18 @@ operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement) =
|
||||
/**
|
||||
* Summation operation for [BufferedNDElement] and single element
|
||||
*/
|
||||
operator fun ComplexNDElement.plus(arg: Complex) =
|
||||
map { it + arg }
|
||||
operator fun ComplexNDElement.plus(arg: Complex): ComplexNDElement = map { it + arg }
|
||||
|
||||
/**
|
||||
* Subtraction operation between [BufferedNDElement] and single element
|
||||
*/
|
||||
operator fun ComplexNDElement.minus(arg: Complex) =
|
||||
operator fun ComplexNDElement.minus(arg: Complex): ComplexNDElement =
|
||||
map { it - arg }
|
||||
|
||||
operator fun ComplexNDElement.plus(arg: Double) =
|
||||
operator fun ComplexNDElement.plus(arg: Double): ComplexNDElement =
|
||||
map { it + arg }
|
||||
|
||||
operator fun ComplexNDElement.minus(arg: Double) =
|
||||
operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement =
|
||||
map { it - arg }
|
||||
|
||||
fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
|
||||
@ -151,4 +150,4 @@ fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(In
|
||||
*/
|
||||
inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R {
|
||||
return NDField.complex(*shape).run(action)
|
||||
}
|
||||
}
|
||||
|
@ -2,9 +2,15 @@ package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
|
||||
/**
|
||||
* [ExtendedField] over [NDStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param F the extended field of structure elements.
|
||||
*/
|
||||
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N>
|
||||
|
||||
|
||||
///**
|
||||
// * NDField that supports [ExtendedField] operations on its elements
|
||||
// */
|
||||
@ -36,5 +42,3 @@ interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : N
|
||||
// return produce { with(elementContext) { cos(arg[it]) } }
|
||||
// }
|
||||
//}
|
||||
|
||||
|
||||
|
@ -2,15 +2,35 @@ package scientifik.kmath.structures
|
||||
|
||||
import kotlin.experimental.and
|
||||
|
||||
/**
|
||||
* Represents flags to supply additional info about values of buffer.
|
||||
*
|
||||
* @property mask bit mask value of this flag.
|
||||
*/
|
||||
enum class ValueFlag(val mask: Byte) {
|
||||
/**
|
||||
* Reports the value is NaN.
|
||||
*/
|
||||
NAN(0b0000_0001),
|
||||
|
||||
/**
|
||||
* Reports the value doesn't present in the buffer (when the type of value doesn't support `null`).
|
||||
*/
|
||||
MISSING(0b0000_0010),
|
||||
|
||||
/**
|
||||
* Reports the value is negative infinity.
|
||||
*/
|
||||
NEGATIVE_INFINITY(0b0000_0100),
|
||||
|
||||
/**
|
||||
* Reports the value is positive infinity
|
||||
*/
|
||||
POSITIVE_INFINITY(0b0000_1000)
|
||||
}
|
||||
|
||||
/**
|
||||
* A buffer with flagged values
|
||||
* A buffer with flagged values.
|
||||
*/
|
||||
interface FlaggedBuffer<T> : Buffer<T> {
|
||||
fun getFlag(index: Int): Byte
|
||||
@ -19,11 +39,11 @@ interface FlaggedBuffer<T> : Buffer<T> {
|
||||
/**
|
||||
* The value is valid if all flags are down
|
||||
*/
|
||||
fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte()
|
||||
fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte()
|
||||
|
||||
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte()
|
||||
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte()
|
||||
|
||||
fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING)
|
||||
fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING)
|
||||
|
||||
/**
|
||||
* A real buffer which supports flags for each value like NaN or Missing
|
||||
@ -45,9 +65,9 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged
|
||||
}
|
||||
|
||||
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
|
||||
for(i in indices){
|
||||
if(isValid(i)){
|
||||
for (i in indices) {
|
||||
if (isValid(i)) {
|
||||
block(values[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,49 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
/**
|
||||
* Specialized [MutableBuffer] implementation over [FloatArray].
|
||||
*
|
||||
* @property array the underlying array.
|
||||
*/
|
||||
inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
override fun get(index: Int): Float = array[index]
|
||||
|
||||
override fun set(index: Int, value: Float) {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator(): FloatIterator = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Float> =
|
||||
FloatBuffer(array.copyOf())
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new [FloatBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||
* [init] function.
|
||||
*
|
||||
* The function [init] is called for each array element sequentially starting from the first one.
|
||||
* It should return the value for an buffer element given its index.
|
||||
*/
|
||||
inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) })
|
||||
|
||||
/**
|
||||
* Returns a new [FloatBuffer] of given elements.
|
||||
*/
|
||||
fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats)
|
||||
|
||||
/**
|
||||
* Returns a [FloatArray] containing all of the elements of this [MutableBuffer].
|
||||
*/
|
||||
val MutableBuffer<out Float>.array: FloatArray
|
||||
get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) })
|
||||
|
||||
/**
|
||||
* Returns [FloatBuffer] over this array.
|
||||
*
|
||||
* @receiver the array.
|
||||
* @return the new buffer.
|
||||
*/
|
||||
fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this)
|
@ -1,5 +1,10 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
/**
|
||||
* Specialized [MutableBuffer] implementation over [IntArray].
|
||||
*
|
||||
* @property array the underlying array.
|
||||
*/
|
||||
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
@ -9,12 +14,37 @@ inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator() = array.iterator()
|
||||
override fun iterator(): IntIterator = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Int> =
|
||||
IntBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new [IntBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||
* [init] function.
|
||||
*
|
||||
* The function [init] is called for each array element sequentially starting from the first one.
|
||||
* It should return the value for an buffer element given its index.
|
||||
*/
|
||||
inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) })
|
||||
|
||||
fun IntArray.asBuffer() = IntBuffer(this)
|
||||
/**
|
||||
* Returns a new [IntBuffer] of given elements.
|
||||
*/
|
||||
fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints)
|
||||
|
||||
/**
|
||||
* Returns a [IntArray] containing all of the elements of this [MutableBuffer].
|
||||
*/
|
||||
val MutableBuffer<out Int>.array: IntArray
|
||||
get() = (if (this is IntBuffer) array else IntArray(size) { get(it) })
|
||||
|
||||
/**
|
||||
* Returns [IntBuffer] over this array.
|
||||
*
|
||||
* @receiver the array.
|
||||
* @return the new buffer.
|
||||
*/
|
||||
fun IntArray.asBuffer(): IntBuffer = IntBuffer(this)
|
||||
|
@ -1,5 +1,10 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
/**
|
||||
* Specialized [MutableBuffer] implementation over [LongArray].
|
||||
*
|
||||
* @property array the underlying array.
|
||||
*/
|
||||
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
@ -9,11 +14,37 @@ inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator() = array.iterator()
|
||||
override fun iterator(): LongIterator = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Long> =
|
||||
LongBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
fun LongArray.asBuffer() = LongBuffer(this)
|
||||
/**
|
||||
* Creates a new [LongBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||
* [init] function.
|
||||
*
|
||||
* The function [init] is called for each array element sequentially starting from the first one.
|
||||
* It should return the value for an buffer element given its index.
|
||||
*/
|
||||
inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) })
|
||||
|
||||
/**
|
||||
* Returns a new [LongBuffer] of given elements.
|
||||
*/
|
||||
fun LongBuffer(vararg longs: Long): LongBuffer = LongBuffer(longs)
|
||||
|
||||
/**
|
||||
* Returns a [IntArray] containing all of the elements of this [MutableBuffer].
|
||||
*/
|
||||
val MutableBuffer<out Long>.array: LongArray
|
||||
get() = (if (this is LongBuffer) array else LongArray(size) { get(it) })
|
||||
|
||||
/**
|
||||
* Returns [LongBuffer] over this array.
|
||||
*
|
||||
* @receiver the array.
|
||||
* @return the new buffer.
|
||||
*/
|
||||
fun LongArray.asBuffer(): LongBuffer = LongBuffer(this)
|
||||
|
@ -3,13 +3,16 @@ package scientifik.kmath.structures
|
||||
import scientifik.memory.*
|
||||
|
||||
/**
|
||||
* A non-boxing buffer based on [ByteBuffer] storage
|
||||
* A non-boxing buffer over [Memory] object.
|
||||
*
|
||||
* @param T the type of elements contained in the buffer.
|
||||
* @property memory the underlying memory segment.
|
||||
* @property spec the spec of [T] type.
|
||||
*/
|
||||
open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> {
|
||||
|
||||
override val size: Int get() = memory.size / spec.objectSize
|
||||
|
||||
private val reader = memory.reader()
|
||||
private val reader: MemoryReader = memory.reader()
|
||||
|
||||
override fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
|
||||
|
||||
@ -17,7 +20,7 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
|
||||
|
||||
|
||||
companion object {
|
||||
fun <T : Any> create(spec: MemorySpec<T>, size: Int) =
|
||||
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
|
||||
MemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
|
||||
|
||||
inline fun <T : Any> create(
|
||||
@ -33,28 +36,35 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A mutable non-boxing buffer over [Memory] object.
|
||||
*
|
||||
* @param T the type of elements contained in the buffer.
|
||||
* @property memory the underlying memory segment.
|
||||
* @property spec the spec of [T] type.
|
||||
*/
|
||||
class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec),
|
||||
MutableBuffer<T> {
|
||||
|
||||
private val writer = memory.writer()
|
||||
private val writer: MemoryWriter = memory.writer()
|
||||
|
||||
override fun set(index: Int, value: T) = writer.write(spec, spec.objectSize * index, value)
|
||||
override fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
|
||||
|
||||
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
|
||||
|
||||
companion object {
|
||||
fun <T : Any> create(spec: MemorySpec<T>, size: Int) =
|
||||
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MutableMemoryBuffer<T> =
|
||||
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
|
||||
|
||||
inline fun <T : Any> create(
|
||||
spec: MemorySpec<T>,
|
||||
size: Int,
|
||||
crossinline initializer: (Int) -> T
|
||||
) =
|
||||
): MutableMemoryBuffer<T> =
|
||||
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
|
||||
(0 until size).forEach {
|
||||
buffer[it] = initializer(it)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ interface NDAlgebra<T, C, N : NDStructure<T>> {
|
||||
/**
|
||||
* element-by-element invoke a function working on [T] on a [NDStructure]
|
||||
*/
|
||||
operator fun Function1<T, T>.invoke(structure: N) = map(structure) { value -> this@invoke(value) }
|
||||
operator fun Function1<T, T>.invoke(structure: N): N = map(structure) { value -> this@invoke(value) }
|
||||
|
||||
companion object
|
||||
}
|
||||
@ -76,12 +76,12 @@ interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T,
|
||||
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
|
||||
|
||||
//TODO move to extensions after KEEP-176
|
||||
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) }
|
||||
operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) }
|
||||
|
||||
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) }
|
||||
operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) }
|
||||
|
||||
operator fun T.plus(arg: N) = map(arg) { value -> add(this@plus, value) }
|
||||
operator fun T.minus(arg: N) = map(arg) { value -> add(-this@minus, value) }
|
||||
operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) }
|
||||
operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) }
|
||||
|
||||
companion object
|
||||
}
|
||||
@ -97,20 +97,19 @@ interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N>
|
||||
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
||||
|
||||
//TODO move to extensions after KEEP-176
|
||||
operator fun N.times(arg: T) = map(this) { value -> multiply(arg, value) }
|
||||
operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) }
|
||||
|
||||
operator fun T.times(arg: N) = map(arg) { value -> multiply(this@times, value) }
|
||||
operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) }
|
||||
|
||||
companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* Field for n-dimensional structures.
|
||||
* @param shape - the list of dimensions of the array
|
||||
* @param elementField - operations field defined on individual array element
|
||||
* @param T - the type of the element contained in ND structure
|
||||
* @param F - field of structure elements
|
||||
* @param R - actual nd-element type of this field
|
||||
* Field of [NDStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param F field of structure elements.
|
||||
*/
|
||||
interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> {
|
||||
|
||||
@ -120,9 +119,9 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
||||
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
||||
|
||||
//TODO move to extensions after KEEP-176
|
||||
operator fun N.div(arg: T) = map(this) { value -> divide(arg, value) }
|
||||
operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) }
|
||||
|
||||
operator fun T.div(arg: N) = map(arg) { divide(it, this@div) }
|
||||
operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
|
||||
|
||||
companion object {
|
||||
|
||||
@ -131,7 +130,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
||||
/**
|
||||
* Create a nd-field for [Double] values or pull it from cache if it was created previously
|
||||
*/
|
||||
fun real(vararg shape: Int) = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
||||
fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
||||
|
||||
/**
|
||||
* Create a nd-field with boxing generic buffer
|
||||
@ -140,7 +139,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
||||
field: F,
|
||||
vararg shape: Int,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||
) = BoxingNDField(shape, field, bufferFactory)
|
||||
): BoxingNDField<T, F> = BoxingNDField(shape, field, bufferFactory)
|
||||
|
||||
/**
|
||||
* Create a most suitable implementation for nd-field using reified class.
|
||||
|
@ -23,19 +23,23 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
||||
/**
|
||||
* Create a optimized NDArray of doubles
|
||||
*/
|
||||
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }) =
|
||||
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement =
|
||||
NDField.real(*shape).produce(initializer)
|
||||
|
||||
|
||||
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }) =
|
||||
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
|
||||
real(intArrayOf(dim)) { initializer(it[0]) }
|
||||
|
||||
|
||||
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }) =
|
||||
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): RealNDElement =
|
||||
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||
|
||||
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }) =
|
||||
real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||
fun real3D(
|
||||
dim1: Int,
|
||||
dim2: Int,
|
||||
dim3: Int,
|
||||
initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }
|
||||
): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||
|
||||
|
||||
/**
|
||||
@ -62,16 +66,17 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
||||
}
|
||||
|
||||
|
||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T) =
|
||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement<T, C, N> =
|
||||
context.mapIndexed(unwrap(), transform).wrap()
|
||||
|
||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T) = context.map(unwrap(), transform).wrap()
|
||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> =
|
||||
context.map(unwrap(), transform).wrap()
|
||||
|
||||
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole [NDElement]
|
||||
*/
|
||||
operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>) =
|
||||
operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>): NDElement<T, C, N> =
|
||||
ndElement.map { value -> this@invoke(value) }
|
||||
|
||||
/* plus and minus */
|
||||
@ -79,13 +84,13 @@ operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElem
|
||||
/**
|
||||
* Summation operation for [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T) =
|
||||
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T): NDElement<T, S, N> =
|
||||
map { value -> arg + value }
|
||||
|
||||
/**
|
||||
* Subtraction operation between [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T) =
|
||||
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T): NDElement<T, S, N> =
|
||||
map { value -> arg - value }
|
||||
|
||||
/* prod and div */
|
||||
@ -93,13 +98,13 @@ operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg:
|
||||
/**
|
||||
* Product operation for [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T) =
|
||||
operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T): NDElement<T, R, N> =
|
||||
map { value -> arg * value }
|
||||
|
||||
/**
|
||||
* Division operation between [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T) =
|
||||
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> =
|
||||
map { value -> arg / value }
|
||||
|
||||
|
||||
|
@ -3,15 +3,38 @@ package scientifik.kmath.structures
|
||||
import kotlin.jvm.JvmName
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
|
||||
/**
|
||||
* Represents n-dimensional structure, i.e. multidimensional container of items of the same type and size. The number
|
||||
* of dimensions and items in an array is defined by its shape, which is a sequence of non-negative integers that
|
||||
* specify the sizes of each dimension.
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
interface NDStructure<T> {
|
||||
|
||||
/**
|
||||
* The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of
|
||||
* this structure.
|
||||
*/
|
||||
val shape: IntArray
|
||||
|
||||
val dimension get() = shape.size
|
||||
/**
|
||||
* The count of dimensions in this structure. It should be equal to size of [shape].
|
||||
*/
|
||||
val dimension: Int get() = shape.size
|
||||
|
||||
/**
|
||||
* Returns the value at the specified indices.
|
||||
*
|
||||
* @param index the indices.
|
||||
* @return the value.
|
||||
*/
|
||||
operator fun get(index: IntArray): T
|
||||
|
||||
/**
|
||||
* Returns the sequence of all the elements associated by their indices.
|
||||
*
|
||||
* @return the lazy sequence of pairs of indices to values.
|
||||
*/
|
||||
fun elements(): Sequence<Pair<IntArray, T>>
|
||||
|
||||
override fun equals(other: Any?): Boolean
|
||||
@ -19,6 +42,9 @@ interface NDStructure<T> {
|
||||
override fun hashCode(): Int
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Indicates whether some [NDStructure] is equal to another one.
|
||||
*/
|
||||
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||
if (st1 === st2) return true
|
||||
|
||||
@ -36,47 +62,79 @@ interface NDStructure<T> {
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a NDStructure with explicit buffer factory
|
||||
* Creates a NDStructure with explicit buffer factory.
|
||||
*
|
||||
* Strides should be reused if possible
|
||||
* Strides should be reused if possible.
|
||||
*/
|
||||
fun <T> build(
|
||||
strides: Strides,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
initializer: (IntArray) -> T
|
||||
) =
|
||||
): BufferNDStructure<T> =
|
||||
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
/**
|
||||
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
||||
*/
|
||||
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||
inline fun <reified T : Any> auto(
|
||||
strides: Strides,
|
||||
crossinline initializer: (IntArray) -> T
|
||||
): BufferNDStructure<T> =
|
||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
inline fun <T : Any> auto(type: KClass<T>, strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||
inline fun <T : Any> auto(
|
||||
type: KClass<T>,
|
||||
strides: Strides,
|
||||
crossinline initializer: (IntArray) -> T
|
||||
): BufferNDStructure<T> =
|
||||
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
fun <T> build(
|
||||
shape: IntArray,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
initializer: (IntArray) -> T
|
||||
) = build(DefaultStrides(shape), bufferFactory, initializer)
|
||||
): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
|
||||
|
||||
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||
inline fun <reified T : Any> auto(
|
||||
shape: IntArray,
|
||||
crossinline initializer: (IntArray) -> T
|
||||
): BufferNDStructure<T> =
|
||||
auto(DefaultStrides(shape), initializer)
|
||||
|
||||
@JvmName("autoVarArg")
|
||||
inline fun <reified T : Any> auto(vararg shape: Int, crossinline initializer: (IntArray) -> T) =
|
||||
inline fun <reified T : Any> auto(
|
||||
vararg shape: Int,
|
||||
crossinline initializer: (IntArray) -> T
|
||||
): BufferNDStructure<T> =
|
||||
auto(DefaultStrides(shape), initializer)
|
||||
|
||||
inline fun <T : Any> auto(type: KClass<T>, vararg shape: Int, crossinline initializer: (IntArray) -> T) =
|
||||
inline fun <T : Any> auto(
|
||||
type: KClass<T>,
|
||||
vararg shape: Int,
|
||||
crossinline initializer: (IntArray) -> T
|
||||
): BufferNDStructure<T> =
|
||||
auto(type, DefaultStrides(shape), initializer)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the value at the specified indices.
|
||||
*
|
||||
* @param index the indices.
|
||||
* @return the value.
|
||||
*/
|
||||
operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
|
||||
|
||||
/**
|
||||
* Represents mutable [NDStructure].
|
||||
*/
|
||||
interface MutableNDStructure<T> : NDStructure<T> {
|
||||
/**
|
||||
* Inserts an item at the specified indices.
|
||||
*
|
||||
* @param index the indices.
|
||||
* @param value the value.
|
||||
*/
|
||||
operator fun set(index: IntArray, value: T)
|
||||
}
|
||||
|
||||
@ -87,7 +145,7 @@ inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
||||
}
|
||||
|
||||
/**
|
||||
* A way to convert ND index to linear one and back
|
||||
* A way to convert ND index to linear one and back.
|
||||
*/
|
||||
interface Strides {
|
||||
/**
|
||||
@ -124,11 +182,14 @@ interface Strides {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Simple implementation of [Strides].
|
||||
*/
|
||||
class DefaultStrides private constructor(override val shape: IntArray) : Strides {
|
||||
/**
|
||||
* Strides for memory access
|
||||
*/
|
||||
override val strides by lazy {
|
||||
override val strides: List<Int> by lazy {
|
||||
sequence {
|
||||
var current = 1
|
||||
yield(1)
|
||||
@ -163,19 +224,14 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
||||
override val linearSize: Int
|
||||
get() = strides[shape.size]
|
||||
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is DefaultStrides) return false
|
||||
|
||||
if (!shape.contentEquals(other.shape)) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
return shape.contentHashCode()
|
||||
}
|
||||
override fun hashCode(): Int = shape.contentHashCode()
|
||||
|
||||
companion object {
|
||||
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
||||
@ -187,8 +243,20 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDStructure] over [Buffer].
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
abstract class NDBuffer<T> : NDStructure<T> {
|
||||
/**
|
||||
* The underlying buffer.
|
||||
*/
|
||||
abstract val buffer: Buffer<T>
|
||||
|
||||
/**
|
||||
* The strides to access elements of [Buffer] by linear indices.
|
||||
*/
|
||||
abstract val strides: Strides
|
||||
|
||||
override fun get(index: IntArray): T = buffer[strides.offset(index)]
|
||||
@ -238,7 +306,7 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
||||
}
|
||||
|
||||
/**
|
||||
* Mutable ND buffer based on linear [autoBuffer]
|
||||
* Mutable ND buffer based on linear [MutableBuffer].
|
||||
*/
|
||||
class MutableBufferNDStructure<T>(
|
||||
override val strides: Strides,
|
||||
@ -246,12 +314,12 @@ class MutableBufferNDStructure<T>(
|
||||
) : NDBuffer<T>(), MutableNDStructure<T> {
|
||||
|
||||
init {
|
||||
if (strides.linearSize != buffer.size) {
|
||||
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
||||
require(strides.linearSize == buffer.size) {
|
||||
"Expected buffer side of ${strides.linearSize}, but found ${buffer.size}"
|
||||
}
|
||||
}
|
||||
|
||||
override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value)
|
||||
override fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
|
||||
}
|
||||
|
||||
inline fun <reified T : Any> NDStructure<T>.combine(
|
||||
@ -260,4 +328,4 @@ inline fun <reified T : Any> NDStructure<T>.combine(
|
||||
): NDStructure<T> {
|
||||
if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination")
|
||||
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,10 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
/**
|
||||
* Specialized [MutableBuffer] implementation over [DoubleArray].
|
||||
*
|
||||
* @property array the underlying array.
|
||||
*/
|
||||
inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
@ -9,26 +14,36 @@ inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator() = array.iterator()
|
||||
override fun iterator(): DoubleIterator = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Double> =
|
||||
RealBuffer(array.copyOf())
|
||||
}
|
||||
|
||||
@Suppress("FunctionName")
|
||||
/**
|
||||
* Creates a new [RealBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||
* [init] function.
|
||||
*
|
||||
* The function [init] is called for each array element sequentially starting from the first one.
|
||||
* It should return the value for an buffer element given its index.
|
||||
*/
|
||||
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) })
|
||||
|
||||
@Suppress("FunctionName")
|
||||
/**
|
||||
* Returns a new [RealBuffer] of given elements.
|
||||
*/
|
||||
fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
|
||||
|
||||
/**
|
||||
* Transform buffer of doubles into array for high performance operations
|
||||
* Returns a [DoubleArray] containing all of the elements of this [MutableBuffer].
|
||||
*/
|
||||
val MutableBuffer<out Double>.array: DoubleArray
|
||||
get() = if (this is RealBuffer) {
|
||||
array
|
||||
} else {
|
||||
DoubleArray(size) { get(it) }
|
||||
}
|
||||
get() = (if (this is RealBuffer) array else DoubleArray(size) { get(it) })
|
||||
|
||||
fun DoubleArray.asBuffer() = RealBuffer(this)
|
||||
/**
|
||||
* Returns [RealBuffer] over this array.
|
||||
*
|
||||
* @receiver the array.
|
||||
* @return the new buffer.
|
||||
*/
|
||||
fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this)
|
||||
|
@ -6,7 +6,7 @@ import kotlin.math.*
|
||||
|
||||
|
||||
/**
|
||||
* A simple field over linear buffers of [Double].
|
||||
* [ExtendedFieldOperations] over [RealBuffer].
|
||||
*/
|
||||
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
@ -39,7 +39,8 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||
val aArray = a.array
|
||||
val bArray = b.array
|
||||
RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
|
||||
} else RealBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
||||
} else
|
||||
RealBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
||||
}
|
||||
|
||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
@ -72,7 +73,9 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { asin(array[it]) })
|
||||
} else RealBuffer(DoubleArray(arg.size) { asin(arg[it]) })
|
||||
} else {
|
||||
RealBuffer(DoubleArray(arg.size) { asin(arg[it]) })
|
||||
}
|
||||
|
||||
override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
@ -132,6 +135,11 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||
} else RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
|
||||
}
|
||||
|
||||
/**
|
||||
* [ExtendedField] over [RealBuffer].
|
||||
*
|
||||
* @property size the size of buffers to operate on.
|
||||
*/
|
||||
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
|
||||
override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
|
||||
override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
|
||||
|
@ -12,7 +12,7 @@ class RealNDField(override val shape: IntArray) :
|
||||
override val strides: Strides = DefaultStrides(shape)
|
||||
|
||||
override val elementContext: RealField get() = RealField
|
||||
override val zero: BufferedNDFieldElement<Double, RealField> by lazy { produce { zero } }
|
||||
override val zero: RealNDElement by lazy { produce { zero } }
|
||||
override val one: RealNDElement by lazy { produce { one } }
|
||||
|
||||
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
||||
@ -67,9 +67,9 @@ class RealNDField(override val shape: IntArray) :
|
||||
|
||||
override fun power(arg: NDBuffer<Double>, pow: Number): RealNDElement = map(arg) { power(it, pow) }
|
||||
|
||||
override fun exp(arg: NDBuffer<Double>) = map(arg) { exp(it) }
|
||||
override fun exp(arg: NDBuffer<Double>): RealNDElement = map(arg) { exp(it) }
|
||||
|
||||
override fun ln(arg: NDBuffer<Double>) = map(arg) { ln(it) }
|
||||
override fun ln(arg: NDBuffer<Double>): RealNDElement = map(arg) { ln(it) }
|
||||
|
||||
override fun sin(arg: NDBuffer<Double>): RealNDElement = map(arg) { sin(it) }
|
||||
override fun cos(arg: NDBuffer<Double>): RealNDElement = map(arg) { cos(it) }
|
||||
@ -96,13 +96,13 @@ inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initiali
|
||||
}
|
||||
|
||||
/**
|
||||
* Map one [RealNDElement] using function with indexes
|
||||
* Map one [RealNDElement] using function with indices.
|
||||
*/
|
||||
inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double) =
|
||||
inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double): RealNDElement =
|
||||
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
|
||||
|
||||
/**
|
||||
* Map one [RealNDElement] using function without indexes
|
||||
* Map one [RealNDElement] using function without indices.
|
||||
*/
|
||||
inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement {
|
||||
val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) }
|
||||
@ -110,9 +110,9 @@ inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double
|
||||
}
|
||||
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy.
|
||||
*/
|
||||
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
|
||||
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement): RealNDElement =
|
||||
ndElement.map { this@invoke(it) }
|
||||
|
||||
|
||||
@ -121,14 +121,17 @@ operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
|
||||
/**
|
||||
* Summation operation for [BufferedNDElement] and single element
|
||||
*/
|
||||
operator fun RealNDElement.plus(arg: Double): RealNDElement = map { it + arg }
|
||||
operator fun RealNDElement.plus(arg: Double): RealNDElement =
|
||||
map { it + arg }
|
||||
|
||||
/**
|
||||
* Subtraction operation between [BufferedNDElement] and single element
|
||||
*/
|
||||
operator fun RealNDElement.minus(arg: Double): RealNDElement = map { it - arg }
|
||||
operator fun RealNDElement.minus(arg: Double): RealNDElement =
|
||||
map { it - arg }
|
||||
|
||||
/**
|
||||
* Produce a context for n-dimensional operations inside this real field
|
||||
*/
|
||||
|
||||
inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action)
|
||||
|
@ -1,5 +1,10 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
/**
|
||||
* Specialized [MutableBuffer] implementation over [ShortBuffer].
|
||||
*
|
||||
* @property array the underlying array.
|
||||
*/
|
||||
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
@ -9,12 +14,37 @@ inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator() = array.iterator()
|
||||
override fun iterator(): ShortIterator = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Short> =
|
||||
ShortBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new [ShortBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||
* [init] function.
|
||||
*
|
||||
* The function [init] is called for each array element sequentially starting from the first one.
|
||||
* It should return the value for an buffer element given its index.
|
||||
*/
|
||||
inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) })
|
||||
|
||||
fun ShortArray.asBuffer() = ShortBuffer(this)
|
||||
/**
|
||||
* Returns a new [ShortBuffer] of given elements.
|
||||
*/
|
||||
fun ShortBuffer(vararg shorts: Short): ShortBuffer = ShortBuffer(shorts)
|
||||
|
||||
/**
|
||||
* Returns a [ShortArray] containing all of the elements of this [MutableBuffer].
|
||||
*/
|
||||
val MutableBuffer<out Short>.array: ShortArray
|
||||
get() = (if (this is ShortBuffer) array else ShortArray(size) { get(it) })
|
||||
|
||||
/**
|
||||
* Returns [ShortBuffer] over this array.
|
||||
*
|
||||
* @receiver the array.
|
||||
* @return the new buffer.
|
||||
*/
|
||||
fun ShortArray.asBuffer(): ShortBuffer = ShortBuffer(this)
|
||||
|
@ -12,8 +12,8 @@ class ShortNDRing(override val shape: IntArray) :
|
||||
override val strides: Strides = DefaultStrides(shape)
|
||||
|
||||
override val elementContext: ShortRing get() = ShortRing
|
||||
override val zero by lazy { produce { ShortRing.zero } }
|
||||
override val one by lazy { produce { ShortRing.one } }
|
||||
override val zero: ShortNDElement by lazy { produce { zero } }
|
||||
override val one: ShortNDElement by lazy { produce { one } }
|
||||
|
||||
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> =
|
||||
ShortBuffer(ShortArray(size) { initializer(it) })
|
||||
@ -40,6 +40,7 @@ class ShortNDRing(override val shape: IntArray) :
|
||||
transform: ShortRing.(index: IntArray, Short) -> Short
|
||||
): ShortNDElement {
|
||||
check(arg)
|
||||
|
||||
return BufferedNDRingElement(
|
||||
this,
|
||||
buildBuffer(arg.strides.linearSize) { offset ->
|
||||
@ -67,7 +68,7 @@ class ShortNDRing(override val shape: IntArray) :
|
||||
|
||||
|
||||
/**
|
||||
* Fast element production using function inlining
|
||||
* Fast element production using function inlining.
|
||||
*/
|
||||
inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement {
|
||||
val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }
|
||||
@ -75,22 +76,22 @@ inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initialize
|
||||
}
|
||||
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||
* Element by element application of any operation on elements to the whole array.
|
||||
*/
|
||||
operator fun Function1<Short, Short>.invoke(ndElement: ShortNDElement) =
|
||||
operator fun Function1<Short, Short>.invoke(ndElement: ShortNDElement): ShortNDElement =
|
||||
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
|
||||
|
||||
|
||||
/* plus and minus */
|
||||
|
||||
/**
|
||||
* Summation operation for [StridedNDFieldElement] and single element
|
||||
* Summation operation for [ShortNDElement] and single element.
|
||||
*/
|
||||
operator fun ShortNDElement.plus(arg: Short) =
|
||||
operator fun ShortNDElement.plus(arg: Short): ShortNDElement =
|
||||
context.produceInline { i -> (buffer[i] + arg).toShort() }
|
||||
|
||||
/**
|
||||
* Subtraction operation between [StridedNDFieldElement] and single element
|
||||
* Subtraction operation between [ShortNDElement] and single element.
|
||||
*/
|
||||
operator fun ShortNDElement.minus(arg: Short) =
|
||||
context.produceInline { i -> (buffer[i] - arg).toShort() }
|
||||
operator fun ShortNDElement.minus(arg: Short): ShortNDElement =
|
||||
context.produceInline { i -> (buffer[i] - arg).toShort() }
|
||||
|
@ -17,7 +17,7 @@ interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
||||
/**
|
||||
* A 1D wrapper for nd-structure
|
||||
*/
|
||||
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T>{
|
||||
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> {
|
||||
|
||||
override val shape: IntArray get() = structure.shape
|
||||
override val size: Int get() = structure.shape[0]
|
||||
@ -39,14 +39,14 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||
asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
|
||||
|
||||
override fun get(index: Int): T = buffer.get(index)
|
||||
override fun get(index: Int): T = buffer[index]
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||
if( this is NDBuffer){
|
||||
if (this is NDBuffer) {
|
||||
Buffer1DWrapper(this.buffer)
|
||||
} else {
|
||||
Structure1DWrapper(this)
|
||||
@ -59,4 +59,4 @@ fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||
/**
|
||||
* Represent this buffer as 1D structure
|
||||
*/
|
||||
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
||||
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
||||
|
@ -32,9 +32,7 @@ interface Structure2D<T> : NDStructure<T> {
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
|
||||
}
|
||||
companion object
|
||||
}
|
||||
|
||||
/**
|
||||
@ -57,4 +55,4 @@ fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
typealias Matrix<T> = Structure2D<T>
|
||||
typealias Matrix<T> = Structure2D<T>
|
||||
|
@ -31,7 +31,7 @@ class ExpressionFieldTest {
|
||||
|
||||
@Test
|
||||
fun separateContext() {
|
||||
fun <T> FunctionalExpressionField<T,*>.expression(): Expression<T> {
|
||||
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
|
||||
val x = variable("x")
|
||||
return x * x + 2 * x + one
|
||||
}
|
||||
@ -42,7 +42,7 @@ class ExpressionFieldTest {
|
||||
|
||||
@Test
|
||||
fun valueExpression() {
|
||||
val expressionBuilder: FunctionalExpressionField<Double,*>.() -> Expression<Double> = {
|
||||
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
|
||||
val x = variable("x")
|
||||
x * x + 2 * x + one
|
||||
}
|
||||
@ -50,4 +50,4 @@ class ExpressionFieldTest {
|
||||
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
||||
assertEquals(expression("x" to 1.0), 4.0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -49,17 +49,17 @@ class MatrixTest {
|
||||
|
||||
@Test
|
||||
fun test2DDot() {
|
||||
val firstMatrix = NDStructure.auto(2,3){ (i, j) -> (i + j).toDouble() }.as2D()
|
||||
val secondMatrix = NDStructure.auto(3,2){ (i, j) -> (i + j).toDouble() }.as2D()
|
||||
val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D()
|
||||
val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D()
|
||||
MatrixContext.real.run {
|
||||
// val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() }
|
||||
// val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() }
|
||||
val result = firstMatrix dot secondMatrix
|
||||
assertEquals(2, result.rowNum)
|
||||
assertEquals(2, result.colNum)
|
||||
assertEquals(8.0, result[0,1])
|
||||
assertEquals(8.0, result[1,0])
|
||||
assertEquals(14.0, result[1,1])
|
||||
assertEquals(8.0, result[0, 1])
|
||||
assertEquals(8.0, result[1, 0])
|
||||
assertEquals(14.0, result[1, 1])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -48,4 +48,4 @@ class RealLUSolverTest {
|
||||
|
||||
assertEquals(expected, inverted)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -8,10 +8,10 @@ import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class AutoDiffTest {
|
||||
fun Variable(int: Int): Variable<Double> = Variable(int.toDouble())
|
||||
|
||||
fun Variable(int: Int) = Variable(int.toDouble())
|
||||
|
||||
fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>) = RealField.deriv(body)
|
||||
fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
|
||||
RealField.deriv(body)
|
||||
|
||||
@Test
|
||||
fun testPlusX2() {
|
||||
@ -178,5 +178,4 @@ class AutoDiffTest {
|
||||
private fun assertApprox(a: Double, b: Double) {
|
||||
if ((a - b) > 1e-10) assertEquals(a, b)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -10,4 +10,4 @@ class CumulativeKtTest {
|
||||
val cumulative = initial.cumulativeSum()
|
||||
assertEquals(listOf(-1.0, 1.0, 2.0, 3.0), cumulative)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -47,4 +47,3 @@ class BigIntAlgebraTest {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -19,8 +19,8 @@ class BigIntConstructorTest {
|
||||
|
||||
@Test
|
||||
fun testConstructor_0xffffffffaL() {
|
||||
val x = -0xffffffffaL.toBigInt()
|
||||
val x = (-0xffffffffaL).toBigInt()
|
||||
val y = uintArrayOf(0xfffffffaU, 0xfU).toBigInt(-1)
|
||||
assertEquals(x, y)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ class BigIntConversionsTest {
|
||||
|
||||
@Test
|
||||
fun testToString_0x17ead2ffffd() {
|
||||
val x = -0x17ead2ffffdL.toBigInt()
|
||||
val x = (-0x17ead2ffffdL).toBigInt()
|
||||
assertEquals("-0x17ead2ffffd", x.toString())
|
||||
}
|
||||
|
||||
@ -40,4 +40,4 @@ class BigIntConversionsTest {
|
||||
val x = "-7059135710711894913860".parseBigInteger()
|
||||
assertEquals("-0x17ead2ffffd11223344", x.toString())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ class BigIntOperationsTest {
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val x = 1234.toBigInt()
|
||||
val y = -1234.toBigInt()
|
||||
val y = (-1234).toBigInt()
|
||||
assertEquals(-x, y)
|
||||
}
|
||||
|
||||
@ -48,18 +48,18 @@ class BigIntOperationsTest {
|
||||
|
||||
@Test
|
||||
fun testMinus__2_1() {
|
||||
val x = -2.toBigInt()
|
||||
val x = (-2).toBigInt()
|
||||
val y = 1.toBigInt()
|
||||
|
||||
val res = x - y
|
||||
val sum = -3.toBigInt()
|
||||
val sum = (-3).toBigInt()
|
||||
|
||||
assertEquals(sum, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus___2_1() {
|
||||
val x = -2.toBigInt()
|
||||
val x = (-2).toBigInt()
|
||||
val y = 1.toBigInt()
|
||||
|
||||
val res = -x - y
|
||||
@ -74,7 +74,7 @@ class BigIntOperationsTest {
|
||||
val y = 0xffffffffaL.toBigInt()
|
||||
|
||||
val res = x - y
|
||||
val sum = -0xfffffcfc1L.toBigInt()
|
||||
val sum = (-0xfffffcfc1L).toBigInt()
|
||||
|
||||
assertEquals(sum, res)
|
||||
}
|
||||
@ -92,11 +92,11 @@ class BigIntOperationsTest {
|
||||
|
||||
@Test
|
||||
fun testMultiply__2_3() {
|
||||
val x = -2.toBigInt()
|
||||
val x = (-2).toBigInt()
|
||||
val y = 3.toBigInt()
|
||||
|
||||
val res = x * y
|
||||
val prod = -6.toBigInt()
|
||||
val prod = (-6).toBigInt()
|
||||
|
||||
assertEquals(prod, res)
|
||||
}
|
||||
@ -129,7 +129,7 @@ class BigIntOperationsTest {
|
||||
val y = -0xfff456
|
||||
|
||||
val res = x * y
|
||||
val prod = -0xffe579ad5dc2L.toBigInt()
|
||||
val prod = (-0xffe579ad5dc2L).toBigInt()
|
||||
|
||||
assertEquals(prod, res)
|
||||
}
|
||||
@ -259,7 +259,7 @@ class BigIntOperationsTest {
|
||||
val y = -3
|
||||
|
||||
val res = x / y
|
||||
val div = -6.toBigInt()
|
||||
val div = (-6).toBigInt()
|
||||
|
||||
assertEquals(div, res)
|
||||
}
|
||||
@ -267,10 +267,10 @@ class BigIntOperationsTest {
|
||||
@Test
|
||||
fun testBigDivision_20__3() {
|
||||
val x = 20.toBigInt()
|
||||
val y = -3.toBigInt()
|
||||
val y = (-3).toBigInt()
|
||||
|
||||
val res = x / y
|
||||
val div = -6.toBigInt()
|
||||
val div = (-6).toBigInt()
|
||||
|
||||
assertEquals(div, res)
|
||||
}
|
||||
@ -378,4 +378,4 @@ class BigIntOperationsTest {
|
||||
|
||||
return assertEquals(res, x % mod)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -11,4 +11,4 @@ class RealFieldTest {
|
||||
}
|
||||
assertEquals(5.0, sqrt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -11,4 +11,4 @@ class ComplexBufferSpecTest {
|
||||
val buffer = Buffer.complex(20) { Complex(it.toDouble(), -it.toDouble()) }
|
||||
assertEquals(Complex(5.0, -5.0), buffer[5])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -10,4 +10,4 @@ class NDFieldTest {
|
||||
val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() }
|
||||
assertEquals(ndArray[5, 5], 10.0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -8,8 +8,8 @@ import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class NumberNDFieldTest {
|
||||
val array1 = real2D(3, 3) { i, j -> (i + j).toDouble() }
|
||||
val array2 = real2D(3, 3) { i, j -> (i - j).toDouble() }
|
||||
val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() }
|
||||
val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() }
|
||||
|
||||
@Test
|
||||
fun testSum() {
|
||||
|
@ -5,33 +5,56 @@ import java.math.BigInteger
|
||||
import java.math.MathContext
|
||||
|
||||
/**
|
||||
* A field wrapper for Java [BigInteger]
|
||||
* A field over [BigInteger].
|
||||
*/
|
||||
object JBigIntegerField : Field<BigInteger> {
|
||||
override val zero: BigInteger = BigInteger.ZERO
|
||||
override val one: BigInteger = BigInteger.ONE
|
||||
override val zero: BigInteger
|
||||
get() = BigInteger.ZERO
|
||||
|
||||
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
|
||||
|
||||
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
|
||||
|
||||
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
||||
override val one: BigInteger
|
||||
get() = BigInteger.ONE
|
||||
|
||||
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
|
||||
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
|
||||
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
|
||||
override fun BigInteger.minus(b: BigInteger): BigInteger = this.subtract(b)
|
||||
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
|
||||
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
||||
override fun BigInteger.unaryMinus(): BigInteger = negate()
|
||||
}
|
||||
|
||||
/**
|
||||
* A Field wrapper for Java [BigDecimal]
|
||||
* An abstract field over [BigDecimal].
|
||||
*
|
||||
* @property mathContext the [MathContext] to use.
|
||||
*/
|
||||
class JBigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> {
|
||||
override val zero: BigDecimal = BigDecimal.ZERO
|
||||
override val one: BigDecimal = BigDecimal.ONE
|
||||
abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathContext = MathContext.DECIMAL64) :
|
||||
Field<BigDecimal>,
|
||||
PowerOperations<BigDecimal> {
|
||||
override val zero: BigDecimal
|
||||
get() = BigDecimal.ZERO
|
||||
|
||||
override val one: BigDecimal
|
||||
get() = BigDecimal.ONE
|
||||
|
||||
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
|
||||
override fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b)
|
||||
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
|
||||
|
||||
override fun multiply(a: BigDecimal, k: Number): BigDecimal =
|
||||
a.multiply(k.toDouble().toBigDecimal(mathContext), mathContext)
|
||||
|
||||
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
|
||||
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
|
||||
}
|
||||
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
|
||||
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
|
||||
override fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* A field over [BigDecimal].
|
||||
*/
|
||||
class JBigDecimalField(mathContext: MathContext = MathContext.DECIMAL64) : JBigDecimalFieldBase(mathContext) {
|
||||
companion object : JBigDecimalFieldBase()
|
||||
}
|
||||
|
@ -9,4 +9,4 @@ abstract class BlockingIntChain : Chain<Int> {
|
||||
override suspend fun next(): Int = nextInt()
|
||||
|
||||
fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() }
|
||||
}
|
||||
}
|
||||
|
@ -9,4 +9,4 @@ abstract class BlockingRealChain : Chain<Double> {
|
||||
override suspend fun next(): Double = nextDouble()
|
||||
|
||||
fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() }
|
||||
}
|
||||
}
|
||||
|
@ -22,12 +22,11 @@ import kotlinx.coroutines.flow.FlowCollector
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
|
||||
|
||||
/**
|
||||
* A not-necessary-Markov chain of some type
|
||||
* @param R - the chain element type
|
||||
*/
|
||||
interface Chain<out R>: Flow<R> {
|
||||
interface Chain<out R> : Flow<R> {
|
||||
/**
|
||||
* Generate next value, changing state if needed
|
||||
*/
|
||||
@ -41,7 +40,7 @@ interface Chain<out R>: Flow<R> {
|
||||
@OptIn(InternalCoroutinesApi::class)
|
||||
override suspend fun collect(collector: FlowCollector<R>) {
|
||||
kotlinx.coroutines.flow.flow {
|
||||
while (true){
|
||||
while (true) {
|
||||
emit(next())
|
||||
}
|
||||
}.collect(collector)
|
||||
@ -71,7 +70,7 @@ class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val ge
|
||||
|
||||
private var value: R? = null
|
||||
|
||||
fun value() = value
|
||||
fun value(): R? = value
|
||||
|
||||
override suspend fun next(): R {
|
||||
mutex.withLock {
|
||||
@ -97,12 +96,11 @@ class StatefulChain<S, out R>(
|
||||
private val forkState: ((S) -> S),
|
||||
private val gen: suspend S.(R) -> R
|
||||
) : Chain<R> {
|
||||
|
||||
private val mutex = Mutex()
|
||||
private val mutex: Mutex = Mutex()
|
||||
|
||||
private var value: R? = null
|
||||
|
||||
fun value() = value
|
||||
fun value(): R? = value
|
||||
|
||||
override suspend fun next(): R {
|
||||
mutex.withLock {
|
||||
@ -112,9 +110,7 @@ class StatefulChain<S, out R>(
|
||||
}
|
||||
}
|
||||
|
||||
override fun fork(): Chain<R> {
|
||||
return StatefulChain(forkState(state), seed, forkState, gen)
|
||||
}
|
||||
override fun fork(): Chain<R> = StatefulChain(forkState(state), seed, forkState, gen)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -163,7 +159,8 @@ fun <T, R> Chain<T>.collect(mapper: suspend (Chain<T>) -> R): Chain<R> = object
|
||||
fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
|
||||
object : Chain<R> {
|
||||
override suspend fun next(): R = state.mapper(this@collectWithState)
|
||||
override fun fork(): Chain<R> = this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper)
|
||||
override fun fork(): Chain<R> =
|
||||
this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -173,4 +170,4 @@ fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R
|
||||
override suspend fun next(): R = block(this@zip.next(), other.next())
|
||||
|
||||
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
|
||||
}
|
||||
}
|
||||
|
@ -24,4 +24,4 @@ fun <T> Flow<T>.mean(space: Space<T>): Flow<T> = with(space) {
|
||||
this.num += 1
|
||||
}
|
||||
}.map { it.sum / it.num }
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,8 @@ import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.channels.produce
|
||||
import kotlinx.coroutines.flow.*
|
||||
|
||||
val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default
|
||||
val Dispatchers.Math: CoroutineDispatcher
|
||||
get() = Default
|
||||
|
||||
/**
|
||||
* An imitator of [Deferred] which holds a suspended function block and dispatcher
|
||||
@ -42,7 +43,7 @@ fun <T, R> Flow<T>.async(
|
||||
}
|
||||
|
||||
@FlowPreview
|
||||
fun <T, R> AsyncFlow<T>.map(action: (T) -> R) =
|
||||
fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> =
|
||||
AsyncFlow(deferredFlow.map { input ->
|
||||
//TODO add function composition
|
||||
LazyDeferred(input.dispatcher) {
|
||||
@ -82,9 +83,9 @@ suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<
|
||||
|
||||
@ExperimentalCoroutinesApi
|
||||
@FlowPreview
|
||||
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, action: suspend (value: T) -> Unit): Unit {
|
||||
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, action: suspend (value: T) -> Unit) {
|
||||
collect(concurrency, object : FlowCollector<T> {
|
||||
override suspend fun emit(value: T) = action(value)
|
||||
override suspend fun emit(value: T): Unit = action(value)
|
||||
})
|
||||
}
|
||||
|
||||
@ -94,9 +95,7 @@ fun <T, R> Flow<T>.mapParallel(
|
||||
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
||||
transform: suspend (T) -> R
|
||||
): Flow<R> {
|
||||
return flatMapMerge{ value ->
|
||||
return flatMapMerge { value ->
|
||||
flow { emit(transform(value)) }
|
||||
}.flowOn(dispatcher)
|
||||
}
|
||||
|
||||
|
||||
|
@ -11,7 +11,7 @@ import scientifik.kmath.structures.asBuffer
|
||||
/**
|
||||
* Create a [Flow] from buffer
|
||||
*/
|
||||
fun <T> Buffer<T>.asFlow() = iterator().asFlow()
|
||||
fun <T> Buffer<T>.asFlow(): Flow<T> = iterator().asFlow()
|
||||
|
||||
/**
|
||||
* Flat map a [Flow] of [Buffer] into continuous [Flow] of elements
|
||||
@ -83,4 +83,4 @@ fun <T> Flow<T>.windowed(window: Int): Flow<Buffer<T>> = flow {
|
||||
ringBuffer.push(element)
|
||||
emit(ringBuffer.snapshot())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,19 +5,17 @@ import kotlinx.coroutines.sync.withLock
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.MutableBuffer
|
||||
import scientifik.kmath.structures.VirtualBuffer
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
* Thread-safe ring buffer
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
internal class RingBuffer<T>(
|
||||
class RingBuffer<T>(
|
||||
private val buffer: MutableBuffer<T?>,
|
||||
private var startIndex: Int = 0,
|
||||
size: Int = 0
|
||||
) : Buffer<T> {
|
||||
|
||||
private val mutex = Mutex()
|
||||
private val mutex: Mutex = Mutex()
|
||||
|
||||
override var size: Int = size
|
||||
private set
|
||||
@ -28,7 +26,7 @@ internal class RingBuffer<T>(
|
||||
return buffer[startIndex.forward(index)] as T
|
||||
}
|
||||
|
||||
fun isFull() = size == buffer.size
|
||||
fun isFull(): Boolean = size == buffer.size
|
||||
|
||||
/**
|
||||
* Iterator could provide wrong results if buffer is changed in initialization (iteration is safe)
|
||||
@ -90,4 +88,4 @@ internal class RingBuffer<T>(
|
||||
return RingBuffer(buffer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import kotlin.sequences.Sequence
|
||||
/**
|
||||
* Represent a chain as regular iterator (uses blocking calls)
|
||||
*/
|
||||
operator fun <R> Chain<R>.iterator() = object : Iterator<R> {
|
||||
operator fun <R> Chain<R>.iterator(): Iterator<R> = object : Iterator<R> {
|
||||
override fun hasNext(): Boolean = true
|
||||
|
||||
override fun next(): R = runBlocking { next() }
|
||||
|
@ -8,10 +8,9 @@ class LazyNDStructure<T>(
|
||||
override val shape: IntArray,
|
||||
val function: suspend (IntArray) -> T
|
||||
) : NDStructure<T> {
|
||||
private val cache: MutableMap<IntArray, Deferred<T>> = hashMapOf()
|
||||
|
||||
private val cache = HashMap<IntArray, Deferred<T>>()
|
||||
|
||||
fun deferred(index: IntArray) = cache.getOrPut(index) {
|
||||
fun deferred(index: IntArray): Deferred<T> = cache.getOrPut(index) {
|
||||
scope.async(context = Dispatchers.Math) {
|
||||
function(index)
|
||||
}
|
||||
@ -42,21 +41,21 @@ class LazyNDStructure<T>(
|
||||
result = 31 * result + cache.hashCode()
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
fun <T> NDStructure<T>.deferred(index: IntArray) =
|
||||
fun <T> NDStructure<T>.deferred(index: IntArray): Deferred<T> =
|
||||
if (this is LazyNDStructure<T>) this.deferred(index) else CompletableDeferred(get(index))
|
||||
|
||||
suspend fun <T> NDStructure<T>.await(index: IntArray) =
|
||||
suspend fun <T> NDStructure<T>.await(index: IntArray): T =
|
||||
if (this is LazyNDStructure<T>) this.await(index) else get(index)
|
||||
|
||||
/**
|
||||
* PENDING would benifit from KEEP-176
|
||||
* PENDING would benefit from KEEP-176
|
||||
*/
|
||||
fun <T, R> NDStructure<T>.mapAsyncIndexed(scope: CoroutineScope, function: suspend (T, index: IntArray) -> R) =
|
||||
LazyNDStructure(scope, shape) { index -> function(get(index), index) }
|
||||
fun <T, R> NDStructure<T>.mapAsyncIndexed(
|
||||
scope: CoroutineScope,
|
||||
function: suspend (T, index: IntArray) -> R
|
||||
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index), index) }
|
||||
|
||||
fun <T, R> NDStructure<T>.mapAsync(scope: CoroutineScope, function: suspend (T) -> R) =
|
||||
fun <T, R> NDStructure<T>.mapAsync(scope: CoroutineScope, function: suspend (T) -> R): LazyNDStructure<R> =
|
||||
LazyNDStructure(scope, shape) { index -> function(get(index)) }
|
@ -15,14 +15,13 @@ import kotlin.test.Test
|
||||
@InternalCoroutinesApi
|
||||
@FlowPreview
|
||||
class BufferFlowTest {
|
||||
|
||||
val dispatcher = Executors.newFixedThreadPool(4).asCoroutineDispatcher()
|
||||
val dispatcher: CoroutineDispatcher = Executors.newFixedThreadPool(4).asCoroutineDispatcher()
|
||||
|
||||
@Test
|
||||
@Timeout(2000)
|
||||
fun map() {
|
||||
runBlocking {
|
||||
(1..20).asFlow().mapParallel( dispatcher) {
|
||||
(1..20).asFlow().mapParallel(dispatcher) {
|
||||
println("Started $it on ${Thread.currentThread().name}")
|
||||
@Suppress("BlockingMethodInNonBlockingContext")
|
||||
Thread.sleep(200)
|
||||
|
@ -19,17 +19,17 @@ class RingBufferTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun windowed(){
|
||||
val flow = flow{
|
||||
fun windowed() {
|
||||
val flow = flow {
|
||||
var i = 0
|
||||
while(true){
|
||||
emit(i++)
|
||||
}
|
||||
while (true) emit(i++)
|
||||
}
|
||||
|
||||
val windowed = flow.windowed(10)
|
||||
|
||||
runBlocking {
|
||||
val first = windowed.take(1).single()
|
||||
val res = windowed.take(15).map { it -> it.asSequence().average() }.toList()
|
||||
val res = windowed.take(15).map { it.asSequence().average() }.toList()
|
||||
assertEquals(0.0, res[0])
|
||||
assertEquals(4.5, res[9])
|
||||
assertEquals(9.5, res[14])
|
||||
|
@ -31,5 +31,5 @@ object D2 : Dimension {
|
||||
}
|
||||
|
||||
object D3 : Dimension {
|
||||
override val dim: UInt get() = 31U
|
||||
}
|
||||
override val dim: UInt get() = 3U
|
||||
}
|
||||
|
@ -1,77 +1,148 @@
|
||||
package scientifik.memory
|
||||
|
||||
/**
|
||||
* Represents a display of certain memory structure.
|
||||
*/
|
||||
interface Memory {
|
||||
/**
|
||||
* The length of this memory in bytes.
|
||||
*/
|
||||
val size: Int
|
||||
|
||||
/**
|
||||
* Get a projection of this memory (it reflects the changes in the parent memory block)
|
||||
* Get a projection of this memory (it reflects the changes in the parent memory block).
|
||||
*/
|
||||
fun view(offset: Int, length: Int): Memory
|
||||
|
||||
/**
|
||||
* Create a copy of this memory, which does not know anything about this memory
|
||||
* Creates an independent copy of this memory.
|
||||
*/
|
||||
fun copy(): Memory
|
||||
|
||||
/**
|
||||
* Create and possibly register a new reader
|
||||
* Gets or creates a reader of this memory.
|
||||
*/
|
||||
fun reader(): MemoryReader
|
||||
|
||||
/**
|
||||
* Gets or creates a writer of this memory.
|
||||
*/
|
||||
fun writer(): MemoryWriter
|
||||
|
||||
companion object {
|
||||
|
||||
}
|
||||
companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* The interface to read primitive types in this memory.
|
||||
*/
|
||||
interface MemoryReader {
|
||||
/**
|
||||
* The underlying memory.
|
||||
*/
|
||||
val memory: Memory
|
||||
|
||||
/**
|
||||
* Reads [Double] at certain [offset].
|
||||
*/
|
||||
fun readDouble(offset: Int): Double
|
||||
|
||||
/**
|
||||
* Reads [Float] at certain [offset].
|
||||
*/
|
||||
fun readFloat(offset: Int): Float
|
||||
|
||||
/**
|
||||
* Reads [Byte] at certain [offset].
|
||||
*/
|
||||
fun readByte(offset: Int): Byte
|
||||
|
||||
/**
|
||||
* Reads [Short] at certain [offset].
|
||||
*/
|
||||
fun readShort(offset: Int): Short
|
||||
|
||||
/**
|
||||
* Reads [Int] at certain [offset].
|
||||
*/
|
||||
fun readInt(offset: Int): Int
|
||||
|
||||
/**
|
||||
* Reads [Long] at certain [offset].
|
||||
*/
|
||||
fun readLong(offset: Int): Long
|
||||
|
||||
/**
|
||||
* Disposes this reader if needed.
|
||||
*/
|
||||
fun release()
|
||||
}
|
||||
|
||||
/**
|
||||
* Use the memory for read then release the reader
|
||||
* Uses the memory for read then releases the reader.
|
||||
*/
|
||||
inline fun Memory.read(block: MemoryReader.() -> Unit) {
|
||||
reader().apply(block).apply { release() }
|
||||
reader().apply(block).release()
|
||||
}
|
||||
|
||||
/**
|
||||
* The interface to write primitive types into this memory.
|
||||
*/
|
||||
interface MemoryWriter {
|
||||
/**
|
||||
* The underlying memory.
|
||||
*/
|
||||
val memory: Memory
|
||||
|
||||
/**
|
||||
* Writes [Double] at certain [offset].
|
||||
*/
|
||||
fun writeDouble(offset: Int, value: Double)
|
||||
|
||||
/**
|
||||
* Writes [Float] at certain [offset].
|
||||
*/
|
||||
fun writeFloat(offset: Int, value: Float)
|
||||
|
||||
/**
|
||||
* Writes [Byte] at certain [offset].
|
||||
*/
|
||||
fun writeByte(offset: Int, value: Byte)
|
||||
|
||||
/**
|
||||
* Writes [Short] at certain [offset].
|
||||
*/
|
||||
fun writeShort(offset: Int, value: Short)
|
||||
|
||||
/**
|
||||
* Writes [Int] at certain [offset].
|
||||
*/
|
||||
fun writeInt(offset: Int, value: Int)
|
||||
|
||||
/**
|
||||
* Writes [Long] at certain [offset].
|
||||
*/
|
||||
fun writeLong(offset: Int, value: Long)
|
||||
|
||||
/**
|
||||
* Disposes this writer if needed.
|
||||
*/
|
||||
fun release()
|
||||
}
|
||||
|
||||
/**
|
||||
* Use the memory for write then release the writer
|
||||
* Uses the memory for write then releases the writer.
|
||||
*/
|
||||
inline fun Memory.write(block: MemoryWriter.() -> Unit) {
|
||||
writer().apply(block).apply { release() }
|
||||
writer().apply(block).release()
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocate the most effective platform-specific memory
|
||||
* Allocates the most effective platform-specific memory.
|
||||
*/
|
||||
expect fun Memory.Companion.allocate(length: Int): Memory
|
||||
|
||||
/**
|
||||
* Wrap a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied
|
||||
* and could be mutated independently from the resulting [Memory]
|
||||
* Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied
|
||||
* and could be mutated independently from the resulting [Memory].
|
||||
*/
|
||||
expect fun Memory.Companion.wrap(array: ByteArray): Memory
|
||||
|
@ -1,35 +1,53 @@
|
||||
package scientifik.memory
|
||||
|
||||
/**
|
||||
* A specification to read or write custom objects with fixed size in bytes
|
||||
* A specification to read or write custom objects with fixed size in bytes.
|
||||
*
|
||||
* @param T the type of object this spec manages.
|
||||
*/
|
||||
interface MemorySpec<T : Any> {
|
||||
/**
|
||||
* Size of [T] in bytes after serialization
|
||||
* Size of [T] in bytes after serialization.
|
||||
*/
|
||||
val objectSize: Int
|
||||
|
||||
/**
|
||||
* Reads the object starting from [offset].
|
||||
*/
|
||||
fun MemoryReader.read(offset: Int): T
|
||||
//TODO consider thread safety
|
||||
|
||||
// TODO consider thread safety
|
||||
|
||||
/**
|
||||
* Writes the object [value] starting from [offset].
|
||||
*/
|
||||
fun MemoryWriter.write(offset: Int, value: T)
|
||||
}
|
||||
|
||||
fun <T : Any> MemoryReader.read(spec: MemorySpec<T>, offset: Int): T = spec.run { read(offset) }
|
||||
fun <T : Any> MemoryWriter.write(spec: MemorySpec<T>, offset: Int, value: T) = spec.run { write(offset, value) }
|
||||
/**
|
||||
* Reads the object with [spec] starting from [offset].
|
||||
*/
|
||||
fun <T : Any> MemoryReader.read(spec: MemorySpec<T>, offset: Int): T = with(spec) { read(offset) }
|
||||
|
||||
inline fun <reified T : Any> MemoryReader.readArray(spec: MemorySpec<T>, offset: Int, size: Int) =
|
||||
/**
|
||||
* Writes the object [value] with [spec] starting from [offset].
|
||||
*/
|
||||
fun <T : Any> MemoryWriter.write(spec: MemorySpec<T>, offset: Int, value: T): Unit = with(spec) { write(offset, value) }
|
||||
|
||||
/**
|
||||
* Reads array of [size] objects mapped by [spec] at certain [offset].
|
||||
*/
|
||||
inline fun <reified T : Any> MemoryReader.readArray(spec: MemorySpec<T>, offset: Int, size: Int): Array<T> =
|
||||
Array(size) { i ->
|
||||
spec.run {
|
||||
read(offset + i * objectSize)
|
||||
}
|
||||
}
|
||||
|
||||
fun <T : Any> MemoryWriter.writeArray(spec: MemorySpec<T>, offset: Int, array: Array<T>) {
|
||||
spec.run {
|
||||
for (i in array.indices) {
|
||||
write(offset + i * objectSize, array[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Writes [array] of objects mapped by [spec] at certain [offset].
|
||||
*/
|
||||
fun <T : Any> MemoryWriter.writeArray(spec: MemorySpec<T>, offset: Int, array: Array<T>): Unit =
|
||||
with(spec) { array.indices.forEach { i -> write(offset + i * objectSize, array[i]) } }
|
||||
|
||||
//TODO It is possible to add elastic MemorySpec with unknown object size
|
||||
// TODO It is possible to add elastic MemorySpec with unknown object size
|
||||
|
@ -4,13 +4,13 @@ import org.khronos.webgl.ArrayBuffer
|
||||
import org.khronos.webgl.DataView
|
||||
import org.khronos.webgl.Int8Array
|
||||
|
||||
class DataViewMemory(val view: DataView) : Memory {
|
||||
|
||||
private class DataViewMemory(val view: DataView) : Memory {
|
||||
override val size: Int get() = view.byteLength
|
||||
|
||||
override fun view(offset: Int, length: Int): Memory {
|
||||
require(offset >= 0) { "offset shouldn't be negative: $offset" }
|
||||
require(length >= 0) { "length shouldn't be negative: $length" }
|
||||
require(offset + length <= size) { "Can't view memory outside the parent region." }
|
||||
|
||||
if (offset + length > size)
|
||||
throw IndexOutOfBoundsException("offset + length > size: $offset + $length > $size")
|
||||
@ -33,11 +33,11 @@ class DataViewMemory(val view: DataView) : Memory {
|
||||
|
||||
override fun readInt(offset: Int): Int = view.getInt32(offset, false)
|
||||
|
||||
override fun readLong(offset: Int): Long = (view.getInt32(offset, false).toLong() shl 32) or
|
||||
view.getInt32(offset + 4, false).toLong()
|
||||
override fun readLong(offset: Int): Long =
|
||||
view.getInt32(offset, false).toLong() shl 32 or view.getInt32(offset + 4, false).toLong()
|
||||
|
||||
override fun release() {
|
||||
// does nothing on JS because of GC
|
||||
// does nothing on JS
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,7 +72,7 @@ class DataViewMemory(val view: DataView) : Memory {
|
||||
}
|
||||
|
||||
override fun release() {
|
||||
//does nothing on JS
|
||||
// does nothing on JS
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,13 +81,17 @@ class DataViewMemory(val view: DataView) : Memory {
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocate the most effective platform-specific memory
|
||||
* Allocates memory based on a [DataView].
|
||||
*/
|
||||
actual fun Memory.Companion.allocate(length: Int): Memory {
|
||||
val buffer = ArrayBuffer(length)
|
||||
return DataViewMemory(DataView(buffer, 0, length))
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied
|
||||
* and could be mutated independently from the resulting [Memory].
|
||||
*/
|
||||
actual fun Memory.Companion.wrap(array: ByteArray): Memory {
|
||||
@Suppress("CAST_NEVER_SUCCEEDS") val int8Array = array as Int8Array
|
||||
return DataViewMemory(DataView(int8Array.buffer, int8Array.byteOffset, int8Array.length))
|
||||
|
@ -6,19 +6,18 @@ import java.nio.file.Files
|
||||
import java.nio.file.Path
|
||||
import java.nio.file.StandardOpenOption
|
||||
|
||||
|
||||
private class ByteBufferMemory(
|
||||
val buffer: ByteBuffer,
|
||||
val startOffset: Int = 0,
|
||||
override val size: Int = buffer.limit()
|
||||
) : Memory {
|
||||
|
||||
|
||||
@Suppress("NOTHING_TO_INLINE")
|
||||
private inline fun position(o: Int): Int = startOffset + o
|
||||
|
||||
override fun view(offset: Int, length: Int): Memory {
|
||||
if (offset + length > size) error("Selecting a Memory view outside of memory range")
|
||||
require(offset >= 0) { "offset shouldn't be negative: $offset" }
|
||||
require(length >= 0) { "length shouldn't be negative: $length" }
|
||||
require(offset + length <= size) { "Can't view memory outside the parent region." }
|
||||
return ByteBufferMemory(buffer, position(offset), length)
|
||||
}
|
||||
|
||||
@ -28,10 +27,9 @@ private class ByteBufferMemory(
|
||||
copy.put(buffer)
|
||||
copy.flip()
|
||||
return ByteBufferMemory(copy)
|
||||
|
||||
}
|
||||
|
||||
private val reader = object : MemoryReader {
|
||||
private val reader: MemoryReader = object : MemoryReader {
|
||||
override val memory: Memory get() = this@ByteBufferMemory
|
||||
|
||||
override fun readDouble(offset: Int) = buffer.getDouble(position(offset))
|
||||
@ -47,13 +45,13 @@ private class ByteBufferMemory(
|
||||
override fun readLong(offset: Int) = buffer.getLong(position(offset))
|
||||
|
||||
override fun release() {
|
||||
//does nothing on JVM
|
||||
// does nothing on JVM
|
||||
}
|
||||
}
|
||||
|
||||
override fun reader(): MemoryReader = reader
|
||||
|
||||
private val writer = object : MemoryWriter {
|
||||
private val writer: MemoryWriter = object : MemoryWriter {
|
||||
override val memory: Memory get() = this@ByteBufferMemory
|
||||
|
||||
override fun writeDouble(offset: Int, value: Double) {
|
||||
@ -81,7 +79,7 @@ private class ByteBufferMemory(
|
||||
}
|
||||
|
||||
override fun release() {
|
||||
//does nothing on JVM
|
||||
// does nothing on JVM
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,26 +87,32 @@ private class ByteBufferMemory(
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocate the most effective platform-specific memory
|
||||
* Allocates memory based on a [ByteBuffer].
|
||||
*/
|
||||
actual fun Memory.Companion.allocate(length: Int): Memory {
|
||||
val buffer = ByteBuffer.allocate(length)
|
||||
return ByteBufferMemory(buffer)
|
||||
}
|
||||
actual fun Memory.Companion.allocate(length: Int): Memory =
|
||||
ByteBufferMemory(checkNotNull(ByteBuffer.allocate(length)))
|
||||
|
||||
actual fun Memory.Companion.wrap(array: ByteArray): Memory {
|
||||
val buffer = ByteBuffer.wrap(array)
|
||||
return ByteBufferMemory(buffer)
|
||||
}
|
||||
/**
|
||||
* Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied
|
||||
* and could be mutated independently from the resulting [Memory].
|
||||
*/
|
||||
actual fun Memory.Companion.wrap(array: ByteArray): Memory = ByteBufferMemory(checkNotNull(ByteBuffer.wrap(array)))
|
||||
|
||||
/**
|
||||
* Wraps this [ByteBuffer] to [Memory] object.
|
||||
*
|
||||
* @receiver the byte buffer.
|
||||
* @param startOffset the start offset.
|
||||
* @param size the size of memory to map.
|
||||
* @return the [Memory] object.
|
||||
*/
|
||||
fun ByteBuffer.asMemory(startOffset: Int = 0, size: Int = limit()): Memory =
|
||||
ByteBufferMemory(this, startOffset, size)
|
||||
|
||||
/**
|
||||
* Use direct memory-mapped buffer from file to read something and close it afterwards.
|
||||
* Uses direct memory-mapped buffer from file to read something and close it afterwards.
|
||||
*/
|
||||
fun <R> Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R {
|
||||
return FileChannel.open(this, StandardOpenOption.READ).use {
|
||||
fun <R> Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R =
|
||||
FileChannel.open(this, StandardOpenOption.READ).use {
|
||||
ByteBufferMemory(it.map(FileChannel.MapMode.READ_ONLY, position, size)).block()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user