Merge remote-tracking branch 'origin/dev' into mp-samplers
# Conflicts: # examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt
This commit is contained in:
commit
26d81bddb5
@ -7,6 +7,7 @@
|
|||||||
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
|
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
|
||||||
- Automatic README generation for features (#139)
|
- Automatic README generation for features (#139)
|
||||||
- Native support for `memory`, `core` and `dimensions`
|
- Native support for `memory`, `core` and `dimensions`
|
||||||
|
- `kmath-ejml` to supply EJML SimpleMatrix wrapper.
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Package changed from `scientifik` to `kscience.kmath`.
|
- Package changed from `scientifik` to `kscience.kmath`.
|
||||||
@ -14,6 +15,7 @@
|
|||||||
- Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`)
|
- Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`)
|
||||||
- `Polynomial` secondary constructor made function.
|
- `Polynomial` secondary constructor made function.
|
||||||
- Kotlin version: 1.3.72 -> 1.4.20-M1
|
- Kotlin version: 1.3.72 -> 1.4.20-M1
|
||||||
|
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
|
@ -54,6 +54,8 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
||||||
to submit a feature request if you want something to be done first.
|
to submit a feature request if you want something to be done first.
|
||||||
|
|
||||||
|
* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures.
|
||||||
|
|
||||||
## Planned features
|
## Planned features
|
||||||
|
|
||||||
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
||||||
|
@ -2,9 +2,9 @@ plugins {
|
|||||||
id("ru.mipt.npm.project")
|
id("ru.mipt.npm.project")
|
||||||
}
|
}
|
||||||
|
|
||||||
val kmathVersion by extra("0.2.0-dev-2")
|
val kmathVersion: String by extra("0.2.0-dev-2")
|
||||||
val bintrayRepo by extra("kscience")
|
val bintrayRepo: String by extra("kscience")
|
||||||
val githubProject by extra("kmath")
|
val githubProject: String by extra("kmath")
|
||||||
|
|
||||||
allprojects {
|
allprojects {
|
||||||
repositories {
|
repositories {
|
||||||
@ -22,6 +22,6 @@ subprojects {
|
|||||||
if (name.startsWith("kmath")) apply<ru.mipt.npm.gradle.KSciencePublishPlugin>()
|
if (name.startsWith("kmath")) apply<ru.mipt.npm.gradle.KSciencePublishPlugin>()
|
||||||
}
|
}
|
||||||
|
|
||||||
readme{
|
readme {
|
||||||
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
||||||
}
|
}
|
||||||
|
@ -6,10 +6,10 @@ back-ends. The new operations added as extensions to contexts instead of being m
|
|||||||
|
|
||||||
Two major contexts used for linear algebra and hyper-geometry:
|
Two major contexts used for linear algebra and hyper-geometry:
|
||||||
|
|
||||||
* `VectorSpace` forms a mathematical space on top of array-like structure (`Buffer` and its typealias `Point` used for geometry).
|
* `VectorSpace` forms a mathematical space on top of array-like structure (`Buffer` and its type alias `Point` used for geometry).
|
||||||
|
|
||||||
* `MatrixContext` forms a space-like context for 2d-structures. It does not store matrix size and therefore does not implement
|
* `MatrixContext` forms a space-like context for 2d-structures. It does not store matrix size and therefore does not implement
|
||||||
`Space` interface (it is not possible to create zero element without knowing the matrix size).
|
`Space` interface (it is impossible to create zero element without knowing the matrix size).
|
||||||
|
|
||||||
## Vector spaces
|
## Vector spaces
|
||||||
|
|
||||||
|
@ -26,9 +26,12 @@ dependencies {
|
|||||||
implementation(project(":kmath-prob"))
|
implementation(project(":kmath-prob"))
|
||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-dimensions"))
|
implementation(project(":kmath-dimensions"))
|
||||||
|
implementation(project(":kmath-ejml"))
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11")
|
implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
|
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
|
||||||
"benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath
|
implementation("org.slf4j:slf4j-simple:1.7.30")
|
||||||
|
"benchmarksImplementation"("org.jetbrains.kotlinx:kotlinx.benchmark.runtime-jvm:0.2.0-dev-8")
|
||||||
|
"benchmarksImplementation"(sourceSets.main.get().output + sourceSets.main.get().runtimeClasspath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure benchmark
|
// Configure benchmark
|
||||||
|
@ -6,34 +6,33 @@ import org.openjdk.jmh.annotations.State
|
|||||||
import java.nio.IntBuffer
|
import java.nio.IntBuffer
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
class ArrayBenchmark {
|
internal class ArrayBenchmark {
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun benchmarkArrayRead() {
|
fun benchmarkArrayRead() {
|
||||||
var res = 0
|
var res = 0
|
||||||
for (i in 1.._root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) res += _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.array[_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size - i]
|
for (i in 1..size) res += array[size - i]
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun benchmarkBufferRead() {
|
fun benchmarkBufferRead() {
|
||||||
var res = 0
|
var res = 0
|
||||||
for (i in 1.._root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) res += _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.arrayBuffer.get(
|
for (i in 1..size) res += arrayBuffer.get(
|
||||||
_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size - i)
|
size - i
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun nativeBufferRead() {
|
fun nativeBufferRead() {
|
||||||
var res = 0
|
var res = 0
|
||||||
for (i in 1.._root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) res += _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.nativeBuffer.get(
|
for (i in 1..size) res += nativeBuffer.get(
|
||||||
_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size - i)
|
size - i
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
const val size: Int = 1000
|
const val size: Int = 1000
|
||||||
val array: IntArray = IntArray(_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) { it }
|
val array: IntArray = IntArray(size) { it }
|
||||||
val arrayBuffer: IntBuffer = IntBuffer.wrap(_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.array)
|
val arrayBuffer: IntBuffer = IntBuffer.wrap(array)
|
||||||
|
val nativeBuffer: IntBuffer = IntBuffer.allocate(size).also { for (i in 0 until size) it.put(i, i) }
|
||||||
val nativeBuffer: IntBuffer = IntBuffer.allocate(_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size).also {
|
|
||||||
for (i in 0 until _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) it.put(i, i)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,11 +7,10 @@ import org.openjdk.jmh.annotations.Scope
|
|||||||
import org.openjdk.jmh.annotations.State
|
import org.openjdk.jmh.annotations.State
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
class BufferBenchmark {
|
internal class BufferBenchmark {
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun genericRealBufferReadWrite() {
|
fun genericRealBufferReadWrite() {
|
||||||
val buffer = RealBuffer(size){it.toDouble()}
|
val buffer = RealBuffer(size) { it.toDouble() }
|
||||||
|
|
||||||
(0 until size).forEach {
|
(0 until size).forEach {
|
||||||
buffer[it]
|
buffer[it]
|
||||||
@ -20,7 +19,7 @@ class BufferBenchmark {
|
|||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun complexBufferReadWrite() {
|
fun complexBufferReadWrite() {
|
||||||
val buffer = MutableBuffer.complex(size / 2){Complex(it.toDouble(), -it.toDouble())}
|
val buffer = MutableBuffer.complex(size / 2) { Complex(it.toDouble(), -it.toDouble()) }
|
||||||
|
|
||||||
(0 until size / 2).forEach {
|
(0 until size / 2).forEach {
|
||||||
buffer[it]
|
buffer[it]
|
||||||
@ -28,6 +27,6 @@ class BufferBenchmark {
|
|||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
const val size = 100
|
const val size: Int = 100
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -7,7 +7,7 @@ import org.openjdk.jmh.annotations.Scope
|
|||||||
import org.openjdk.jmh.annotations.State
|
import org.openjdk.jmh.annotations.State
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
class NDFieldBenchmark {
|
internal class NDFieldBenchmark {
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun autoFieldAdd() {
|
fun autoFieldAdd() {
|
||||||
bufferedField {
|
bufferedField {
|
||||||
@ -40,11 +40,10 @@ class NDFieldBenchmark {
|
|||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
val dim = 1000
|
const val dim: Int = 1000
|
||||||
val n = 100
|
const val n: Int = 100
|
||||||
|
val bufferedField: BufferedNDField<Double, RealField> = NDField.auto(RealField, dim, dim)
|
||||||
val bufferedField = NDField.auto(RealField, dim, dim)
|
val specializedField: RealNDField = NDField.real(dim, dim)
|
||||||
val specializedField = NDField.real(dim, dim)
|
val genericField: BoxingNDField<Double, RealField> = NDField.boxing(RealField, dim, dim)
|
||||||
val genericField = NDField.boxing(RealField, dim, dim)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -9,9 +9,9 @@ import org.openjdk.jmh.annotations.Scope
|
|||||||
import org.openjdk.jmh.annotations.State
|
import org.openjdk.jmh.annotations.State
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
class ViktorBenchmark {
|
internal class ViktorBenchmark {
|
||||||
final val dim = 1000
|
final val dim: Int = 1000
|
||||||
final val n = 100
|
final val n: Int = 100
|
||||||
|
|
||||||
// automatically build context most suited for given type.
|
// automatically build context most suited for given type.
|
||||||
final val autoField: BufferedNDField<Double, RealField> = NDField.auto(RealField, dim, dim)
|
final val autoField: BufferedNDField<Double, RealField> = NDField.auto(RealField, dim, dim)
|
||||||
@ -42,7 +42,7 @@ class ViktorBenchmark {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun realdFieldLog() {
|
fun realFieldLog() {
|
||||||
realField {
|
realField {
|
||||||
val fortyTwo = produce { 42.0 }
|
val fortyTwo = produce { 42.0 }
|
||||||
var res = one
|
var res = one
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
//package kscience.kmath.ast
|
package kscience.kmath.ast
|
||||||
//
|
//
|
||||||
//import kscience.kmath.asm.compile
|
//import kscience.kmath.asm.compile
|
||||||
//import kscience.kmath.expressions.Expression
|
//import kscience.kmath.expressions.Expression
|
||||||
|
@ -6,9 +6,9 @@ import kscience.kmath.chains.collectWithState
|
|||||||
import kscience.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import kscience.kmath.prob.samplers.ZigguratNormalizedGaussianSampler
|
import kscience.kmath.prob.samplers.ZigguratNormalizedGaussianSampler
|
||||||
|
|
||||||
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
private data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
||||||
|
|
||||||
fun Chain<Double>.mean(): Chain<Double> = collectWithState(AveragingChainState(), { it.copy() }) { chain ->
|
private fun Chain<Double>.mean(): Chain<Double> = collectWithState(AveragingChainState(), { it.copy() }) { chain ->
|
||||||
val next = chain.next()
|
val next = chain.next()
|
||||||
num++
|
num++
|
||||||
value += next
|
value += next
|
||||||
|
@ -0,0 +1,50 @@
|
|||||||
|
package kscience.kmath.linear
|
||||||
|
|
||||||
|
import kscience.kmath.commons.linear.CMMatrixContext
|
||||||
|
import kscience.kmath.commons.linear.inverse
|
||||||
|
import kscience.kmath.commons.linear.toCM
|
||||||
|
import kscience.kmath.ejml.EjmlMatrixContext
|
||||||
|
import kscience.kmath.ejml.inverse
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kscience.kmath.operations.invoke
|
||||||
|
import kscience.kmath.structures.Matrix
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val random = Random(1224)
|
||||||
|
val dim = 100
|
||||||
|
//creating invertible matrix
|
||||||
|
val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||||
|
val l = Matrix.real(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 }
|
||||||
|
val matrix = l dot u
|
||||||
|
|
||||||
|
val n = 5000 // iterations
|
||||||
|
|
||||||
|
MatrixContext.real {
|
||||||
|
repeat(50) { inverse(matrix) }
|
||||||
|
val inverseTime = measureTimeMillis { repeat(n) { inverse(matrix) } }
|
||||||
|
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
||||||
|
}
|
||||||
|
|
||||||
|
//commons-math
|
||||||
|
|
||||||
|
val commonsTime = measureTimeMillis {
|
||||||
|
CMMatrixContext {
|
||||||
|
val cm = matrix.toCM() //avoid overhead on conversion
|
||||||
|
repeat(n) { inverse(cm) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
println("[commons-math] Inversion of $n matrices $dim x $dim finished in $commonsTime millis")
|
||||||
|
|
||||||
|
val ejmlTime = measureTimeMillis {
|
||||||
|
(EjmlMatrixContext(RealField)) {
|
||||||
|
val km = matrix.toEjml() //avoid overhead on conversion
|
||||||
|
repeat(n) { inverse(km) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println("[ejml] Inversion of $n matrices $dim x $dim finished in $ejmlTime millis")
|
||||||
|
}
|
@ -0,0 +1,38 @@
|
|||||||
|
package kscience.kmath.linear
|
||||||
|
|
||||||
|
import kscience.kmath.commons.linear.CMMatrixContext
|
||||||
|
import kscience.kmath.commons.linear.toCM
|
||||||
|
import kscience.kmath.ejml.EjmlMatrixContext
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kscience.kmath.operations.invoke
|
||||||
|
import kscience.kmath.structures.Matrix
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val random = Random(12224)
|
||||||
|
val dim = 1000
|
||||||
|
//creating invertible matrix
|
||||||
|
val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||||
|
val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||||
|
|
||||||
|
// //warmup
|
||||||
|
// matrix1 dot matrix2
|
||||||
|
|
||||||
|
CMMatrixContext {
|
||||||
|
val cmMatrix1 = matrix1.toCM()
|
||||||
|
val cmMatrix2 = matrix2.toCM()
|
||||||
|
val cmTime = measureTimeMillis { cmMatrix1 dot cmMatrix2 }
|
||||||
|
println("CM implementation time: $cmTime")
|
||||||
|
}
|
||||||
|
|
||||||
|
(EjmlMatrixContext(RealField)) {
|
||||||
|
val ejmlMatrix1 = matrix1.toEjml()
|
||||||
|
val ejmlMatrix2 = matrix2.toEjml()
|
||||||
|
val ejmlTime = measureTimeMillis { ejmlMatrix1 dot ejmlMatrix2 }
|
||||||
|
println("EJML implementation time: $ejmlTime")
|
||||||
|
}
|
||||||
|
|
||||||
|
val genericTime = measureTimeMillis { val res = matrix1 dot matrix2 }
|
||||||
|
println("Generic implementation time: $genericTime")
|
||||||
|
}
|
@ -1,8 +1,6 @@
|
|||||||
package kscience.kmath.operations
|
package kscience.kmath.operations
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
val res = BigIntField {
|
val res = BigIntField { number(1) * 2 }
|
||||||
number(1) * 2
|
|
||||||
}
|
|
||||||
println("bigint:$res")
|
println("bigint:$res")
|
||||||
}
|
}
|
@ -5,15 +5,19 @@ import kscience.kmath.structures.NDField
|
|||||||
import kscience.kmath.structures.complex
|
import kscience.kmath.structures.complex
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
|
// 2d element
|
||||||
val element = NDElement.complex(2, 2) { index: IntArray ->
|
val element = NDElement.complex(2, 2) { index: IntArray ->
|
||||||
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
||||||
}
|
}
|
||||||
|
println(element)
|
||||||
|
|
||||||
val compute = (NDField.complex(8)) {
|
// 1d element operation
|
||||||
|
val result = with(NDField.complex(8)) {
|
||||||
val a = produce { (it) -> i * it - it.toDouble() }
|
val a = produce { (it) -> i * it - it.toDouble() }
|
||||||
val b = 3
|
val b = 3
|
||||||
val c = Complex(1.0, 1.0)
|
val c = Complex(1.0, 1.0)
|
||||||
|
|
||||||
(a pow b) + c
|
(a pow b) + c
|
||||||
}
|
}
|
||||||
|
println(result)
|
||||||
}
|
}
|
||||||
|
@ -4,32 +4,30 @@ import kotlin.system.measureTimeMillis
|
|||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
val n = 6000
|
val n = 6000
|
||||||
|
|
||||||
val array = DoubleArray(n * n) { 1.0 }
|
val array = DoubleArray(n * n) { 1.0 }
|
||||||
val buffer = RealBuffer(array)
|
val buffer = RealBuffer(array)
|
||||||
val strides = DefaultStrides(intArrayOf(n, n))
|
val strides = DefaultStrides(intArrayOf(n, n))
|
||||||
|
|
||||||
val structure = BufferNDStructure(strides, buffer)
|
val structure = BufferNDStructure(strides, buffer)
|
||||||
|
|
||||||
measureTimeMillis {
|
measureTimeMillis {
|
||||||
var res: Double = 0.0
|
var res = 0.0
|
||||||
strides.indices().forEach { res = structure[it] }
|
strides.indices().forEach { res = structure[it] }
|
||||||
} // warmup
|
} // warmup
|
||||||
|
|
||||||
val time1 = measureTimeMillis {
|
val time1 = measureTimeMillis {
|
||||||
var res: Double = 0.0
|
var res = 0.0
|
||||||
strides.indices().forEach { res = structure[it] }
|
strides.indices().forEach { res = structure[it] }
|
||||||
}
|
}
|
||||||
println("Structure reading finished in $time1 millis")
|
println("Structure reading finished in $time1 millis")
|
||||||
|
|
||||||
val time2 = measureTimeMillis {
|
val time2 = measureTimeMillis {
|
||||||
var res: Double = 0.0
|
var res = 0.0
|
||||||
strides.indices().forEach { res = buffer[strides.offset(it)] }
|
strides.indices().forEach { res = buffer[strides.offset(it)] }
|
||||||
}
|
}
|
||||||
println("Buffer reading finished in $time2 millis")
|
println("Buffer reading finished in $time2 millis")
|
||||||
|
|
||||||
val time3 = measureTimeMillis {
|
val time3 = measureTimeMillis {
|
||||||
var res: Double = 0.0
|
var res = 0.0
|
||||||
strides.indices().forEach { res = array[strides.offset(it)] }
|
strides.indices().forEach { res = array[strides.offset(it)] }
|
||||||
}
|
}
|
||||||
println("Array reading finished in $time3 millis")
|
println("Array reading finished in $time3 millis")
|
||||||
|
@ -4,24 +4,17 @@ import kotlin.system.measureTimeMillis
|
|||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
val n = 6000
|
val n = 6000
|
||||||
|
|
||||||
val structure = NDStructure.build(intArrayOf(n, n), Buffer.Companion::auto) { 1.0 }
|
val structure = NDStructure.build(intArrayOf(n, n), Buffer.Companion::auto) { 1.0 }
|
||||||
|
|
||||||
structure.mapToBuffer { it + 1 } // warm-up
|
structure.mapToBuffer { it + 1 } // warm-up
|
||||||
|
val time1 = measureTimeMillis { val res = structure.mapToBuffer { it + 1 } }
|
||||||
val time1 = measureTimeMillis {
|
|
||||||
val res = structure.mapToBuffer { it + 1 }
|
|
||||||
}
|
|
||||||
println("Structure mapping finished in $time1 millis")
|
println("Structure mapping finished in $time1 millis")
|
||||||
|
|
||||||
val array = DoubleArray(n * n) { 1.0 }
|
val array = DoubleArray(n * n) { 1.0 }
|
||||||
|
|
||||||
val time2 = measureTimeMillis {
|
val time2 = measureTimeMillis {
|
||||||
val target = DoubleArray(n * n)
|
val target = DoubleArray(n * n)
|
||||||
val res = array.forEachIndexed { index, value ->
|
val res = array.forEachIndexed { index, value -> target[index] = value + 1 }
|
||||||
target[index] = value + 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
println("Array mapping finished in $time2 millis")
|
println("Array mapping finished in $time2 millis")
|
||||||
|
|
||||||
val buffer = RealBuffer(DoubleArray(n * n) { 1.0 })
|
val buffer = RealBuffer(DoubleArray(n * n) { 1.0 })
|
||||||
|
@ -6,7 +6,7 @@ import kscience.kmath.dimensions.DMatrixContext
|
|||||||
import kscience.kmath.dimensions.Dimension
|
import kscience.kmath.dimensions.Dimension
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
|
|
||||||
fun DMatrixContext<Double, RealField>.simple() {
|
private fun DMatrixContext<Double, RealField>.simple() {
|
||||||
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
||||||
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
||||||
|
|
||||||
@ -14,12 +14,11 @@ fun DMatrixContext<Double, RealField>.simple() {
|
|||||||
m1.transpose() + m2
|
m1.transpose() + m2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private object D5 : Dimension {
|
||||||
object D5 : Dimension {
|
|
||||||
override val dim: UInt = 5u
|
override val dim: UInt = 5u
|
||||||
}
|
}
|
||||||
|
|
||||||
fun DMatrixContext<Double, RealField>.custom() {
|
private fun DMatrixContext<Double, RealField>.custom() {
|
||||||
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
|
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
|
||||||
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
|
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
|
||||||
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }
|
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }
|
||||||
|
@ -14,7 +14,6 @@ kotlin.sourceSets {
|
|||||||
implementation("org.ow2.asm:asm:8.0.1")
|
implementation("org.ow2.asm:asm:8.0.1")
|
||||||
implementation("org.ow2.asm:asm-commons:8.0.1")
|
implementation("org.ow2.asm:asm-commons:8.0.1")
|
||||||
implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
|
implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
|
||||||
implementation(kotlin("reflect"))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,6 @@ import kscience.kmath.ast.MST
|
|||||||
import kscience.kmath.ast.MstExpression
|
import kscience.kmath.ast.MstExpression
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compiles given MST to an Expression using AST compiler.
|
* Compiles given MST to an Expression using AST compiler.
|
||||||
@ -18,7 +17,8 @@ import kotlin.reflect.KClass
|
|||||||
* @return the compiled expression.
|
* @return the compiled expression.
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
@PublishedApi
|
||||||
|
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
||||||
is MST.Symbolic -> {
|
is MST.Symbolic -> {
|
||||||
val symbol = try {
|
val symbol = try {
|
||||||
@ -61,11 +61,12 @@ public fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expr
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik.
|
* @author Alexander Nozik.
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
|
||||||
|
mst.compileWith(T::class.java, this)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimizes performance of an [MstExpression] using ASM codegen.
|
* Optimizes performance of an [MstExpression] using ASM codegen.
|
||||||
*
|
*
|
||||||
* @author Alexander Nozik.
|
* @author Alexander Nozik.
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class.java, algebra)
|
||||||
|
@ -10,7 +10,6 @@ import org.objectweb.asm.Opcodes.*
|
|||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
import java.util.*
|
import java.util.*
|
||||||
import java.util.stream.Collectors
|
import java.util.stream.Collectors
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
|
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
|
||||||
@ -23,7 +22,7 @@ import kotlin.reflect.KClass
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
internal class AsmBuilder<T> internal constructor(
|
internal class AsmBuilder<T> internal constructor(
|
||||||
private val classOfT: KClass<*>,
|
private val classOfT: Class<*>,
|
||||||
private val algebra: Algebra<T>,
|
private val algebra: Algebra<T>,
|
||||||
private val className: String,
|
private val className: String,
|
||||||
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
||||||
@ -32,7 +31,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
||||||
*/
|
*/
|
||||||
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||||
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -43,7 +42,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* ASM Type for [algebra].
|
* ASM Type for [algebra].
|
||||||
*/
|
*/
|
||||||
private val tAlgebraType: Type = algebra::class.asm
|
private val tAlgebraType: Type = algebra.javaClass.asm
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [T].
|
* ASM type for [T].
|
||||||
@ -55,16 +54,6 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
*/
|
*/
|
||||||
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
||||||
|
|
||||||
/**
|
|
||||||
* Index of `this` variable in invoke method of the built subclass.
|
|
||||||
*/
|
|
||||||
private val invokeThisVar: Int = 0
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Index of `arguments` variable in invoke method of the built subclass.
|
|
||||||
*/
|
|
||||||
private val invokeArgumentsVar: Int = 1
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List of constants to provide to the subclass.
|
* List of constants to provide to the subclass.
|
||||||
*/
|
*/
|
||||||
@ -76,22 +65,22 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
private lateinit var invokeMethodVisitor: InstructionAdapter
|
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* State if this [AsmBuilder] needs to generate constants field.
|
* States whether this [AsmBuilder] needs to generate constants field.
|
||||||
*/
|
*/
|
||||||
private var hasConstants: Boolean = true
|
private var hasConstants: Boolean = true
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
* States whether [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
||||||
*/
|
*/
|
||||||
internal var primitiveMode: Boolean = false
|
internal var primitiveMode: Boolean = false
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
* Primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||||
*/
|
*/
|
||||||
internal var primitiveMask: Type = OBJECT_TYPE
|
internal var primitiveMask: Type = OBJECT_TYPE
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
* Boxed primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||||
*/
|
*/
|
||||||
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
|
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
|
||||||
|
|
||||||
@ -103,7 +92,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Stack of useful objects types on stack expected by algebra calls.
|
* Stack of useful objects types on stack expected by algebra calls.
|
||||||
*/
|
*/
|
||||||
internal val expectationStack: ArrayDeque<Type> = ArrayDeque(listOf(tType))
|
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>(1).also { it.push(tType) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The cache for instance built by this builder.
|
* The cache for instance built by this builder.
|
||||||
@ -361,7 +350,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* from it).
|
* from it).
|
||||||
*/
|
*/
|
||||||
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
|
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
|
||||||
val boxed = value::class.asm
|
val boxed = value.javaClass.asm
|
||||||
val primitive = BOXED_TO_PRIMITIVES[boxed]
|
val primitive = BOXED_TO_PRIMITIVES[boxed]
|
||||||
|
|
||||||
if (primitive != null) {
|
if (primitive != null) {
|
||||||
@ -475,17 +464,27 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
|
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
|
||||||
|
|
||||||
internal companion object {
|
internal companion object {
|
||||||
|
/**
|
||||||
|
* Index of `this` variable in invoke method of the built subclass.
|
||||||
|
*/
|
||||||
|
private const val invokeThisVar: Int = 0
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Index of `arguments` variable in invoke method of the built subclass.
|
||||||
|
*/
|
||||||
|
private const val invokeArgumentsVar: Int = 1
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps JVM primitive numbers boxed types to their primitive ASM types.
|
* Maps JVM primitive numbers boxed types to their primitive ASM types.
|
||||||
*/
|
*/
|
||||||
private val SIGNATURE_LETTERS: Map<KClass<out Any>, Type> by lazy {
|
private val SIGNATURE_LETTERS: Map<Class<out Any>, Type> by lazy {
|
||||||
hashMapOf(
|
hashMapOf(
|
||||||
java.lang.Byte::class to Type.BYTE_TYPE,
|
java.lang.Byte::class.java to Type.BYTE_TYPE,
|
||||||
java.lang.Short::class to Type.SHORT_TYPE,
|
java.lang.Short::class.java to Type.SHORT_TYPE,
|
||||||
java.lang.Integer::class to Type.INT_TYPE,
|
java.lang.Integer::class.java to Type.INT_TYPE,
|
||||||
java.lang.Long::class to Type.LONG_TYPE,
|
java.lang.Long::class.java to Type.LONG_TYPE,
|
||||||
java.lang.Float::class to Type.FLOAT_TYPE,
|
java.lang.Float::class.java to Type.FLOAT_TYPE,
|
||||||
java.lang.Double::class to Type.DOUBLE_TYPE
|
java.lang.Double::class.java to Type.DOUBLE_TYPE
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -523,43 +522,43 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
||||||
*/
|
*/
|
||||||
private val INLINABLE_NUMBERS: Set<KClass<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [Expression].
|
* ASM type for [Expression].
|
||||||
*/
|
*/
|
||||||
internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm }
|
internal val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/expressions/Expression") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.Number].
|
* ASM type for [java.lang.Number].
|
||||||
*/
|
*/
|
||||||
internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm }
|
internal val NUMBER_TYPE: Type by lazy { Type.getObjectType("java/lang/Number") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.util.Map].
|
* ASM type for [java.util.Map].
|
||||||
*/
|
*/
|
||||||
internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm }
|
internal val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.Object].
|
* ASM type for [java.lang.Object].
|
||||||
*/
|
*/
|
||||||
internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm }
|
internal val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for array of [java.lang.Object].
|
* ASM type for array of [java.lang.Object].
|
||||||
*/
|
*/
|
||||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
||||||
internal val OBJECT_ARRAY_TYPE: Type by lazy { Array<java.lang.Object>::class.asm }
|
internal val OBJECT_ARRAY_TYPE: Type by lazy { Type.getType("[Ljava/lang/Object;") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [Algebra].
|
* ASM type for [Algebra].
|
||||||
*/
|
*/
|
||||||
internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm }
|
internal val ALGEBRA_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/operations/Algebra") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.String].
|
* ASM type for [java.lang.String].
|
||||||
*/
|
*/
|
||||||
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
|
internal val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for MapIntrinsics.
|
* ASM type for MapIntrinsics.
|
||||||
|
@ -10,9 +10,9 @@ import org.objectweb.asm.*
|
|||||||
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
|
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
import java.lang.reflect.Method
|
import java.lang.reflect.Method
|
||||||
|
import java.util.*
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||||
hashMapOf(
|
hashMapOf(
|
||||||
@ -26,12 +26,12 @@ private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns ASM [Type] for given [KClass].
|
* Returns ASM [Type] for given [Class].
|
||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
internal val KClass<*>.asm: Type
|
internal inline val Class<*>.asm: Type
|
||||||
get() = Type.getType(java)
|
get() = Type.getType(this)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
|
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
|
||||||
@ -140,7 +140,7 @@ private fun <T> AsmBuilder<T>.buildExpectationStack(
|
|||||||
if (specific != null)
|
if (specific != null)
|
||||||
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
|
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
|
||||||
else
|
else
|
||||||
repeat(arity) { expectationStack.push(tType) }
|
expectationStack.addAll(Collections.nCopies(arity, tType))
|
||||||
|
|
||||||
return specific != null
|
return specific != null
|
||||||
}
|
}
|
||||||
@ -169,7 +169,7 @@ private fun <T> AsmBuilder<T>.tryInvokeSpecific(
|
|||||||
val arity = parameterTypes.size
|
val arity = parameterTypes.size
|
||||||
val theName = methodNameAdapters[name to arity] ?: name
|
val theName = methodNameAdapters[name to arity] ?: name
|
||||||
val spec = findSpecific(context, theName, parameterTypes) ?: return false
|
val spec = findSpecific(context, theName, parameterTypes) ?: return false
|
||||||
val owner = context::class.asm
|
val owner = context.javaClass.asm
|
||||||
|
|
||||||
invokeAlgebraOperation(
|
invokeAlgebraOperation(
|
||||||
owner = owner.internalName,
|
owner = owner.internalName,
|
||||||
|
@ -7,6 +7,7 @@ import com.github.h0tk3y.betterParse.grammar.parser
|
|||||||
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
||||||
import com.github.h0tk3y.betterParse.lexer.Token
|
import com.github.h0tk3y.betterParse.lexer.Token
|
||||||
import com.github.h0tk3y.betterParse.lexer.TokenMatch
|
import com.github.h0tk3y.betterParse.lexer.TokenMatch
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.literalToken
|
||||||
import com.github.h0tk3y.betterParse.lexer.regexToken
|
import com.github.h0tk3y.betterParse.lexer.regexToken
|
||||||
import com.github.h0tk3y.betterParse.parser.ParseResult
|
import com.github.h0tk3y.betterParse.parser.ParseResult
|
||||||
import com.github.h0tk3y.betterParse.parser.Parser
|
import com.github.h0tk3y.betterParse.parser.Parser
|
||||||
@ -23,14 +24,14 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
|
|||||||
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released
|
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released
|
||||||
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
|
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
|
||||||
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
|
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
|
||||||
private val lpar: Token by regexToken("\\(")
|
private val lpar: Token by literalToken("(")
|
||||||
private val rpar: Token by regexToken("\\)")
|
private val rpar: Token by literalToken(")")
|
||||||
private val comma: Token by regexToken(",")
|
private val comma: Token by literalToken(",")
|
||||||
private val mul: Token by regexToken("\\*")
|
private val mul: Token by literalToken("*")
|
||||||
private val pow: Token by regexToken("\\^")
|
private val pow: Token by literalToken("^")
|
||||||
private val div: Token by regexToken("/")
|
private val div: Token by literalToken("/")
|
||||||
private val minus: Token by regexToken("-")
|
private val minus: Token by literalToken("-")
|
||||||
private val plus: Token by regexToken("\\+")
|
private val plus: Token by literalToken("+")
|
||||||
private val ws: Token by regexToken("\\s+", ignore = true)
|
private val ws: Token by regexToken("\\s+", ignore = true)
|
||||||
|
|
||||||
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||||
|
@ -9,14 +9,17 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
|||||||
import kotlin.properties.ReadOnlyProperty
|
import kotlin.properties.ReadOnlyProperty
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field wrapping commons-math derivative structures
|
* A field over commons-math [DerivativeStructure].
|
||||||
|
*
|
||||||
|
* @property order The derivation order.
|
||||||
|
* @property parameters The map of free parameters.
|
||||||
*/
|
*/
|
||||||
public class DerivativeStructureField(
|
public class DerivativeStructureField(
|
||||||
public val order: Int,
|
public val order: Int,
|
||||||
public val parameters: Map<String, Double>
|
public val parameters: Map<String, Double>
|
||||||
) : ExtendedField<DerivativeStructure> {
|
) : ExtendedField<DerivativeStructure> {
|
||||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
|
public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) }
|
||||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) }
|
public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) }
|
||||||
|
|
||||||
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
|
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
|
||||||
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
|
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
|
||||||
|
@ -4,7 +4,7 @@ import kscience.kmath.operations.*
|
|||||||
|
|
||||||
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
||||||
Expression<T> {
|
Expression<T> {
|
||||||
public override operator fun invoke(arguments: Map<String, T>): T =
|
override operator fun invoke(arguments: Map<String, T>): T =
|
||||||
context.unaryOperation(name, expr.invoke(arguments))
|
context.unaryOperation(name, expr.invoke(arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -14,17 +14,17 @@ internal class FunctionalBinaryOperation<T>(
|
|||||||
val first: Expression<T>,
|
val first: Expression<T>,
|
||||||
val second: Expression<T>
|
val second: Expression<T>
|
||||||
) : Expression<T> {
|
) : Expression<T> {
|
||||||
public override operator fun invoke(arguments: Map<String, T>): T =
|
override operator fun invoke(arguments: Map<String, T>): T =
|
||||||
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||||
public override operator fun invoke(arguments: Map<String, T>): T =
|
override operator fun invoke(arguments: Map<String, T>): T =
|
||||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
||||||
public override operator fun invoke(arguments: Map<String, T>): T = value
|
override operator fun invoke(arguments: Map<String, T>): T = value
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalConstProductExpression<T>(
|
internal class FunctionalConstProductExpression<T>(
|
||||||
@ -32,7 +32,7 @@ internal class FunctionalConstProductExpression<T>(
|
|||||||
private val expr: Expression<T>,
|
private val expr: Expression<T>,
|
||||||
val const: Number
|
val const: Number
|
||||||
) : Expression<T> {
|
) : Expression<T> {
|
||||||
public override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -139,16 +139,27 @@ public open class FunctionalExpressionField<T, A>(algebra: A) :
|
|||||||
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
||||||
FunctionalExpressionField<T, A>(algebra),
|
FunctionalExpressionField<T, A>(algebra),
|
||||||
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
||||||
public override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
public override fun sin(arg: Expression<T>): Expression<T> =
|
||||||
public override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||||
public override fun asin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
|
||||||
public override fun acos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
public override fun cos(arg: Expression<T>): Expression<T> =
|
||||||
public override fun atan(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||||
|
|
||||||
|
public override fun asin(arg: Expression<T>): Expression<T> =
|
||||||
|
unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||||
|
|
||||||
|
public override fun acos(arg: Expression<T>): Expression<T> =
|
||||||
|
unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||||
|
|
||||||
|
public override fun atan(arg: Expression<T>): Expression<T> =
|
||||||
|
unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||||
|
|
||||||
public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
|
public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
|
||||||
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||||
|
|
||||||
public override fun exp(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
public override fun exp(arg: Expression<T>): Expression<T> =
|
||||||
|
unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||||
|
|
||||||
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||||
|
|
||||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
@ -24,7 +24,11 @@ public interface FeaturedMatrix<T : Any> : Matrix<T> {
|
|||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> =
|
public inline fun Structure2D.Companion.real(
|
||||||
|
rows: Int,
|
||||||
|
columns: Int,
|
||||||
|
initializer: (Int, Int) -> Double
|
||||||
|
): Matrix<Double> =
|
||||||
MatrixContext.real.produce(rows, columns, initializer)
|
MatrixContext.real.produce(rows, columns, initializer)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -18,20 +18,52 @@ public interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
|
|||||||
*/
|
*/
|
||||||
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
||||||
|
|
||||||
|
public override fun binaryOperation(operation: String, left: Matrix<T>, right: Matrix<T>): Matrix<T> = when (operation) {
|
||||||
|
"dot" -> left dot right
|
||||||
|
else -> super.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the dot product of this matrix and another one.
|
||||||
|
*
|
||||||
|
* @receiver the multiplicand.
|
||||||
|
* @param other the multiplier.
|
||||||
|
* @return the dot product.
|
||||||
|
*/
|
||||||
public infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
|
public infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the dot product of this matrix and a vector.
|
||||||
|
*
|
||||||
|
* @receiver the multiplicand.
|
||||||
|
* @param vector the multiplier.
|
||||||
|
* @return the dot product.
|
||||||
|
*/
|
||||||
public infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
|
public infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplies a matrix by its element.
|
||||||
|
*
|
||||||
|
* @receiver the multiplicand.
|
||||||
|
* @param value the multiplier.
|
||||||
|
* @receiver the product.
|
||||||
|
*/
|
||||||
public operator fun Matrix<T>.times(value: T): Matrix<T>
|
public operator fun Matrix<T>.times(value: T): Matrix<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplies an element by a matrix of it.
|
||||||
|
*
|
||||||
|
* @receiver the multiplicand.
|
||||||
|
* @param value the multiplier.
|
||||||
|
* @receiver the product.
|
||||||
|
*/
|
||||||
public operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
public operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* Non-boxing double matrix
|
* Non-boxing double matrix
|
||||||
*/
|
*/
|
||||||
public val real: RealMatrixContext
|
public val real: RealMatrixContext = RealMatrixContext
|
||||||
get() = RealMatrixContext
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A structured matrix with custom buffer
|
* A structured matrix with custom buffer
|
||||||
@ -60,7 +92,7 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
*/
|
*/
|
||||||
public fun point(size: Int, initializer: (Int) -> T): Point<T>
|
public fun point(size: Int, initializer: (Int) -> T): Point<T>
|
||||||
|
|
||||||
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
public override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
||||||
//TODO add typed error
|
//TODO add typed error
|
||||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||||
|
|
||||||
@ -71,7 +103,7 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
public override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
||||||
//TODO add typed error
|
//TODO add typed error
|
||||||
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
|
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
|
||||||
|
|
||||||
@ -81,10 +113,10 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
|
public override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
|
||||||
produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } }
|
produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } }
|
||||||
|
|
||||||
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
public override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||||
require(a.rowNum == b.rowNum && a.colNum == b.colNum) {
|
require(a.rowNum == b.rowNum && a.colNum == b.colNum) {
|
||||||
"Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]"
|
"Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]"
|
||||||
}
|
}
|
||||||
@ -92,7 +124,7 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } }
|
return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
public override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||||
require(rowNum == b.rowNum && colNum == b.colNum) {
|
require(rowNum == b.rowNum && colNum == b.colNum) {
|
||||||
"Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]"
|
"Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]"
|
||||||
}
|
}
|
||||||
@ -100,11 +132,11 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } }
|
return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
public override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
||||||
produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
|
produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
|
||||||
|
|
||||||
public operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
public operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
||||||
|
|
||||||
override operator fun Matrix<T>.times(value: T): Matrix<T> =
|
public override operator fun Matrix<T>.times(value: T): Matrix<T> =
|
||||||
produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
|
produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,7 @@
|
|||||||
package kscience.kmath.misc
|
package kscience.kmath.misc
|
||||||
|
|
||||||
import kscience.kmath.linear.Point
|
import kscience.kmath.linear.Point
|
||||||
import kscience.kmath.operations.ExtendedField
|
import kscience.kmath.operations.*
|
||||||
import kscience.kmath.operations.Field
|
|
||||||
import kscience.kmath.operations.invoke
|
|
||||||
import kscience.kmath.operations.sum
|
|
||||||
import kscience.kmath.structures.asBuffer
|
import kscience.kmath.structures.asBuffer
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
@ -17,23 +14,37 @@ import kotlin.contracts.contract
|
|||||||
/**
|
/**
|
||||||
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
||||||
* with respect to this variable.
|
* with respect to this variable.
|
||||||
|
*
|
||||||
|
* @param T the non-nullable type of value.
|
||||||
|
* @property value The value of this variable.
|
||||||
*/
|
*/
|
||||||
public open class Variable<T : Any>(public val value: T)
|
public open class Variable<T : Any>(public val value: T)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents result of [deriv] call.
|
||||||
|
*
|
||||||
|
* @param T the non-nullable type of value.
|
||||||
|
* @param value the value of result.
|
||||||
|
* @property deriv The mapping of differentiated variables to their derivatives.
|
||||||
|
* @property context The field over [T].
|
||||||
|
*/
|
||||||
public class DerivationResult<T : Any>(
|
public class DerivationResult<T : Any>(
|
||||||
value: T,
|
value: T,
|
||||||
public val deriv: Map<Variable<T>, T>,
|
public val deriv: Map<Variable<T>, T>,
|
||||||
public val context: Field<T>
|
public val context: Field<T>
|
||||||
) : Variable<T>(value) {
|
) : Variable<T>(value) {
|
||||||
|
/**
|
||||||
|
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
||||||
|
*/
|
||||||
public fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
|
public fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* compute divergence
|
* Computes the divergence.
|
||||||
*/
|
*/
|
||||||
public fun div(): T = context { sum(deriv.values) }
|
public fun div(): T = context { sum(deriv.values) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute a gradient for variables in given order
|
* Computes the gradient for variables in given order.
|
||||||
*/
|
*/
|
||||||
public fun grad(vararg variables: Variable<T>): Point<T> {
|
public fun grad(vararg variables: Variable<T>): Point<T> {
|
||||||
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
||||||
@ -53,6 +64,9 @@ public class DerivationResult<T : Any>(
|
|||||||
* assertEquals(17.0, y.x) // the value of result (y)
|
* assertEquals(17.0, y.x) // the value of result (y)
|
||||||
* assertEquals(9.0, x.d) // dy/dx
|
* assertEquals(9.0, x.d) // dy/dx
|
||||||
* ```
|
* ```
|
||||||
|
*
|
||||||
|
* @param body the action in [AutoDiffField] context returning [Variable] to differentiate with respect to.
|
||||||
|
* @return the result of differentiation.
|
||||||
*/
|
*/
|
||||||
public inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
|
public inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
|
||||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||||
@ -65,12 +79,15 @@ public inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents field in context of which functions can be derived.
|
||||||
|
*/
|
||||||
public abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
public abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
||||||
public abstract val context: F
|
public abstract val context: F
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A variable accessing inner state of derivatives.
|
* A variable accessing inner state of derivatives.
|
||||||
* Use this function in inner builders to avoid creating additional derivative bindings
|
* Use this value in inner builders to avoid creating additional derivative bindings.
|
||||||
*/
|
*/
|
||||||
public abstract var Variable<T>.d: T
|
public abstract var Variable<T>.d: T
|
||||||
|
|
||||||
@ -87,6 +104,9 @@ public abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>>
|
|||||||
*/
|
*/
|
||||||
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
public abstract fun variable(value: T): Variable<T>
|
public abstract fun variable(value: T): Variable<T>
|
||||||
|
|
||||||
public inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
|
public inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
|
||||||
|
@ -299,7 +299,7 @@ public class BigInt internal constructor(
|
|||||||
|
|
||||||
for (i in mag.indices) {
|
for (i in mag.indices) {
|
||||||
val cur: ULong = carry + mag[i].toULong() * x.toULong()
|
val cur: ULong = carry + mag[i].toULong() * x.toULong()
|
||||||
result[i] = (cur and BASE.toULong()).toUInt()
|
result[i] = (cur and BASE).toUInt()
|
||||||
carry = cur shr BASE_SIZE
|
carry = cur shr BASE_SIZE
|
||||||
}
|
}
|
||||||
result[resultLength - 1] = (carry and BASE).toUInt()
|
result[resultLength - 1] = (carry and BASE).toUInt()
|
||||||
@ -316,7 +316,7 @@ public class BigInt internal constructor(
|
|||||||
|
|
||||||
for (j in mag2.indices) {
|
for (j in mag2.indices) {
|
||||||
val cur: ULong = result[i + j].toULong() + mag1[i].toULong() * mag2[j].toULong() + carry
|
val cur: ULong = result[i + j].toULong() + mag1[i].toULong() * mag2[j].toULong() + carry
|
||||||
result[i + j] = (cur and BASE.toULong()).toUInt()
|
result[i + j] = (cur and BASE).toUInt()
|
||||||
carry = cur shr BASE_SIZE
|
carry = cur shr BASE_SIZE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import kscience.kmath.memory.MemoryWriter
|
|||||||
import kscience.kmath.structures.Buffer
|
import kscience.kmath.structures.Buffer
|
||||||
import kscience.kmath.structures.MemoryBuffer
|
import kscience.kmath.structures.MemoryBuffer
|
||||||
import kscience.kmath.structures.MutableBuffer
|
import kscience.kmath.structures.MutableBuffer
|
||||||
|
import kscience.kmath.structures.MutableMemoryBuffer
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -159,7 +160,7 @@ public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents complex number.
|
* Represents `double`-based complex number.
|
||||||
*
|
*
|
||||||
* @property re The real part.
|
* @property re The real part.
|
||||||
* @property im The imaginary part.
|
* @property im The imaginary part.
|
||||||
@ -176,11 +177,16 @@ public data class Complex(val re: Double, val im: Double) : FieldElement<Complex
|
|||||||
|
|
||||||
override fun compareTo(other: Complex): Int = r.compareTo(other.r)
|
override fun compareTo(other: Complex): Int = r.compareTo(other.r)
|
||||||
|
|
||||||
public companion object : MemorySpec<Complex> {
|
override fun toString(): String {
|
||||||
override val objectSize: Int = 16
|
return "($re + i*$im)"
|
||||||
|
}
|
||||||
|
|
||||||
override fun MemoryReader.read(offset: Int): Complex =
|
|
||||||
Complex(readDouble(offset), readDouble(offset + 8))
|
public companion object : MemorySpec<Complex> {
|
||||||
|
override val objectSize: Int
|
||||||
|
get() = 16
|
||||||
|
|
||||||
|
override fun MemoryReader.read(offset: Int): Complex = Complex(readDouble(offset), readDouble(offset + 8))
|
||||||
|
|
||||||
override fun MemoryWriter.write(offset: Int, value: Complex) {
|
override fun MemoryWriter.write(offset: Int, value: Complex) {
|
||||||
writeDouble(offset, value.re)
|
writeDouble(offset, value.re)
|
||||||
@ -197,8 +203,16 @@ public data class Complex(val re: Double, val im: Double) : FieldElement<Complex
|
|||||||
*/
|
*/
|
||||||
public fun Number.toComplex(): Complex = Complex(this, 0.0)
|
public fun Number.toComplex(): Complex = Complex(this, 0.0)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new buffer of complex numbers with the specified [size], where each element is calculated by calling the
|
||||||
|
* specified [init] function.
|
||||||
|
*/
|
||||||
public inline fun Buffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer<Complex> =
|
public inline fun Buffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer<Complex> =
|
||||||
MemoryBuffer.create(Complex, size, init)
|
MemoryBuffer.create(Complex, size, init)
|
||||||
|
|
||||||
public inline fun MutableBuffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer<Complex> =
|
/**
|
||||||
MemoryBuffer.create(Complex, size, init)
|
* Creates a new buffer of complex numbers with the specified [size], where each element is calculated by calling the
|
||||||
|
* specified [init] function.
|
||||||
|
*/
|
||||||
|
public inline fun MutableBuffer.Companion.complex(size: Int, init: (Int) -> Complex): MutableBuffer<Complex> =
|
||||||
|
MutableMemoryBuffer.create(Complex, size, init)
|
||||||
|
@ -15,8 +15,9 @@ public class BoxingNDField<T, F : Field<T>>(
|
|||||||
public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
bufferFactory(size, initializer)
|
bufferFactory(size, initializer)
|
||||||
|
|
||||||
public override fun check(vararg elements: NDBuffer<T>) {
|
public override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||||
check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
|
require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
|
||||||
|
return elements
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
|
public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
|
||||||
@ -75,6 +76,6 @@ public inline fun <T : Any, F : Field<T>, R> F.nd(
|
|||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
action: NDField<T, F, *>.() -> R
|
action: NDField<T, F, *>.() -> R
|
||||||
): R {
|
): R {
|
||||||
val ndfield: BoxingNDField<T, F> = NDField.boxing(this, *shape, bufferFactory = bufferFactory)
|
val ndfield = NDField.boxing(this, *shape, bufferFactory = bufferFactory)
|
||||||
return ndfield.action()
|
return ndfield.action()
|
||||||
}
|
}
|
||||||
|
@ -14,8 +14,9 @@ public class BoxingNDRing<T, R : Ring<T>>(
|
|||||||
|
|
||||||
public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
|
public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>) {
|
override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||||
require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
|
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
||||||
|
return elements
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement<T, R> =
|
override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement<T, R> =
|
||||||
|
@ -5,8 +5,10 @@ import kscience.kmath.operations.*
|
|||||||
public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
||||||
public val strides: Strides
|
public val strides: Strides
|
||||||
|
|
||||||
public override fun check(vararg elements: NDBuffer<T>): Unit =
|
public override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||||
require(elements.all { it.strides == strides }) { ("Strides mismatch") }
|
require(elements.all { it.strides == strides }) { "Strides mismatch" }
|
||||||
|
return elements
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert any [NDStructure] to buffered structure using strides from this context.
|
* Convert any [NDStructure] to buffered structure using strides from this context.
|
||||||
|
@ -46,35 +46,48 @@ public interface Buffer<T> {
|
|||||||
asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
|
asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer {
|
/**
|
||||||
val array = DoubleArray(size) { initializer(it) }
|
* Creates a [RealBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||||
return RealBuffer(array)
|
* [initializer] function.
|
||||||
}
|
*/
|
||||||
|
public inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer =
|
||||||
|
RealBuffer(size) { initializer(it) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a boxing buffer of given type
|
* Creates a [ListBuffer] of given type [T] with given [size]. Each element is calculated by calling the
|
||||||
|
* specified [initializer] function.
|
||||||
*/
|
*/
|
||||||
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> =
|
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
ListBuffer(List(size, initializer))
|
ListBuffer(List(size, initializer))
|
||||||
|
|
||||||
|
// TODO add resolution based on Annotation or companion resolution
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([IntBuffer],
|
||||||
|
* [RealBuffer], etc.), [ListBuffer] is returned otherwise.
|
||||||
|
*
|
||||||
|
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
||||||
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
public inline fun <T : Any> auto(type: KClass<T>, size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
//TODO add resolution based on Annotation or companion resolution
|
when (type) {
|
||||||
return when (type) {
|
Double::class -> RealBuffer(size) { initializer(it) as Double } as Buffer<T>
|
||||||
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
Short::class -> ShortBuffer(size) { initializer(it) as Short } as Buffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
Int::class -> IntBuffer(size) { initializer(it) as Int } as Buffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
Long::class -> LongBuffer(size) { initializer(it) as Long } as Buffer<T>
|
||||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
|
Float::class -> FloatBuffer(size) { initializer(it) as Float } as Buffer<T>
|
||||||
Complex::class -> complex(size) { initializer(it) as Complex } as Buffer<T>
|
Complex::class -> complex(size) { initializer(it) as Complex } as Buffer<T>
|
||||||
else -> boxing(size, initializer)
|
else -> boxing(size, initializer)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible
|
* Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([IntBuffer],
|
||||||
|
* [RealBuffer], etc.), [ListBuffer] is returned otherwise.
|
||||||
|
*
|
||||||
|
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <reified T : Any> auto(size: Int, crossinline initializer: (Int) -> T): Buffer<T> =
|
public inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
auto(T::class, size, initializer)
|
auto(T::class, size, initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -117,25 +130,40 @@ public interface MutableBuffer<T> : Buffer<T> {
|
|||||||
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
MutableListBuffer(MutableList(size, initializer))
|
MutableListBuffer(MutableList(size, initializer))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used
|
||||||
|
* ([IntBuffer], [RealBuffer], etc.), [ListBuffer] is returned otherwise.
|
||||||
|
*
|
||||||
|
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
||||||
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
public inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
when (type) {
|
when (type) {
|
||||||
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
Double::class -> RealBuffer(size) { initializer(it) as Double } as MutableBuffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
Short::class -> ShortBuffer(size) { initializer(it) as Short } as MutableBuffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
Int::class -> IntBuffer(size) { initializer(it) as Int } as MutableBuffer<T>
|
||||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
|
Float::class -> FloatBuffer(size) { initializer(it) as Float } as MutableBuffer<T>
|
||||||
|
Long::class -> LongBuffer(size) { initializer(it) as Long } as MutableBuffer<T>
|
||||||
|
Complex::class -> complex(size) { initializer(it) as Complex } as MutableBuffer<T>
|
||||||
else -> boxing(size, initializer)
|
else -> boxing(size, initializer)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
* Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used
|
||||||
|
* ([IntBuffer], [RealBuffer], etc.), [ListBuffer] is returned otherwise.
|
||||||
|
*
|
||||||
|
* The [size] is specified, and each element is calculated by calling the specified [initializer] function.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
public inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
auto(T::class, size, initializer)
|
auto(T::class, size, initializer)
|
||||||
|
|
||||||
public val real: MutableBufferFactory<Double> =
|
/**
|
||||||
{ size, initializer -> RealBuffer(DoubleArray(size) { initializer(it) }) }
|
* Creates a [RealBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||||
|
* [initializer] function.
|
||||||
|
*/
|
||||||
|
public inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer =
|
||||||
|
RealBuffer(size) { initializer(it) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,7 +48,8 @@ public fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, Valu
|
|||||||
/**
|
/**
|
||||||
* A real buffer which supports flags for each value like NaN or Missing
|
* A real buffer which supports flags for each value like NaN or Missing
|
||||||
*/
|
*/
|
||||||
public class FlaggedRealBuffer(public val values: DoubleArray, public val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> {
|
public class FlaggedRealBuffer(public val values: DoubleArray, public val flags: ByteArray) : FlaggedBuffer<Double?>,
|
||||||
|
Buffer<Double?> {
|
||||||
init {
|
init {
|
||||||
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
|
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ public class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) :
|
|||||||
public inline fun <T : Any> create(
|
public inline fun <T : Any> create(
|
||||||
spec: MemorySpec<T>,
|
spec: MemorySpec<T>,
|
||||||
size: Int,
|
size: Int,
|
||||||
crossinline initializer: (Int) -> T
|
initializer: (Int) -> T
|
||||||
): MutableMemoryBuffer<T> = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
|
): MutableMemoryBuffer<T> = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
|
||||||
(0 until size).forEach { buffer[it] = initializer(it) }
|
(0 until size).forEach { buffer[it] = initializer(it) }
|
||||||
}
|
}
|
||||||
|
@ -7,49 +7,77 @@ import kscience.kmath.operations.Space
|
|||||||
import kotlin.native.concurrent.ThreadLocal
|
import kotlin.native.concurrent.ThreadLocal
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An exception is thrown when the expected ans actual shape of NDArray differs
|
* An exception is thrown when the expected ans actual shape of NDArray differs.
|
||||||
|
*
|
||||||
|
* @property expected the expected shape.
|
||||||
|
* @property actual the actual shape.
|
||||||
*/
|
*/
|
||||||
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) : RuntimeException()
|
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
|
||||||
|
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The base interface for all nd-algebra implementations
|
* The base interface for all ND-algebra implementations.
|
||||||
* @param T the type of nd-structure element
|
*
|
||||||
* @param C the type of the element context
|
* @param T the type of ND-structure element.
|
||||||
* @param N the type of the structure
|
* @param C the type of the element context.
|
||||||
|
* @param N the type of the structure.
|
||||||
*/
|
*/
|
||||||
public interface NDAlgebra<T, C, N : NDStructure<T>> {
|
public interface NDAlgebra<T, C, N : NDStructure<T>> {
|
||||||
|
/**
|
||||||
|
* The shape of ND-structures this algebra operates on.
|
||||||
|
*/
|
||||||
public val shape: IntArray
|
public val shape: IntArray
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The algebra over elements of ND structure.
|
||||||
|
*/
|
||||||
public val elementContext: C
|
public val elementContext: C
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Produce a new [N] structure using given initializer function
|
* Produces a new [N] structure using given initializer function.
|
||||||
*/
|
*/
|
||||||
public fun produce(initializer: C.(IntArray) -> T): N
|
public fun produce(initializer: C.(IntArray) -> T): N
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map elements from one structure to another one
|
* Maps elements from one structure to another one by applying [transform] to them.
|
||||||
*/
|
*/
|
||||||
public fun map(arg: N, transform: C.(T) -> T): N
|
public fun map(arg: N, transform: C.(T) -> T): N
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map indexed elements
|
* Maps elements from one structure to another one by applying [transform] to them alongside with their indices.
|
||||||
*/
|
*/
|
||||||
public fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N
|
public fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Combine two structures into one
|
* Combines two structures into one.
|
||||||
*/
|
*/
|
||||||
public fun combine(a: N, b: N, transform: C.(T, T) -> T): N
|
public fun combine(a: N, b: N, transform: C.(T, T) -> T): N
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if given elements are consistent with this context
|
* Checks if given element is consistent with this context.
|
||||||
|
*
|
||||||
|
* @param element the structure to check.
|
||||||
|
* @return the valid structure.
|
||||||
*/
|
*/
|
||||||
public fun check(vararg elements: N): Unit = elements.forEach {
|
public fun check(element: N): N {
|
||||||
if (!shape.contentEquals(it.shape)) throw ShapeMismatchException(shape, it.shape)
|
if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape)
|
||||||
|
return element
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* element-by-element invoke a function working on [T] on a [NDStructure]
|
* Checks if given elements are consistent with this context.
|
||||||
|
*
|
||||||
|
* @param elements the structures to check.
|
||||||
|
* @return the array of valid structures.
|
||||||
|
*/
|
||||||
|
public fun check(vararg elements: N): Array<out N> = elements
|
||||||
|
.map(NDStructure<T>::shape)
|
||||||
|
.singleOrNull { !shape.contentEquals(it) }
|
||||||
|
?.let { throw ShapeMismatchException(shape, it) }
|
||||||
|
?: elements
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise invocation of function working on [T] on a [NDStructure].
|
||||||
*/
|
*/
|
||||||
public operator fun Function1<T, T>.invoke(structure: N): N = map(structure) { value -> this@invoke(value) }
|
public operator fun Function1<T, T>.invoke(structure: N): N = map(structure) { value -> this@invoke(value) }
|
||||||
|
|
||||||
@ -57,42 +85,107 @@ public interface NDAlgebra<T, C, N : NDStructure<T>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An nd-space over element space
|
* Space of [NDStructure].
|
||||||
|
*
|
||||||
|
* @param T the type of the element contained in ND structure.
|
||||||
|
* @param N the type of ND structure.
|
||||||
|
* @param S the type of space of structure elements.
|
||||||
*/
|
*/
|
||||||
public interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T, S, N> {
|
public interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T, S, N> {
|
||||||
/**
|
/**
|
||||||
* Element-by-element addition
|
* Element-wise addition.
|
||||||
|
*
|
||||||
|
* @param a the addend.
|
||||||
|
* @param b the augend.
|
||||||
|
* @return the sum.
|
||||||
*/
|
*/
|
||||||
override fun add(a: N, b: N): N = combine(a, b) { aValue, bValue -> add(aValue, bValue) }
|
public override fun add(a: N, b: N): N = combine(a, b) { aValue, bValue -> add(aValue, bValue) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Multiply all elements by constant
|
* Element-wise multiplication by scalar.
|
||||||
|
*
|
||||||
|
* @param a the multiplicand.
|
||||||
|
* @param k the multiplier.
|
||||||
|
* @return the product.
|
||||||
*/
|
*/
|
||||||
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
|
public override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
|
||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
// TODO move to extensions after KEEP-176
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds an ND structure to an element of it.
|
||||||
|
*
|
||||||
|
* @receiver the addend.
|
||||||
|
* @param arg the augend.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
public operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) }
|
public operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subtracts an element from ND structure of it.
|
||||||
|
*
|
||||||
|
* @receiver the dividend.
|
||||||
|
* @param arg the divisor.
|
||||||
|
* @return the quotient.
|
||||||
|
*/
|
||||||
public operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) }
|
public operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds an element to ND structure of it.
|
||||||
|
*
|
||||||
|
* @receiver the addend.
|
||||||
|
* @param arg the augend.
|
||||||
|
* @return the sum.
|
||||||
|
*/
|
||||||
public operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) }
|
public operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subtracts an ND structure from an element of it.
|
||||||
|
*
|
||||||
|
* @receiver the dividend.
|
||||||
|
* @param arg the divisor.
|
||||||
|
* @return the quotient.
|
||||||
|
*/
|
||||||
public operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) }
|
public operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) }
|
||||||
|
|
||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An nd-ring over element ring
|
* Ring of [NDStructure].
|
||||||
|
*
|
||||||
|
* @param T the type of the element contained in ND structure.
|
||||||
|
* @param N the type of ND structure.
|
||||||
|
* @param R the type of ring of structure elements.
|
||||||
*/
|
*/
|
||||||
public interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N> {
|
public interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N> {
|
||||||
/**
|
/**
|
||||||
* Element-by-element multiplication
|
* Element-wise multiplication.
|
||||||
|
*
|
||||||
|
* @param a the multiplicand.
|
||||||
|
* @param b the multiplier.
|
||||||
|
* @return the product.
|
||||||
*/
|
*/
|
||||||
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
public override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after KEEP-176
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplies an ND structure by an element of it.
|
||||||
|
*
|
||||||
|
* @receiver the multiplicand.
|
||||||
|
* @param arg the multiplier.
|
||||||
|
* @return the product.
|
||||||
|
*/
|
||||||
public operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) }
|
public operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplies an element by a ND structure of it.
|
||||||
|
*
|
||||||
|
* @receiver the multiplicand.
|
||||||
|
* @param arg the multiplier.
|
||||||
|
* @return the product.
|
||||||
|
*/
|
||||||
public operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) }
|
public operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) }
|
||||||
|
|
||||||
public companion object
|
public companion object
|
||||||
@ -103,17 +196,35 @@ public interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T
|
|||||||
*
|
*
|
||||||
* @param T the type of the element contained in ND structure.
|
* @param T the type of the element contained in ND structure.
|
||||||
* @param N the type of ND structure.
|
* @param N the type of ND structure.
|
||||||
* @param F field of structure elements.
|
* @param F the type field of structure elements.
|
||||||
*/
|
*/
|
||||||
public interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> {
|
public interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> {
|
||||||
/**
|
/**
|
||||||
* Element-by-element division
|
* Element-wise division.
|
||||||
|
*
|
||||||
|
* @param a the dividend.
|
||||||
|
* @param b the divisor.
|
||||||
|
* @return the quotient.
|
||||||
*/
|
*/
|
||||||
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
public override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after KEEP-176
|
||||||
|
/**
|
||||||
|
* Divides an ND structure by an element of it.
|
||||||
|
*
|
||||||
|
* @receiver the dividend.
|
||||||
|
* @param arg the divisor.
|
||||||
|
* @return the quotient.
|
||||||
|
*/
|
||||||
public operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) }
|
public operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Divides an element by an ND structure of it.
|
||||||
|
*
|
||||||
|
* @receiver the dividend.
|
||||||
|
* @param arg the divisor.
|
||||||
|
* @return the quotient.
|
||||||
|
*/
|
||||||
public operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
|
public operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
|
||||||
|
|
||||||
@ThreadLocal
|
@ThreadLocal
|
||||||
@ -121,12 +232,12 @@ public interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing
|
|||||||
private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf()
|
private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a nd-field for [Double] values or pull it from cache if it was created previously
|
* Create a nd-field for [Double] values or pull it from cache if it was created previously.
|
||||||
*/
|
*/
|
||||||
public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a nd-field with boxing generic buffer
|
* Create an ND field with boxing generic buffer.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : Field<T>> boxing(
|
public fun <T : Any, F : Field<T>> boxing(
|
||||||
field: F,
|
field: F,
|
||||||
|
@ -38,9 +38,8 @@ public interface NDStructure<T> {
|
|||||||
*/
|
*/
|
||||||
public fun elements(): Sequence<Pair<IntArray, T>>
|
public fun elements(): Sequence<Pair<IntArray, T>>
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean
|
public override fun equals(other: Any?): Boolean
|
||||||
|
public override fun hashCode(): Int
|
||||||
override fun hashCode(): Int
|
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
@ -50,13 +49,8 @@ public interface NDStructure<T> {
|
|||||||
if (st1 === st2) return true
|
if (st1 === st2) return true
|
||||||
|
|
||||||
// fast comparison of buffers if possible
|
// fast comparison of buffers if possible
|
||||||
if (
|
if (st1 is NDBuffer && st2 is NDBuffer && st1.strides == st2.strides)
|
||||||
st1 is NDBuffer &&
|
|
||||||
st2 is NDBuffer &&
|
|
||||||
st1.strides == st2.strides
|
|
||||||
) {
|
|
||||||
return st1.buffer.contentEquals(st2.buffer)
|
return st1.buffer.contentEquals(st2.buffer)
|
||||||
}
|
|
||||||
|
|
||||||
//element by element comparison if it could not be avoided
|
//element by element comparison if it could not be avoided
|
||||||
return st1.elements().all { (index, value) -> value == st2[index] }
|
return st1.elements().all { (index, value) -> value == st2[index] }
|
||||||
@ -70,7 +64,7 @@ public interface NDStructure<T> {
|
|||||||
public fun <T> build(
|
public fun <T> build(
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
initializer: (IntArray) -> T
|
initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
@ -79,40 +73,40 @@ public interface NDStructure<T> {
|
|||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> auto(
|
public inline fun <reified T : Any> auto(
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
public inline fun <T : Any> auto(
|
public inline fun <T : Any> auto(
|
||||||
type: KClass<T>,
|
type: KClass<T>,
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
public fun <T> build(
|
public fun <T> build(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
initializer: (IntArray) -> T
|
initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
|
): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
|
||||||
|
|
||||||
public inline fun <reified T : Any> auto(
|
public inline fun <reified T : Any> auto(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
auto(DefaultStrides(shape), initializer)
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
@JvmName("autoVarArg")
|
@JvmName("autoVarArg")
|
||||||
public inline fun <reified T : Any> auto(
|
public inline fun <reified T : Any> auto(
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
auto(DefaultStrides(shape), initializer)
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
public inline fun <T : Any> auto(
|
public inline fun <T : Any> auto(
|
||||||
type: KClass<T>,
|
type: KClass<T>,
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
auto(type, DefaultStrides(shape), initializer)
|
auto(type, DefaultStrides(shape), initializer)
|
||||||
}
|
}
|
||||||
@ -274,6 +268,22 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
|||||||
result = 31 * result + buffer.hashCode()
|
result = 31 * result + buffer.hashCode()
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun toString(): String {
|
||||||
|
val bufferRepr: String = when (shape.size) {
|
||||||
|
1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ")
|
||||||
|
2 -> (0 until shape[0]).joinToString(prefix = "[", postfix = "]", separator = ", ") { i ->
|
||||||
|
(0 until shape[1]).joinToString(prefix = "[", postfix = "]", separator = ", ") { j ->
|
||||||
|
val offset = strides.offset(intArrayOf(i, j))
|
||||||
|
buffer[offset].toString()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else -> "..."
|
||||||
|
}
|
||||||
|
return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -281,7 +291,7 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
|||||||
*/
|
*/
|
||||||
public class BufferNDStructure<T>(
|
public class BufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>,
|
||||||
) : NDBuffer<T>() {
|
) : NDBuffer<T>() {
|
||||||
init {
|
init {
|
||||||
if (strides.linearSize != buffer.size) {
|
if (strides.linearSize != buffer.size) {
|
||||||
@ -295,7 +305,7 @@ public class BufferNDStructure<T>(
|
|||||||
*/
|
*/
|
||||||
public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
||||||
factory: BufferFactory<R> = Buffer.Companion::auto,
|
factory: BufferFactory<R> = Buffer.Companion::auto,
|
||||||
crossinline transform: (T) -> R
|
crossinline transform: (T) -> R,
|
||||||
): BufferNDStructure<R> {
|
): BufferNDStructure<R> {
|
||||||
return if (this is BufferNDStructure<T>)
|
return if (this is BufferNDStructure<T>)
|
||||||
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
||||||
@ -310,7 +320,7 @@ public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
|||||||
*/
|
*/
|
||||||
public class MutableBufferNDStructure<T>(
|
public class MutableBufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: MutableBuffer<T>
|
override val buffer: MutableBuffer<T>,
|
||||||
) : NDBuffer<T>(), MutableNDStructure<T> {
|
) : NDBuffer<T>(), MutableNDStructure<T> {
|
||||||
|
|
||||||
init {
|
init {
|
||||||
@ -324,7 +334,7 @@ public class MutableBufferNDStructure<T>(
|
|||||||
|
|
||||||
public inline fun <reified T : Any> NDStructure<T>.combine(
|
public inline fun <reified T : Any> NDStructure<T>.combine(
|
||||||
struct: NDStructure<T>,
|
struct: NDStructure<T>,
|
||||||
crossinline block: (T, T) -> T
|
crossinline block: (T, T) -> T,
|
||||||
): NDStructure<T> {
|
): NDStructure<T> {
|
||||||
require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
|
require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
|
||||||
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
||||||
|
@ -17,7 +17,7 @@ public interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
|||||||
/**
|
/**
|
||||||
* A 1D wrapper for nd-structure
|
* A 1D wrapper for nd-structure
|
||||||
*/
|
*/
|
||||||
private inline class Structure1DWrapper<T>(public val structure: NDStructure<T>) : Structure1D<T> {
|
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> {
|
||||||
override val shape: IntArray get() = structure.shape
|
override val shape: IntArray get() = structure.shape
|
||||||
override val size: Int get() = structure.shape[0]
|
override val size: Int get() = structure.shape[0]
|
||||||
|
|
||||||
|
@ -21,7 +21,8 @@ internal class LazyDeferred<T>(val dispatcher: CoroutineDispatcher, val block: s
|
|||||||
}
|
}
|
||||||
|
|
||||||
public class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
|
public class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
|
||||||
override suspend fun collect(collector: FlowCollector<T>): Unit = deferredFlow.collect { collector.emit((it.await())) }
|
override suspend fun collect(collector: FlowCollector<T>): Unit =
|
||||||
|
deferredFlow.collect { collector.emit((it.await())) }
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T, R> Flow<T>.async(
|
public fun <T, R> Flow<T>.async(
|
||||||
|
@ -3,8 +3,9 @@ package kscience.kmath.dimensions
|
|||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An abstract class which is not used in runtime. Designates a size of some structure.
|
* Represents a quantity of dimensions in certain structure.
|
||||||
* Could be replaced later by fully inline constructs
|
*
|
||||||
|
* @property dim The number of dimensions.
|
||||||
*/
|
*/
|
||||||
public interface Dimension {
|
public interface Dimension {
|
||||||
public val dim: UInt
|
public val dim: UInt
|
||||||
@ -16,18 +17,33 @@ public fun <D : Dimension> KClass<D>.dim(): UInt = Dimension.resolve(this).dim
|
|||||||
|
|
||||||
public expect fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D
|
public expect fun <D : Dimension> Dimension.Companion.resolve(type: KClass<D>): D
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Finds or creates [Dimension] with [Dimension.dim] equal to [dim].
|
||||||
|
*/
|
||||||
public expect fun Dimension.Companion.of(dim: UInt): Dimension
|
public expect fun Dimension.Companion.of(dim: UInt): Dimension
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Finds [Dimension.dim] of given type [D].
|
||||||
|
*/
|
||||||
public inline fun <reified D : Dimension> Dimension.Companion.dim(): UInt = D::class.dim()
|
public inline fun <reified D : Dimension> Dimension.Companion.dim(): UInt = D::class.dim()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Type representing 1 dimension.
|
||||||
|
*/
|
||||||
public object D1 : Dimension {
|
public object D1 : Dimension {
|
||||||
override val dim: UInt get() = 1U
|
override val dim: UInt get() = 1U
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Type representing 2 dimensions.
|
||||||
|
*/
|
||||||
public object D2 : Dimension {
|
public object D2 : Dimension {
|
||||||
override val dim: UInt get() = 2U
|
override val dim: UInt get() = 2U
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Type representing 3 dimensions.
|
||||||
|
*/
|
||||||
public object D3 : Dimension {
|
public object D3 : Dimension {
|
||||||
override val dim: UInt get() = 3U
|
override val dim: UInt get() = 3U
|
||||||
}
|
}
|
||||||
|
8
kmath-ejml/build.gradle.kts
Normal file
8
kmath-ejml/build.gradle.kts
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.jvm")
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation("org.ejml:ejml-simple:0.39")
|
||||||
|
implementation(project(":kmath-core"))
|
||||||
|
}
|
71
kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt
Normal file
71
kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import kscience.kmath.linear.DeterminantFeature
|
||||||
|
import kscience.kmath.linear.FeaturedMatrix
|
||||||
|
import kscience.kmath.linear.LUPDecompositionFeature
|
||||||
|
import kscience.kmath.linear.MatrixFeature
|
||||||
|
import kscience.kmath.structures.NDStructure
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents featured matrix over EJML [SimpleMatrix].
|
||||||
|
*
|
||||||
|
* @property origin the underlying [SimpleMatrix].
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public class EjmlMatrix(public val origin: SimpleMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
||||||
|
public override val rowNum: Int
|
||||||
|
get() = origin.numRows()
|
||||||
|
|
||||||
|
public override val colNum: Int
|
||||||
|
get() = origin.numCols()
|
||||||
|
|
||||||
|
public override val shape: IntArray
|
||||||
|
get() = intArrayOf(origin.numRows(), origin.numCols())
|
||||||
|
|
||||||
|
public override val features: Set<MatrixFeature> = setOf(
|
||||||
|
object : LUPDecompositionFeature<Double>, DeterminantFeature<Double> {
|
||||||
|
override val determinant: Double
|
||||||
|
get() = origin.determinant()
|
||||||
|
|
||||||
|
private val lup by lazy {
|
||||||
|
val ludecompositionF64 = DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols())
|
||||||
|
.also { it.decompose(origin.ddrm.copy()) }
|
||||||
|
|
||||||
|
Triple(
|
||||||
|
EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))),
|
||||||
|
EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))),
|
||||||
|
EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val l: FeaturedMatrix<Double>
|
||||||
|
get() = lup.second
|
||||||
|
|
||||||
|
override val u: FeaturedMatrix<Double>
|
||||||
|
get() = lup.third
|
||||||
|
|
||||||
|
override val p: FeaturedMatrix<Double>
|
||||||
|
get() = lup.first
|
||||||
|
}
|
||||||
|
) union features.orEmpty()
|
||||||
|
|
||||||
|
public override fun suggestFeature(vararg features: MatrixFeature): EjmlMatrix =
|
||||||
|
EjmlMatrix(origin, this.features + features)
|
||||||
|
|
||||||
|
public override operator fun get(i: Int, j: Int): Double = origin[i, j]
|
||||||
|
|
||||||
|
public override fun equals(other: Any?): Boolean {
|
||||||
|
if (other is EjmlMatrix) return origin.isIdentical(other.origin, 0.0)
|
||||||
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun hashCode(): Int {
|
||||||
|
var result = origin.hashCode()
|
||||||
|
result = 31 * result + features.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun toString(): String = "EjmlMatrix(origin=$origin, features=$features)"
|
||||||
|
}
|
@ -0,0 +1,86 @@
|
|||||||
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import kscience.kmath.linear.MatrixContext
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
import kscience.kmath.operations.Space
|
||||||
|
import kscience.kmath.operations.invoke
|
||||||
|
import kscience.kmath.structures.Matrix
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents context of basic operations operating with [EjmlMatrix].
|
||||||
|
*
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public class EjmlMatrixContext(private val space: Space<Double>) : MatrixContext<Double> {
|
||||||
|
/**
|
||||||
|
* Converts this matrix to EJML one.
|
||||||
|
*/
|
||||||
|
public fun Matrix<Double>.toEjml(): EjmlMatrix =
|
||||||
|
if (this is EjmlMatrix) this else produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts this vector to EJML one.
|
||||||
|
*/
|
||||||
|
public fun Point<Double>.toEjml(): EjmlVector =
|
||||||
|
if (this is EjmlVector) this else EjmlVector(SimpleMatrix(size, 1).also {
|
||||||
|
(0 until it.numRows()).forEach { row -> it[row, 0] = get(row) }
|
||||||
|
})
|
||||||
|
|
||||||
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix =
|
||||||
|
EjmlMatrix(SimpleMatrix(rows, columns).also {
|
||||||
|
(0 until it.numRows()).forEach { row ->
|
||||||
|
(0 until it.numCols()).forEach { col -> it[row, col] = initializer(row, col) }
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
public override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(toEjml().origin.mult(other.toEjml().origin))
|
||||||
|
|
||||||
|
public override fun Matrix<Double>.dot(vector: Point<Double>): EjmlVector =
|
||||||
|
EjmlVector(toEjml().origin.mult(vector.toEjml().origin))
|
||||||
|
|
||||||
|
public override fun add(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(a.toEjml().origin + b.toEjml().origin)
|
||||||
|
|
||||||
|
public override operator fun Matrix<Double>.minus(b: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(toEjml().origin - b.toEjml().origin)
|
||||||
|
|
||||||
|
public override fun multiply(a: Matrix<Double>, k: Number): EjmlMatrix =
|
||||||
|
produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } }
|
||||||
|
|
||||||
|
public override operator fun Matrix<Double>.times(value: Double): EjmlMatrix = EjmlMatrix(toEjml().origin.scale(value))
|
||||||
|
|
||||||
|
public companion object
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix.
|
||||||
|
*
|
||||||
|
* @param a the base matrix.
|
||||||
|
* @param b n by p matrix.
|
||||||
|
* @return the solution for 'x' that is n by p.
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix.
|
||||||
|
*
|
||||||
|
* @param a the base matrix.
|
||||||
|
* @param b n by p vector.
|
||||||
|
* @return the solution for 'x' that is n by p.
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
|
||||||
|
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the inverse of given matrix: b = a^(-1).
|
||||||
|
*
|
||||||
|
* @param a the matrix.
|
||||||
|
* @return the inverse of this matrix.
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public fun EjmlMatrixContext.inverse(a: Matrix<Double>): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert())
|
40
kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlVector.kt
Normal file
40
kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlVector.kt
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
import kscience.kmath.structures.Buffer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents point over EJML [SimpleMatrix].
|
||||||
|
*
|
||||||
|
* @property origin the underlying [SimpleMatrix].
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public class EjmlVector internal constructor(public val origin: SimpleMatrix) : Point<Double> {
|
||||||
|
public override val size: Int
|
||||||
|
get() = origin.numRows()
|
||||||
|
|
||||||
|
init {
|
||||||
|
require(origin.numCols() == 1) { "Only single column matrices are allowed" }
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun get(index: Int): Double = origin[index]
|
||||||
|
|
||||||
|
public override operator fun iterator(): Iterator<Double> = object : Iterator<Double> {
|
||||||
|
private var cursor: Int = 0
|
||||||
|
|
||||||
|
override fun next(): Double {
|
||||||
|
cursor += 1
|
||||||
|
return origin[cursor - 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hasNext(): Boolean = cursor < origin.numCols() * origin.numRows()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun contentEquals(other: Buffer<*>): Boolean {
|
||||||
|
if (other is EjmlVector) return origin.isIdentical(other.origin, 0.0)
|
||||||
|
return super.contentEquals(other)
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun toString(): String = "EjmlVector(origin=$origin)"
|
||||||
|
}
|
@ -0,0 +1,75 @@
|
|||||||
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
|
import kscience.kmath.linear.DeterminantFeature
|
||||||
|
import kscience.kmath.linear.LUPDecompositionFeature
|
||||||
|
import kscience.kmath.linear.MatrixFeature
|
||||||
|
import kscience.kmath.linear.getFeature
|
||||||
|
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.random.asJavaRandom
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
internal class EjmlMatrixTest {
|
||||||
|
private val random = Random(0)
|
||||||
|
|
||||||
|
private val randomMatrix: SimpleMatrix
|
||||||
|
get() {
|
||||||
|
val s = random.nextInt(2, 100)
|
||||||
|
return SimpleMatrix.random_DDRM(s, s, 0.0, 10.0, random.asJavaRandom())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun rowNum() {
|
||||||
|
val m = randomMatrix
|
||||||
|
assertEquals(m.numRows(), EjmlMatrix(m).rowNum)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun colNum() {
|
||||||
|
val m = randomMatrix
|
||||||
|
assertEquals(m.numCols(), EjmlMatrix(m).rowNum)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun shape() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlMatrix(m)
|
||||||
|
assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun features() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlMatrix(m)
|
||||||
|
val det = w.getFeature<DeterminantFeature<Double>>() ?: fail()
|
||||||
|
assertEquals(m.determinant(), det.determinant)
|
||||||
|
val lup = w.getFeature<LUPDecompositionFeature<Double>>() ?: fail()
|
||||||
|
|
||||||
|
val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols())
|
||||||
|
.also { it.decompose(m.ddrm.copy()) }
|
||||||
|
|
||||||
|
assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))), lup.l)
|
||||||
|
assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))), lup.u)
|
||||||
|
assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))), lup.p)
|
||||||
|
}
|
||||||
|
|
||||||
|
private object SomeFeature : MatrixFeature {}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun suggestFeature() {
|
||||||
|
assertNotNull(EjmlMatrix(randomMatrix).suggestFeature(SomeFeature).getFeature<SomeFeature>())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun get() {
|
||||||
|
val m = randomMatrix
|
||||||
|
assertEquals(m[0, 0], EjmlMatrix(m)[0, 0])
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun origin() {
|
||||||
|
val m = randomMatrix
|
||||||
|
assertSame(m, EjmlMatrix(m).origin)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,47 @@
|
|||||||
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.random.asJavaRandom
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertSame
|
||||||
|
|
||||||
|
internal class EjmlVectorTest {
|
||||||
|
private val random = Random(0)
|
||||||
|
|
||||||
|
private val randomMatrix: SimpleMatrix
|
||||||
|
get() = SimpleMatrix.random_DDRM(random.nextInt(2, 100), 1, 0.0, 10.0, random.asJavaRandom())
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun size() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlVector(m)
|
||||||
|
assertEquals(m.numRows(), w.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun get() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlVector(m)
|
||||||
|
assertEquals(m[0, 0], w[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun iterator() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlVector(m)
|
||||||
|
|
||||||
|
assertEquals(
|
||||||
|
m.iterator(true, 0, 0, m.numRows() - 1, 0).asSequence().toList(),
|
||||||
|
w.iterator().asSequence().toList()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun origin() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlVector(m)
|
||||||
|
assertSame(m, w.origin)
|
||||||
|
}
|
||||||
|
}
|
@ -4,7 +4,7 @@ import kscience.kmath.operations.Space
|
|||||||
|
|
||||||
public interface Vector
|
public interface Vector
|
||||||
|
|
||||||
public interface GeometrySpace<V: Vector>: Space<V> {
|
public interface GeometrySpace<V : Vector> : Space<V> {
|
||||||
/**
|
/**
|
||||||
* L2 distance
|
* L2 distance
|
||||||
*/
|
*/
|
||||||
|
@ -10,9 +10,10 @@ import kscience.kmath.structures.RealBuffer
|
|||||||
*/
|
*/
|
||||||
public interface Bin<T : Any> : Domain<T> {
|
public interface Bin<T : Any> : Domain<T> {
|
||||||
/**
|
/**
|
||||||
* The value of this bin
|
* The value of this bin.
|
||||||
*/
|
*/
|
||||||
public val value: Number
|
public val value: Number
|
||||||
|
|
||||||
public val center: Point<T>
|
public val center: Point<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,10 +5,7 @@ import kscience.kmath.histogram.fill
|
|||||||
import kscience.kmath.histogram.put
|
import kscience.kmath.histogram.put
|
||||||
import kscience.kmath.real.RealVector
|
import kscience.kmath.real.RealVector
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.test.Test
|
import kotlin.test.*
|
||||||
import kotlin.test.assertEquals
|
|
||||||
import kotlin.test.assertFalse
|
|
||||||
import kotlin.test.assertTrue
|
|
||||||
|
|
||||||
internal class MultivariateHistogramTest {
|
internal class MultivariateHistogramTest {
|
||||||
@Test
|
@Test
|
||||||
@ -18,7 +15,7 @@ internal class MultivariateHistogramTest {
|
|||||||
(-1.0..1.0)
|
(-1.0..1.0)
|
||||||
)
|
)
|
||||||
histogram.put(0.55, 0.55)
|
histogram.put(0.55, 0.55)
|
||||||
val bin = histogram.find { it.value.toInt() > 0 }!!
|
val bin = histogram.find { it.value.toInt() > 0 } ?: fail()
|
||||||
assertTrue { bin.contains(RealVector(0.55, 0.55)) }
|
assertTrue { bin.contains(RealVector(0.55, 0.55)) }
|
||||||
assertTrue { bin.contains(RealVector(0.6, 0.5)) }
|
assertTrue { bin.contains(RealVector(0.6, 0.5)) }
|
||||||
assertFalse { bin.contains(RealVector(-0.55, 0.55)) }
|
assertFalse { bin.contains(RealVector(-0.55, 0.55)) }
|
||||||
|
@ -63,7 +63,7 @@ public fun <T : Any> Sampler<T>.sampleBuffer(
|
|||||||
//clear list from previous run
|
//clear list from previous run
|
||||||
tmp.clear()
|
tmp.clear()
|
||||||
//Fill list
|
//Fill list
|
||||||
repeat(size){
|
repeat(size) {
|
||||||
tmp.add(chain.next())
|
tmp.add(chain.next())
|
||||||
}
|
}
|
||||||
//return new buffer with elements from tmp
|
//return new buffer with elements from tmp
|
||||||
|
@ -3,16 +3,59 @@ package kscience.kmath.prob
|
|||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A basic generator
|
* An interface that is implemented by random number generator algorithms.
|
||||||
*/
|
*/
|
||||||
public interface RandomGenerator {
|
public interface RandomGenerator {
|
||||||
|
/**
|
||||||
|
* Gets the next random [Boolean] value.
|
||||||
|
*/
|
||||||
public fun nextBoolean(): Boolean
|
public fun nextBoolean(): Boolean
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the next random [Double] value uniformly distributed between 0 (inclusive) and 1 (exclusive).
|
||||||
|
*/
|
||||||
public fun nextDouble(): Double
|
public fun nextDouble(): Double
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the next random `Int` from the random number generator.
|
||||||
|
*
|
||||||
|
* Generates an `Int` random value uniformly distributed between [Int.MIN_VALUE] and [Int.MAX_VALUE] (inclusive).
|
||||||
|
*/
|
||||||
public fun nextInt(): Int
|
public fun nextInt(): Int
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the next random non-negative `Int` from the random number generator less than the specified [until] bound.
|
||||||
|
*
|
||||||
|
* Generates an `Int` random value uniformly distributed between `0` (inclusive) and the specified [until] bound
|
||||||
|
* (exclusive).
|
||||||
|
*/
|
||||||
public fun nextInt(until: Int): Int
|
public fun nextInt(until: Int): Int
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the next random `Long` from the random number generator.
|
||||||
|
*
|
||||||
|
* Generates a `Long` random value uniformly distributed between [Long.MIN_VALUE] and [Long.MAX_VALUE] (inclusive).
|
||||||
|
*/
|
||||||
public fun nextLong(): Long
|
public fun nextLong(): Long
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the next random non-negative `Long` from the random number generator less than the specified [until] bound.
|
||||||
|
*
|
||||||
|
* Generates a `Long` random value uniformly distributed between `0` (inclusive) and the specified [until] bound (exclusive).
|
||||||
|
*/
|
||||||
public fun nextLong(until: Long): Long
|
public fun nextLong(until: Long): Long
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fills a subrange of the specified byte [array] starting from [fromIndex] inclusive and ending [toIndex] exclusive
|
||||||
|
* with random bytes.
|
||||||
|
*
|
||||||
|
* @return [array] with the subrange filled with random bytes.
|
||||||
|
*/
|
||||||
public fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size)
|
public fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a byte array of the specified [size], filled with random bytes.
|
||||||
|
*/
|
||||||
public fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) }
|
public fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -25,12 +68,21 @@ public interface RandomGenerator {
|
|||||||
public fun fork(): RandomGenerator
|
public fun fork(): RandomGenerator
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public val default: DefaultGenerator by lazy { DefaultGenerator() }
|
/**
|
||||||
|
* The [DefaultGenerator] instance.
|
||||||
|
*/
|
||||||
|
public val default: DefaultGenerator by lazy(::DefaultGenerator)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns [DefaultGenerator] of given [seed].
|
||||||
|
*/
|
||||||
public fun default(seed: Long): DefaultGenerator = DefaultGenerator(Random(seed))
|
public fun default(seed: Long): DefaultGenerator = DefaultGenerator(Random(seed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implements [RandomGenerator] by delegating all operations to [Random].
|
||||||
|
*/
|
||||||
public inline class DefaultGenerator(public val random: Random = Random) : RandomGenerator {
|
public inline class DefaultGenerator(public val random: Random = Random) : RandomGenerator {
|
||||||
public override fun nextBoolean(): Boolean = random.nextBoolean()
|
public override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||||
public override fun nextDouble(): Double = random.nextDouble()
|
public override fun nextDouble(): Double = random.nextDouble()
|
||||||
|
@ -6,7 +6,7 @@ import kotlin.test.Test
|
|||||||
class SamplerTest {
|
class SamplerTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun bufferSamplerTest(){
|
fun bufferSamplerTest() {
|
||||||
val sampler: Sampler<Double> =
|
val sampler: Sampler<Double> =
|
||||||
BasicSampler { it.chain { nextDouble() } }
|
BasicSampler { it.chain { nextDouble() } }
|
||||||
val data = sampler.sampleBuffer(RandomGenerator.default, 100)
|
val data = sampler.sampleBuffer(RandomGenerator.default, 100)
|
||||||
|
@ -40,5 +40,6 @@ include(
|
|||||||
":kmath-for-real",
|
":kmath-for-real",
|
||||||
":kmath-geometry",
|
":kmath-geometry",
|
||||||
":kmath-ast",
|
":kmath-ast",
|
||||||
":examples"
|
":examples",
|
||||||
|
":kmath-ejml"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user