forked from kscience/kmath
Merge branch 'dev' into kotlingrad
This commit is contained in:
commit
520f6cedeb
@ -7,12 +7,13 @@
|
||||
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
|
||||
- Automatic README generation for features (#139)
|
||||
- Native support for `memory`, `core` and `dimensions`
|
||||
- `kmath-ejml` to supply EJML SimpleMatrix wrapper.
|
||||
- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136).
|
||||
- A separate `Symbol` entity, which is used for global unbound symbol.
|
||||
- A `Symbol` indexing scope.
|
||||
- Basic optimization API for Commons-math.
|
||||
- Chi squared optimization for array-like data in CM
|
||||
- `Fitting` utility object in prob/stat
|
||||
- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`.
|
||||
|
||||
### Changed
|
||||
- Package changed from `scientifik` to `kscience.kmath`.
|
||||
|
92
README.md
92
README.md
@ -8,41 +8,50 @@ Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience
|
||||
Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion)
|
||||
|
||||
# KMath
|
||||
Could be pronounced as `key-math`.
|
||||
The Kotlin MATHematics library was initially intended as a Kotlin-based analog to Python's `numpy` library. Later we found that kotlin is much more flexible language and allows superior architecture designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could be achieved with [kmath-for-real](/kmath-for-real) extension module.
|
||||
|
||||
Could be pronounced as `key-math`. The Kotlin MATHematics library was initially intended as a Kotlin-based analog to
|
||||
Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior architecture
|
||||
designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could
|
||||
be achieved with [kmath-for-real](/kmath-for-real) extension module.
|
||||
|
||||
## Publications and talks
|
||||
|
||||
* [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2)
|
||||
* [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814)
|
||||
* [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103)
|
||||
|
||||
# Goal
|
||||
* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM and JS for now and Native in future).
|
||||
|
||||
* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native).
|
||||
* Provide basic multiplatform implementations for those abstractions (without significant performance optimization).
|
||||
* Provide bindings and wrappers with those abstractions for popular optimized platform libraries.
|
||||
|
||||
## Non-goals
|
||||
* Be like Numpy. It was the idea at the beginning, but we decided that we can do better in terms of API.
|
||||
* Provide best performance out of the box. We have specialized libraries for that. Need only API wrappers for them.
|
||||
|
||||
* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in terms of API.
|
||||
* Provide the best performance out of the box. We have specialized libraries for that. Need only API wrappers for them.
|
||||
* Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually.
|
||||
* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better experience for those, who want to work with specific types.
|
||||
* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like
|
||||
for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better
|
||||
experience for those, who want to work with specific types.
|
||||
|
||||
## Features
|
||||
|
||||
Actual feature list is [here](/docs/features.md)
|
||||
Current feature list is [here](/docs/features.md)
|
||||
|
||||
* **Algebra**
|
||||
* Algebraic structures like rings, spaces and field (**TODO** add example to wiki)
|
||||
* Algebraic structures like rings, spaces and fields (**TODO** add example to wiki)
|
||||
* Basic linear algebra operations (sums, products, etc.), backed by the `Space` API.
|
||||
* Complex numbers backed by the `Field` API (meaning that they will be usable in any structure like vectors and N-dimensional arrays).
|
||||
* Complex numbers backed by the `Field` API (meaning they will be usable in any structure like vectors and
|
||||
N-dimensional arrays).
|
||||
* Advanced linear algebra operations like matrix inversion and LU decomposition.
|
||||
|
||||
* **Array-like structures** Full support of many-dimensional array-like structures
|
||||
including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking).
|
||||
|
||||
* **Expressions** By writing a single mathematical expression
|
||||
once, users will be able to apply different types of objects to the expression by providing a context. Expressions
|
||||
can be used for a wide variety of purposes from high performance calculations to code generation.
|
||||
* **Expressions** By writing a single mathematical expression once, users will be able to apply different types of
|
||||
objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high
|
||||
performance calculations to code generation.
|
||||
|
||||
* **Histograms** Fast multi-dimensional histograms.
|
||||
|
||||
@ -50,9 +59,10 @@ can be used for a wide variety of purposes from high performance calculations to
|
||||
|
||||
* **Type-safe dimensions** Type-safe dimensions for matrix operations.
|
||||
|
||||
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
||||
to submit a feature request if you want something to be done first.
|
||||
* **Commons-math wrapper** It is planned to gradually wrap most parts of
|
||||
[Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some
|
||||
parts to better suit the Kotlin programming paradigm, however there is no established roadmap for that. Feel free to
|
||||
submit a feature request if you want something to be implemented first.
|
||||
|
||||
## Planned features
|
||||
|
||||
@ -151,6 +161,18 @@ can be used for a wide variety of purposes from high performance calculations to
|
||||
> **Maturity**: EXPERIMENTAL
|
||||
<hr/>
|
||||
|
||||
* ### [kmath-nd4j](kmath-nd4j)
|
||||
> ND4J NDStructure implementation and according NDAlgebra classes
|
||||
>
|
||||
> **Maturity**: EXPERIMENTAL
|
||||
>
|
||||
> **Features:**
|
||||
> - [nd4jarraystrucure](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt) : NDStructure wrapper for INDArray
|
||||
> - [nd4jarrayrings](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt) : Rings over Nd4jArrayStructure of Int and Long
|
||||
> - [nd4jarrayfields](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : Fields over Nd4jArrayStructure of Float and Double
|
||||
|
||||
<hr/>
|
||||
|
||||
* ### [kmath-stat](kmath-stat)
|
||||
>
|
||||
>
|
||||
@ -166,39 +188,53 @@ can be used for a wide variety of purposes from high performance calculations to
|
||||
|
||||
## Multi-platform support
|
||||
|
||||
KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the [common module](/kmath-core/src/commonMain). Implementation is also done in the common module wherever possible. In some cases, features are delegated to platform-specific implementations even if they could be done in the common module for performance reasons. Currently, the JVM is the main focus of development, however Kotlin/Native and Kotlin/JS contributions are also welcome.
|
||||
KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the
|
||||
[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features
|
||||
are delegated to platform-specific implementations even if they could be provided in the common module for performance
|
||||
reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and
|
||||
feedback are also welcome.
|
||||
|
||||
## Performance
|
||||
|
||||
Calculation performance is one of major goals of KMath in the future, but in some cases it is not possible to achieve both performance and flexibility. We expect to focus on creating convenient universal API first and then work on increasing performance for specific cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be better than SciPy.
|
||||
Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve
|
||||
both performance and flexibility.
|
||||
|
||||
### Dependency
|
||||
We expect to focus on creating convenient universal API first and then work on increasing performance for specific
|
||||
cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized
|
||||
native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be
|
||||
better than SciPy.
|
||||
|
||||
Release artifacts are accessible from bintray with following configuration (see documentation for [kotlin-multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) form more details):
|
||||
### Repositories
|
||||
|
||||
Release artifacts are accessible from bintray with following configuration (see documentation of
|
||||
[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details):
|
||||
|
||||
```kotlin
|
||||
repositories{
|
||||
repositories {
|
||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
}
|
||||
|
||||
dependencies{
|
||||
api("kscience.kmath:kmath-core:0.2.0-dev-2")
|
||||
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version
|
||||
dependencies {
|
||||
api("kscience.kmath:kmath-core:0.2.0-dev-3")
|
||||
// api("kscience.kmath:kmath-core-jvm:0.2.0-dev-3") for jvm-specific version
|
||||
}
|
||||
```
|
||||
|
||||
Gradle `6.0+` is required for multiplatform artifacts.
|
||||
|
||||
### Development
|
||||
#### Development
|
||||
|
||||
Development builds are uploaded to the separate repository:
|
||||
|
||||
Development builds are accessible from the reposirtory
|
||||
```kotlin
|
||||
repositories{
|
||||
repositories {
|
||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
}
|
||||
```
|
||||
with the same artifact names.
|
||||
|
||||
## Contributing
|
||||
|
||||
The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
|
||||
The project requires a lot of additional work. The most important thing we need is a feedback about what features are
|
||||
required the most. Feel free to create feature requests. We are also welcome to code contributions,
|
||||
especially in issues marked with
|
||||
[waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero) label.
|
||||
|
@ -2,9 +2,9 @@ plugins {
|
||||
id("ru.mipt.npm.project")
|
||||
}
|
||||
|
||||
val kmathVersion: String by extra("0.2.0-dev-3")
|
||||
val bintrayRepo: String by extra("kscience")
|
||||
val githubProject: String by extra("kmath")
|
||||
internal val kmathVersion: String by extra("0.2.0-dev-3")
|
||||
internal val bintrayRepo: String by extra("kscience")
|
||||
internal val githubProject: String by extra("kmath")
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
@ -27,6 +27,6 @@ readme {
|
||||
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
||||
}
|
||||
|
||||
apiValidation{
|
||||
apiValidation {
|
||||
validationDisabled = true
|
||||
}
|
||||
}
|
||||
|
78
docs/templates/README-TEMPLATE.md
vendored
78
docs/templates/README-TEMPLATE.md
vendored
@ -8,41 +8,50 @@ Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience
|
||||
Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion)
|
||||
|
||||
# KMath
|
||||
Could be pronounced as `key-math`.
|
||||
The Kotlin MATHematics library was initially intended as a Kotlin-based analog to Python's `numpy` library. Later we found that kotlin is much more flexible language and allows superior architecture designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could be achieved with [kmath-for-real](/kmath-for-real) extension module.
|
||||
|
||||
Could be pronounced as `key-math`. The Kotlin MATHematics library was initially intended as a Kotlin-based analog to
|
||||
Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior architecture
|
||||
designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could
|
||||
be achieved with [kmath-for-real](/kmath-for-real) extension module.
|
||||
|
||||
## Publications and talks
|
||||
|
||||
* [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2)
|
||||
* [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814)
|
||||
* [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103)
|
||||
|
||||
# Goal
|
||||
* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM and JS for now and Native in future).
|
||||
|
||||
* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native).
|
||||
* Provide basic multiplatform implementations for those abstractions (without significant performance optimization).
|
||||
* Provide bindings and wrappers with those abstractions for popular optimized platform libraries.
|
||||
|
||||
## Non-goals
|
||||
* Be like Numpy. It was the idea at the beginning, but we decided that we can do better in terms of API.
|
||||
* Provide best performance out of the box. We have specialized libraries for that. Need only API wrappers for them.
|
||||
|
||||
* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in terms of API.
|
||||
* Provide the best performance out of the box. We have specialized libraries for that. Need only API wrappers for them.
|
||||
* Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually.
|
||||
* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better experience for those, who want to work with specific types.
|
||||
* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like
|
||||
for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better
|
||||
experience for those, who want to work with specific types.
|
||||
|
||||
## Features
|
||||
|
||||
Actual feature list is [here](/docs/features.md)
|
||||
Current feature list is [here](/docs/features.md)
|
||||
|
||||
* **Algebra**
|
||||
* Algebraic structures like rings, spaces and field (**TODO** add example to wiki)
|
||||
* Algebraic structures like rings, spaces and fields (**TODO** add example to wiki)
|
||||
* Basic linear algebra operations (sums, products, etc.), backed by the `Space` API.
|
||||
* Complex numbers backed by the `Field` API (meaning that they will be usable in any structure like vectors and N-dimensional arrays).
|
||||
* Complex numbers backed by the `Field` API (meaning they will be usable in any structure like vectors and
|
||||
N-dimensional arrays).
|
||||
* Advanced linear algebra operations like matrix inversion and LU decomposition.
|
||||
|
||||
* **Array-like structures** Full support of many-dimensional array-like structures
|
||||
including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking).
|
||||
|
||||
* **Expressions** By writing a single mathematical expression
|
||||
once, users will be able to apply different types of objects to the expression by providing a context. Expressions
|
||||
can be used for a wide variety of purposes from high performance calculations to code generation.
|
||||
* **Expressions** By writing a single mathematical expression once, users will be able to apply different types of
|
||||
objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high
|
||||
performance calculations to code generation.
|
||||
|
||||
* **Histograms** Fast multi-dimensional histograms.
|
||||
|
||||
@ -50,9 +59,10 @@ can be used for a wide variety of purposes from high performance calculations to
|
||||
|
||||
* **Type-safe dimensions** Type-safe dimensions for matrix operations.
|
||||
|
||||
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
||||
to submit a feature request if you want something to be done first.
|
||||
* **Commons-math wrapper** It is planned to gradually wrap most parts of
|
||||
[Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some
|
||||
parts to better suit the Kotlin programming paradigm, however there is no established roadmap for that. Feel free to
|
||||
submit a feature request if you want something to be implemented first.
|
||||
|
||||
## Planned features
|
||||
|
||||
@ -72,39 +82,53 @@ $modules
|
||||
|
||||
## Multi-platform support
|
||||
|
||||
KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the [common module](/kmath-core/src/commonMain). Implementation is also done in the common module wherever possible. In some cases, features are delegated to platform-specific implementations even if they could be done in the common module for performance reasons. Currently, the JVM is the main focus of development, however Kotlin/Native and Kotlin/JS contributions are also welcome.
|
||||
KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the
|
||||
[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features
|
||||
are delegated to platform-specific implementations even if they could be provided in the common module for performance
|
||||
reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and
|
||||
feedback are also welcome.
|
||||
|
||||
## Performance
|
||||
|
||||
Calculation performance is one of major goals of KMath in the future, but in some cases it is not possible to achieve both performance and flexibility. We expect to focus on creating convenient universal API first and then work on increasing performance for specific cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be better than SciPy.
|
||||
Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve
|
||||
both performance and flexibility.
|
||||
|
||||
### Dependency
|
||||
We expect to focus on creating convenient universal API first and then work on increasing performance for specific
|
||||
cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized
|
||||
native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be
|
||||
better than SciPy.
|
||||
|
||||
Release artifacts are accessible from bintray with following configuration (see documentation for [kotlin-multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) form more details):
|
||||
### Repositories
|
||||
|
||||
Release artifacts are accessible from bintray with following configuration (see documentation of
|
||||
[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details):
|
||||
|
||||
```kotlin
|
||||
repositories{
|
||||
repositories {
|
||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
}
|
||||
|
||||
dependencies{
|
||||
dependencies {
|
||||
api("kscience.kmath:kmath-core:$version")
|
||||
//api("kscience.kmath:kmath-core-jvm:$version") for jvm-specific version
|
||||
// api("kscience.kmath:kmath-core-jvm:$version") for jvm-specific version
|
||||
}
|
||||
```
|
||||
|
||||
Gradle `6.0+` is required for multiplatform artifacts.
|
||||
|
||||
### Development
|
||||
#### Development
|
||||
|
||||
Development builds are uploaded to the separate repository:
|
||||
|
||||
Development builds are accessible from the reposirtory
|
||||
```kotlin
|
||||
repositories{
|
||||
repositories {
|
||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
}
|
||||
```
|
||||
with the same artifact names.
|
||||
|
||||
## Contributing
|
||||
|
||||
The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
|
||||
The project requires a lot of additional work. The most important thing we need is a feedback about what features are
|
||||
required the most. Feel free to create feature requests. We are also welcome to code contributions,
|
||||
especially in issues marked with
|
||||
[waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero) label.
|
||||
|
@ -12,16 +12,22 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
*/
|
||||
public class DerivativeStructureField(
|
||||
public val order: Int,
|
||||
private val bindings: Map<Symbol, Double>
|
||||
bindings: Map<Symbol, Double>,
|
||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
|
||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
|
||||
public val numberOfVariables: Int = bindings.size
|
||||
|
||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) }
|
||||
|
||||
/**
|
||||
* A class that implements both [DerivativeStructure] and a [Symbol]
|
||||
*/
|
||||
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
||||
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
||||
public inner class DerivativeStructureSymbol(
|
||||
size: Int,
|
||||
index: Int,
|
||||
symbol: Symbol,
|
||||
value: Double,
|
||||
) : DerivativeStructure(size, order, index, value), Symbol {
|
||||
override val identity: String = symbol.identity
|
||||
override fun toString(): String = identity
|
||||
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||
@ -31,27 +37,26 @@ public class DerivativeStructureField(
|
||||
/**
|
||||
* Identity-based symbol bindings map
|
||||
*/
|
||||
private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
|
||||
key.identity to DerivativeStructureSymbol(key, value)
|
||||
}
|
||||
private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.mapIndexed { index, (key, value) ->
|
||||
key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value)
|
||||
}.toMap()
|
||||
|
||||
override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value)
|
||||
override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value)
|
||||
|
||||
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
||||
|
||||
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
||||
|
||||
//public fun Number.const(): DerivativeStructure = const(toDouble())
|
||||
override fun symbol(value: String): DerivativeStructureSymbol = bind(StringSymbol(value))
|
||||
|
||||
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
|
||||
return derivative(mapOf(parameter to order))
|
||||
public fun DerivativeStructure.derivative(symbols: List<Symbol>): Double {
|
||||
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
|
||||
val ordersCount = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||
}
|
||||
|
||||
public fun DerivativeStructure.derivative(orders: Map<Symbol, Int>): Double {
|
||||
return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
|
||||
}
|
||||
public fun DerivativeStructure.derivative(vararg symbols: Symbol): Double = derivative(symbols.toList())
|
||||
|
||||
public fun DerivativeStructure.derivative(vararg orders: Pair<Symbol, Int>): Double = derivative(mapOf(*orders))
|
||||
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
||||
|
||||
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
||||
@ -97,6 +102,7 @@ public class DerivativeStructureField(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A constructs that creates a derivative structure with required order on-demand
|
||||
*/
|
||||
@ -109,7 +115,7 @@ public class DerivativeStructureExpression(
|
||||
/**
|
||||
* Get the derivative expression with given orders
|
||||
*/
|
||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
||||
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
|
||||
public override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments ->
|
||||
with(DerivativeStructureField(symbols.size, arguments)) { function().derivative(symbols) }
|
||||
}
|
||||
}
|
||||
|
@ -5,14 +5,15 @@ import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFails
|
||||
|
||||
internal inline fun <R> diff(
|
||||
internal inline fun diff(
|
||||
order: Int,
|
||||
vararg parameters: Pair<Symbol, Double>,
|
||||
block: DerivativeStructureField.() -> R,
|
||||
): R {
|
||||
block: DerivativeStructureField.() -> Unit,
|
||||
): Unit {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||
DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||
}
|
||||
|
||||
internal class AutoDiffTest {
|
||||
@ -21,13 +22,16 @@ internal class AutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun derivativeStructureFieldTest() {
|
||||
val res: Double = diff(3, x to 1.0, y to 1.0) {
|
||||
diff(2, x to 1.0, y to 1.0) {
|
||||
val x = bind(x)//by binding()
|
||||
val y = symbol("y")
|
||||
val z = x * (-sin(x * y) + y)
|
||||
z.derivative(x)
|
||||
val z = x * (-sin(x * y) + y) + 2.0
|
||||
println(z.derivative(x))
|
||||
println(z.derivative(y,x))
|
||||
assertEquals(z.derivative(x, y), z.derivative(y, x))
|
||||
//check that improper order cause failure
|
||||
assertFails { z.derivative(x,x,y) }
|
||||
}
|
||||
println(res)
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -40,5 +44,7 @@ internal class AutoDiffTest {
|
||||
|
||||
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
||||
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
||||
assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0))
|
||||
assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0))
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import kscience.kmath.stat.Distribution
|
||||
import kscience.kmath.stat.Fitting
|
||||
import kscience.kmath.stat.RandomGenerator
|
||||
import kscience.kmath.stat.normal
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import org.junit.jupiter.api.Test
|
||||
import kotlin.math.pow
|
||||
|
||||
@ -53,7 +52,7 @@ internal class OptimizeTest {
|
||||
it.pow(2) + it + 1 + chain.nextDouble()
|
||||
}
|
||||
val yErr = x.map { sigma }
|
||||
val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
|
||||
val chi2 = Fitting.chiSquared(x, y, yErr) { x ->
|
||||
val cWithDefault = bindOrNull(c) ?: one
|
||||
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ The core features of KMath:
|
||||
|
||||
> #### Artifact:
|
||||
>
|
||||
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`.
|
||||
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-3`.
|
||||
>
|
||||
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
|
||||
>
|
||||
@ -30,7 +30,7 @@ The core features of KMath:
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2'
|
||||
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-3'
|
||||
> }
|
||||
> ```
|
||||
> **Gradle Kotlin DSL:**
|
||||
@ -44,6 +44,6 @@ The core features of KMath:
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
> implementation("kscience.kmath:kmath-core:0.2.0-dev-2")
|
||||
> implementation("kscience.kmath:kmath-core:0.2.0-dev-3")
|
||||
> }
|
||||
> ```
|
||||
|
@ -1,3 +1,5 @@
|
||||
import ru.mipt.npm.gradle.Maturity
|
||||
|
||||
plugins {
|
||||
id("ru.mipt.npm.mpp")
|
||||
id("ru.mipt.npm.native")
|
||||
@ -11,36 +13,42 @@ kotlin.sourceSets.commonMain {
|
||||
|
||||
readme {
|
||||
description = "Core classes, algebra definitions, basic linear algebra"
|
||||
maturity = ru.mipt.npm.gradle.Maturity.DEVELOPMENT
|
||||
maturity = Maturity.DEVELOPMENT
|
||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||
|
||||
feature(
|
||||
id = "algebras",
|
||||
description = "Algebraic structures: contexts and elements",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "nd",
|
||||
description = "Many-dimensional structures",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "buffers",
|
||||
description = "One-dimensional structure",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "expressions",
|
||||
description = "Functional Expressions",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/expressions"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "domains",
|
||||
description = "Domains",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/domains"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "autodif",
|
||||
description = "Automatic differentiation",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -3,20 +3,18 @@ package kscience.kmath.expressions
|
||||
/**
|
||||
* An expression that provides derivatives
|
||||
*/
|
||||
public interface DifferentiableExpression<T> : Expression<T>{
|
||||
public fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>?
|
||||
public interface DifferentiableExpression<T> : Expression<T> {
|
||||
public fun derivativeOrNull(symbols: List<Symbol>): Expression<T>?
|
||||
}
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(orders: Map<Symbol, Int>): Expression<T> =
|
||||
derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided")
|
||||
public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> =
|
||||
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||
derivative(mapOf(*orders))
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||
public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> =
|
||||
derivative(symbols.toList())
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
||||
derivative(StringSymbol(name) to 1)
|
||||
derivative(StringSymbol(name))
|
||||
|
||||
/**
|
||||
* A [DifferentiableExpression] that defines only first derivatives
|
||||
@ -25,8 +23,8 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
|
||||
|
||||
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
||||
|
||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>? {
|
||||
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
|
||||
public override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
|
||||
val dSymbol = symbols.firstOrNull() ?: return null
|
||||
return derivativeOrNull(dSymbol)
|
||||
}
|
||||
}
|
||||
|
@ -35,20 +35,27 @@ public fun interface Expression<T> {
|
||||
}
|
||||
|
||||
/**
|
||||
* Invoke an expression without parameters
|
||||
* Calls this expression without providing any arguments.
|
||||
*
|
||||
* @return a value.
|
||||
*/
|
||||
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
||||
//This method exists to avoid resolution ambiguity of vararg methods
|
||||
|
||||
/**
|
||||
* Calls this expression from arguments.
|
||||
*
|
||||
* @param pairs the pair of arguments' names to values.
|
||||
* @return the value.
|
||||
* @param pairs the pairs of arguments to values.
|
||||
* @return a value.
|
||||
*/
|
||||
@JvmName("callBySymbol")
|
||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
|
||||
|
||||
/**
|
||||
* Calls this expression from arguments.
|
||||
*
|
||||
* @param pairs the pairs of arguments' names to values.
|
||||
* @return a value.
|
||||
*/
|
||||
@JvmName("callByString")
|
||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
||||
@ -61,7 +68,6 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||
* @param E type of the actual expression state
|
||||
*/
|
||||
public interface ExpressionAlgebra<in T, E> : Algebra<E> {
|
||||
|
||||
/**
|
||||
* Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
|
||||
*/
|
||||
@ -87,7 +93,7 @@ public fun <T, E> ExpressionAlgebra<T, E>.bind(symbol: Symbol): E =
|
||||
/**
|
||||
* A delegate to create a symbol with a string identity in this scope
|
||||
*/
|
||||
public val symbol: ReadOnlyProperty<Any?, StringSymbol> = ReadOnlyProperty { thisRef, property ->
|
||||
public val symbol: ReadOnlyProperty<Any?, StringSymbol> = ReadOnlyProperty { _, property ->
|
||||
StringSymbol(property.name)
|
||||
}
|
||||
|
||||
@ -96,4 +102,4 @@ public val symbol: ReadOnlyProperty<Any?, StringSymbol> = ReadOnlyProperty {
|
||||
*/
|
||||
public fun <T, E> ExpressionAlgebra<T, E>.binding(): ReadOnlyProperty<Any?, E> = ReadOnlyProperty { _, property ->
|
||||
bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist")
|
||||
}
|
||||
}
|
||||
|
@ -74,9 +74,9 @@ public interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : Math
|
||||
/**
|
||||
* The element of [Ring].
|
||||
*
|
||||
* @param T the type of space operation results.
|
||||
* @param T the type of ring operation results.
|
||||
* @param I self type of the element. Needed for static type checking.
|
||||
* @param R the type of space.
|
||||
* @param R the type of ring.
|
||||
*/
|
||||
public interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
|
||||
/**
|
||||
@ -91,7 +91,7 @@ public interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceEl
|
||||
/**
|
||||
* The element of [Field].
|
||||
*
|
||||
* @param T the type of space operation results.
|
||||
* @param T the type of field operation results.
|
||||
* @param I self type of the element. Needed for static type checking.
|
||||
* @param F the type of field.
|
||||
*/
|
||||
|
@ -73,7 +73,7 @@ public interface NDAlgebra<T, C, N : NDStructure<T>> {
|
||||
public fun check(vararg elements: N): Array<out N> = elements
|
||||
.map(NDStructure<T>::shape)
|
||||
.singleOrNull { !shape.contentEquals(it) }
|
||||
?.let { throw ShapeMismatchException(shape, it) }
|
||||
?.let<IntArray, Array<out N>> { throw ShapeMismatchException(shape, it) }
|
||||
?: elements
|
||||
|
||||
/**
|
||||
|
82
kmath-nd4j/README.md
Normal file
82
kmath-nd4j/README.md
Normal file
@ -0,0 +1,82 @@
|
||||
# ND4J NDStructure implementation (`kmath-nd4j`)
|
||||
|
||||
This subproject implements the following features:
|
||||
|
||||
- [nd4jarraystrucure](src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt) : NDStructure wrapper for INDArray
|
||||
- [nd4jarrayrings](src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt) : Rings over Nd4jArrayStructure of Int and Long
|
||||
- [nd4jarrayfields](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : Fields over Nd4jArrayStructure of Float and Double
|
||||
|
||||
|
||||
> #### Artifact:
|
||||
>
|
||||
> This module artifact: `kscience.kmath:kmath-nd4j:0.2.0-dev-3`.
|
||||
>
|
||||
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-nd4j/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-nd4j/_latestVersion)
|
||||
>
|
||||
> Bintray development version: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-nd4j/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-nd4j/_latestVersion)
|
||||
>
|
||||
> **Gradle:**
|
||||
>
|
||||
> ```gradle
|
||||
> repositories {
|
||||
> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
> implementation 'kscience.kmath:kmath-nd4j:0.2.0-dev-3'
|
||||
> }
|
||||
> ```
|
||||
> **Gradle Kotlin DSL:**
|
||||
>
|
||||
> ```kotlin
|
||||
> repositories {
|
||||
> maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
> maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
> implementation("kscience.kmath:kmath-nd4j:0.2.0-dev-3")
|
||||
> }
|
||||
> ```
|
||||
|
||||
## Examples
|
||||
|
||||
NDStructure wrapper for INDArray:
|
||||
|
||||
```kotlin
|
||||
import org.nd4j.linalg.factory.*
|
||||
import scientifik.kmath.nd4j.*
|
||||
import scientifik.kmath.structures.*
|
||||
|
||||
val array = Nd4j.ones(2, 2).asRealStructure()
|
||||
println(array[0, 0]) // 1.0
|
||||
array[intArrayOf(0, 0)] = 24.0
|
||||
println(array[0, 0]) // 24.0
|
||||
```
|
||||
|
||||
Fast element-wise and in-place arithmetics for INDArray:
|
||||
|
||||
```kotlin
|
||||
import org.nd4j.linalg.factory.*
|
||||
import scientifik.kmath.nd4j.*
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
val field = RealNd4jArrayField(intArrayOf(2, 2))
|
||||
val array = Nd4j.rand(2, 2).asRealStructure()
|
||||
|
||||
val res = field {
|
||||
(25.0 / array + 20) * 4
|
||||
}
|
||||
|
||||
println(res.ndArray)
|
||||
// [[ 250.6449, 428.5840],
|
||||
// [ 269.7913, 202.2077]]
|
||||
```
|
||||
|
||||
Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis).
|
37
kmath-nd4j/build.gradle.kts
Normal file
37
kmath-nd4j/build.gradle.kts
Normal file
@ -0,0 +1,37 @@
|
||||
import ru.mipt.npm.gradle.Maturity
|
||||
|
||||
plugins {
|
||||
id("ru.mipt.npm.jvm")
|
||||
}
|
||||
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api("org.nd4j:nd4j-api:1.0.0-beta7")
|
||||
testImplementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7")
|
||||
testImplementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||
testImplementation("org.slf4j:slf4j-simple:1.7.30")
|
||||
}
|
||||
|
||||
readme {
|
||||
description = "ND4J NDStructure implementation and according NDAlgebra classes"
|
||||
maturity = Maturity.EXPERIMENTAL
|
||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||
|
||||
feature(
|
||||
id = "nd4jarraystructure",
|
||||
description = "NDStructure wrapper for INDArray",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "nd4jarrayrings",
|
||||
description = "Rings over Nd4jArrayStructure of Int and Long",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "nd4jarrayfields",
|
||||
description = "Fields over Nd4jArrayStructure of Float and Double",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt"
|
||||
)
|
||||
}
|
43
kmath-nd4j/docs/README-TEMPLATE.md
Normal file
43
kmath-nd4j/docs/README-TEMPLATE.md
Normal file
@ -0,0 +1,43 @@
|
||||
# ND4J NDStructure implementation (`kmath-nd4j`)
|
||||
|
||||
This subproject implements the following features:
|
||||
|
||||
${features}
|
||||
|
||||
${artifact}
|
||||
|
||||
## Examples
|
||||
|
||||
NDStructure wrapper for INDArray:
|
||||
|
||||
```kotlin
|
||||
import org.nd4j.linalg.factory.*
|
||||
import scientifik.kmath.nd4j.*
|
||||
import scientifik.kmath.structures.*
|
||||
|
||||
val array = Nd4j.ones(2, 2).asRealStructure()
|
||||
println(array[0, 0]) // 1.0
|
||||
array[intArrayOf(0, 0)] = 24.0
|
||||
println(array[0, 0]) // 24.0
|
||||
```
|
||||
|
||||
Fast element-wise and in-place arithmetics for INDArray:
|
||||
|
||||
```kotlin
|
||||
import org.nd4j.linalg.factory.*
|
||||
import scientifik.kmath.nd4j.*
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
val field = RealNd4jArrayField(intArrayOf(2, 2))
|
||||
val array = Nd4j.rand(2, 2).asRealStructure()
|
||||
|
||||
val res = field {
|
||||
(25.0 / array + 20) * 4
|
||||
}
|
||||
|
||||
println(res.ndArray)
|
||||
// [[ 250.6449, 428.5840],
|
||||
// [ 269.7913, 202.2077]]
|
||||
```
|
||||
|
||||
Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis).
|
@ -0,0 +1,288 @@
|
||||
package kscience.kmath.nd4j
|
||||
|
||||
import kscience.kmath.operations.*
|
||||
import kscience.kmath.structures.NDAlgebra
|
||||
import kscience.kmath.structures.NDField
|
||||
import kscience.kmath.structures.NDRing
|
||||
import kscience.kmath.structures.NDSpace
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
|
||||
/**
|
||||
* Represents [NDAlgebra] over [Nd4jArrayAlgebra].
|
||||
*
|
||||
* @param T the type of ND-structure element.
|
||||
* @param C the type of the element context.
|
||||
*/
|
||||
public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C, Nd4jArrayStructure<T>> {
|
||||
/**
|
||||
* Wraps [INDArray] to [N].
|
||||
*/
|
||||
public fun INDArray.wrap(): Nd4jArrayStructure<T>
|
||||
|
||||
public override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
|
||||
val struct = Nd4j.create(*shape)!!.wrap()
|
||||
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
|
||||
return struct
|
||||
}
|
||||
|
||||
public override fun map(arg: Nd4jArrayStructure<T>, transform: C.(T) -> T): Nd4jArrayStructure<T> {
|
||||
check(arg)
|
||||
val newStruct = arg.ndArray.dup().wrap()
|
||||
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
|
||||
return newStruct
|
||||
}
|
||||
|
||||
public override fun mapIndexed(
|
||||
arg: Nd4jArrayStructure<T>,
|
||||
transform: C.(index: IntArray, T) -> T
|
||||
): Nd4jArrayStructure<T> {
|
||||
check(arg)
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) }
|
||||
return new
|
||||
}
|
||||
|
||||
public override fun combine(
|
||||
a: Nd4jArrayStructure<T>,
|
||||
b: Nd4jArrayStructure<T>,
|
||||
transform: C.(T, T) -> T
|
||||
): Nd4jArrayStructure<T> {
|
||||
check(a, b)
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
||||
return new
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDSpace] over [Nd4jArrayStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param S the type of space of structure elements.
|
||||
*/
|
||||
public interface Nd4jArraySpace<T, S> : NDSpace<T, S, Nd4jArrayStructure<T>>,
|
||||
Nd4jArrayAlgebra<T, S> where S : Space<T> {
|
||||
public override val zero: Nd4jArrayStructure<T>
|
||||
get() = Nd4j.zeros(*shape).wrap()
|
||||
|
||||
public override fun add(a: Nd4jArrayStructure<T>, b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
check(a, b)
|
||||
return a.ndArray.add(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<T>.minus(b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
check(this, b)
|
||||
return ndArray.sub(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<T>.unaryMinus(): Nd4jArrayStructure<T> {
|
||||
check(this)
|
||||
return ndArray.neg().wrap()
|
||||
}
|
||||
|
||||
public override fun multiply(a: Nd4jArrayStructure<T>, k: Number): Nd4jArrayStructure<T> {
|
||||
check(a)
|
||||
return a.ndArray.mul(k).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<T>.div(k: Number): Nd4jArrayStructure<T> {
|
||||
check(this)
|
||||
return ndArray.div(k).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<T>.times(k: Number): Nd4jArrayStructure<T> {
|
||||
check(this)
|
||||
return ndArray.mul(k).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDRing] over [Nd4jArrayStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param R the type of ring of structure elements.
|
||||
*/
|
||||
public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4jArraySpace<T, R> where R : Ring<T> {
|
||||
public override val one: Nd4jArrayStructure<T>
|
||||
get() = Nd4j.ones(*shape).wrap()
|
||||
|
||||
public override fun multiply(a: Nd4jArrayStructure<T>, b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
check(a, b)
|
||||
return a.ndArray.mul(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
|
||||
check(this)
|
||||
return ndArray.sub(b).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<T>.plus(b: Number): Nd4jArrayStructure<T> {
|
||||
check(this)
|
||||
return ndArray.add(b).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Number.minus(b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
check(b)
|
||||
return b.ndArray.rsub(this).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDField] over [Nd4jArrayStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param F the type field of structure elements.
|
||||
*/
|
||||
public interface Nd4jArrayField<T, F> : NDField<T, F, Nd4jArrayStructure<T>>, Nd4jArrayRing<T, F> where F : Field<T> {
|
||||
public override fun divide(a: Nd4jArrayStructure<T>, b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
check(a, b)
|
||||
return a.ndArray.div(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Number.div(b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
check(b)
|
||||
return b.ndArray.rdiv(this).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDField] over [Nd4jArrayRealStructure].
|
||||
*/
|
||||
public class RealNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, RealField> {
|
||||
public override val elementContext: RealField
|
||||
get() = RealField
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = check(asRealStructure())
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Double>.div(arg: Double): Nd4jArrayStructure<Double> {
|
||||
check(this)
|
||||
return ndArray.div(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Double>.plus(arg: Double): Nd4jArrayStructure<Double> {
|
||||
check(this)
|
||||
return ndArray.add(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Double>.minus(arg: Double): Nd4jArrayStructure<Double> {
|
||||
check(this)
|
||||
return ndArray.sub(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Double>.times(arg: Double): Nd4jArrayStructure<Double> {
|
||||
check(this)
|
||||
return ndArray.mul(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Double.div(arg: Nd4jArrayStructure<Double>): Nd4jArrayStructure<Double> {
|
||||
check(arg)
|
||||
return arg.ndArray.rdiv(this).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Double.minus(arg: Nd4jArrayStructure<Double>): Nd4jArrayStructure<Double> {
|
||||
check(arg)
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDField] over [Nd4jArrayStructure] of [Float].
|
||||
*/
|
||||
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField> {
|
||||
public override val elementContext: FloatField
|
||||
get() = FloatField
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = check(asFloatStructure())
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Float>.div(arg: Float): Nd4jArrayStructure<Float> {
|
||||
check(this)
|
||||
return ndArray.div(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Float>.plus(arg: Float): Nd4jArrayStructure<Float> {
|
||||
check(this)
|
||||
return ndArray.add(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Float>.minus(arg: Float): Nd4jArrayStructure<Float> {
|
||||
check(this)
|
||||
return ndArray.sub(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Float>.times(arg: Float): Nd4jArrayStructure<Float> {
|
||||
check(this)
|
||||
return ndArray.mul(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Float.div(arg: Nd4jArrayStructure<Float>): Nd4jArrayStructure<Float> {
|
||||
check(arg)
|
||||
return arg.ndArray.rdiv(this).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Float.minus(arg: Nd4jArrayStructure<Float>): Nd4jArrayStructure<Float> {
|
||||
check(arg)
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDRing] over [Nd4jArrayIntStructure].
|
||||
*/
|
||||
public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing<Int, IntRing> {
|
||||
public override val elementContext: IntRing
|
||||
get() = IntRing
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = check(asIntStructure())
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Int>.plus(arg: Int): Nd4jArrayStructure<Int> {
|
||||
check(this)
|
||||
return ndArray.add(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Int>.minus(arg: Int): Nd4jArrayStructure<Int> {
|
||||
check(this)
|
||||
return ndArray.sub(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Int>.times(arg: Int): Nd4jArrayStructure<Int> {
|
||||
check(this)
|
||||
return ndArray.mul(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Int.minus(arg: Nd4jArrayStructure<Int>): Nd4jArrayStructure<Int> {
|
||||
check(arg)
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDRing] over [Nd4jArrayStructure] of [Long].
|
||||
*/
|
||||
public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing<Long, LongRing> {
|
||||
public override val elementContext: LongRing
|
||||
get() = LongRing
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = check(asLongStructure())
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Long>.plus(arg: Long): Nd4jArrayStructure<Long> {
|
||||
check(this)
|
||||
return ndArray.add(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Long>.minus(arg: Long): Nd4jArrayStructure<Long> {
|
||||
check(this)
|
||||
return ndArray.sub(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<Long>.times(arg: Long): Nd4jArrayStructure<Long> {
|
||||
check(this)
|
||||
return ndArray.mul(arg).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Long.minus(arg: Nd4jArrayStructure<Long>): Nd4jArrayStructure<Long> {
|
||||
check(arg)
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
package kscience.kmath.nd4j
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.api.shape.Shape
|
||||
|
||||
private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iterator<IntArray> {
|
||||
private var i: Int = 0
|
||||
|
||||
override fun hasNext(): Boolean = i < iterateOver.length()
|
||||
|
||||
override fun next(): IntArray {
|
||||
val la = if (iterateOver.ordering() == 'c')
|
||||
Shape.ind2subC(iterateOver, i++.toLong())!!
|
||||
else
|
||||
Shape.ind2sub(iterateOver, i++.toLong())!!
|
||||
|
||||
return la.toIntArray()
|
||||
}
|
||||
}
|
||||
|
||||
internal fun INDArray.indicesIterator(): Iterator<IntArray> = Nd4jArrayIndicesIterator(this)
|
||||
|
||||
private sealed class Nd4jArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
|
||||
private var i: Int = 0
|
||||
|
||||
final override fun hasNext(): Boolean = i < iterateOver.length()
|
||||
|
||||
abstract fun getSingle(indices: LongArray): T
|
||||
|
||||
final override fun next(): Pair<IntArray, T> {
|
||||
val la = if (iterateOver.ordering() == 'c')
|
||||
Shape.ind2subC(iterateOver, i++.toLong())!!
|
||||
else
|
||||
Shape.ind2sub(iterateOver, i++.toLong())!!
|
||||
|
||||
return la.toIntArray() to getSingle(la)
|
||||
}
|
||||
}
|
||||
|
||||
private class Nd4jArrayRealIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Double>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices)
|
||||
}
|
||||
|
||||
internal fun INDArray.realIterator(): Iterator<Pair<IntArray, Double>> = Nd4jArrayRealIterator(this)
|
||||
|
||||
private class Nd4jArrayLongIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Long>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
|
||||
}
|
||||
|
||||
internal fun INDArray.longIterator(): Iterator<Pair<IntArray, Long>> = Nd4jArrayLongIterator(this)
|
||||
|
||||
private class Nd4jArrayIntIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Int>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray())
|
||||
}
|
||||
|
||||
internal fun INDArray.intIterator(): Iterator<Pair<IntArray, Int>> = Nd4jArrayIntIterator(this)
|
||||
|
||||
private class Nd4jArrayFloatIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Float>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices)
|
||||
}
|
||||
|
||||
internal fun INDArray.floatIterator(): Iterator<Pair<IntArray, Float>> = Nd4jArrayFloatIterator(this)
|
@ -0,0 +1,68 @@
|
||||
package kscience.kmath.nd4j
|
||||
|
||||
import kscience.kmath.structures.MutableNDStructure
|
||||
import kscience.kmath.structures.NDStructure
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
|
||||
/**
|
||||
* Represents a [NDStructure] wrapping an [INDArray] object.
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
public sealed class Nd4jArrayStructure<T> : MutableNDStructure<T> {
|
||||
/**
|
||||
* The wrapped [INDArray].
|
||||
*/
|
||||
public abstract val ndArray: INDArray
|
||||
|
||||
public override val shape: IntArray
|
||||
get() = ndArray.shape().toIntArray()
|
||||
|
||||
internal abstract fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
||||
internal fun indicesIterator(): Iterator<IntArray> = ndArray.indicesIterator()
|
||||
public override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||
}
|
||||
|
||||
private data class Nd4jArrayIntStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Int>() {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = ndArray.intIterator()
|
||||
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
||||
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [INDArray] to [Nd4jArrayStructure].
|
||||
*/
|
||||
public fun INDArray.asIntStructure(): Nd4jArrayStructure<Int> = Nd4jArrayIntStructure(this)
|
||||
|
||||
private data class Nd4jArrayLongStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Long>() {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()
|
||||
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
|
||||
override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [INDArray] to [Nd4jArrayStructure].
|
||||
*/
|
||||
public fun INDArray.asLongStructure(): Nd4jArrayStructure<Long> = Nd4jArrayLongStructure(this)
|
||||
|
||||
private data class Nd4jArrayRealStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Double>() {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()
|
||||
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
||||
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [INDArray] to [Nd4jArrayStructure].
|
||||
*/
|
||||
public fun INDArray.asRealStructure(): Nd4jArrayStructure<Double> = Nd4jArrayRealStructure(this)
|
||||
|
||||
private data class Nd4jArrayFloatStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Float>() {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = ndArray.floatIterator()
|
||||
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
||||
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [INDArray] to [Nd4jArrayStructure].
|
||||
*/
|
||||
public fun INDArray.asFloatStructure(): Nd4jArrayStructure<Float> = Nd4jArrayFloatStructure(this)
|
4
kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/arrays.kt
Normal file
4
kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/arrays.kt
Normal file
@ -0,0 +1,4 @@
|
||||
package kscience.kmath.nd4j
|
||||
|
||||
internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() }
|
||||
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() }
|
@ -0,0 +1,42 @@
|
||||
package kscience.kmath.nd4j
|
||||
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.fail
|
||||
|
||||
internal class Nd4jArrayAlgebraTest {
|
||||
@Test
|
||||
fun testProduce() {
|
||||
val res = (RealNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } }
|
||||
val expected = (Nd4j.create(2, 2) ?: fail()).asRealStructure()
|
||||
expected[intArrayOf(0, 0)] = 0.0
|
||||
expected[intArrayOf(0, 1)] = 1.0
|
||||
expected[intArrayOf(1, 0)] = 1.0
|
||||
expected[intArrayOf(1, 1)] = 2.0
|
||||
assertEquals(expected, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMap() {
|
||||
val res = (IntNd4jArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } }
|
||||
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
||||
expected[intArrayOf(0, 0)] = 3
|
||||
expected[intArrayOf(0, 1)] = 3
|
||||
expected[intArrayOf(1, 0)] = 3
|
||||
expected[intArrayOf(1, 1)] = 3
|
||||
assertEquals(expected, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val res = (IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 }
|
||||
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
||||
expected[intArrayOf(0, 0)] = 26
|
||||
expected[intArrayOf(0, 1)] = 26
|
||||
expected[intArrayOf(1, 0)] = 26
|
||||
expected[intArrayOf(1, 1)] = 26
|
||||
assertEquals(expected, res)
|
||||
}
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
package kscience.kmath.nd4j
|
||||
|
||||
import kscience.kmath.structures.get
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertNotEquals
|
||||
import kotlin.test.fail
|
||||
|
||||
internal class Nd4jArrayStructureTest {
|
||||
@Test
|
||||
fun testElements() {
|
||||
val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||
val struct = nd.asRealStructure()
|
||||
val res = struct.elements().map(Pair<IntArray, Double>::second).toList()
|
||||
assertEquals(listOf(1.0, 2.0, 3.0), res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testShape() {
|
||||
val nd = Nd4j.rand(10, 2, 3, 6) ?: fail()
|
||||
val struct = nd.asRealStructure()
|
||||
assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testEquals() {
|
||||
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
|
||||
val struct1 = nd1.asRealStructure()
|
||||
assertEquals(struct1, struct1)
|
||||
assertNotEquals(struct1 as Any?, null)
|
||||
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
|
||||
val struct2 = nd2.asRealStructure()
|
||||
assertEquals(struct1, struct2)
|
||||
assertEquals(struct2, struct1)
|
||||
val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
|
||||
val struct3 = nd3.asRealStructure()
|
||||
assertEquals(struct2, struct3)
|
||||
assertEquals(struct1, struct3)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testHashCode() {
|
||||
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail()
|
||||
val struct1 = nd1.asRealStructure()
|
||||
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail()
|
||||
val struct2 = nd2.asRealStructure()
|
||||
assertEquals(struct1.hashCode(), struct2.hashCode())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDimension() {
|
||||
val nd = Nd4j.rand(8, 16, 3, 7, 1)!!
|
||||
val struct = nd.asFloatStructure()
|
||||
assertEquals(5, struct.dimension)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGet() {
|
||||
val nd = Nd4j.rand(10, 2, 3, 6)?:fail()
|
||||
val struct = nd.asIntStructure()
|
||||
assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0])
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSet() {
|
||||
val nd = Nd4j.rand(17, 12, 4, 8)!!
|
||||
val struct = nd.asLongStructure()
|
||||
struct[intArrayOf(1, 2, 3, 4)] = 777
|
||||
assertEquals(777, struct[1, 2, 3, 4])
|
||||
}
|
||||
}
|
@ -34,6 +34,7 @@ include(
|
||||
":kmath-commons",
|
||||
":kmath-viktor",
|
||||
":kmath-stat",
|
||||
":kmath-nd4j",
|
||||
":kmath-dimensions",
|
||||
":kmath-for-real",
|
||||
":kmath-geometry",
|
||||
|
Loading…
Reference in New Issue
Block a user