Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116

Merged
CommanderTvis merged 50 commits from nd4j into dev 2020-10-29 19:58:53 +03:00
247 changed files with 4504 additions and 4858 deletions
Showing only changes of commit 202bc2e904 - Show all commits

1
.space.kts Normal file
View File

@ -0,0 +1 @@
job("Build") { gradlew("openjdk:11", "build") }

View File

@ -2,17 +2,27 @@
## [Unreleased] ## [Unreleased]
### Added ### Added
- `fun` annotation for SAM interfaces in library
- Explicit `public` visibility for all public APIs
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`. - ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`.
### Changed ### Changed
- Package changed from `scientifik` to `kscience.kmath`.
- Gradle version: 6.6 -> 6.6.1
- Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`)
- `Polynomial` secondary constructor made function.
### Deprecated ### Deprecated
### Removed ### Removed
- `kmath-koma` module because it doesn't support Kotlin 1.4.
### Fixed ### Fixed
- `symbol` method in `MstExtendedField` (https://github.com/mipt-npm/kmath/pull/140)
### Security ### Security
## [0.1.4] ## [0.1.4]
### Added ### Added

View File

@ -3,7 +3,7 @@
![Gradle build](https://github.com/mipt-npm/kmath/workflows/Gradle%20build/badge.svg) ![Gradle build](https://github.com/mipt-npm/kmath/workflows/Gradle%20build/badge.svg)
Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/scientifik/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/scientifik/kmath-core/_latestVersion) Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion) Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion)
@ -54,9 +54,6 @@ can be used for a wide variety of purposes from high performance calculations to
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
to submit a feature request if you want something to be done first. to submit a feature request if you want something to be done first.
* **Koma wrapper** [Koma](https://github.com/kyonifer/koma) is a well established numerics library in Kotlin, specifically linear algebra.
The plan is to have wrappers for koma implementations for compatibility with kmath API.
## Planned features ## Planned features
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks. * **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
@ -83,12 +80,12 @@ Release artifacts are accessible from bintray with following configuration (see
```kotlin ```kotlin
repositories{ repositories{
maven("https://dl.bintray.com/mipt-npm/scientifik") maven("https://dl.bintray.com/mipt-npm/kscience")
} }
dependencies{ dependencies{
api("scientifik:kmath-core:${kmathVersion}") api("kscience.kmath:kmath-core:${kmathVersion}")
//api("scientifik:kmath-core-jvm:${kmathVersion}") for jvm-specific version //api("kscience.kmath:kmath-core-jvm:${kmathVersion}") for jvm-specific version
} }
``` ```

View File

@ -1,13 +1,10 @@
import org.jetbrains.kotlin.gradle.dsl.KotlinProjectExtension
import scientifik.ScientifikPublishPlugin
plugins { plugins {
id("scientifik.publish") apply false id("ru.mipt.npm.base")
id("org.jetbrains.changelog") version "0.4.0" id("org.jetbrains.changelog") version "0.4.0"
} }
val kmathVersion by extra("0.1.4") val kmathVersion by extra("0.2.0-dev-1")
val bintrayRepo by extra("scientifik") val bintrayRepo by extra("kscience")
val githubProject by extra("kmath") val githubProject by extra("kmath")
allprojects { allprojects {
@ -20,17 +17,6 @@ allprojects {
group = "kscience.kmath" group = "kscience.kmath"
version = kmathVersion version = kmathVersion
afterEvaluate {
extensions.findByType<KotlinProjectExtension>()?.run {
sourceSets.all {
languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts")
}
}
}
} }
subprojects { subprojects { if (name.startsWith("kmath")) apply(plugin = "ru.mipt.npm.publish") }
if (name.startsWith("kmath"))
apply<ScientifikPublishPlugin>()
}

View File

@ -5,7 +5,7 @@ operation, say `+`, one needs two objects of a type `T` and an algebra context,
say `Space<T>`. Next one needs to run the actual operation in the context: say `Space<T>`. Next one needs to run the actual operation in the context:
```kotlin ```kotlin
import scientifik.kmath.operations.* import kscience.kmath.operations.*
val a: T = ... val a: T = ...
val b: T = ... val b: T = ...
@ -47,7 +47,7 @@ but it also holds reference to the `ComplexField` singleton, which allows perfor
numbers without explicit involving the context like: numbers without explicit involving the context like:
```kotlin ```kotlin
import scientifik.kmath.operations.* import kscience.kmath.operations.*
// Using elements // Using elements
val c1 = Complex(1.0, 1.0) val c1 = Complex(1.0, 1.0)
@ -82,7 +82,7 @@ operations in all performance-critical places. The performance of element operat
KMath submits both contexts and elements for builtin algebraic structures: KMath submits both contexts and elements for builtin algebraic structures:
```kotlin ```kotlin
import scientifik.kmath.operations.* import kscience.kmath.operations.*
val c1 = Complex(1.0, 2.0) val c1 = Complex(1.0, 2.0)
val c2 = ComplexField.i val c2 = ComplexField.i
@ -95,7 +95,7 @@ val c3 = ComplexField { c1 + c2 }
Also, `ComplexField` features special operations to mix complex and real numbers, for example: Also, `ComplexField` features special operations to mix complex and real numbers, for example:
```kotlin ```kotlin
import scientifik.kmath.operations.* import kscience.kmath.operations.*
val c1 = Complex(1.0, 2.0) val c1 = Complex(1.0, 2.0)
val c2 = ComplexField { c1 - 1.0 } // Returns: Complex(re=0.0, im=2.0) val c2 = ComplexField { c1 - 1.0 } // Returns: Complex(re=0.0, im=2.0)

View File

@ -12,6 +12,3 @@ api and multiple library back-ends.
* [Expressions](./expressions.md) * [Expressions](./expressions.md)
* Commons math integration * Commons math integration
* Koma integration

View File

@ -1,59 +1,49 @@
import org.jetbrains.kotlin.allopen.gradle.AllOpenExtension
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
plugins { plugins {
java java
kotlin("jvm") kotlin("jvm")
kotlin("plugin.allopen") version "1.3.72" kotlin("plugin.allopen") version "1.4.20-dev-3898-14"
id("kotlinx.benchmark") version "0.2.0-dev-8" id("kotlinx.benchmark") version "0.2.0-dev-20"
} }
configure<AllOpenExtension> { allOpen.annotation("org.openjdk.jmh.annotations.State")
annotation("org.openjdk.jmh.annotations.State")
}
repositories { repositories {
maven("http://dl.bintray.com/kyonifer/maven") maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://dl.bintray.com/mipt-npm/scientifik")
maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
mavenCentral() mavenCentral()
} }
sourceSets { sourceSets.register("benchmarks")
register("benchmarks")
}
dependencies { dependencies {
implementation(project(":kmath-ast")) // implementation(project(":kmath-ast"))
implementation(project(":kmath-core")) implementation(project(":kmath-core"))
implementation(project(":kmath-coroutines")) implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons")) implementation(project(":kmath-commons"))
implementation(project(":kmath-prob")) implementation(project(":kmath-prob"))
implementation(project(":kmath-koma"))
implementation(project(":kmath-viktor")) implementation(project(":kmath-viktor"))
implementation(project(":kmath-dimensions")) implementation(project(":kmath-dimensions"))
implementation("com.kyonifer:koma-core-ejml:0.12")
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8") implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
"benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath
} }
// Configure benchmark // Configure benchmark
benchmark { benchmark {
// Setup configurations // Setup configurations
targets { targets
// This one matches sourceSet name above // This one matches sourceSet name above
register("benchmarks") .register("benchmarks")
}
configurations { configurations.register("fast") {
register("fast") {
warmups = 5 // number of warmup iterations warmups = 5 // number of warmup iterations
iterations = 3 // number of iterations iterations = 3 // number of iterations
iterationTime = 500 // time in seconds per iteration iterationTime = 500 // time in seconds per iteration
iterationTimeUnit = "ms" // time unity for iterationTime, default is seconds iterationTimeUnit = "ms" // time unity for iterationTime, default is seconds
} }
}
} }
kotlin.sourceSets.all { kotlin.sourceSets.all {
@ -63,9 +53,4 @@ kotlin.sourceSets.all {
} }
} }
tasks.withType<KotlinCompile> { tasks.withType<KotlinCompile> { kotlinOptions.jvmTarget = "11" }
kotlinOptions {
jvmTarget = Scientifik.JVM_TARGET.toString()
freeCompilerArgs = freeCompilerArgs + "-Xopt-in=kotlin.RequiresOptIn"
}
}

View File

@ -0,0 +1,39 @@
package kscience.kmath.structures
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import java.nio.IntBuffer
@State(Scope.Benchmark)
class ArrayBenchmark {
@Benchmark
fun benchmarkArrayRead() {
var res = 0
for (i in 1.._root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) res += _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.array[_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size - i]
}
@Benchmark
fun benchmarkBufferRead() {
var res = 0
for (i in 1.._root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) res += _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.arrayBuffer.get(
_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size - i)
}
@Benchmark
fun nativeBufferRead() {
var res = 0
for (i in 1.._root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) res += _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.nativeBuffer.get(
_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size - i)
}
companion object {
const val size: Int = 1000
val array: IntArray = IntArray(_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) { it }
val arrayBuffer: IntBuffer = IntBuffer.wrap(_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.array)
val nativeBuffer: IntBuffer = IntBuffer.allocate(_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size).also {
for (i in 0 until _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) it.put(i, i)
}
}
}

View File

@ -1,10 +1,10 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kscience.kmath.operations.Complex
import kscience.kmath.operations.complex
import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State import org.openjdk.jmh.annotations.State
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.complex
@State(Scope.Benchmark) @State(Scope.Benchmark)
class BufferBenchmark { class BufferBenchmark {

View File

@ -1,10 +1,10 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke
import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State import org.openjdk.jmh.annotations.State
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
@State(Scope.Benchmark) @State(Scope.Benchmark)
class NDFieldBenchmark { class NDFieldBenchmark {

View File

@ -1,12 +1,12 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke
import kscience.kmath.viktor.ViktorNDField
import org.jetbrains.bio.viktor.F64Array import org.jetbrains.bio.viktor.F64Array
import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State import org.openjdk.jmh.annotations.State
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.viktor.ViktorNDField
@State(Scope.Benchmark) @State(Scope.Benchmark)
class ViktorBenchmark { class ViktorBenchmark {

View File

@ -1,4 +1,4 @@
package scientifik.kmath.utils package kscience.kmath.utils
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract

View File

@ -1,48 +0,0 @@
package scientifik.kmath.structures
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import java.nio.IntBuffer
@State(Scope.Benchmark)
class ArrayBenchmark {
@Benchmark
fun benchmarkArrayRead() {
var res = 0
for (i in 1..size) {
res += array[size - i]
}
}
@Benchmark
fun benchmarkBufferRead() {
var res = 0
for (i in 1..size) {
res += arrayBuffer.get(size - i)
}
}
@Benchmark
fun nativeBufferRead() {
var res = 0
for (i in 1..size) {
res += nativeBuffer.get(size - i)
}
}
companion object {
val size = 1000
val array = IntArray(size) { it }
val arrayBuffer = IntBuffer.wrap(array)
val nativeBuffer = IntBuffer.allocate(size).also {
for (i in 0 until size) {
it.put(i, i)
}
}
}
}

View File

@ -0,0 +1,70 @@
//package kscience.kmath.ast
//
//import kscience.kmath.asm.compile
//import kscience.kmath.expressions.Expression
//import kscience.kmath.expressions.expressionInField
//import kscience.kmath.expressions.invoke
//import kscience.kmath.operations.Field
//import kscience.kmath.operations.RealField
//import kotlin.random.Random
//import kotlin.system.measureTimeMillis
//
//class ExpressionsInterpretersBenchmark {
// private val algebra: Field<Double> = RealField
// fun functionalExpression() {
// val expr = algebra.expressionInField {
// variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
// }
//
// invokeAndSum(expr)
// }
//
// fun mstExpression() {
// val expr = algebra.mstInField {
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
// }
//
// invokeAndSum(expr)
// }
//
// fun asmExpression() {
// val expr = algebra.mstInField {
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
// }.compile()
//
// invokeAndSum(expr)
// }
//
// private fun invokeAndSum(expr: Expression<Double>) {
// val random = Random(0)
// var sum = 0.0
//
// repeat(1000000) {
// sum += expr("x" to random.nextDouble())
// }
//
// println(sum)
// }
//}
//
//fun main() {
// val benchmark = ExpressionsInterpretersBenchmark()
//
// val fe = measureTimeMillis {
// benchmark.functionalExpression()
// }
//
// println("fe=$fe")
//
// val mst = measureTimeMillis {
// benchmark.mstExpression()
// }
//
// println("mst=$mst")
//
// val asm = measureTimeMillis {
// benchmark.asmExpression()
// }
//
// println("asm=$asm")
//}

View File

@ -1,12 +1,12 @@
package scientifik.kmath.commons.prob package kscience.kmath.commons.prob
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import kscience.kmath.chains.BlockingRealChain
import kscience.kmath.prob.*
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
import org.apache.commons.rng.simple.RandomSource import org.apache.commons.rng.simple.RandomSource
import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.prob.*
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant

View File

@ -1,11 +1,11 @@
package scientifik.kmath.commons.prob package kscience.kmath.commons.prob
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.chains.collectWithState import kscience.kmath.chains.collectWithState
import scientifik.kmath.prob.Distribution import kscience.kmath.prob.Distribution
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.normal import kscience.kmath.prob.normal
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)

View File

@ -1,4 +1,4 @@
package scientifik.kmath.operations package kscience.kmath.operations
fun main() { fun main() {
val res = BigIntField { val res = BigIntField {

View File

@ -1,8 +1,8 @@
package scientifik.kmath.operations package kscience.kmath.operations
import scientifik.kmath.structures.NDElement import kscience.kmath.structures.NDElement
import scientifik.kmath.structures.NDField import kscience.kmath.structures.NDField
import scientifik.kmath.structures.complex import kscience.kmath.structures.complex
fun main() { fun main() {
val element = NDElement.complex(2, 2) { index: IntArray -> val element = NDElement.complex(2, 2) { index: IntArray ->

View File

@ -1,9 +1,9 @@
package scientifik.kmath.structures package kscience.kmath.structures
import scientifik.kmath.linear.transpose import kscience.kmath.linear.transpose
import scientifik.kmath.operations.Complex import kscience.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import kscience.kmath.operations.ComplexField
import scientifik.kmath.operations.invoke import kscience.kmath.operations.invoke
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
fun main() { fun main() {

View File

@ -1,8 +1,8 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.GlobalScope
import scientifik.kmath.operations.RealField import kscience.kmath.operations.RealField
import scientifik.kmath.operations.invoke import kscience.kmath.operations.invoke
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis

View File

@ -1,8 +1,8 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
fun main(args: Array<String>) { fun main() {
val n = 6000 val n = 6000
val array = DoubleArray(n * n) { 1.0 } val array = DoubleArray(n * n) { 1.0 }

View File

@ -1,10 +1,8 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
fun main() {
fun main(args: Array<String>) {
val n = 6000 val n = 6000
val structure = NDStructure.build(intArrayOf(n, n), Buffer.Companion::auto) { 1.0 } val structure = NDStructure.build(intArrayOf(n, n), Buffer.Companion::auto) { 1.0 }

View File

@ -1,10 +1,10 @@
package scientifik.kmath.structures package kscience.kmath.structures
import scientifik.kmath.dimensions.D2 import kscience.kmath.dimensions.D2
import scientifik.kmath.dimensions.D3 import kscience.kmath.dimensions.D3
import scientifik.kmath.dimensions.DMatrixContext import kscience.kmath.dimensions.DMatrixContext
import scientifik.kmath.dimensions.Dimension import kscience.kmath.dimensions.Dimension
import scientifik.kmath.operations.RealField import kscience.kmath.operations.RealField
fun DMatrixContext<Double, RealField>.simple() { fun DMatrixContext<Double, RealField>.simple() {
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() } val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }

View File

@ -1,70 +0,0 @@
package scientifik.kmath.ast
import scientifik.kmath.asm.compile
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.expressionInField
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField
import kotlin.random.Random
import kotlin.system.measureTimeMillis
class ExpressionsInterpretersBenchmark {
private val algebra: Field<Double> = RealField
fun functionalExpression() {
val expr = algebra.expressionInField {
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
}
invokeAndSum(expr)
}
fun mstExpression() {
val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
}
invokeAndSum(expr)
}
fun asmExpression() {
val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
}.compile()
invokeAndSum(expr)
}
private fun invokeAndSum(expr: Expression<Double>) {
val random = Random(0)
var sum = 0.0
repeat(1000000) {
sum += expr("x" to random.nextDouble())
}
println(sum)
}
}
fun main() {
val benchmark = ExpressionsInterpretersBenchmark()
val fe = measureTimeMillis {
benchmark.functionalExpression()
}
println("fe=$fe")
val mst = measureTimeMillis {
benchmark.mstExpression()
}
println("mst=$mst")
val asm = measureTimeMillis {
benchmark.asmExpression()
}
println("asm=$asm")
}

View File

@ -1,53 +0,0 @@
package scientifik.kmath.linear
import koma.matrix.ejml.EJMLMatrixFactory
import scientifik.kmath.commons.linear.CMMatrixContext
import scientifik.kmath.commons.linear.inverse
import scientifik.kmath.commons.linear.toCM
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix
import kotlin.random.Random
import kotlin.system.measureTimeMillis
fun main() {
val random = Random(1224)
val dim = 100
//creating invertible matrix
val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
val l = Matrix.real(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 }
val matrix = l dot u
val n = 5000 // iterations
MatrixContext.real {
repeat(50) { val res = inverse(matrix) }
val inverseTime = measureTimeMillis { repeat(n) { val res = inverse(matrix) } }
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
}
//commons-math
val commonsTime = measureTimeMillis {
CMMatrixContext {
val cm = matrix.toCM() //avoid overhead on conversion
repeat(n) { val res = inverse(cm) }
}
}
println("[commons-math] Inversion of $n matrices $dim x $dim finished in $commonsTime millis")
//koma-ejml
val komaTime = measureTimeMillis {
(KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
val km = matrix.toKoma() //avoid overhead on conversion
repeat(n) {
val res = inverse(km)
}
}
}
println("[koma-ejml] Inversion of $n matrices $dim x $dim finished in $komaTime millis")
}

View File

@ -1,49 +0,0 @@
package scientifik.kmath.linear
import koma.matrix.ejml.EJMLMatrixFactory
import scientifik.kmath.commons.linear.CMMatrixContext
import scientifik.kmath.commons.linear.toCM
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix
import kotlin.random.Random
import kotlin.system.measureTimeMillis
fun main() {
val random = Random(12224)
val dim = 1000
//creating invertible matrix
val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
// //warmup
// matrix1 dot matrix2
CMMatrixContext {
val cmMatrix1 = matrix1.toCM()
val cmMatrix2 = matrix2.toCM()
val cmTime = measureTimeMillis {
cmMatrix1 dot cmMatrix2
}
println("CM implementation time: $cmTime")
}
(KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
val komaMatrix1 = matrix1.toKoma()
val komaMatrix2 = matrix2.toKoma()
val komaTime = measureTimeMillis {
komaMatrix1 dot komaMatrix2
}
println("Koma-ejml implementation time: $komaTime")
}
val genericTime = measureTimeMillis {
val res = matrix1 dot matrix2
}
println("Generic implementation time: $genericTime")
}

Binary file not shown.

View File

@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-6.5.1-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-6.6.1-bin.zip
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

21
gradlew.bat vendored
View File

@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1 %JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto init if "%ERRORLEVEL%" == "0" goto execute
echo. echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
@ -54,7 +54,7 @@ goto fail
set JAVA_HOME=%JAVA_HOME:"=% set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto init if exist "%JAVA_EXE%" goto execute
echo. echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
@ -64,21 +64,6 @@ echo location of your Java installation.
goto fail goto fail
:init
@rem Get command-line arguments, handling Windows variants
if not "%OS%" == "Windows_NT" goto win9xME_args
:win9xME_args
@rem Slurp the command line arguments.
set CMD_LINE_ARGS=
set _SKIP=2
:win9xME_args_slurp
if "x%~1" == "x" goto execute
set CMD_LINE_ARGS=%*
:execute :execute
@rem Setup the command line @rem Setup the command line
@ -86,7 +71,7 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle @rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end :end
@rem End local scope for the variables with windows NT shell @rem End local scope for the variables with windows NT shell

View File

@ -8,32 +8,32 @@ This subproject implements the following features:
- Evaluating expressions by traversing MST. - Evaluating expressions by traversing MST.
> #### Artifact: > #### Artifact:
> This module is distributed in the artifact `scientifik:kmath-ast:0.1.4-dev-8`. > This module is distributed in the artifact `kscience.kmath:kmath-ast:0.1.4-dev-8`.
> >
> **Gradle:** > **Gradle:**
> >
> ```gradle > ```gradle
> repositories { > repositories {
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' } > maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
> maven { url 'https://dl.bintray.com/mipt-npm/dev' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' }
> maven { url https://dl.bintray.com/hotkeytlt/maven' } > maven { url https://dl.bintray.com/hotkeytlt/maven' }
> } > }
> >
> dependencies { > dependencies {
> implementation 'scientifik:kmath-ast:0.1.4-dev-8' > implementation 'kscience.kmath:kmath-ast:0.1.4-dev-8'
> } > }
> ``` > ```
> **Gradle Kotlin DSL:** > **Gradle Kotlin DSL:**
> >
> ```kotlin > ```kotlin
> repositories { > repositories {
> maven("https://dl.bintray.com/mipt-npm/scientifik") > maven("https://dl.bintray.com/mipt-npm/kscience")
> maven("https://dl.bintray.com/mipt-npm/dev") > maven("https://dl.bintray.com/mipt-npm/dev")
> maven("https://dl.bintray.com/hotkeytlt/maven") > maven("https://dl.bintray.com/hotkeytlt/maven")
> } > }
> >
> dependencies { > dependencies {
> implementation("scientifik:kmath-ast:0.1.4-dev-8") > implementation("kscience.kmath:kmath-ast:0.1.4-dev-8")
> } > }
> ``` > ```
> >
@ -52,12 +52,12 @@ RealField.mstInField { symbol("x") + 2 }.compile()
… leads to generation of bytecode, which can be decompiled to the following Java class: … leads to generation of bytecode, which can be decompiled to the following Java class:
```java ```java
package scientifik.kmath.asm.generated; package kscience.kmath.asm.generated;
import java.util.Map; import java.util.Map;
import scientifik.kmath.asm.internal.MapIntrinsics; import kscience.kmath.asm.internal.MapIntrinsics;
import scientifik.kmath.expressions.Expression; import kscience.kmath.expressions.Expression;
import scientifik.kmath.operations.RealField; import kscience.kmath.operations.RealField;
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> { public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
private final RealField algebra; private final RealField algebra;

View File

@ -1,12 +1,11 @@
plugins { plugins {
id("scientifik.mpp") id("ru.mipt.npm.mpp")
} }
kotlin.sourceSets { kotlin.sourceSets {
commonMain { commonMain {
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
} }
} }
@ -14,6 +13,7 @@ kotlin.sourceSets {
dependencies { dependencies {
implementation("org.ow2.asm:asm:8.0.1") implementation("org.ow2.asm:asm:8.0.1")
implementation("org.ow2.asm:asm-commons:8.0.1") implementation("org.ow2.asm:asm-commons:8.0.1")
implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
implementation(kotlin("reflect")) implementation(kotlin("reflect"))
} }
} }

View File

@ -1,26 +1,28 @@
package scientifik.kmath.ast package kscience.kmath.ast
import scientifik.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra import kscience.kmath.operations.NumericAlgebra
import scientifik.kmath.operations.RealField import kscience.kmath.operations.RealField
/** /**
* A Mathematical Syntax Tree node for mathematical expressions. * A Mathematical Syntax Tree node for mathematical expressions.
*
* @author Alexander Nozik
*/ */
sealed class MST { public sealed class MST {
/** /**
* A node containing raw string. * A node containing raw string.
* *
* @property value the value of this node. * @property value the value of this node.
*/ */
data class Symbolic(val value: String) : MST() public data class Symbolic(val value: String) : MST()
/** /**
* A node containing a numeric value or scalar. * A node containing a numeric value or scalar.
* *
* @property value the value of this number. * @property value the value of this number.
*/ */
data class Numeric(val value: Number) : MST() public data class Numeric(val value: Number) : MST()
/** /**
* A node containing an unary operation. * A node containing an unary operation.
@ -28,9 +30,7 @@ sealed class MST {
* @property operation the identifier of operation. * @property operation the identifier of operation.
* @property value the argument of this operation. * @property value the argument of this operation.
*/ */
data class Unary(val operation: String, val value: MST) : MST() { public data class Unary(val operation: String, val value: MST) : MST()
companion object
}
/** /**
* A node containing binary operation. * A node containing binary operation.
@ -39,9 +39,7 @@ sealed class MST {
* @property left the left operand. * @property left the left operand.
* @property right the right operand. * @property right the right operand.
*/ */
data class Binary(val operation: String, val left: MST, val right: MST) : MST() { public data class Binary(val operation: String, val left: MST, val right: MST) : MST()
companion object
}
} }
// TODO add a function with named arguments // TODO add a function with named arguments
@ -52,8 +50,9 @@ sealed class MST {
* @receiver the algebra that provides operations. * @receiver the algebra that provides operations.
* @param node the node to evaluate. * @param node the node to evaluate.
* @return the value of expression. * @return the value of expression.
* @author Alexander Nozik
*/ */
fun <T> Algebra<T>.evaluate(node: MST): T = when (node) { public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value) is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this") ?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value) is MST.Symbolic -> symbol(node.value)
@ -84,4 +83,4 @@ fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
* @param algebra the algebra that provides operations. * @param algebra the algebra that provides operations.
* @return the value of expression. * @return the value of expression.
*/ */
fun <T> MST.interpret(algebra: Algebra<T>): T = algebra.evaluate(this) public fun <T> MST.interpret(algebra: Algebra<T>): T = algebra.evaluate(this)

View File

@ -1,11 +1,11 @@
package scientifik.kmath.ast package kscience.kmath.ast
import scientifik.kmath.operations.* import kscience.kmath.operations.*
/** /**
* [Algebra] over [MST] nodes. * [Algebra] over [MST] nodes.
*/ */
object MstAlgebra : NumericAlgebra<MST> { public object MstAlgebra : NumericAlgebra<MST> {
override fun number(value: Number): MST = MST.Numeric(value) override fun number(value: Number): MST = MST.Numeric(value)
override fun symbol(value: String): MST = MST.Symbolic(value) override fun symbol(value: String): MST = MST.Symbolic(value)
@ -20,7 +20,7 @@ object MstAlgebra : NumericAlgebra<MST> {
/** /**
* [Space] over [MST] nodes. * [Space] over [MST] nodes.
*/ */
object MstSpace : Space<MST>, NumericAlgebra<MST> { public object MstSpace : Space<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0) override val zero: MST = number(0.0)
override fun number(value: Number): MST = MstAlgebra.number(value) override fun number(value: Number): MST = MstAlgebra.number(value)
@ -37,8 +37,9 @@ object MstSpace : Space<MST>, NumericAlgebra<MST> {
/** /**
* [Ring] over [MST] nodes. * [Ring] over [MST] nodes.
*/ */
object MstRing : Ring<MST>, NumericAlgebra<MST> { public object MstRing : Ring<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0) override val zero: MST
get() = MstSpace.zero
override val one: MST = number(1.0) override val one: MST = number(1.0)
override fun number(value: Number): MST = MstSpace.number(value) override fun number(value: Number): MST = MstSpace.number(value)
@ -58,18 +59,21 @@ object MstRing : Ring<MST>, NumericAlgebra<MST> {
/** /**
* [Field] over [MST] nodes. * [Field] over [MST] nodes.
*/ */
object MstField : Field<MST> { public object MstField : Field<MST> {
override val zero: MST = number(0.0) public override val zero: MST
override val one: MST = number(1.0) get() = MstRing.zero
override fun symbol(value: String): MST = MstRing.symbol(value) public override val one: MST
override fun number(value: Number): MST = MstRing.number(value) get() = MstRing.one
override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = public override fun symbol(value: String): MST = MstRing.symbol(value)
public override fun number(value: Number): MST = MstRing.number(value)
public override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
public override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstRing.binaryOperation(operation, left, right) MstRing.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
@ -78,15 +82,26 @@ object MstField : Field<MST> {
/** /**
* [ExtendedField] over [MST] nodes. * [ExtendedField] over [MST] nodes.
*/ */
object MstExtendedField : ExtendedField<MST> { public object MstExtendedField : ExtendedField<MST> {
override val zero: MST = number(0.0) override val zero: MST
override val one: MST = number(1.0) get() = MstField.zero
override val one: MST
get() = MstField.one
override fun symbol(value: String): MST = MstField.symbol(value)
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
override fun add(a: MST, b: MST): MST = MstField.add(a, b) override fun add(a: MST, b: MST): MST = MstField.add(a, b)
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)

View File

@ -1,7 +1,7 @@
package scientifik.kmath.ast package kscience.kmath.ast
import scientifik.kmath.expressions.* import kscience.kmath.expressions.*
import scientifik.kmath.operations.* import kscience.kmath.operations.*
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@ -11,8 +11,9 @@ import kotlin.contracts.contract
* *
* @property algebra the algebra that provides operations. * @property algebra the algebra that provides operations.
* @property mst the [MST] node. * @property mst the [MST] node.
* @author Alexander Nozik
*/ */
class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> { public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> { private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
@ -31,72 +32,92 @@ class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
/** /**
* Builds [MstExpression] over [Algebra]. * Builds [MstExpression] over [Algebra].
*
* @author Alexander Nozik
*/ */
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst( public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E, mstAlgebra: E,
block: E.() -> MST block: E.() -> MST
): MstExpression<T> = MstExpression(this, mstAlgebra.block()) ): MstExpression<T> = MstExpression(this, mstAlgebra.block())
/** /**
* Builds [MstExpression] over [Space]. * Builds [MstExpression] over [Space].
*
* @author Alexander Nozik
*/ */
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> { public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstSpace.block()) return MstExpression(this, MstSpace.block())
} }
/** /**
* Builds [MstExpression] over [Ring]. * Builds [MstExpression] over [Ring].
*
* @author Alexander Nozik
*/ */
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> { public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstRing.block()) return MstExpression(this, MstRing.block())
} }
/** /**
* Builds [MstExpression] over [Field]. * Builds [MstExpression] over [Field].
*
* @author Alexander Nozik
*/ */
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> { public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstField.block()) return MstExpression(this, MstField.block())
} }
/** /**
* Builds [MstExpression] over [ExtendedField]. * Builds [MstExpression] over [ExtendedField].
*
* @author Iaroslav Postovalov
*/ */
inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> { public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstExtendedField.block()) return MstExpression(this, MstExtendedField.block())
} }
/** /**
* Builds [MstExpression] over [FunctionalExpressionSpace]. * Builds [MstExpression] over [FunctionalExpressionSpace].
*
* @author Alexander Nozik
*/ */
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInSpace(block) return algebra.mstInSpace(block)
} }
/** /**
* Builds [MstExpression] over [FunctionalExpressionRing]. * Builds [MstExpression] over [FunctionalExpressionRing].
*
* @author Alexander Nozik
*/ */
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInRing(block) return algebra.mstInRing(block)
} }
/** /**
* Builds [MstExpression] over [FunctionalExpressionField]. * Builds [MstExpression] over [FunctionalExpressionField].
*
* @author Alexander Nozik
*/ */
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInField(block) return algebra.mstInField(block)
} }
/** /**
* Builds [MstExpression] over [FunctionalExpressionExtendedField]. * Builds [MstExpression] over [FunctionalExpressionExtendedField].
*
* @author Iaroslav Postovalov
*/ */
inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
block: MstExtendedField.() -> MST
): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block) return algebra.mstInExtendedField(block)
} }

View File

@ -1,19 +1,24 @@
package scientifik.kmath.asm package kscience.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder import kscience.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.MstType import kscience.kmath.asm.internal.MstType
import scientifik.kmath.asm.internal.buildAlgebraOperationCall import kscience.kmath.asm.internal.buildAlgebraOperationCall
import scientifik.kmath.asm.internal.buildName import kscience.kmath.asm.internal.buildName
import scientifik.kmath.ast.MST import kscience.kmath.ast.MST
import scientifik.kmath.ast.MstExpression import kscience.kmath.ast.MstExpression
import scientifik.kmath.expressions.Expression import kscience.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
* Compile given MST to an Expression using AST compiler * Compiles given MST to an Expression using AST compiler.
*
* @param type the target type.
* @param algebra the target algebra.
* @return the compiled expression.
* @author Alexander Nozik
*/ */
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> { public fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST) { fun AsmBuilder<T>.visit(node: MST) {
when (node) { when (node) {
is MST.Symbolic -> { is MST.Symbolic -> {
@ -54,11 +59,15 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
} }
/** /**
* Compile an [MST] to ASM using given algebra * Compiles an [MST] to ASM using given algebra.
*
* @author Alexander Nozik.
*/ */
inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this) public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
/** /**
* Optimize performance of an [MstExpression] using ASM codegen * Optimizes performance of an [MstExpression] using ASM codegen.
*
* @author Alexander Nozik.
*/ */
inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra) public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)

View File

@ -1,13 +1,13 @@
package scientifik.kmath.asm.internal package kscience.kmath.asm.internal
import kscience.kmath.asm.internal.AsmBuilder.ClassLoader
import kscience.kmath.ast.MST
import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra
import org.objectweb.asm.* import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.commons.InstructionAdapter import org.objectweb.asm.commons.InstructionAdapter
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
import scientifik.kmath.ast.MST
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import java.util.* import java.util.*
import java.util.stream.Collectors import java.util.stream.Collectors
import kotlin.reflect.KClass import kotlin.reflect.KClass
@ -20,6 +20,7 @@ import kotlin.reflect.KClass
* @property algebra the algebra the applied AsmExpressions use. * @property algebra the algebra the applied AsmExpressions use.
* @property className the unique class name of new loaded class. * @property className the unique class name of new loaded class.
* @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. * @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
* @author Iaroslav Postovalov
*/ */
internal class AsmBuilder<T> internal constructor( internal class AsmBuilder<T> internal constructor(
private val classOfT: KClass<*>, private val classOfT: KClass<*>,
@ -563,6 +564,6 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* ASM type for MapIntrinsics. * ASM type for MapIntrinsics.
*/ */
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") } internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/asm/internal/MapIntrinsics") }
} }
} }

View File

@ -1,7 +1,10 @@
package scientifik.kmath.asm.internal package kscience.kmath.asm.internal
import scientifik.kmath.ast.MST import kscience.kmath.ast.MST
/**
* Represents types known in [MST], numbers and general values.
*/
internal enum class MstType { internal enum class MstType {
GENERAL, GENERAL,
NUMBER; NUMBER;

View File

@ -1,11 +1,14 @@
package scientifik.kmath.asm.internal package kscience.kmath.asm.internal
import kscience.kmath.ast.MST
import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.FieldOperations
import kscience.kmath.operations.RingOperations
import kscience.kmath.operations.SpaceOperations
import org.objectweb.asm.* import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
import org.objectweb.asm.commons.InstructionAdapter import org.objectweb.asm.commons.InstructionAdapter
import scientifik.kmath.ast.MST
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
import java.lang.reflect.Method import java.lang.reflect.Method
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@ -13,20 +16,27 @@ import kotlin.reflect.KClass
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy { private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
hashMapOf( hashMapOf(
"+" to 2 to "add", SpaceOperations.PLUS_OPERATION to 2 to "add",
"*" to 2 to "multiply", RingOperations.TIMES_OPERATION to 2 to "multiply",
"/" to 2 to "divide", FieldOperations.DIV_OPERATION to 2 to "divide",
"+" to 1 to "unaryPlus", SpaceOperations.PLUS_OPERATION to 1 to "unaryPlus",
"-" to 1 to "unaryMinus", SpaceOperations.MINUS_OPERATION to 1 to "unaryMinus",
"-" to 2 to "minus" SpaceOperations.MINUS_OPERATION to 2 to "minus"
) )
} }
/**
* Returns ASM [Type] for given [KClass].
*
* @author Iaroslav Postovalov
*/
internal val KClass<*>.asm: Type internal val KClass<*>.asm: Type
get() = Type.getType(java) get() = Type.getType(java)
/** /**
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
*
* @author Iaroslav Postovalov
*/ */
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> { internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> {
contract { callsInPlace(predicate, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(predicate, InvocationKind.EXACTLY_ONCE) }
@ -35,11 +45,15 @@ internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Arra
/** /**
* Creates an [InstructionAdapter] from this [MethodVisitor]. * Creates an [InstructionAdapter] from this [MethodVisitor].
*
* @author Iaroslav Postovalov
*/ */
private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
/** /**
* Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it.
*
* @author Iaroslav Postovalov
*/ */
internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter { internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
@ -48,6 +62,8 @@ internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.(
/** /**
* Constructs a [Label], then applies it to this visitor. * Constructs a [Label], then applies it to this visitor.
*
* @author Iaroslav Postovalov
*/ */
internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) } internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) }
@ -56,9 +72,11 @@ internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) }
* *
* This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
*
* @author Iaroslav Postovalov
*/ */
internal tailrec fun buildName(mst: MST, collision: Int = 0): String { internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" val name = "kscience.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try { try {
Class.forName(name) Class.forName(name)
@ -75,6 +93,11 @@ internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): Clas
return ClassWriter(flags).apply(block) return ClassWriter(flags).apply(block)
} }
/**
* Invokes [visitField] and applies [block] to the [FieldVisitor].
*
* @author Iaroslav Postovalov
*/
internal inline fun ClassWriter.visitField( internal inline fun ClassWriter.visitField(
access: Int, access: Int,
name: String, name: String,
@ -104,7 +127,7 @@ private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, pa
* Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds * Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds
* type expectation stack for needed arity. * type expectation stack for needed arity.
* *
* @return `true` if contains, else `false`. * @author Iaroslav Postovalov
*/ */
private fun <T> AsmBuilder<T>.buildExpectationStack( private fun <T> AsmBuilder<T>.buildExpectationStack(
context: Algebra<T>, context: Algebra<T>,
@ -136,7 +159,7 @@ private fun <T> AsmBuilder<T>.mapTypes(method: Method, parameterTypes: Array<Mst
* Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts * Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts
* [AsmBuilder.invokeAlgebraOperation] of this method. * [AsmBuilder.invokeAlgebraOperation] of this method.
* *
* @return `true` if contains, else `false`. * @author Iaroslav Postovalov
*/ */
private fun <T> AsmBuilder<T>.tryInvokeSpecific( private fun <T> AsmBuilder<T>.tryInvokeSpecific(
context: Algebra<T>, context: Algebra<T>,
@ -160,7 +183,9 @@ private fun <T> AsmBuilder<T>.tryInvokeSpecific(
} }
/** /**
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String. * Builds specialized [context] call with option to fallback to generic algebra operation accepting [String].
*
* @author Iaroslav Postovalov
*/ */
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall( internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
context: Algebra<T>, context: Algebra<T>,

View File

@ -1,7 +1,12 @@
@file:JvmName("MapIntrinsics") @file:JvmName("MapIntrinsics")
package scientifik.kmath.asm.internal package kscience.kmath.asm.internal
/**
* Gets value with given [key] or throws [IllegalStateException] whenever it is not present.
*
* @author Iaroslav Postovalov
*/
@JvmOverloads @JvmOverloads
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V = internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
this[key] ?: default ?: error("Parameter not found: $key") this[key] ?: default ?: error("Parameter not found: $key")

View File

@ -1,4 +1,4 @@
package scientifik.kmath.ast package kscience.kmath.ast
import com.github.h0tk3y.betterParse.combinators.* import com.github.h0tk3y.betterParse.combinators.*
import com.github.h0tk3y.betterParse.grammar.Grammar import com.github.h0tk3y.betterParse.grammar.Grammar
@ -10,15 +10,16 @@ import com.github.h0tk3y.betterParse.lexer.TokenMatch
import com.github.h0tk3y.betterParse.lexer.regexToken import com.github.h0tk3y.betterParse.lexer.regexToken
import com.github.h0tk3y.betterParse.parser.ParseResult import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser import com.github.h0tk3y.betterParse.parser.Parser
import scientifik.kmath.operations.FieldOperations import kscience.kmath.operations.FieldOperations
import scientifik.kmath.operations.PowerOperations import kscience.kmath.operations.PowerOperations
import scientifik.kmath.operations.RingOperations import kscience.kmath.operations.RingOperations
import scientifik.kmath.operations.SpaceOperations import kscience.kmath.operations.SpaceOperations
/** /**
* TODO move to core * TODO move to common after IR version is released
* @author Alexander Nozik and Iaroslav Postovalov
*/ */
object ArithmeticsEvaluator : Grammar<MST>() { public object ArithmeticsEvaluator : Grammar<MST>() {
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released // TODO replace with "...".toRegex() when better-parse 0.4.1 is released
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?") private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*") private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
@ -35,23 +36,23 @@ object ArithmeticsEvaluator : Grammar<MST>() {
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) } private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
private val singular: Parser<MST> by id use { MST.Symbolic(text) } private val singular: Parser<MST> by id use { MST.Symbolic(text) }
private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar)) private val unaryFunction: Parser<MST> by (id and -lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar)
.map { (id, term) -> MST.Unary(id.text, term) } .map { (id, term) -> MST.Unary(id.text, term) }
private val binaryFunction: Parser<MST> by id private val binaryFunction: Parser<MST> by id
.and(skip(lpar)) .and(-lpar)
.and(parser(::subSumChain)) .and(parser(ArithmeticsEvaluator::subSumChain))
.and(skip(comma)) .and(-comma)
.and(parser(::subSumChain)) .and(parser(ArithmeticsEvaluator::subSumChain))
.and(skip(rpar)) .and(-rpar)
.map { (id, left, right) -> MST.Binary(id.text, left, right) } .map { (id, left, right) -> MST.Binary(id.text, left, right) }
private val term: Parser<MST> by number private val term: Parser<MST> by number
.or(binaryFunction) .or(binaryFunction)
.or(unaryFunction) .or(unaryFunction)
.or(singular) .or(singular)
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) .or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
.or(skip(lpar) and parser(::subSumChain) and skip(rpar)) .or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar)
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b -> private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
MST.Binary(PowerOperations.POW_OPERATION, a, b) MST.Binary(PowerOperations.POW_OPERATION, a, b)
@ -85,13 +86,15 @@ object ArithmeticsEvaluator : Grammar<MST>() {
* *
* @receiver the string to parse. * @receiver the string to parse.
* @return the [MST] node. * @return the [MST] node.
* @author Alexander Nozik
*/ */
fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this) public fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this)
/** /**
* Parses the string into [MST]. * Parses the string into [MST].
* *
* @receiver the string to parse. * @receiver the string to parse.
* @return the [MST] node. * @return the [MST] node.
* @author Alexander Nozik
*/ */
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this) public fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)

View File

@ -1,12 +1,12 @@
package scietifik.kmath.asm package scietifik.kmath.asm
import scientifik.kmath.asm.compile import kscience.kmath.asm.compile
import scientifik.kmath.ast.mstInField import kscience.kmath.ast.mstInField
import scientifik.kmath.ast.mstInRing import kscience.kmath.ast.mstInRing
import scientifik.kmath.ast.mstInSpace import kscience.kmath.ast.mstInSpace
import scientifik.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import scientifik.kmath.operations.ByteRing import kscience.kmath.operations.ByteRing
import scientifik.kmath.operations.RealField import kscience.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -1,10 +1,10 @@
package scietifik.kmath.asm package scietifik.kmath.asm
import scientifik.kmath.asm.compile import kscience.kmath.asm.compile
import scientifik.kmath.ast.mstInField import kscience.kmath.ast.mstInField
import scientifik.kmath.ast.mstInSpace import kscience.kmath.ast.mstInSpace
import scientifik.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import scientifik.kmath.operations.RealField import kscience.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -1,9 +1,9 @@
package scietifik.kmath.asm package scietifik.kmath.asm
import scientifik.kmath.asm.compile import kscience.kmath.asm.compile
import scientifik.kmath.ast.mstInField import kscience.kmath.ast.mstInField
import scientifik.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import scientifik.kmath.operations.RealField import kscience.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -1,8 +1,8 @@
package scietifik.kmath.asm package scietifik.kmath.asm
import scientifik.kmath.ast.mstInRing import kscience.kmath.ast.mstInRing
import scientifik.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import scientifik.kmath.operations.ByteRing import kscience.kmath.operations.ByteRing
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFailsWith import kotlin.test.assertFailsWith

View File

@ -1,12 +1,12 @@
package scietifik.kmath.ast package scietifik.kmath.ast
import scientifik.kmath.asm.compile import kscience.kmath.asm.compile
import scientifik.kmath.asm.expression import kscience.kmath.asm.expression
import scientifik.kmath.ast.mstInField import kscience.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath import kscience.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import scientifik.kmath.operations.Complex import kscience.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import kscience.kmath.operations.ComplexField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -1,9 +1,9 @@
package scietifik.kmath.ast package scietifik.kmath.ast
import scientifik.kmath.ast.evaluate import kscience.kmath.ast.evaluate
import scientifik.kmath.ast.parseMath import kscience.kmath.ast.parseMath
import scientifik.kmath.operations.Field import kscience.kmath.operations.Field
import scientifik.kmath.operations.RealField import kscience.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -1,13 +1,13 @@
package scietifik.kmath.ast package scietifik.kmath.ast
import scientifik.kmath.ast.evaluate import kscience.kmath.ast.evaluate
import scientifik.kmath.ast.mstInField import kscience.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath import kscience.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import scientifik.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import scientifik.kmath.operations.Complex import kscience.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import kscience.kmath.operations.ComplexField
import scientifik.kmath.operations.RealField import kscience.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals

View File

@ -1,13 +1,12 @@
plugins { plugins {
id("scientifik.jvm") id("ru.mipt.npm.jvm")
} }
description = "Commons math binding for kmath" description = "Commons math binding for kmath"
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
api(project(":kmath-coroutines")) api(project(":kmath-coroutines"))
api(project(":kmath-prob")) api(project(":kmath-prob"))
api(project(":kmath-functions")) // api(project(":kmath-functions"))
api("org.apache.commons:commons-math3:3.6.1") api("org.apache.commons:commons-math3:3.6.1")
} }

View File

@ -0,0 +1,128 @@
package kscience.kmath.commons.expressions
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.ExpressionAlgebra
import kscience.kmath.operations.ExtendedField
import kscience.kmath.operations.Field
import kscience.kmath.operations.invoke
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import kotlin.properties.ReadOnlyProperty
/**
* A field wrapping commons-math derivative structures
*/
public class DerivativeStructureField(
public val order: Int,
public val parameters: Map<String, Double>
) : ExtendedField<DerivativeStructure> {
public override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
public override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) }
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
}
public val variable: ReadOnlyProperty<Any?, DerivativeStructure> = ReadOnlyProperty { _, property ->
variables[property.name] ?: error("A variable with name ${property.name} does not exist")
}
public fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
variables[name] ?: default ?: error("A variable with name $name does not exist")
public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
return deriv(mapOf(parName to order))
}
public fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
}
public fun DerivativeStructure.deriv(vararg orders: Pair<String, Int>): Double = deriv(mapOf(*orders))
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
is Double -> a.multiply(k)
is Int -> a.multiply(k)
else -> a.multiply(k.toDouble())
}
public override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
public override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
public override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
public override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
public override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
public override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
public override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
public override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
public override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
public override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
public override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
public override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
public override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
public override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
public override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
is Double -> arg.pow(pow)
is Int -> arg.pow(pow)
else -> arg.pow(pow.toDouble())
}
public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
}
/**
* A constructs that creates a derivative structure with required order on-demand
*/
public class DiffExpression(public val function: DerivativeStructureField.() -> DerivativeStructure) :
Expression<Double> {
public override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
0,
arguments
).function().value
/**
* Get the derivative expression with given orders
* TODO make result [DiffExpression]
*/
public fun derivative(orders: Map<String, Int>): Expression<Double> = Expression { arguments ->
(DerivativeStructureField(orders.values.max() ?: 0, arguments)) { function().deriv(orders) }
}
//TODO add gradient and maybe other vector operators
}
public fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
public fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
/**
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
*/
public object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
public override val zero: DiffExpression = DiffExpression { 0.0.const() }
public override val one: DiffExpression = DiffExpression { 1.0.const() }
public override fun variable(name: String, default: Double?): DiffExpression =
DiffExpression { variable(name, default?.const()) }
public override fun const(value: Double): DiffExpression = DiffExpression { value.const() }
public override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) + b.function(this) }
public override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k }
public override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) * b.function(this) }
public override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) / b.function(this) }
}

View File

@ -0,0 +1,93 @@
package kscience.kmath.commons.linear
import kscience.kmath.linear.*
import kscience.kmath.structures.Matrix
import kscience.kmath.structures.NDStructure
import org.apache.commons.math3.linear.*
public class CMMatrix(public val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
FeaturedMatrix<Double> {
public override val rowNum: Int get() = origin.rowDimension
public override val colNum: Int get() = origin.columnDimension
public override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
if (origin is DiagonalMatrix) yield(DiagonalFeature)
}.toHashSet()
public override fun suggestFeature(vararg features: MatrixFeature): CMMatrix =
CMMatrix(origin, this.features + features)
public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
public override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
public override fun hashCode(): Int {
var result = origin.hashCode()
result = 31 * result + features.hashCode()
return result
}
}
public fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
this
} else {
//TODO add feature analysis
val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } }
CMMatrix(Array2DRowRealMatrix(array))
}
public fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this)
public class CMVector(public val origin: RealVector) : Point<Double> {
public override val size: Int get() = origin.dimension
public override operator fun get(index: Int): Double = origin.getEntry(index)
public override operator fun iterator(): Iterator<Double> = origin.toArray().iterator()
}
public fun Point<Double>.toCM(): CMVector = if (this is CMVector) this else {
val array = DoubleArray(size) { this[it] }
CMVector(ArrayRealVector(array))
}
public fun RealVector.toPoint(): CMVector = CMVector(this)
public object CMMatrixContext : MatrixContext<Double> {
public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
return CMMatrix(Array2DRowRealMatrix(array))
}
public override fun Matrix<Double>.dot(other: Matrix<Double>): CMMatrix =
CMMatrix(toCM().origin.multiply(other.toCM().origin))
public override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
CMVector(toCM().origin.preMultiply(vector.toCM().origin))
public override operator fun Matrix<Double>.unaryMinus(): CMMatrix =
produce(rowNum, colNum) { i, j -> -get(i, j) }
public override fun add(a: Matrix<Double>, b: Matrix<Double>): CMMatrix =
CMMatrix(a.toCM().origin.multiply(b.toCM().origin))
public override operator fun Matrix<Double>.minus(b: Matrix<Double>): CMMatrix =
CMMatrix(toCM().origin.subtract(b.toCM().origin))
public override fun multiply(a: Matrix<Double>, k: Number): CMMatrix =
CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble()))
public override operator fun Matrix<Double>.times(value: Double): Matrix<Double> =
produce(rowNum, colNum) { i, j -> get(i, j) * value }
}
public operator fun CMMatrix.plus(other: CMMatrix): CMMatrix =
CMMatrix(origin.add(other.origin))
public operator fun CMMatrix.minus(other: CMMatrix): CMMatrix =
CMMatrix(origin.subtract(other.origin))
public infix fun CMMatrix.dot(other: CMMatrix): CMMatrix =
CMMatrix(origin.multiply(other.origin))

View File

@ -0,0 +1,41 @@
package kscience.kmath.commons.linear
import kscience.kmath.linear.Point
import kscience.kmath.structures.Matrix
import org.apache.commons.math3.linear.*
public enum class CMDecomposition {
LUP,
QR,
RRQR,
EIGEN,
CHOLESKY
}
public fun CMMatrixContext.solver(
a: Matrix<Double>,
decomposition: CMDecomposition = CMDecomposition.LUP
): DecompositionSolver = when (decomposition) {
CMDecomposition.LUP -> LUDecomposition(a.toCM().origin).solver
CMDecomposition.RRQR -> RRQRDecomposition(a.toCM().origin).solver
CMDecomposition.QR -> QRDecomposition(a.toCM().origin).solver
CMDecomposition.EIGEN -> EigenDecomposition(a.toCM().origin).solver
CMDecomposition.CHOLESKY -> CholeskyDecomposition(a.toCM().origin).solver
}
public fun CMMatrixContext.solve(
a: Matrix<Double>,
b: Matrix<Double>,
decomposition: CMDecomposition = CMDecomposition.LUP
): CMMatrix = solver(a, decomposition).solve(b.toCM().origin).asMatrix()
public fun CMMatrixContext.solve(
a: Matrix<Double>,
b: Point<Double>,
decomposition: CMDecomposition = CMDecomposition.LUP
): CMVector = solver(a, decomposition).solve(b.toCM().origin).toPoint()
public fun CMMatrixContext.inverse(
a: Matrix<Double>,
decomposition: CMDecomposition = CMDecomposition.LUP
): CMMatrix = solver(a, decomposition).inverse.asMatrix()

View File

@ -0,0 +1,33 @@
package kscience.kmath.commons.random
import kscience.kmath.prob.RandomGenerator
public class CMRandomGeneratorWrapper(public val factory: (IntArray) -> RandomGenerator) :
org.apache.commons.math3.random.RandomGenerator {
private var generator: RandomGenerator = factory(intArrayOf())
public override fun nextBoolean(): Boolean = generator.nextBoolean()
public override fun nextFloat(): Float = generator.nextDouble().toFloat()
public override fun setSeed(seed: Int) {
generator = factory(intArrayOf(seed))
}
public override fun setSeed(seed: IntArray) {
generator = factory(seed)
}
public override fun setSeed(seed: Long) {
setSeed(seed.toInt())
}
public override fun nextBytes(bytes: ByteArray) {
generator.fillBytes(bytes)
}
public override fun nextInt(): Int = generator.nextInt()
public override fun nextInt(n: Int): Int = generator.nextInt(n)
public override fun nextGaussian(): Double = TODO()
public override fun nextDouble(): Double = generator.nextDouble()
public override fun nextLong(): Long = generator.nextLong()
}

View File

@ -1,20 +1,19 @@
package scientifik.kmath.commons.transform package kscience.kmath.commons.transform
import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.map
import kscience.kmath.operations.Complex
import kscience.kmath.streaming.chunked
import kscience.kmath.streaming.spread
import kscience.kmath.structures.*
import org.apache.commons.math3.transform.* import org.apache.commons.math3.transform.*
import scientifik.kmath.operations.Complex
import scientifik.kmath.streaming.chunked
import scientifik.kmath.streaming.spread
import scientifik.kmath.structures.*
/** /**
* Streaming and buffer transformations * Streaming and buffer transformations
*/ */
object Transformations { public object Transformations {
private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> = private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> =
Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) } Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) }
@ -32,35 +31,35 @@ object Transformations {
Complex(value.real, value.imaginary) Complex(value.real, value.imaginary)
} }
fun fourier( public fun fourier(
normalization: DftNormalization = DftNormalization.STANDARD, normalization: DftNormalization = DftNormalization.STANDARD,
direction: TransformType = TransformType.FORWARD direction: TransformType = TransformType.FORWARD
): SuspendBufferTransform<Complex, Complex> = { ): SuspendBufferTransform<Complex, Complex> = {
FastFourierTransformer(normalization).transform(it.toArray(), direction).asBuffer() FastFourierTransformer(normalization).transform(it.toArray(), direction).asBuffer()
} }
fun realFourier( public fun realFourier(
normalization: DftNormalization = DftNormalization.STANDARD, normalization: DftNormalization = DftNormalization.STANDARD,
direction: TransformType = TransformType.FORWARD direction: TransformType = TransformType.FORWARD
): SuspendBufferTransform<Double, Complex> = { ): SuspendBufferTransform<Double, Complex> = {
FastFourierTransformer(normalization).transform(it.asArray(), direction).asBuffer() FastFourierTransformer(normalization).transform(it.asArray(), direction).asBuffer()
} }
fun sine( public fun sine(
normalization: DstNormalization = DstNormalization.STANDARD_DST_I, normalization: DstNormalization = DstNormalization.STANDARD_DST_I,
direction: TransformType = TransformType.FORWARD direction: TransformType = TransformType.FORWARD
): SuspendBufferTransform<Double, Double> = { ): SuspendBufferTransform<Double, Double> = {
FastSineTransformer(normalization).transform(it.asArray(), direction).asBuffer() FastSineTransformer(normalization).transform(it.asArray(), direction).asBuffer()
} }
fun cosine( public fun cosine(
normalization: DctNormalization = DctNormalization.STANDARD_DCT_I, normalization: DctNormalization = DctNormalization.STANDARD_DCT_I,
direction: TransformType = TransformType.FORWARD direction: TransformType = TransformType.FORWARD
): SuspendBufferTransform<Double, Double> = { ): SuspendBufferTransform<Double, Double> = {
FastCosineTransformer(normalization).transform(it.asArray(), direction).asBuffer() FastCosineTransformer(normalization).transform(it.asArray(), direction).asBuffer()
} }
fun hadamard( public fun hadamard(
direction: TransformType = TransformType.FORWARD direction: TransformType = TransformType.FORWARD
): SuspendBufferTransform<Double, Double> = { ): SuspendBufferTransform<Double, Double> = {
FastHadamardTransformer().transform(it.asArray(), direction).asBuffer() FastHadamardTransformer().transform(it.asArray(), direction).asBuffer()
@ -71,7 +70,7 @@ object Transformations {
* Process given [Flow] with commons-math fft transformation * Process given [Flow] with commons-math fft transformation
*/ */
@FlowPreview @FlowPreview
fun Flow<Buffer<Complex>>.FFT( public fun Flow<Buffer<Complex>>.FFT(
normalization: DftNormalization = DftNormalization.STANDARD, normalization: DftNormalization = DftNormalization.STANDARD,
direction: TransformType = TransformType.FORWARD direction: TransformType = TransformType.FORWARD
): Flow<Buffer<Complex>> { ): Flow<Buffer<Complex>> {
@ -81,7 +80,7 @@ fun Flow<Buffer<Complex>>.FFT(
@FlowPreview @FlowPreview
@JvmName("realFFT") @JvmName("realFFT")
fun Flow<Buffer<Double>>.FFT( public fun Flow<Buffer<Double>>.FFT(
normalization: DftNormalization = DftNormalization.STANDARD, normalization: DftNormalization = DftNormalization.STANDARD,
direction: TransformType = TransformType.FORWARD direction: TransformType = TransformType.FORWARD
): Flow<Buffer<Complex>> { ): Flow<Buffer<Complex>> {
@ -90,20 +89,18 @@ fun Flow<Buffer<Double>>.FFT(
} }
/** /**
* Process a continous flow of real numbers in FFT splitting it in chunks of [bufferSize]. * Process a continuous flow of real numbers in FFT splitting it in chunks of [bufferSize].
*/ */
@FlowPreview @FlowPreview
@JvmName("realFFT") @JvmName("realFFT")
fun Flow<Double>.FFT( public fun Flow<Double>.FFT(
bufferSize: Int = Int.MAX_VALUE, bufferSize: Int = Int.MAX_VALUE,
normalization: DftNormalization = DftNormalization.STANDARD, normalization: DftNormalization = DftNormalization.STANDARD,
direction: TransformType = TransformType.FORWARD direction: TransformType = TransformType.FORWARD
): Flow<Complex> { ): Flow<Complex> = chunked(bufferSize).FFT(normalization, direction).spread()
return chunked(bufferSize).FFT(normalization,direction).spread()
}
/** /**
* Map a complex flow into real flow by taking real part of each number * Map a complex flow into real flow by taking real part of each number
*/ */
@FlowPreview @FlowPreview
fun Flow<Complex>.real(): Flow<Double> = map{it.re} public fun Flow<Complex>.real(): Flow<Double> = map { it.re }

View File

@ -1,137 +0,0 @@
package scientifik.kmath.commons.expressions
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.invoke
import kotlin.properties.ReadOnlyProperty
import kotlin.reflect.KProperty
/**
* A field wrapping commons-math derivative structures
*/
class DerivativeStructureField(
val order: Int,
val parameters: Map<String, Double>
) : ExtendedField<DerivativeStructure> {
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) }
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
}
val variable: ReadOnlyProperty<Any?, DerivativeStructure> = object : ReadOnlyProperty<Any?, DerivativeStructure> {
override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure =
variables[property.name] ?: error("A variable with name ${property.name} does not exist")
}
fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
variables[name] ?: default ?: error("A variable with name $name does not exist")
fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
return deriv(mapOf(parName to order))
}
fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
}
fun DerivativeStructure.deriv(vararg orders: Pair<String, Int>): Double = deriv(mapOf(*orders))
override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
is Double -> a.multiply(k)
is Int -> a.multiply(k)
else -> a.multiply(k.toDouble())
}
override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
is Double -> arg.pow(pow)
is Int -> arg.pow(pow)
else -> arg.pow(pow.toDouble())
}
fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
}
/**
* A constructs that creates a derivative structure with required order on-demand
*/
class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> {
override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
0,
arguments
).run(function).value
/**
* Get the derivative expression with given orders
* TODO make result [DiffExpression]
*/
fun derivative(orders: Map<String, Int>): Expression<Double> = object : Expression<Double> {
override operator fun invoke(arguments: Map<String, Double>): Double =
(DerivativeStructureField(orders.values.max() ?: 0, arguments)) { function().deriv(orders) }
}
//TODO add gradient and maybe other vector operators
}
fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
/**
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
*/
object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
override fun variable(name: String, default: Double?): DiffExpression =
DiffExpression { variable(name, default?.const()) }
override fun const(value: Double): DiffExpression =
DiffExpression { value.const() }
override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) + b.function(this) }
override val zero: DiffExpression = DiffExpression { 0.0.const() }
override fun multiply(a: DiffExpression, k: Number): DiffExpression =
DiffExpression { a.function(this) * k }
override val one: DiffExpression = DiffExpression { 1.0.const() }
override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) * b.function(this) }
override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) / b.function(this) }
}

View File

@ -1,93 +0,0 @@
package scientifik.kmath.commons.linear
import org.apache.commons.math3.linear.*
import scientifik.kmath.linear.*
import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.NDStructure
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
FeaturedMatrix<Double> {
override val rowNum: Int get() = origin.rowDimension
override val colNum: Int get() = origin.columnDimension
override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
if (origin is DiagonalMatrix) yield(DiagonalFeature)
}.toHashSet()
override fun suggestFeature(vararg features: MatrixFeature): CMMatrix =
CMMatrix(origin, this.features + features)
override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
override fun hashCode(): Int {
var result = origin.hashCode()
result = 31 * result + features.hashCode()
return result
}
}
fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
this
} else {
//TODO add feature analysis
val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } }
CMMatrix(Array2DRowRealMatrix(array))
}
fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this)
class CMVector(val origin: RealVector) : Point<Double> {
override val size: Int get() = origin.dimension
override operator fun get(index: Int): Double = origin.getEntry(index)
override operator fun iterator(): Iterator<Double> = origin.toArray().iterator()
}
fun Point<Double>.toCM(): CMVector = if (this is CMVector) this else {
val array = DoubleArray(size) { this[it] }
CMVector(ArrayRealVector(array))
}
fun RealVector.toPoint(): CMVector = CMVector(this)
object CMMatrixContext : MatrixContext<Double> {
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
return CMMatrix(Array2DRowRealMatrix(array))
}
override fun Matrix<Double>.dot(other: Matrix<Double>): CMMatrix =
CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
CMVector(this.toCM().origin.preMultiply(vector.toCM().origin))
override operator fun Matrix<Double>.unaryMinus(): CMMatrix =
produce(rowNum, colNum) { i, j -> -get(i, j) }
override fun add(a: Matrix<Double>, b: Matrix<Double>): CMMatrix =
CMMatrix(a.toCM().origin.multiply(b.toCM().origin))
override operator fun Matrix<Double>.minus(b: Matrix<Double>): CMMatrix =
CMMatrix(this.toCM().origin.subtract(b.toCM().origin))
override fun multiply(a: Matrix<Double>, k: Number): CMMatrix =
CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble()))
override operator fun Matrix<Double>.times(value: Double): Matrix<Double> =
produce(rowNum, colNum) { i, j -> get(i, j) * value }
}
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix =
CMMatrix(this.origin.add(other.origin))
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix =
CMMatrix(this.origin.subtract(other.origin))
infix fun CMMatrix.dot(other: CMMatrix): CMMatrix =
CMMatrix(this.origin.multiply(other.origin))

View File

@ -1,40 +0,0 @@
package scientifik.kmath.commons.linear
import org.apache.commons.math3.linear.*
import scientifik.kmath.linear.Point
import scientifik.kmath.structures.Matrix
enum class CMDecomposition {
LUP,
QR,
RRQR,
EIGEN,
CHOLESKY
}
fun CMMatrixContext.solver(a: Matrix<Double>, decomposition: CMDecomposition = CMDecomposition.LUP) =
when (decomposition) {
CMDecomposition.LUP -> LUDecomposition(a.toCM().origin).solver
CMDecomposition.RRQR -> RRQRDecomposition(a.toCM().origin).solver
CMDecomposition.QR -> QRDecomposition(a.toCM().origin).solver
CMDecomposition.EIGEN -> EigenDecomposition(a.toCM().origin).solver
CMDecomposition.CHOLESKY -> CholeskyDecomposition(a.toCM().origin).solver
}
fun CMMatrixContext.solve(
a: Matrix<Double>,
b: Matrix<Double>,
decomposition: CMDecomposition = CMDecomposition.LUP
) = solver(a, decomposition).solve(b.toCM().origin).asMatrix()
fun CMMatrixContext.solve(
a: Matrix<Double>,
b: Point<Double>,
decomposition: CMDecomposition = CMDecomposition.LUP
) = solver(a, decomposition).solve(b.toCM().origin).toPoint()
fun CMMatrixContext.inverse(
a: Matrix<Double>,
decomposition: CMDecomposition = CMDecomposition.LUP
) = solver(a, decomposition).inverse.asMatrix()

View File

@ -1,33 +0,0 @@
package scientifik.kmath.commons.random
import scientifik.kmath.prob.RandomGenerator
class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
org.apache.commons.math3.random.RandomGenerator {
private var generator: RandomGenerator = factory(intArrayOf())
override fun nextBoolean(): Boolean = generator.nextBoolean()
override fun nextFloat(): Float = generator.nextDouble().toFloat()
override fun setSeed(seed: Int) {
generator = factory(intArrayOf(seed))
}
override fun setSeed(seed: IntArray) {
generator = factory(seed)
}
override fun setSeed(seed: Long) {
setSeed(seed.toInt())
}
override fun nextBytes(bytes: ByteArray) {
generator.fillBytes(bytes)
}
override fun nextInt(): Int = generator.nextInt()
override fun nextInt(n: Int): Int = generator.nextInt(n)
override fun nextGaussian(): Double = TODO()
override fun nextDouble(): Double = generator.nextDouble()
override fun nextLong(): Long = generator.nextLong()
}

View File

@ -1,17 +1,21 @@
package scientifik.kmath.commons.expressions package kscience.kmath.commons.expressions
import scientifik.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R): R { internal inline fun <R> diff(
order: Int,
vararg parameters: Pair<String, Double>,
block: DerivativeStructureField.() -> R
): R {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return DerivativeStructureField(order, mapOf(*parameters)).run(block) return DerivativeStructureField(order, mapOf(*parameters)).run(block)
} }
class AutoDiffTest { internal class AutoDiffTest {
@Test @Test
fun derivativeStructureFieldTest() { fun derivativeStructureFieldTest() {
val res = diff(3, "x" to 1.0, "y" to 1.0) { val res = diff(3, "x" to 1.0, "y" to 1.0) {

View File

@ -10,31 +10,31 @@ The core features of KMath:
- Automatic differentiation. - Automatic differentiation.
> #### Artifact: > #### Artifact:
> This module is distributed in the artifact `scientifik:kmath-core:0.1.4-dev-8`. > This module is distributed in the artifact `kscience.kmath:kmath-core:0.1.4-dev-8`.
> >
> **Gradle:** > **Gradle:**
> >
> ```gradle > ```gradle
> repositories { > repositories {
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' } > maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
> maven { url 'https://dl.bintray.com/mipt-npm/dev' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' }
> maven { url https://dl.bintray.com/hotkeytlt/maven' } > maven { url https://dl.bintray.com/hotkeytlt/maven' }
> } > }
> >
> dependencies { > dependencies {
> implementation 'scientifik:kmath-core:0.1.4-dev-8' > implementation 'kscience.kmath:kmath-core:0.1.4-dev-8'
> } > }
> ``` > ```
> **Gradle Kotlin DSL:** > **Gradle Kotlin DSL:**
> >
> ```kotlin > ```kotlin
> repositories { > repositories {
> maven("https://dl.bintray.com/mipt-npm/scientifik") > maven("https://dl.bintray.com/mipt-npm/kscience")
> maven("https://dl.bintray.com/mipt-npm/dev") > maven("https://dl.bintray.com/mipt-npm/dev")
> maven("https://dl.bintray.com/hotkeytlt/maven") > maven("https://dl.bintray.com/hotkeytlt/maven")
> } > }
> >
> dependencies {`` > dependencies {
> implementation("scientifik:kmath-core:0.1.4-dev-8") > implementation("kscience.kmath:kmath-core:0.1.4-dev-8")
> } > }
> ``` > ```

View File

@ -1,6 +1,4 @@
plugins { plugins { id("ru.mipt.npm.mpp") }
id("scientifik.mpp")
}
kotlin.sourceSets.commonMain { kotlin.sourceSets.commonMain {
dependencies { dependencies {

View File

@ -1,20 +1,20 @@
package scientifik.kmath.domains package kscience.kmath.domains
import scientifik.kmath.linear.Point import kscience.kmath.linear.Point
/** /**
* A simple geometric domain. * A simple geometric domain.
* *
* @param T the type of element of this domain. * @param T the type of element of this domain.
*/ */
interface Domain<T : Any> { public interface Domain<T : Any> {
/** /**
* Checks if the specified point is contained in this domain. * Checks if the specified point is contained in this domain.
*/ */
operator fun contains(point: Point<T>): Boolean public operator fun contains(point: Point<T>): Boolean
/** /**
* Number of hyperspace dimensions. * Number of hyperspace dimensions.
*/ */
val dimension: Int public val dimension: Int
} }

View File

@ -13,11 +13,11 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package scientifik.kmath.domains package kscience.kmath.domains
import scientifik.kmath.linear.Point import kscience.kmath.linear.Point
import scientifik.kmath.structures.RealBuffer import kscience.kmath.structures.RealBuffer
import scientifik.kmath.structures.indices import kscience.kmath.structures.indices
/** /**
* *
@ -25,23 +25,22 @@ import scientifik.kmath.structures.indices
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain { public class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain {
public override val dimension: Int get() = lower.size
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i -> public override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
point[i] in lower[i]..upper[i] point[i] in lower[i]..upper[i]
} }
override val dimension: Int get() = lower.size public override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num]
override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num] public override fun getLowerBound(num: Int): Double? = lower[num]
override fun getLowerBound(num: Int): Double? = lower[num] public override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num]
override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num] public override fun getUpperBound(num: Int): Double? = upper[num]
override fun getUpperBound(num: Int): Double? = upper[num] public override fun nearestInDomain(point: Point<Double>): Point<Double> {
override fun nearestInDomain(point: Point<Double>): Point<Double> {
val res = DoubleArray(point.size) { i -> val res = DoubleArray(point.size) { i ->
when { when {
point[i] < lower[i] -> lower[i] point[i] < lower[i] -> lower[i]
@ -53,16 +52,14 @@ class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBu
return RealBuffer(*res) return RealBuffer(*res)
} }
override fun volume(): Double { public override fun volume(): Double {
var res = 1.0 var res = 1.0
for (i in 0 until dimension) { for (i in 0 until dimension) {
if (lower[i].isInfinite() || upper[i].isInfinite()) { if (lower[i].isInfinite() || upper[i].isInfinite()) return Double.POSITIVE_INFINITY
return Double.POSITIVE_INFINITY if (upper[i] > lower[i]) res *= upper[i] - lower[i]
}
if (upper[i] > lower[i]) {
res *= upper[i] - lower[i]
}
} }
return res return res
} }
} }

View File

@ -13,17 +13,17 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package scientifik.kmath.domains package kscience.kmath.domains
import scientifik.kmath.linear.Point import kscience.kmath.linear.Point
/** /**
* n-dimensional volume * n-dimensional volume
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
interface RealDomain : Domain<Double> { public interface RealDomain : Domain<Double> {
fun nearestInDomain(point: Point<Double>): Point<Double> public fun nearestInDomain(point: Point<Double>): Point<Double>
/** /**
* The lower edge for the domain going down from point * The lower edge for the domain going down from point
@ -31,7 +31,7 @@ interface RealDomain : Domain<Double> {
* @param point * @param point
* @return * @return
*/ */
fun getLowerBound(num: Int, point: Point<Double>): Double? public fun getLowerBound(num: Int, point: Point<Double>): Double?
/** /**
* The upper edge of the domain going up from point * The upper edge of the domain going up from point
@ -39,25 +39,25 @@ interface RealDomain : Domain<Double> {
* @param point * @param point
* @return * @return
*/ */
fun getUpperBound(num: Int, point: Point<Double>): Double? public fun getUpperBound(num: Int, point: Point<Double>): Double?
/** /**
* Global lower edge * Global lower edge
* @param num * @param num
* @return * @return
*/ */
fun getLowerBound(num: Int): Double? public fun getLowerBound(num: Int): Double?
/** /**
* Global upper edge * Global upper edge
* @param num * @param num
* @return * @return
*/ */
fun getUpperBound(num: Int): Double? public fun getUpperBound(num: Int): Double?
/** /**
* Hyper volume * Hyper volume
* @return * @return
*/ */
fun volume(): Double public fun volume(): Double
} }

View File

@ -0,0 +1,34 @@
/*
* Copyright 2015 Alexander Nozik.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kscience.kmath.domains
import kscience.kmath.linear.Point
public class UnconstrainedDomain(public override val dimension: Int) : RealDomain {
public override operator fun contains(point: Point<Double>): Boolean = true
public override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
public override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY
public override fun getUpperBound(num: Int, point: Point<Double>): Double? = Double.POSITIVE_INFINITY
public override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY
public override fun nearestInDomain(point: Point<Double>): Point<Double> = point
public override fun volume(): Double = Double.POSITIVE_INFINITY
}

View File

@ -0,0 +1,49 @@
package kscience.kmath.domains
import kscience.kmath.linear.Point
import kscience.kmath.structures.asBuffer
public inline class UnivariateDomain(public val range: ClosedFloatingPointRange<Double>) : RealDomain {
public override val dimension: Int
get() = 1
public operator fun contains(d: Double): Boolean = range.contains(d)
public override operator fun contains(point: Point<Double>): Boolean {
require(point.size == 0)
return contains(point[0])
}
public override fun nearestInDomain(point: Point<Double>): Point<Double> {
require(point.size == 1)
val value = point[0]
return when {
value in range -> point
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
else -> doubleArrayOf(range.start).asBuffer()
}
}
public override fun getLowerBound(num: Int, point: Point<Double>): Double? {
require(num == 0)
return range.start
}
public override fun getUpperBound(num: Int, point: Point<Double>): Double? {
require(num == 0)
return range.endInclusive
}
public override fun getLowerBound(num: Int): Double? {
require(num == 0)
return range.start
}
public override fun getUpperBound(num: Int): Double? {
require(num == 0)
return range.endInclusive
}
public override fun volume(): Double = range.endInclusive - range.start
}

View File

@ -0,0 +1,41 @@
package kscience.kmath.expressions
import kscience.kmath.operations.Algebra
/**
* An elementary function that could be invoked on a map of arguments
*/
public fun interface Expression<T> {
/**
* Calls this expression from arguments.
*
* @param arguments the map of arguments.
* @return the value.
*/
public operator fun invoke(arguments: Map<String, T>): T
public companion object
}
/**
* Calls this expression from arguments.
*
* @param pairs the pair of arguments' names to values.
* @return the value.
*/
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
/**
* A context for expression construction
*/
public interface ExpressionAlgebra<T, E> : Algebra<E> {
/**
* Introduce a variable into expression context
*/
public fun variable(name: String, default: T? = null): E
/**
* A constant expression which does not depend on arguments
*/
public fun const(value: T): E
}

View File

@ -0,0 +1,171 @@
package kscience.kmath.expressions
import kscience.kmath.operations.*
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
Expression<T> {
public override operator fun invoke(arguments: Map<String, T>): T =
context.unaryOperation(name, expr.invoke(arguments))
}
internal class FunctionalBinaryOperation<T>(
val context: Algebra<T>,
val name: String,
val first: Expression<T>,
val second: Expression<T>
) : Expression<T> {
public override operator fun invoke(arguments: Map<String, T>): T =
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
}
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
public override operator fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
public override operator fun invoke(arguments: Map<String, T>): T = value
}
internal class FunctionalConstProductExpression<T>(
val context: Space<T>,
private val expr: Expression<T>,
val const: Number
) : Expression<T> {
public override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
/**
* A context class for [Expression] construction.
*
* @param algebra The algebra to provide for Expressions built.
*/
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
ExpressionAlgebra<T, Expression<T>> {
/**
* Builds an Expression of constant expression which does not depend on arguments.
*/
public override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
/**
* Builds an Expression to access a variable.
*/
public override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
/**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, operation, left, right)
/**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
FunctionalUnaryOperation(algebra, operation, arg)
}
/**
* A context class for [Expression] construction for [Space] algebras.
*/
public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
public override val zero: Expression<T> get() = const(algebra.zero)
/**
* Builds an Expression of addition of two another expressions.
*/
public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
/**
* Builds an Expression of multiplication of expression by number.
*/
public override fun multiply(a: Expression<T>, k: Number): Expression<T> =
FunctionalConstProductExpression(algebra, a, k)
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
}
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
public override val one: Expression<T>
get() = const(algebra.one)
/**
* Builds an Expression of multiplication of two expressions.
*/
public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
}
public open class FunctionalExpressionField<T, A>(algebra: A) :
FunctionalExpressionRing<T, A>(algebra),
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
/**
* Builds an Expression of division an expression by another one.
*/
public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
}
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
FunctionalExpressionField<T, A>(algebra),
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
public override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
public override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
public override fun asin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
public override fun acos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
public override fun atan(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
public override fun exp(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionField>.unaryOperation(operation, arg)
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionField>.binaryOperation(operation, left, right)
}
public inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionSpace(this).block()
public inline fun <T, A : Ring<T>> A.expressionInRing(block: FunctionalExpressionRing<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionRing(this).block()
public inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionField(this).block()
public inline fun <T, A : ExtendedField<T>> A.expressionInExtendedField(block: FunctionalExpressionExtendedField<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionExtendedField(this).block()

View File

@ -1,16 +1,16 @@
package scientifik.kmath.expressions package kscience.kmath.expressions
import scientifik.kmath.operations.ExtendedField import kscience.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field import kscience.kmath.operations.Field
import scientifik.kmath.operations.Ring import kscience.kmath.operations.Ring
import scientifik.kmath.operations.Space import kscience.kmath.operations.Space
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
/** /**
* Creates a functional expression with this [Space]. * Creates a functional expression with this [Space].
*/ */
inline fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> { public inline fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionSpace(this).block() return FunctionalExpressionSpace(this).block()
} }
@ -18,7 +18,7 @@ inline fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Spac
/** /**
* Creates a functional expression with this [Ring]. * Creates a functional expression with this [Ring].
*/ */
inline fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> { public inline fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionRing(this).block() return FunctionalExpressionRing(this).block()
} }
@ -26,7 +26,7 @@ inline fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>
/** /**
* Creates a functional expression with this [Field]. * Creates a functional expression with this [Field].
*/ */
inline fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> { public inline fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionField(this).block() return FunctionalExpressionField(this).block()
} }
@ -34,7 +34,7 @@ inline fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Fiel
/** /**
* Creates a functional expression with this [ExtendedField]. * Creates a functional expression with this [ExtendedField].
*/ */
inline fun <T> ExtendedField<T>.extendedFieldExpression(block: FunctionalExpressionExtendedField<T, ExtendedField<T>>.() -> Expression<T>): Expression<T> { public inline fun <T> ExtendedField<T>.extendedFieldExpression(block: FunctionalExpressionExtendedField<T, ExtendedField<T>>.() -> Expression<T>): Expression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return FunctionalExpressionExtendedField(this).block() return FunctionalExpressionExtendedField(this).block()
} }

View File

@ -0,0 +1,113 @@
package kscience.kmath.linear
import kscience.kmath.operations.RealField
import kscience.kmath.operations.Ring
import kscience.kmath.structures.*
/**
* Basic implementation of Matrix space based on [NDStructure]
*/
public class BufferMatrixContext<T : Any, R : Ring<T>>(
public override val elementContext: R,
private val bufferFactory: BufferFactory<T>
) : GenericMatrixContext<T, R> {
public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix<T> {
val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
return BufferMatrix(rows, columns, buffer)
}
public override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
public companion object
}
@Suppress("OVERRIDE_BY_INLINE")
public object RealMatrixContext : GenericMatrixContext<Double, RealField> {
public override val elementContext: RealField
get() = RealField
public override inline fun produce(
rows: Int,
columns: Int,
initializer: (i: Int, j: Int) -> Double
): Matrix<Double> {
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
return BufferMatrix(rows, columns, buffer)
}
public override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> =
RealBuffer(size, initializer)
}
public class BufferMatrix<T : Any>(
public override val rowNum: Int,
public override val colNum: Int,
public val buffer: Buffer<out T>,
public override val features: Set<MatrixFeature> = emptySet()
) : FeaturedMatrix<T> {
override val shape: IntArray
get() = intArrayOf(rowNum, colNum)
init {
require(buffer.size == rowNum * colNum) { "Dimension mismatch for matrix structure" }
}
public override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
BufferMatrix(rowNum, colNum, buffer, this.features + features)
public override operator fun get(index: IntArray): T = get(index[0], index[1])
public override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
public override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
}
public override fun equals(other: Any?): Boolean {
if (this === other) return true
return when (other) {
is NDStructure<*> -> return NDStructure.equals(this, other)
else -> false
}
}
public override fun hashCode(): Int {
var result = buffer.hashCode()
result = 31 * result + features.hashCode()
return result
}
public override fun toString(): String {
return if (rowNum <= 5 && colNum <= 5)
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
buffer.asSequence().joinToString(separator = "\t") { it.toString() }
}
else "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
}
}
/**
* Optimized dot product for real matrices
*/
public infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> {
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
val array = DoubleArray(this.rowNum * other.colNum)
//convert to array to insure there is not memory indirection
fun Buffer<out Double>.unsafeArray() = if (this is RealBuffer)
array
else
DoubleArray(size) { get(it) }
val a = this.buffer.unsafeArray()
val b = other.buffer.unsafeArray()
for (i in (0 until rowNum))
for (j in (0 until other.colNum))
for (k in (0 until colNum))
array[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j]
val buffer = RealBuffer(array)
return BufferMatrix(rowNum, other.colNum, buffer)
}

View File

@ -1,20 +1,17 @@
package scientifik.kmath.linear package kscience.kmath.linear
import scientifik.kmath.operations.Ring import kscience.kmath.operations.Ring
import scientifik.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D import kscience.kmath.structures.Structure2D
import scientifik.kmath.structures.asBuffer import kscience.kmath.structures.asBuffer
import kotlin.contracts.contract
import kotlin.math.sqrt import kotlin.math.sqrt
/** /**
* A 2d structure plus optional matrix-specific features * A 2d structure plus optional matrix-specific features
*/ */
interface FeaturedMatrix<T : Any> : Matrix<T> { public interface FeaturedMatrix<T : Any> : Matrix<T> {
override val shape: IntArray get() = intArrayOf(rowNum, colNum) override val shape: IntArray get() = intArrayOf(rowNum, colNum)
public val features: Set<MatrixFeature>
val features: Set<MatrixFeature>
/** /**
* Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure. * Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure.
@ -22,44 +19,42 @@ interface FeaturedMatrix<T : Any> : Matrix<T> {
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to * The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to
* add only those features that are valid. * add only those features that are valid.
*/ */
fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> public fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
companion object public companion object
} }
inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> { public inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> =
contract { callsInPlace(initializer) } MatrixContext.real.produce(rows, columns, initializer)
return MatrixContext.real.produce(rows, columns, initializer)
}
/** /**
* Build a square matrix from given elements. * Build a square matrix from given elements.
*/ */
fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> { public fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
val size: Int = sqrt(elements.size.toDouble()).toInt() val size: Int = sqrt(elements.size.toDouble()).toInt()
require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" } require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" }
val buffer = elements.asBuffer() val buffer = elements.asBuffer()
return BufferMatrix(size, size, buffer) return BufferMatrix(size, size, buffer)
} }
val Matrix<*>.features: Set<MatrixFeature> get() = (this as? FeaturedMatrix)?.features ?: emptySet() public val Matrix<*>.features: Set<MatrixFeature> get() = (this as? FeaturedMatrix)?.features ?: emptySet()
/** /**
* Check if matrix has the given feature class * Check if matrix has the given feature class
*/ */
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean = public inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean =
features.find { it is T } != null features.find { it is T } != null
/** /**
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
*/ */
inline fun <reified T : Any> Matrix<*>.getFeature(): T? = public inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
features.filterIsInstance<T>().firstOrNull() features.filterIsInstance<T>().firstOrNull()
/** /**
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created * Diagonal matrix of ones. The matrix is virtual no actual matrix is created
*/ */
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): FeaturedMatrix<T> = public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): FeaturedMatrix<T> =
VirtualMatrix(rows, columns, DiagonalFeature) { i, j -> VirtualMatrix(rows, columns, DiagonalFeature) { i, j ->
if (i == j) elementContext.one else elementContext.zero if (i == j) elementContext.one else elementContext.zero
} }
@ -68,20 +63,20 @@ fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: In
/** /**
* A virtual matrix of zeroes * A virtual matrix of zeroes
*/ */
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> = public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } VirtualMatrix(rows, columns) { _, _ -> elementContext.zero }
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature public class TransposedFeature<T : Any>(public val original: Matrix<T>) : MatrixFeature
/** /**
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
*/ */
fun <T : Any> Matrix<T>.transpose(): Matrix<T> { public fun <T : Any> Matrix<T>.transpose(): Matrix<T> {
return this.getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix( return getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
this.colNum, colNum,
this.rowNum, rowNum,
setOf(TransposedFeature(this)) setOf(TransposedFeature(this))
) { i, j -> get(j, i) } ) { i, j -> get(j, i) }
} }
infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = with(MatrixContext.real) { dot(other) } public infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = with(MatrixContext.real) { dot(other) }

View File

@ -1,25 +1,25 @@
package scientifik.kmath.linear package kscience.kmath.linear
import scientifik.kmath.operations.Field import kscience.kmath.operations.Field
import scientifik.kmath.operations.RealField import kscience.kmath.operations.RealField
import scientifik.kmath.operations.Ring import kscience.kmath.operations.Ring
import scientifik.kmath.operations.invoke import kscience.kmath.operations.invoke
import scientifik.kmath.structures.BufferAccessor2D import kscience.kmath.structures.BufferAccessor2D
import scientifik.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import scientifik.kmath.structures.Structure2D import kscience.kmath.structures.Structure2D
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
* Common implementation of [LUPDecompositionFeature] * Common implementation of [LUPDecompositionFeature]
*/ */
class LUPDecomposition<T : Any>( public class LUPDecomposition<T : Any>(
val context: GenericMatrixContext<T, out Field<T>>, public val context: GenericMatrixContext<T, out Field<T>>,
val lu: Structure2D<T>, public val lu: Structure2D<T>,
val pivot: IntArray, public val pivot: IntArray,
private val even: Boolean private val even: Boolean
) : LUPDecompositionFeature<T>, DeterminantFeature<T> { ) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
public val elementContext: Field<T>
val elementContext: Field<T> get() = context.elementContext get() = context.elementContext
/** /**
* Returns the matrix L of the decomposition. * Returns the matrix L of the decomposition.
@ -44,7 +44,6 @@ class LUPDecomposition<T : Any>(
if (j >= i) lu[i, j] else elementContext.zero if (j >= i) lu[i, j] else elementContext.zero
} }
/** /**
* Returns the P rows permutation matrix. * Returns the P rows permutation matrix.
* *
@ -55,7 +54,6 @@ class LUPDecomposition<T : Any>(
if (j == pivot[i]) elementContext.one else elementContext.zero if (j == pivot[i]) elementContext.one else elementContext.zero
} }
/** /**
* Return the determinant of the matrix * Return the determinant of the matrix
* @return determinant of the matrix * @return determinant of the matrix
@ -66,22 +64,18 @@ class LUPDecomposition<T : Any>(
} }
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T): T = public fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T): T =
if (value > elementContext.zero) value else elementContext { -value } if (value > elementContext.zero) value else elementContext { -value }
/** /**
* Create a lup decomposition of generic matrix * Create a lup decomposition of generic matrix
*/ */
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup( public inline fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
type: KClass<T>, type: KClass<T>,
matrix: Matrix<T>, matrix: Matrix<T>,
checkSingular: (T) -> Boolean checkSingular: (T) -> Boolean
): LUPDecomposition<T> { ): LUPDecomposition<T> {
if (matrix.rowNum != matrix.colNum) { require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
error("LU decomposition supports only square matrices")
}
val m = matrix.colNum val m = matrix.colNum
val pivot = IntArray(matrix.rowNum) val pivot = IntArray(matrix.rowNum)
@ -154,15 +148,15 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
} }
} }
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup( public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
matrix: Matrix<T>, matrix: Matrix<T>,
noinline checkSingular: (T) -> Boolean checkSingular: (T) -> Boolean
): LUPDecomposition<T> = lup(T::class, matrix, checkSingular) ): LUPDecomposition<T> = lup(T::class, matrix, checkSingular)
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> = public fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> =
lup(Double::class, matrix) { it < 1e-11 } lup(Double::class, matrix) { it < 1e-11 }
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> { public fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" } require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" }
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
@ -207,27 +201,27 @@ fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Mat
} }
} }
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matrix<T> = solve(T::class, matrix) public inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>): Matrix<T> = solve(T::class, matrix)
/** /**
* Solve a linear equation **a*x = b** * Solve a linear equation **a*x = b**
*/ */
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.solve( public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.solve(
a: Matrix<T>, a: Matrix<T>,
b: Matrix<T>, b: Matrix<T>,
noinline checkSingular: (T) -> Boolean checkSingular: (T) -> Boolean
): Matrix<T> { ): Matrix<T> {
// Use existing decomposition if it is provided by matrix // Use existing decomposition if it is provided by matrix
val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular) val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular)
return decomposition.solve(T::class, b) return decomposition.solve(T::class, b)
} }
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> = solve(a, b) { it < 1e-11 } public fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> = solve(a, b) { it < 1e-11 }
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse( public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
matrix: Matrix<T>, matrix: Matrix<T>,
noinline checkSingular: (T) -> Boolean checkSingular: (T) -> Boolean
): Matrix<T> = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular) ): Matrix<T> = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
fun RealMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> = public fun RealMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> =
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 } solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }

View File

@ -0,0 +1,27 @@
package kscience.kmath.linear
import kscience.kmath.structures.Buffer
import kscience.kmath.structures.Matrix
import kscience.kmath.structures.VirtualBuffer
public typealias Point<T> = Buffer<T>
/**
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
*/
public interface LinearSolver<T : Any> {
public fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
public fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.asMatrix()).asPoint()
public fun inverse(a: Matrix<T>): Matrix<T>
}
/**
* Convert matrix to vector if it is possible
*/
public fun <T : Any> Matrix<T>.asPoint(): Point<T> =
if (this.colNum == 1)
VirtualBuffer(rowNum) { get(it, 0) }
else
error("Can't convert matrix with more than one column to vector")
public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }

View File

@ -1,12 +1,12 @@
package scientifik.kmath.linear package kscience.kmath.linear
import scientifik.kmath.structures.Buffer import kscience.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import kscience.kmath.structures.BufferFactory
import scientifik.kmath.structures.Structure2D import kscience.kmath.structures.Structure2D
import scientifik.kmath.structures.asBuffer import kscience.kmath.structures.asBuffer
class MatrixBuilder(val rows: Int, val columns: Int) { public class MatrixBuilder(public val rows: Int, public val columns: Int) {
operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> { public operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> {
require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" } require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" }
val buffer = elements.asBuffer() val buffer = elements.asBuffer()
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
@ -15,14 +15,14 @@ class MatrixBuilder(val rows: Int, val columns: Int) {
//TODO add specific matrix builder functions like diagonal, etc //TODO add specific matrix builder functions like diagonal, etc
} }
fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns) public fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns)
fun <T : Any> Structure2D.Companion.row(vararg values: T): FeaturedMatrix<T> { public fun <T : Any> Structure2D.Companion.row(vararg values: T): FeaturedMatrix<T> {
val buffer = values.asBuffer() val buffer = values.asBuffer()
return BufferMatrix(1, values.size, buffer) return BufferMatrix(1, values.size, buffer)
} }
inline fun <reified T : Any> Structure2D.Companion.row( public inline fun <reified T : Any> Structure2D.Companion.row(
size: Int, size: Int,
factory: BufferFactory<T> = Buffer.Companion::auto, factory: BufferFactory<T> = Buffer.Companion::auto,
noinline builder: (Int) -> T noinline builder: (Int) -> T
@ -31,12 +31,12 @@ inline fun <reified T : Any> Structure2D.Companion.row(
return BufferMatrix(1, size, buffer) return BufferMatrix(1, size, buffer)
} }
fun <T : Any> Structure2D.Companion.column(vararg values: T): FeaturedMatrix<T> { public fun <T : Any> Structure2D.Companion.column(vararg values: T): FeaturedMatrix<T> {
val buffer = values.asBuffer() val buffer = values.asBuffer()
return BufferMatrix(values.size, 1, buffer) return BufferMatrix(values.size, 1, buffer)
} }
inline fun <reified T : Any> Structure2D.Companion.column( public inline fun <reified T : Any> Structure2D.Companion.column(
size: Int, size: Int,
factory: BufferFactory<T> = Buffer.Companion::auto, factory: BufferFactory<T> = Buffer.Companion::auto,
noinline builder: (Int) -> T noinline builder: (Int) -> T

View File

@ -1,41 +1,42 @@
package scientifik.kmath.linear package kscience.kmath.linear
import scientifik.kmath.operations.Ring import kscience.kmath.operations.Ring
import scientifik.kmath.operations.SpaceOperations import kscience.kmath.operations.SpaceOperations
import scientifik.kmath.operations.invoke import kscience.kmath.operations.invoke
import scientifik.kmath.operations.sum import kscience.kmath.operations.sum
import scientifik.kmath.structures.Buffer import kscience.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import kscience.kmath.structures.BufferFactory
import scientifik.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import scientifik.kmath.structures.asSequence import kscience.kmath.structures.asSequence
/** /**
* Basic operations on matrices. Operates on [Matrix] * Basic operations on matrices. Operates on [Matrix]
*/ */
interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> { public interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
/** /**
* Produce a matrix with this context and given dimensions * Produce a matrix with this context and given dimensions
*/ */
fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T> public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> public infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
infix fun Matrix<T>.dot(vector: Point<T>): Point<T> public infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
operator fun Matrix<T>.times(value: T): Matrix<T> public operator fun Matrix<T>.times(value: T): Matrix<T>
operator fun T.times(m: Matrix<T>): Matrix<T> = m * this public operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
companion object { public companion object {
/** /**
* Non-boxing double matrix * Non-boxing double matrix
*/ */
val real: RealMatrixContext = RealMatrixContext public val real: RealMatrixContext
get() = RealMatrixContext
/** /**
* A structured matrix with custom buffer * A structured matrix with custom buffer
*/ */
fun <T : Any, R : Ring<T>> buffered( public fun <T : Any, R : Ring<T>> buffered(
ring: R, ring: R,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
): GenericMatrixContext<T, R> = BufferMatrixContext(ring, bufferFactory) ): GenericMatrixContext<T, R> = BufferMatrixContext(ring, bufferFactory)
@ -43,21 +44,21 @@ interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
/** /**
* Automatic buffered matrix, unboxed if it is possible * Automatic buffered matrix, unboxed if it is possible
*/ */
inline fun <reified T : Any, R : Ring<T>> auto(ring: R): GenericMatrixContext<T, R> = public inline fun <reified T : Any, R : Ring<T>> auto(ring: R): GenericMatrixContext<T, R> =
buffered(ring, Buffer.Companion::auto) buffered(ring, Buffer.Companion::auto)
} }
} }
interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> { public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
/** /**
* The ring context for matrix elements * The ring context for matrix elements
*/ */
val elementContext: R public val elementContext: R
/** /**
* Produce a point compatible with matrix space * Produce a point compatible with matrix space
*/ */
fun point(size: Int, initializer: (Int) -> T): Point<T> public fun point(size: Int, initializer: (Int) -> T): Point<T>
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> { override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
//TODO add typed error //TODO add typed error
@ -102,7 +103,7 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> = override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } } produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this public operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
override operator fun Matrix<T>.times(value: T): Matrix<T> = override operator fun Matrix<T>.times(value: T): Matrix<T> =
produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } } produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }

View File

@ -0,0 +1,62 @@
package kscience.kmath.linear
/**
* A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix
* operations performance in some cases.
*/
public interface MatrixFeature
/**
* The matrix with this feature is considered to have only diagonal non-null elements
*/
public object DiagonalFeature : MatrixFeature
/**
* Matrix with this feature has all zero elements
*/
public object ZeroFeature : MatrixFeature
/**
* Matrix with this feature have unit elements on diagonal and zero elements in all other places
*/
public object UnitFeature : MatrixFeature
/**
* Inverted matrix feature
*/
public interface InverseMatrixFeature<T : Any> : MatrixFeature {
public val inverse: FeaturedMatrix<T>
}
/**
* A determinant container
*/
public interface DeterminantFeature<T : Any> : MatrixFeature {
public val determinant: T
}
@Suppress("FunctionName")
public fun <T : Any> DeterminantFeature(determinant: T): DeterminantFeature<T> = object : DeterminantFeature<T> {
override val determinant: T = determinant
}
/**
* Lower triangular matrix
*/
public object LFeature : MatrixFeature
/**
* Upper triangular feature
*/
public object UFeature : MatrixFeature
/**
* TODO add documentation
*/
public interface LUPDecompositionFeature<T : Any> : MatrixFeature {
public val l: FeaturedMatrix<T>
public val u: FeaturedMatrix<T>
public val p: FeaturedMatrix<T>
}
//TODO add sparse matrix feature

View File

@ -1,21 +1,21 @@
package scientifik.kmath.linear package kscience.kmath.linear
import scientifik.kmath.operations.RealField import kscience.kmath.operations.RealField
import scientifik.kmath.operations.Space import kscience.kmath.operations.Space
import scientifik.kmath.operations.invoke import kscience.kmath.operations.invoke
import scientifik.kmath.structures.Buffer import kscience.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import kscience.kmath.structures.BufferFactory
/** /**
* A linear space for vectors. * A linear space for vectors.
* Could be used on any point-like structure * Could be used on any point-like structure
*/ */
interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> { public interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
val size: Int public val size: Int
val space: S public val space: S
override val zero: Point<T> get() = produce { space.zero } override val zero: Point<T> get() = produce { space.zero }
fun produce(initializer: (Int) -> T): Point<T> public fun produce(initializer: (Int) -> T): Point<T>
/** /**
* Produce a space-element of this vector space for expressions * Produce a space-element of this vector space for expressions
@ -28,13 +28,13 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
//TODO add basis //TODO add basis
companion object { public companion object {
private val realSpaceCache: MutableMap<Int, BufferVectorSpace<Double, RealField>> = hashMapOf() private val realSpaceCache: MutableMap<Int, BufferVectorSpace<Double, RealField>> = hashMapOf()
/** /**
* Non-boxing double vector space * Non-boxing double vector space
*/ */
fun real(size: Int): BufferVectorSpace<Double, RealField> = realSpaceCache.getOrPut(size) { public fun real(size: Int): BufferVectorSpace<Double, RealField> = realSpaceCache.getOrPut(size) {
BufferVectorSpace( BufferVectorSpace(
size, size,
RealField, RealField,
@ -45,7 +45,7 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
/** /**
* A structured vector space with custom buffer * A structured vector space with custom buffer
*/ */
fun <T : Any, S : Space<T>> buffered( public fun <T : Any, S : Space<T>> buffered(
size: Int, size: Int,
space: S, space: S,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
@ -54,16 +54,16 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
/** /**
* Automatic buffered vector, unboxed if it is possible * Automatic buffered vector, unboxed if it is possible
*/ */
inline fun <reified T : Any, S : Space<T>> auto(size: Int, space: S): VectorSpace<T, S> = public inline fun <reified T : Any, S : Space<T>> auto(size: Int, space: S): VectorSpace<T, S> =
buffered(size, space, Buffer.Companion::auto) buffered(size, space, Buffer.Companion::auto)
} }
} }
class BufferVectorSpace<T : Any, S : Space<T>>( public class BufferVectorSpace<T : Any, S : Space<T>>(
override val size: Int, override val size: Int,
override val space: S, override val space: S,
val bufferFactory: BufferFactory<T> public val bufferFactory: BufferFactory<T>
) : VectorSpace<T, S> { ) : VectorSpace<T, S> {
override fun produce(initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer) override fun produce(initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
//override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer)) //override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))

View File

@ -1,15 +1,19 @@
package scientifik.kmath.linear package kscience.kmath.linear
import scientifik.kmath.structures.Matrix import kscience.kmath.structures.Matrix
class VirtualMatrix<T : Any>( public class VirtualMatrix<T : Any>(
override val rowNum: Int, override val rowNum: Int,
override val colNum: Int, override val colNum: Int,
override val features: Set<MatrixFeature> = emptySet(), override val features: Set<MatrixFeature> = emptySet(),
val generator: (i: Int, j: Int) -> T public val generator: (i: Int, j: Int) -> T
) : FeaturedMatrix<T> { ) : FeaturedMatrix<T> {
public constructor(
constructor(rowNum: Int, colNum: Int, vararg features: MatrixFeature, generator: (i: Int, j: Int) -> T) : this( rowNum: Int,
colNum: Int,
vararg features: MatrixFeature,
generator: (i: Int, j: Int) -> T
) : this(
rowNum, rowNum,
colNum, colNum,
setOf(*features), setOf(*features),
@ -42,18 +46,15 @@ class VirtualMatrix<T : Any>(
} }
companion object { public companion object {
/** /**
* Wrap a matrix adding additional features to it * Wrap a matrix adding additional features to it
*/ */
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): FeaturedMatrix<T> { public fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): FeaturedMatrix<T> {
return if (matrix is VirtualMatrix) { return if (matrix is VirtualMatrix)
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator) VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
} else { else
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j -> VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j -> matrix[i, j] }
matrix[i, j]
}
}
} }
} }
} }

View File

@ -1,11 +1,11 @@
package scientifik.kmath.misc package kscience.kmath.misc
import scientifik.kmath.linear.Point import kscience.kmath.linear.Point
import scientifik.kmath.operations.ExtendedField import kscience.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field import kscience.kmath.operations.Field
import scientifik.kmath.operations.invoke import kscience.kmath.operations.invoke
import scientifik.kmath.operations.sum import kscience.kmath.operations.sum
import scientifik.kmath.structures.asBuffer import kscience.kmath.structures.asBuffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@ -18,24 +18,24 @@ import kotlin.contracts.contract
* Differentiable variable with value and derivative of differentiation ([deriv]) result * Differentiable variable with value and derivative of differentiation ([deriv]) result
* with respect to this variable. * with respect to this variable.
*/ */
open class Variable<T : Any>(val value: T) public open class Variable<T : Any>(public val value: T)
class DerivationResult<T : Any>( public class DerivationResult<T : Any>(
value: T, value: T,
val deriv: Map<Variable<T>, T>, public val deriv: Map<Variable<T>, T>,
val context: Field<T> public val context: Field<T>
) : Variable<T>(value) { ) : Variable<T>(value) {
fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero public fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
/** /**
* compute divergence * compute divergence
*/ */
fun div(): T = context { sum(deriv.values) } public fun div(): T = context { sum(deriv.values) }
/** /**
* Compute a gradient for variables in given order * Compute a gradient for variables in given order
*/ */
fun grad(vararg variables: Variable<T>): Point<T> { public fun grad(vararg variables: Variable<T>): Point<T> {
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" } check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
return variables.map(::deriv).asBuffer() return variables.map(::deriv).asBuffer()
} }
@ -54,7 +54,7 @@ class DerivationResult<T : Any>(
* assertEquals(9.0, x.d) // dy/dx * assertEquals(9.0, x.d) // dy/dx
* ``` * ```
*/ */
inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> { public inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return (AutoDiffContext(this)) { return (AutoDiffContext(this)) {
@ -65,15 +65,14 @@ inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Varia
} }
} }
public abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> { public abstract val context: F
abstract val context: F
/** /**
* A variable accessing inner state of derivatives. * A variable accessing inner state of derivatives.
* Use this function in inner builders to avoid creating additional derivative bindings * Use this function in inner builders to avoid creating additional derivative bindings
*/ */
abstract var Variable<T>.d: T public abstract var Variable<T>.d: T
/** /**
* Performs update of derivative after the rest of the formula in the back-pass. * Performs update of derivative after the rest of the formula in the back-pass.
@ -86,11 +85,11 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
* } * }
* ``` * ```
*/ */
abstract fun <R> derive(value: R, block: F.(R) -> Unit): R public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
abstract fun variable(value: T): Variable<T> public abstract fun variable(value: T): Variable<T>
inline fun variable(block: F.() -> T): Variable<T> = variable(context.block()) public inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
// Overloads for Double constants // Overloads for Double constants
@ -152,7 +151,6 @@ internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) :
// Basic math (+, -, *, /) // Basic math (+, -, *, /)
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z -> override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
a.d += z.d a.d += z.d
b.d += z.d b.d += z.d
@ -176,35 +174,73 @@ internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) :
// Extensions for differentiation of various basic mathematical functions // Extensions for differentiation of various basic mathematical functions
// x ^ 2 // x ^ 2
fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> = public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
// x ^ 1/2 // x ^ 1/2
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
// x ^ y (const) // x ^ y (const)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble()) public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> =
pow(x, y.toDouble())
// exp(x) // exp(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value } derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
// ln(x) // ln(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value } derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
// x ^ y (any) // x ^ y (any)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
exp(y * ln(x)) exp(y * ln(x))
// sin(x) // sin(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
// cos(x) // cos(x)
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: Variable<T>): Variable<T> =
derive(variable { tan(x.value) }) { z ->
val c = cos(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: Variable<T>): Variable<T> =
derive(variable { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: Variable<T>): Variable<T> =
derive(variable { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: Variable<T>): Variable<T> =
derive(variable { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: Variable<T>): Variable<T> =
derive(variable { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: Variable<T>): Variable<T> =
derive(variable { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: Variable<T>): Variable<T> =
derive(variable { tan(x.value) }) { z ->
val c = cosh(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: Variable<T>): Variable<T> =
derive(variable { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: Variable<T>): Variable<T> =
derive(variable { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: Variable<T>): Variable<T> =
derive(variable { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }

View File

@ -1,4 +1,4 @@
package scientifik.kmath.misc package kscience.kmath.misc
import kotlin.math.abs import kotlin.math.abs
@ -10,17 +10,21 @@ import kotlin.math.abs
* *
* If step is negative, the same goes from upper boundary downwards * If step is negative, the same goes from upper boundary downwards
*/ */
fun ClosedFloatingPointRange<Double>.toSequenceWithStep(step: Double): Sequence<Double> = when { public fun ClosedFloatingPointRange<Double>.toSequenceWithStep(step: Double): Sequence<Double> = when {
step == 0.0 -> error("Zero step in double progression") step == 0.0 -> error("Zero step in double progression")
step > 0 -> sequence { step > 0 -> sequence {
var current = start var current = start
while (current <= endInclusive) { while (current <= endInclusive) {
yield(current) yield(current)
current += step current += step
} }
} }
else -> sequence { else -> sequence {
var current = endInclusive var current = endInclusive
while (current >= start) { while (current >= start) {
yield(current) yield(current)
current += step current += step
@ -31,7 +35,7 @@ fun ClosedFloatingPointRange<Double>.toSequenceWithStep(step: Double): Sequence<
/** /**
* Convert double range to sequence with the fixed number of points * Convert double range to sequence with the fixed number of points
*/ */
fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Sequence<Double> { public fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Sequence<Double> {
require(numPoints > 1) { "The number of points should be more than 2" } require(numPoints > 1) { "The number of points should be more than 2" }
return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1)) return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1))
} }
@ -40,7 +44,7 @@ fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Seque
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints] * Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
*/ */
@Deprecated("Replace by 'toSequenceWithPoints'") @Deprecated("Replace by 'toSequenceWithPoints'")
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray { public fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
require(numPoints >= 2) { "Can't create generic grid with less than two points" } require(numPoints >= 2) { "Can't create generic grid with less than two points" }
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
} }

View File

@ -0,0 +1,74 @@
package kscience.kmath.misc
import kscience.kmath.operations.Space
import kscience.kmath.operations.invoke
import kotlin.jvm.JvmName
/**
* Generic cumulative operation on iterator.
*
* @param T the type of initial iterable.
* @param R the type of resulting iterable.
* @param initial lazy evaluated.
*/
public inline fun <T, R> Iterator<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator<R> =
object : Iterator<R> {
var state: R = initial
override fun hasNext(): Boolean = this@cumulative.hasNext()
override fun next(): R {
state = operation(state, this@cumulative.next())
return state
}
}
public inline fun <T, R> Iterable<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable<R> =
Iterable { this@cumulative.iterator().cumulative(initial, operation) }
public inline fun <T, R> Sequence<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence<R> =
Sequence { this@cumulative.iterator().cumulative(initial, operation) }
public fun <T, R> List<T>.cumulative(initial: R, operation: (R, T) -> R): List<R> =
iterator().cumulative(initial, operation).asSequence().toList()
//Cumulative sum
/**
* Cumulative sum with custom space
*/
public fun <T> Iterable<T>.cumulativeSum(space: Space<T>): Iterable<T> =
space { cumulative(zero) { element: T, sum: T -> sum + element } }
@JvmName("cumulativeSumOfDouble")
public fun Iterable<Double>.cumulativeSum(): Iterable<Double> = cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt")
public fun Iterable<Int>.cumulativeSum(): Iterable<Int> = cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong")
public fun Iterable<Long>.cumulativeSum(): Iterable<Long> = cumulative(0L) { element, sum -> sum + element }
public fun <T> Sequence<T>.cumulativeSum(space: Space<T>): Sequence<T> =
space { cumulative(zero) { element: T, sum: T -> sum + element } }
@JvmName("cumulativeSumOfDouble")
public fun Sequence<Double>.cumulativeSum(): Sequence<Double> = cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt")
public fun Sequence<Int>.cumulativeSum(): Sequence<Int> = cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong")
public fun Sequence<Long>.cumulativeSum(): Sequence<Long> = cumulative(0L) { element, sum -> sum + element }
public fun <T> List<T>.cumulativeSum(space: Space<T>): List<T> =
space { cumulative(zero) { element: T, sum: T -> sum + element } }
@JvmName("cumulativeSumOfDouble")
public fun List<Double>.cumulativeSum(): List<Double> = cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt")
public fun List<Int>.cumulativeSum(): List<Int> = cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong")
public fun List<Long>.cumulativeSum(): List<Long> = cumulative(0L) { element, sum -> sum + element }

View File

@ -1,31 +1,31 @@
package scientifik.kmath.operations package kscience.kmath.operations
/** /**
* Stub for DSL the [Algebra] is. * Stub for DSL the [Algebra] is.
*/ */
@DslMarker @DslMarker
annotation class KMathContext public annotation class KMathContext
/** /**
* Represents an algebraic structure. * Represents an algebraic structure.
* *
* @param T the type of element of this structure. * @param T the type of element of this structure.
*/ */
interface Algebra<T> { public interface Algebra<T> {
/** /**
* Wrap raw string or variable * Wrap raw string or variable
*/ */
fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
/** /**
* Dynamic call of unary operation with name [operation] on [arg] * Dynamic call of unary operation with name [operation] on [arg]
*/ */
fun unaryOperation(operation: String, arg: T): T public fun unaryOperation(operation: String, arg: T): T
/** /**
* Dynamic call of binary operation [operation] on [left] and [right] * Dynamic call of binary operation [operation] on [left] and [right]
*/ */
fun binaryOperation(operation: String, left: T, right: T): T public fun binaryOperation(operation: String, left: T, right: T): T
} }
/** /**
@ -33,29 +33,30 @@ interface Algebra<T> {
* *
* @param T the type of element of this structure. * @param T the type of element of this structure.
*/ */
interface NumericAlgebra<T> : Algebra<T> { public interface NumericAlgebra<T> : Algebra<T> {
/** /**
* Wraps a number. * Wraps a number.
*/ */
fun number(value: Number): T public fun number(value: Number): T
/** /**
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number]. * Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number].
*/ */
fun leftSideNumberOperation(operation: String, left: Number, right: T): T = public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
binaryOperation(operation, number(left), right) binaryOperation(operation, number(left), right)
/** /**
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number]. * Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
*/ */
fun rightSideNumberOperation(operation: String, left: T, right: Number): T = public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
leftSideNumberOperation(operation, right, left) leftSideNumberOperation(operation, right, left)
} }
/** /**
* Call a block with an [Algebra] as receiver. * Call a block with an [Algebra] as receiver.
*/ */
inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block) // TODO add contract when KT-32313 is fixed
public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = block()
/** /**
* Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as * Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as
@ -63,7 +64,7 @@ inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(bloc
* *
* @param T the type of element of this semispace. * @param T the type of element of this semispace.
*/ */
interface SpaceOperations<T> : Algebra<T> { public interface SpaceOperations<T> : Algebra<T> {
/** /**
* Addition of two elements. * Addition of two elements.
* *
@ -71,7 +72,7 @@ interface SpaceOperations<T> : Algebra<T> {
* @param b the augend. * @param b the augend.
* @return the sum. * @return the sum.
*/ */
fun add(a: T, b: T): T public fun add(a: T, b: T): T
/** /**
* Multiplication of element by scalar. * Multiplication of element by scalar.
@ -80,7 +81,7 @@ interface SpaceOperations<T> : Algebra<T> {
* @param k the multiplicand. * @param k the multiplicand.
* @return the produce. * @return the produce.
*/ */
fun multiply(a: T, k: Number): T public fun multiply(a: T, k: Number): T
// Operations to be performed in this context. Could be moved to extensions in case of KEEP-176 // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176
@ -90,7 +91,7 @@ interface SpaceOperations<T> : Algebra<T> {
* @receiver this value. * @receiver this value.
* @return the additive inverse of this value. * @return the additive inverse of this value.
*/ */
operator fun T.unaryMinus(): T = multiply(this, -1.0) public operator fun T.unaryMinus(): T = multiply(this, -1.0)
/** /**
* Returns this value. * Returns this value.
@ -98,7 +99,7 @@ interface SpaceOperations<T> : Algebra<T> {
* @receiver this value. * @receiver this value.
* @return this value. * @return this value.
*/ */
operator fun T.unaryPlus(): T = this public operator fun T.unaryPlus(): T = this
/** /**
* Addition of two elements. * Addition of two elements.
@ -107,7 +108,7 @@ interface SpaceOperations<T> : Algebra<T> {
* @param b the augend. * @param b the augend.
* @return the sum. * @return the sum.
*/ */
operator fun T.plus(b: T): T = add(this, b) public operator fun T.plus(b: T): T = add(this, b)
/** /**
* Subtraction of two elements. * Subtraction of two elements.
@ -116,7 +117,7 @@ interface SpaceOperations<T> : Algebra<T> {
* @param b the subtrahend. * @param b the subtrahend.
* @return the difference. * @return the difference.
*/ */
operator fun T.minus(b: T): T = add(this, -b) public operator fun T.minus(b: T): T = add(this, -b)
/** /**
* Multiplication of this element by a scalar. * Multiplication of this element by a scalar.
@ -125,7 +126,7 @@ interface SpaceOperations<T> : Algebra<T> {
* @param k the multiplicand. * @param k the multiplicand.
* @return the product. * @return the product.
*/ */
operator fun T.times(k: Number): T = multiply(this, k.toDouble()) public operator fun T.times(k: Number): T = multiply(this, k.toDouble())
/** /**
* Division of this element by scalar. * Division of this element by scalar.
@ -134,7 +135,7 @@ interface SpaceOperations<T> : Algebra<T> {
* @param k the divisor. * @param k the divisor.
* @return the quotient. * @return the quotient.
*/ */
operator fun T.div(k: Number): T = multiply(this, 1.0 / k.toDouble()) public operator fun T.div(k: Number): T = multiply(this, 1.0 / k.toDouble())
/** /**
* Multiplication of this number by element. * Multiplication of this number by element.
@ -143,7 +144,7 @@ interface SpaceOperations<T> : Algebra<T> {
* @param b the multiplicand. * @param b the multiplicand.
* @return the product. * @return the product.
*/ */
operator fun Number.times(b: T): T = b * this public operator fun Number.times(b: T): T = b * this
override fun unaryOperation(operation: String, arg: T): T = when (operation) { override fun unaryOperation(operation: String, arg: T): T = when (operation) {
PLUS_OPERATION -> arg PLUS_OPERATION -> arg
@ -157,18 +158,16 @@ interface SpaceOperations<T> : Algebra<T> {
else -> error("Binary operation $operation not defined in $this") else -> error("Binary operation $operation not defined in $this")
} }
companion object { public companion object {
/** /**
* The identifier of addition. * The identifier of addition.
*/ */
const val PLUS_OPERATION: String = "+" public const val PLUS_OPERATION: String = "+"
/** /**
* The identifier of subtraction (and negation). * The identifier of subtraction (and negation).
*/ */
const val MINUS_OPERATION: String = "-" public const val MINUS_OPERATION: String = "-"
const val NOT_OPERATION: String = "!"
} }
} }
@ -178,11 +177,11 @@ interface SpaceOperations<T> : Algebra<T> {
* *
* @param T the type of element of this group. * @param T the type of element of this group.
*/ */
interface Space<T> : SpaceOperations<T> { public interface Space<T> : SpaceOperations<T> {
/** /**
* The neutral element of addition. * The neutral element of addition.
*/ */
val zero: T public val zero: T
} }
/** /**
@ -191,14 +190,14 @@ interface Space<T> : SpaceOperations<T> {
* *
* @param T the type of element of this semiring. * @param T the type of element of this semiring.
*/ */
interface RingOperations<T> : SpaceOperations<T> { public interface RingOperations<T> : SpaceOperations<T> {
/** /**
* Multiplies two elements. * Multiplies two elements.
* *
* @param a the multiplier. * @param a the multiplier.
* @param b the multiplicand. * @param b the multiplicand.
*/ */
fun multiply(a: T, b: T): T public fun multiply(a: T, b: T): T
/** /**
* Multiplies this element by scalar. * Multiplies this element by scalar.
@ -206,18 +205,18 @@ interface RingOperations<T> : SpaceOperations<T> {
* @receiver the multiplier. * @receiver the multiplier.
* @param b the multiplicand. * @param b the multiplicand.
*/ */
operator fun T.times(b: T): T = multiply(this, b) public operator fun T.times(b: T): T = multiply(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
TIMES_OPERATION -> multiply(left, right) TIMES_OPERATION -> multiply(left, right)
else -> super.binaryOperation(operation, left, right) else -> super.binaryOperation(operation, left, right)
} }
companion object { public companion object {
/** /**
* The identifier of multiplication. * The identifier of multiplication.
*/ */
const val TIMES_OPERATION: String = "*" public const val TIMES_OPERATION: String = "*"
} }
} }
@ -227,11 +226,11 @@ interface RingOperations<T> : SpaceOperations<T> {
* *
* @param T the type of element of this ring. * @param T the type of element of this ring.
*/ */
interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> { public interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
/** /**
* neutral operation for multiplication * neutral operation for multiplication
*/ */
val one: T public val one: T
override fun number(value: Number): T = one * value.toDouble() override fun number(value: Number): T = one * value.toDouble()
@ -255,7 +254,7 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
* @receiver the addend. * @receiver the addend.
* @param b the augend. * @param b the augend.
*/ */
operator fun T.plus(b: Number): T = this + number(b) public operator fun T.plus(b: Number): T = this + number(b)
/** /**
* Addition of scalar and element. * Addition of scalar and element.
@ -263,7 +262,7 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
* @receiver the addend. * @receiver the addend.
* @param b the augend. * @param b the augend.
*/ */
operator fun Number.plus(b: T): T = b + this public operator fun Number.plus(b: T): T = b + this
/** /**
* Subtraction of element from number. * Subtraction of element from number.
@ -272,7 +271,7 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
* @param b the subtrahend. * @param b the subtrahend.
* @receiver the difference. * @receiver the difference.
*/ */
operator fun T.minus(b: Number): T = this - number(b) public operator fun T.minus(b: Number): T = this - number(b)
/** /**
* Subtraction of number from element. * Subtraction of number from element.
@ -281,7 +280,7 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
* @param b the subtrahend. * @param b the subtrahend.
* @receiver the difference. * @receiver the difference.
*/ */
operator fun Number.minus(b: T): T = -b + this public operator fun Number.minus(b: T): T = -b + this
} }
/** /**
@ -290,7 +289,7 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
* *
* @param T the type of element of this semifield. * @param T the type of element of this semifield.
*/ */
interface FieldOperations<T> : RingOperations<T> { public interface FieldOperations<T> : RingOperations<T> {
/** /**
* Division of two elements. * Division of two elements.
* *
@ -298,7 +297,7 @@ interface FieldOperations<T> : RingOperations<T> {
* @param b the divisor. * @param b the divisor.
* @return the quotient. * @return the quotient.
*/ */
fun divide(a: T, b: T): T public fun divide(a: T, b: T): T
/** /**
* Division of two elements. * Division of two elements.
@ -307,18 +306,18 @@ interface FieldOperations<T> : RingOperations<T> {
* @param b the divisor. * @param b the divisor.
* @return the quotient. * @return the quotient.
*/ */
operator fun T.div(b: T): T = divide(this, b) public operator fun T.div(b: T): T = divide(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
DIV_OPERATION -> divide(left, right) DIV_OPERATION -> divide(left, right)
else -> super.binaryOperation(operation, left, right) else -> super.binaryOperation(operation, left, right)
} }
companion object { public companion object {
/** /**
* The identifier of division. * The identifier of division.
*/ */
const val DIV_OPERATION: String = "/" public const val DIV_OPERATION: String = "/"
} }
} }
@ -328,7 +327,7 @@ interface FieldOperations<T> : RingOperations<T> {
* *
* @param T the type of element of this semifield. * @param T the type of element of this semifield.
*/ */
interface Field<T> : Ring<T>, FieldOperations<T> { public interface Field<T> : Ring<T>, FieldOperations<T> {
/** /**
* Division of element by scalar. * Division of element by scalar.
* *
@ -336,5 +335,5 @@ interface Field<T> : Ring<T>, FieldOperations<T> {
* @param b the divisor. * @param b the divisor.
* @return the quotient. * @return the quotient.
*/ */
operator fun Number.div(b: T): T = this * divide(one, b) public operator fun Number.div(b: T): T = this * divide(one, b)
} }

View File

@ -1,15 +1,15 @@
package scientifik.kmath.operations package kscience.kmath.operations
/** /**
* The generic mathematics elements which is able to store its context * The generic mathematics elements which is able to store its context
* *
* @param C the type of mathematical context for this element. * @param C the type of mathematical context for this element.
*/ */
interface MathElement<C> { public interface MathElement<C> {
/** /**
* The context this element belongs to. * The context this element belongs to.
*/ */
val context: C public val context: C
} }
/** /**
@ -18,16 +18,16 @@ interface MathElement<C> {
* @param T the type wrapped by this wrapper. * @param T the type wrapped by this wrapper.
* @param I the type of this wrapper. * @param I the type of this wrapper.
*/ */
interface MathWrapper<T, I> { public interface MathWrapper<T, I> {
/** /**
* Unwraps [I] to [T]. * Unwraps [I] to [T].
*/ */
fun unwrap(): T public fun unwrap(): T
/** /**
* Wraps [T] to [I]. * Wraps [T] to [I].
*/ */
fun T.wrap(): I public fun T.wrap(): I
} }
/** /**
@ -37,14 +37,14 @@ interface MathWrapper<T, I> {
* @param I self type of the element. Needed for static type checking. * @param I self type of the element. Needed for static type checking.
* @param S the type of space. * @param S the type of space.
*/ */
interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S>, MathWrapper<T, I> { public interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S>, MathWrapper<T, I> {
/** /**
* Adds element to this one. * Adds element to this one.
* *
* @param b the augend. * @param b the augend.
* @return the sum. * @return the sum.
*/ */
operator fun plus(b: T): I = context.add(unwrap(), b).wrap() public operator fun plus(b: T): I = context.add(unwrap(), b).wrap()
/** /**
* Subtracts element from this one. * Subtracts element from this one.
@ -52,7 +52,7 @@ interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement
* @param b the subtrahend. * @param b the subtrahend.
* @return the difference. * @return the difference.
*/ */
operator fun minus(b: T): I = context.add(unwrap(), context.multiply(b, -1.0)).wrap() public operator fun minus(b: T): I = context.add(unwrap(), context.multiply(b, -1.0)).wrap()
/** /**
* Multiplies this element by number. * Multiplies this element by number.
@ -60,7 +60,7 @@ interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement
* @param k the multiplicand. * @param k the multiplicand.
* @return the product. * @return the product.
*/ */
operator fun times(k: Number): I = context.multiply(unwrap(), k.toDouble()).wrap() public operator fun times(k: Number): I = context.multiply(unwrap(), k.toDouble()).wrap()
/** /**
* Divides this element by number. * Divides this element by number.
@ -68,7 +68,7 @@ interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement
* @param k the divisor. * @param k the divisor.
* @return the quotient. * @return the quotient.
*/ */
operator fun div(k: Number): I = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap() public operator fun div(k: Number): I = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap()
} }
/** /**
@ -78,14 +78,14 @@ interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement
* @param I self type of the element. Needed for static type checking. * @param I self type of the element. Needed for static type checking.
* @param R the type of ring. * @param R the type of ring.
*/ */
interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> { public interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
/** /**
* Multiplies this element by another one. * Multiplies this element by another one.
* *
* @param b the multiplicand. * @param b the multiplicand.
* @return the product. * @return the product.
*/ */
operator fun times(b: T): I = context.multiply(unwrap(), b).wrap() public operator fun times(b: T): I = context.multiply(unwrap(), b).wrap()
} }
/** /**
@ -95,7 +95,7 @@ interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T
* @param I self type of the element. Needed for static type checking. * @param I self type of the element. Needed for static type checking.
* @param F the type of field. * @param F the type of field.
*/ */
interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> { public interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> {
override val context: F override val context: F
/** /**
@ -104,5 +104,5 @@ interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement
* @param b the divisor. * @param b the divisor.
* @return the quotient. * @return the quotient.
*/ */
operator fun div(b: T): I = context.divide(unwrap(), b).wrap() public operator fun div(b: T): I = context.divide(unwrap(), b).wrap()
} }

View File

@ -1,4 +1,4 @@
package scientifik.kmath.operations package kscience.kmath.operations
/** /**
* Returns the sum of all elements in the iterable in this [Space]. * Returns the sum of all elements in the iterable in this [Space].
@ -7,7 +7,7 @@ package scientifik.kmath.operations
* @param data the iterable to sum up. * @param data the iterable to sum up.
* @return the sum. * @return the sum.
*/ */
fun <T> Space<T>.sum(data: Iterable<T>): T = data.fold(zero) { left, right -> add(left, right) } public fun <T> Space<T>.sum(data: Iterable<T>): T = data.fold(zero) { left, right -> add(left, right) }
/** /**
* Returns the sum of all elements in the sequence in this [Space]. * Returns the sum of all elements in the sequence in this [Space].
@ -16,7 +16,7 @@ fun <T> Space<T>.sum(data: Iterable<T>): T = data.fold(zero) { left, right -> ad
* @param data the sequence to sum up. * @param data the sequence to sum up.
* @return the sum. * @return the sum.
*/ */
fun <T> Space<T>.sum(data: Sequence<T>): T = data.fold(zero) { left, right -> add(left, right) } public fun <T> Space<T>.sum(data: Sequence<T>): T = data.fold(zero) { left, right -> add(left, right) }
/** /**
* Returns an average value of elements in the iterable in this [Space]. * Returns an average value of elements in the iterable in this [Space].
@ -24,8 +24,9 @@ fun <T> Space<T>.sum(data: Sequence<T>): T = data.fold(zero) { left, right -> ad
* @receiver the algebra that provides addition and division. * @receiver the algebra that provides addition and division.
* @param data the iterable to find average. * @param data the iterable to find average.
* @return the average value. * @return the average value.
* @author Iaroslav Postovalov
*/ */
fun <T> Space<T>.average(data: Iterable<T>): T = sum(data) / data.count() public fun <T> Space<T>.average(data: Iterable<T>): T = sum(data) / data.count()
/** /**
* Returns an average value of elements in the sequence in this [Space]. * Returns an average value of elements in the sequence in this [Space].
@ -33,8 +34,9 @@ fun <T> Space<T>.average(data: Iterable<T>): T = sum(data) / data.count()
* @receiver the algebra that provides addition and division. * @receiver the algebra that provides addition and division.
* @param data the sequence to find average. * @param data the sequence to find average.
* @return the average value. * @return the average value.
* @author Iaroslav Postovalov
*/ */
fun <T> Space<T>.average(data: Sequence<T>): T = sum(data) / data.count() public fun <T> Space<T>.average(data: Sequence<T>): T = sum(data) / data.count()
/** /**
* Returns the sum of all elements in the iterable in provided space. * Returns the sum of all elements in the iterable in provided space.
@ -43,7 +45,7 @@ fun <T> Space<T>.average(data: Sequence<T>): T = sum(data) / data.count()
* @param space the algebra that provides addition. * @param space the algebra that provides addition.
* @return the sum. * @return the sum.
*/ */
fun <T> Iterable<T>.sumWith(space: Space<T>): T = space.sum(this) public fun <T> Iterable<T>.sumWith(space: Space<T>): T = space.sum(this)
/** /**
* Returns the sum of all elements in the sequence in provided space. * Returns the sum of all elements in the sequence in provided space.
@ -52,7 +54,7 @@ fun <T> Iterable<T>.sumWith(space: Space<T>): T = space.sum(this)
* @param space the algebra that provides addition. * @param space the algebra that provides addition.
* @return the sum. * @return the sum.
*/ */
fun <T> Sequence<T>.sumWith(space: Space<T>): T = space.sum(this) public fun <T> Sequence<T>.sumWith(space: Space<T>): T = space.sum(this)
/** /**
* Returns an average value of elements in the iterable in this [Space]. * Returns an average value of elements in the iterable in this [Space].
@ -60,8 +62,9 @@ fun <T> Sequence<T>.sumWith(space: Space<T>): T = space.sum(this)
* @receiver the iterable to find average. * @receiver the iterable to find average.
* @param space the algebra that provides addition and division. * @param space the algebra that provides addition and division.
* @return the average value. * @return the average value.
* @author Iaroslav Postovalov
*/ */
fun <T> Iterable<T>.averageWith(space: Space<T>): T = space.average(this) public fun <T> Iterable<T>.averageWith(space: Space<T>): T = space.average(this)
/** /**
* Returns an average value of elements in the sequence in this [Space]. * Returns an average value of elements in the sequence in this [Space].
@ -69,8 +72,9 @@ fun <T> Iterable<T>.averageWith(space: Space<T>): T = space.average(this)
* @receiver the sequence to find average. * @receiver the sequence to find average.
* @param space the algebra that provides addition and division. * @param space the algebra that provides addition and division.
* @return the average value. * @return the average value.
* @author Iaroslav Postovalov
*/ */
fun <T> Sequence<T>.averageWith(space: Space<T>): T = space.average(this) public fun <T> Sequence<T>.averageWith(space: Space<T>): T = space.average(this)
//TODO optimized power operation //TODO optimized power operation
@ -82,7 +86,7 @@ fun <T> Sequence<T>.averageWith(space: Space<T>): T = space.average(this)
* @param power the exponent. * @param power the exponent.
* @return the base raised to the power. * @return the base raised to the power.
*/ */
fun <T> Ring<T>.power(arg: T, power: Int): T { public fun <T> Ring<T>.power(arg: T, power: Int): T {
require(power >= 0) { "The power can't be negative." } require(power >= 0) { "The power can't be negative." }
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." } require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
if (power == 0) return one if (power == 0) return one
@ -98,8 +102,9 @@ fun <T> Ring<T>.power(arg: T, power: Int): T {
* @param arg the base. * @param arg the base.
* @param power the exponent. * @param power the exponent.
* @return the base raised to the power. * @return the base raised to the power.
* @author Iaroslav Postovalov
*/ */
fun <T> Field<T>.power(arg: T, power: Int): T { public fun <T> Field<T>.power(arg: T, power: Int): T {
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." } require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
if (power == 0) return one if (power == 0) return one
if (power < 0) return one / (this as Ring<T>).power(arg, -power) if (power < 0) return one / (this as Ring<T>).power(arg, -power)

View File

@ -1,23 +1,22 @@
package scientifik.kmath.operations package kscience.kmath.operations
import scientifik.kmath.operations.BigInt.Companion.BASE import kscience.kmath.operations.BigInt.Companion.BASE
import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE import kscience.kmath.operations.BigInt.Companion.BASE_SIZE
import scientifik.kmath.structures.* import kscience.kmath.structures.*
import kotlin.contracts.contract
import kotlin.math.log2 import kotlin.math.log2
import kotlin.math.max import kotlin.math.max
import kotlin.math.min import kotlin.math.min
import kotlin.math.sign import kotlin.math.sign
typealias Magnitude = UIntArray public typealias Magnitude = UIntArray
typealias TBase = ULong public typealias TBase = ULong
/** /**
* Kotlin Multiplatform implementation of Big Integer numbers (KBigInteger). * Kotlin Multiplatform implementation of Big Integer numbers (KBigInteger).
* *
* @author Robert Drynkin (https://github.com/robdrynkin) and Peter Klimai (https://github.com/pklimai) * @author Robert Drynkin (https://github.com/robdrynkin) and Peter Klimai (https://github.com/pklimai)
*/ */
object BigIntField : Field<BigInt> { public object BigIntField : Field<BigInt> {
override val zero: BigInt = BigInt.ZERO override val zero: BigInt = BigInt.ZERO
override val one: BigInt = BigInt.ONE override val one: BigInt = BigInt.ONE
@ -28,113 +27,92 @@ object BigIntField : Field<BigInt> {
override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b) override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b)
operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer") public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer")
operator fun String.unaryMinus(): BigInt = public operator fun String.unaryMinus(): BigInt =
-(this.parseBigInteger() ?: error("Can't parse $this as big integer")) -(this.parseBigInteger() ?: error("Can't parse $this as big integer"))
override fun divide(a: BigInt, b: BigInt): BigInt = a.div(b) override fun divide(a: BigInt, b: BigInt): BigInt = a.div(b)
} }
class BigInt internal constructor( public class BigInt internal constructor(
private val sign: Byte, private val sign: Byte,
private val magnitude: Magnitude private val magnitude: Magnitude
) : Comparable<BigInt> { ) : Comparable<BigInt> {
public override fun compareTo(other: BigInt): Int = when {
override fun compareTo(other: BigInt): Int { (sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0
return when { sign < other.sign -> -1
(this.sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0 sign > other.sign -> 1
this.sign < other.sign -> -1 else -> sign * compareMagnitudes(magnitude, other.magnitude)
this.sign > other.sign -> 1
else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude)
}
} }
override fun equals(other: Any?): Boolean { public override fun equals(other: Any?): Boolean =
if (other is BigInt) { if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type")
return this.compareTo(other) == 0
} else error("Can't compare KBigInteger to a different type")
}
override fun hashCode(): Int { public override fun hashCode(): Int = magnitude.hashCode() + sign
return magnitude.hashCode() + this.sign
}
fun abs(): BigInt = if (sign == 0.toByte()) this else BigInt(1, magnitude) public fun abs(): BigInt = if (sign == 0.toByte()) this else BigInt(1, magnitude)
operator fun unaryMinus(): BigInt { public operator fun unaryMinus(): BigInt =
return if (this.sign == 0.toByte()) this else BigInt((-this.sign).toByte(), this.magnitude) if (this.sign == 0.toByte()) this else BigInt((-this.sign).toByte(), this.magnitude)
}
operator fun plus(b: BigInt): BigInt { public operator fun plus(b: BigInt): BigInt = when {
return when {
b.sign == 0.toByte() -> this b.sign == 0.toByte() -> this
this.sign == 0.toByte() -> b sign == 0.toByte() -> b
this == -b -> ZERO this == -b -> ZERO
this.sign == b.sign -> BigInt(this.sign, addMagnitudes(this.magnitude, b.magnitude)) sign == b.sign -> BigInt(sign, addMagnitudes(magnitude, b.magnitude))
else -> { else -> {
val comp: Int = compareMagnitudes(this.magnitude, b.magnitude) val comp = compareMagnitudes(magnitude, b.magnitude)
if (comp == 1) { if (comp == 1)
BigInt(this.sign, subtractMagnitudes(this.magnitude, b.magnitude)) BigInt(sign, subtractMagnitudes(magnitude, b.magnitude))
} else { else
BigInt((-this.sign).toByte(), subtractMagnitudes(b.magnitude, this.magnitude)) BigInt((-sign).toByte(), subtractMagnitudes(b.magnitude, magnitude))
}
}
} }
} }
operator fun minus(b: BigInt): BigInt { public operator fun minus(b: BigInt): BigInt = this + (-b)
return this + (-b)
}
operator fun times(b: BigInt): BigInt { public operator fun times(b: BigInt): BigInt = when {
return when {
this.sign == 0.toByte() -> ZERO this.sign == 0.toByte() -> ZERO
b.sign == 0.toByte() -> ZERO b.sign == 0.toByte() -> ZERO
// TODO: Karatsuba // TODO: Karatsuba
else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude)) else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude))
} }
}
operator fun times(other: UInt): BigInt { public operator fun times(other: UInt): BigInt = when {
return when { sign == 0.toByte() -> ZERO
this.sign == 0.toByte() -> ZERO
other == 0U -> ZERO other == 0U -> ZERO
else -> BigInt(this.sign, multiplyMagnitudeByUInt(this.magnitude, other)) else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other))
}
} }
operator fun times(other: Int): BigInt { public operator fun times(other: Int): BigInt = if (other > 0)
return if (other > 0)
this * kotlin.math.abs(other).toUInt() this * kotlin.math.abs(other).toUInt()
else else
-this * kotlin.math.abs(other).toUInt() -this * kotlin.math.abs(other).toUInt()
}
operator fun div(other: UInt): BigInt { public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other))
return BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other))
}
operator fun div(other: Int): BigInt { public operator fun div(other: Int): BigInt = BigInt(
return BigInt(
(this.sign * other.sign).toByte(), (this.sign * other.sign).toByte(),
divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt()) divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt())
) )
}
private fun division(other: BigInt): Pair<BigInt, BigInt> { private fun division(other: BigInt): Pair<BigInt, BigInt> {
// Long division algorithm: // Long division algorithm:
// https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_(unsigned)_with_remainder // https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_(unsigned)_with_remainder
// TODO: Implement more effective algorithm // TODO: Implement more effective algorithm
var q: BigInt = ZERO var q = ZERO
var r: BigInt = ZERO var r = ZERO
val bitSize = val bitSize =
(BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.lastOrNull()?.toFloat() ?: 0f + 1)).toInt() (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.lastOrNull()?.toFloat() ?: 0f + 1)).toInt()
for (i in bitSize downTo 0) { for (i in bitSize downTo 0) {
r = r shl 1 r = r shl 1
r = r or ((abs(this) shr i) and ONE) r = r or ((abs(this) shr i) and ONE)
if (r >= abs(other)) { if (r >= abs(other)) {
r -= abs(other) r -= abs(other)
q += (ONE shl i) q += (ONE shl i)
@ -144,101 +122,86 @@ class BigInt internal constructor(
return Pair(BigInt((this.sign * other.sign).toByte(), q.magnitude), r) return Pair(BigInt((this.sign * other.sign).toByte(), q.magnitude), r)
} }
operator fun div(other: BigInt): BigInt { public operator fun div(other: BigInt): BigInt = division(other).first
return this.division(other).first
}
infix fun shl(i: Int): BigInt { public infix fun shl(i: Int): BigInt {
if (this == ZERO) return ZERO if (this == ZERO) return ZERO
if (i == 0) return this if (i == 0) return this
val fullShifts = i / BASE_SIZE + 1 val fullShifts = i / BASE_SIZE + 1
val relShift = i % BASE_SIZE val relShift = i % BASE_SIZE
val shiftLeft = { x: UInt -> if (relShift >= 32) 0U else x shl relShift } val shiftLeft = { x: UInt -> if (relShift >= 32) 0U else x shl relShift }
val shiftRight = { x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shr (BASE_SIZE - relShift) } val shiftRight = { x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shr (BASE_SIZE - relShift) }
val newMagnitude = Magnitude(magnitude.size + fullShifts)
val newMagnitude: Magnitude = Magnitude(this.magnitude.size + fullShifts) for (j in magnitude.indices) {
for (j in this.magnitude.indices) {
newMagnitude[j + fullShifts - 1] = shiftLeft(this.magnitude[j]) newMagnitude[j + fullShifts - 1] = shiftLeft(this.magnitude[j])
if (j != 0) {
if (j != 0)
newMagnitude[j + fullShifts - 1] = newMagnitude[j + fullShifts - 1] or shiftRight(this.magnitude[j - 1]) newMagnitude[j + fullShifts - 1] = newMagnitude[j + fullShifts - 1] or shiftRight(this.magnitude[j - 1])
} }
}
newMagnitude[this.magnitude.size + fullShifts - 1] = shiftRight(this.magnitude.last())
newMagnitude[magnitude.size + fullShifts - 1] = shiftRight(magnitude.last())
return BigInt(this.sign, stripLeadingZeros(newMagnitude)) return BigInt(this.sign, stripLeadingZeros(newMagnitude))
} }
infix fun shr(i: Int): BigInt { public infix fun shr(i: Int): BigInt {
if (this == ZERO) return ZERO if (this == ZERO) return ZERO
if (i == 0) return this if (i == 0) return this
val fullShifts = i / BASE_SIZE val fullShifts = i / BASE_SIZE
val relShift = i % BASE_SIZE val relShift = i % BASE_SIZE
val shiftRight = { x: UInt -> if (relShift >= 32) 0U else x shr relShift } val shiftRight = { x: UInt -> if (relShift >= 32) 0U else x shr relShift }
val shiftLeft = { x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shl (BASE_SIZE - relShift) } val shiftLeft = { x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shl (BASE_SIZE - relShift) }
if (this.magnitude.size - fullShifts <= 0) { if (this.magnitude.size - fullShifts <= 0) return ZERO
return ZERO val newMagnitude: Magnitude = Magnitude(magnitude.size - fullShifts)
}
val newMagnitude: Magnitude = Magnitude(this.magnitude.size - fullShifts)
for (j in fullShifts until this.magnitude.size) { for (j in fullShifts until magnitude.size) {
newMagnitude[j - fullShifts] = shiftRight(this.magnitude[j]) newMagnitude[j - fullShifts] = shiftRight(magnitude[j])
if (j != this.magnitude.size - 1) {
newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(this.magnitude[j + 1]) if (j != magnitude.size - 1)
} newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(magnitude[j + 1])
} }
return BigInt(this.sign, stripLeadingZeros(newMagnitude)) return BigInt(this.sign, stripLeadingZeros(newMagnitude))
} }
infix fun or(other: BigInt): BigInt { public infix fun or(other: BigInt): BigInt {
if (this == ZERO) return other if (this == ZERO) return other
if (other == ZERO) return this if (other == ZERO) return this
val resSize = max(this.magnitude.size, other.magnitude.size) val resSize = max(magnitude.size, other.magnitude.size)
val newMagnitude: Magnitude = Magnitude(resSize) val newMagnitude: Magnitude = Magnitude(resSize)
for (i in 0 until resSize) { for (i in 0 until resSize) {
if (i < this.magnitude.size) { if (i < magnitude.size) newMagnitude[i] = newMagnitude[i] or magnitude[i]
newMagnitude[i] = newMagnitude[i] or this.magnitude[i] if (i < other.magnitude.size) newMagnitude[i] = newMagnitude[i] or other.magnitude[i]
}
if (i < other.magnitude.size) {
newMagnitude[i] = newMagnitude[i] or other.magnitude[i]
}
} }
return BigInt(1, stripLeadingZeros(newMagnitude)) return BigInt(1, stripLeadingZeros(newMagnitude))
} }
infix fun and(other: BigInt): BigInt { public infix fun and(other: BigInt): BigInt {
if ((this == ZERO) or (other == ZERO)) return ZERO if ((this == ZERO) or (other == ZERO)) return ZERO
val resSize = min(this.magnitude.size, other.magnitude.size) val resSize = min(this.magnitude.size, other.magnitude.size)
val newMagnitude: Magnitude = Magnitude(resSize) val newMagnitude: Magnitude = Magnitude(resSize)
for (i in 0 until resSize) { for (i in 0 until resSize) newMagnitude[i] = this.magnitude[i] and other.magnitude[i]
newMagnitude[i] = this.magnitude[i] and other.magnitude[i]
}
return BigInt(1, stripLeadingZeros(newMagnitude)) return BigInt(1, stripLeadingZeros(newMagnitude))
} }
operator fun rem(other: Int): Int { public operator fun rem(other: Int): Int {
val res = this - (this / other) * other val res = this - (this / other) * other
return if (res == ZERO) 0 else res.sign * res.magnitude[0].toInt() return if (res == ZERO) 0 else res.sign * res.magnitude[0].toInt()
} }
operator fun rem(other: BigInt): BigInt { public operator fun rem(other: BigInt): BigInt = this - (this / other) * other
return this - (this / other) * other
}
fun modPow(exponent: BigInt, m: BigInt): BigInt { public fun modPow(exponent: BigInt, m: BigInt): BigInt = when {
return when {
exponent == ZERO -> ONE exponent == ZERO -> ONE
exponent % 2 == 1 -> (this * modPow(exponent - ONE, m)) % m exponent % 2 == 1 -> (this * modPow(exponent - ONE, m)) % m
else -> { else -> {
val sqRoot = modPow(exponent / 2, m) val sqRoot = modPow(exponent / 2, m)
(sqRoot * sqRoot) % m (sqRoot * sqRoot) % m
} }
} }
}
override fun toString(): String { override fun toString(): String {
if (this.sign == 0.toByte()) { if (this.sign == 0.toByte()) {
@ -260,11 +223,11 @@ class BigInt internal constructor(
return res return res
} }
companion object { public companion object {
const val BASE: ULong = 0xffffffffUL public const val BASE: ULong = 0xffffffffUL
const val BASE_SIZE: Int = 32 public const val BASE_SIZE: Int = 32
val ZERO: BigInt = BigInt(0, uintArrayOf()) public val ZERO: BigInt = BigInt(0, uintArrayOf())
val ONE: BigInt = BigInt(1, uintArrayOf(1u)) public val ONE: BigInt = BigInt(1, uintArrayOf(1u))
private val hexMapping: HashMap<UInt, String> = hashMapOf( private val hexMapping: HashMap<UInt, String> = hashMapOf(
0U to "0", 1U to "1", 2U to "2", 3U to "3", 0U to "0", 1U to "1", 2U to "2", 3U to "3",
@ -291,9 +254,9 @@ class BigInt internal constructor(
} }
private fun addMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { private fun addMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
val resultLength: Int = max(mag1.size, mag2.size) + 1 val resultLength = max(mag1.size, mag2.size) + 1
val result = Magnitude(resultLength) val result = Magnitude(resultLength)
var carry: TBase = 0UL var carry = 0uL
for (i in 0 until resultLength - 1) { for (i in 0 until resultLength - 1) {
val res = when { val res = when {
@ -301,20 +264,22 @@ class BigInt internal constructor(
i >= mag2.size -> mag1[i].toULong() + carry i >= mag2.size -> mag1[i].toULong() + carry
else -> mag1[i].toULong() + mag2[i].toULong() + carry else -> mag1[i].toULong() + mag2[i].toULong() + carry
} }
result[i] = (res and BASE).toUInt() result[i] = (res and BASE).toUInt()
carry = (res shr BASE_SIZE) carry = (res shr BASE_SIZE)
} }
result[resultLength - 1] = carry.toUInt() result[resultLength - 1] = carry.toUInt()
return stripLeadingZeros(result) return stripLeadingZeros(result)
} }
private fun subtractMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { private fun subtractMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
val resultLength: Int = mag1.size val resultLength = mag1.size
val result = Magnitude(resultLength) val result = Magnitude(resultLength)
var carry = 0L var carry = 0L
for (i in 0 until resultLength) { for (i in 0 until resultLength) {
var res: Long = var res =
if (i < mag2.size) mag1[i].toLong() - mag2[i].toLong() - carry if (i < mag2.size) mag1[i].toLong() - mag2[i].toLong() - carry
else mag1[i].toLong() - carry else mag1[i].toLong() - carry
@ -328,9 +293,9 @@ class BigInt internal constructor(
} }
private fun multiplyMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { private fun multiplyMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
val resultLength: Int = mag.size + 1 val resultLength = mag.size + 1
val result = Magnitude(resultLength) val result = Magnitude(resultLength)
var carry: ULong = 0UL var carry = 0uL
for (i in mag.indices) { for (i in mag.indices) {
val cur: ULong = carry + mag[i].toULong() * x.toULong() val cur: ULong = carry + mag[i].toULong() * x.toULong()
@ -343,16 +308,18 @@ class BigInt internal constructor(
} }
private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
val resultLength: Int = mag1.size + mag2.size val resultLength = mag1.size + mag2.size
val result = Magnitude(resultLength) val result = Magnitude(resultLength)
for (i in mag1.indices) { for (i in mag1.indices) {
var carry: ULong = 0UL var carry = 0uL
for (j in mag2.indices) { for (j in mag2.indices) {
val cur: ULong = result[i + j].toULong() + mag1[i].toULong() * mag2[j].toULong() + carry val cur: ULong = result[i + j].toULong() + mag1[i].toULong() * mag2[j].toULong() + carry
result[i + j] = (cur and BASE.toULong()).toUInt() result[i + j] = (cur and BASE.toULong()).toUInt()
carry = cur shr BASE_SIZE carry = cur shr BASE_SIZE
} }
result[i + mag2.size] = (carry and BASE).toUInt() result[i + mag2.size] = (carry and BASE).toUInt()
} }
@ -360,48 +327,46 @@ class BigInt internal constructor(
} }
private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
val resultLength: Int = mag.size val resultLength = mag.size
val result = Magnitude(resultLength) val result = Magnitude(resultLength)
var carry: ULong = 0UL var carry = 0uL
for (i in mag.size - 1 downTo 0) { for (i in mag.size - 1 downTo 0) {
val cur: ULong = mag[i].toULong() + (carry shl BASE_SIZE) val cur: ULong = mag[i].toULong() + (carry shl BASE_SIZE)
result[i] = (cur / x).toUInt() result[i] = (cur / x).toUInt()
carry = cur % x carry = cur % x
} }
return stripLeadingZeros(result) return stripLeadingZeros(result)
} }
} }
} }
private fun stripLeadingZeros(mag: Magnitude): Magnitude { private fun stripLeadingZeros(mag: Magnitude): Magnitude {
if (mag.isEmpty() || mag.last() != 0U) { if (mag.isEmpty() || mag.last() != 0U) return mag
return mag var resSize = mag.size - 1
}
var resSize: Int = mag.size - 1
while (mag[resSize] == 0U) { while (mag[resSize] == 0U) {
if (resSize == 0) if (resSize == 0) break
break
resSize -= 1 resSize -= 1
} }
return mag.sliceArray(IntRange(0, resSize)) return mag.sliceArray(IntRange(0, resSize))
} }
fun abs(x: BigInt): BigInt = x.abs() public fun abs(x: BigInt): BigInt = x.abs()
/** /**
* Convert this [Int] to [BigInt] * Convert this [Int] to [BigInt]
*/ */
fun Int.toBigInt(): BigInt = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt())) public fun Int.toBigInt(): BigInt = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt()))
/** /**
* Convert this [Long] to [BigInt] * Convert this [Long] to [BigInt]
*/ */
fun Long.toBigInt(): BigInt = BigInt( public fun Long.toBigInt(): BigInt = BigInt(
sign.toByte(), stripLeadingZeros( sign.toByte(),
stripLeadingZeros(
uintArrayOf( uintArrayOf(
(kotlin.math.abs(this).toULong() and BASE).toUInt(), (kotlin.math.abs(this).toULong() and BASE).toUInt(),
((kotlin.math.abs(this).toULong() shr BASE_SIZE) and BASE).toUInt() ((kotlin.math.abs(this).toULong() shr BASE_SIZE) and BASE).toUInt()
@ -412,12 +377,12 @@ fun Long.toBigInt(): BigInt = BigInt(
/** /**
* Convert UInt to [BigInt] * Convert UInt to [BigInt]
*/ */
fun UInt.toBigInt(): BigInt = BigInt(1, uintArrayOf(this)) public fun UInt.toBigInt(): BigInt = BigInt(1, uintArrayOf(this))
/** /**
* Convert ULong to [BigInt] * Convert ULong to [BigInt]
*/ */
fun ULong.toBigInt(): BigInt = BigInt( public fun ULong.toBigInt(): BigInt = BigInt(
1, 1,
stripLeadingZeros( stripLeadingZeros(
uintArrayOf( uintArrayOf(
@ -430,12 +395,12 @@ fun ULong.toBigInt(): BigInt = BigInt(
/** /**
* Create a [BigInt] with this array of magnitudes with protective copy * Create a [BigInt] with this array of magnitudes with protective copy
*/ */
fun UIntArray.toBigInt(sign: Byte): BigInt { public fun UIntArray.toBigInt(sign: Byte): BigInt {
require(sign != 0.toByte() || !isNotEmpty()) require(sign != 0.toByte() || !isNotEmpty())
return BigInt(sign, copyOf()) return BigInt(sign, copyOf())
} }
val hexChToInt: MutableMap<Char, Int> = hashMapOf( private val hexChToInt: MutableMap<Char, Int> = hashMapOf(
'0' to 0, '1' to 1, '2' to 2, '3' to 3, '0' to 0, '1' to 1, '2' to 2, '3' to 3,
'4' to 4, '5' to 5, '6' to 6, '7' to 7, '4' to 4, '5' to 5, '6' to 6, '7' to 7,
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11, '8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
@ -445,9 +410,10 @@ val hexChToInt: MutableMap<Char, Int> = hashMapOf(
/** /**
* Returns null if a valid number can not be read from a string * Returns null if a valid number can not be read from a string
*/ */
fun String.parseBigInteger(): BigInt? { public fun String.parseBigInteger(): BigInt? {
val sign: Int val sign: Int
val sPositive: String val sPositive: String
when { when {
this[0] == '+' -> { this[0] == '+' -> {
sign = +1 sign = +1
@ -462,18 +428,21 @@ fun String.parseBigInteger(): BigInt? {
sign = +1 sign = +1
} }
} }
var res = BigInt.ZERO var res = BigInt.ZERO
var digitValue = BigInt.ONE var digitValue = BigInt.ONE
val sPositiveUpper = sPositive.toUpperCase() val sPositiveUpper = sPositive.toUpperCase()
if (sPositiveUpper.startsWith("0X")) { // hex representation if (sPositiveUpper.startsWith("0X")) { // hex representation
val sHex = sPositiveUpper.substring(2) val sHex = sPositiveUpper.substring(2)
for (ch in sHex.reversed()) { for (ch in sHex.reversed()) {
if (ch == '_') continue if (ch == '_') continue
res += digitValue * (hexChToInt[ch] ?: return null) res += digitValue * (hexChToInt[ch] ?: return null)
digitValue *= 16.toBigInt() digitValue *= 16.toBigInt()
} }
} else { // decimal representation } else for (ch in sPositiveUpper.reversed()) {
for (ch in sPositiveUpper.reversed()) { // decimal representation
if (ch == '_') continue if (ch == '_') continue
if (ch !in '0'..'9') { if (ch !in '0'..'9') {
return null return null
@ -481,24 +450,20 @@ fun String.parseBigInteger(): BigInt? {
res += digitValue * (ch.toInt() - '0'.toInt()) res += digitValue * (ch.toInt() - '0'.toInt())
digitValue *= 10.toBigInt() digitValue *= 10.toBigInt()
} }
}
return res * sign return res * sign
} }
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> { public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
contract { callsInPlace(initializer) } boxing(size, initializer)
return boxing(size, initializer)
}
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> { public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
contract { callsInPlace(initializer) } boxing(size, initializer)
return boxing(size, initializer)
}
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> = public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt) BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
fun NDElement.Companion.bigInt( public fun NDElement.Companion.bigInt(
vararg shape: Int, vararg shape: Int,
initializer: BigIntField.(IntArray) -> BigInt initializer: BigIntField.(IntArray) -> BigInt
): BufferedNDRingElement<BigInt, BigIntField> = NDAlgebra.bigInt(*shape).produce(initializer) ): BufferedNDRingElement<BigInt, BigIntField> = NDAlgebra.bigInt(*shape).produce(initializer)

View File

@ -1,24 +1,23 @@
package scientifik.kmath.operations package kscience.kmath.operations
import scientifik.kmath.structures.Buffer import kscience.kmath.structures.Buffer
import scientifik.kmath.structures.MemoryBuffer import kscience.kmath.structures.MemoryBuffer
import scientifik.kmath.structures.MutableBuffer import kscience.kmath.structures.MutableBuffer
import scientifik.memory.MemoryReader import kscience.memory.MemoryReader
import scientifik.memory.MemorySpec import kscience.memory.MemorySpec
import scientifik.memory.MemoryWriter import kscience.memory.MemoryWriter
import kotlin.contracts.contract
import kotlin.math.* import kotlin.math.*
/** /**
* This complex's conjugate. * This complex's conjugate.
*/ */
val Complex.conjugate: Complex public val Complex.conjugate: Complex
get() = Complex(re, -im) get() = Complex(re, -im)
/** /**
* This complex's reciprocal. * This complex's reciprocal.
*/ */
val Complex.reciprocal: Complex public val Complex.reciprocal: Complex
get() { get() {
val scale = re * re + im * im val scale = re * re + im * im
return Complex(re / scale, -im / scale) return Complex(re / scale, -im / scale)
@ -27,13 +26,13 @@ val Complex.reciprocal: Complex
/** /**
* Absolute value of complex number. * Absolute value of complex number.
*/ */
val Complex.r: Double public val Complex.r: Double
get() = sqrt(re * re + im * im) get() = sqrt(re * re + im * im)
/** /**
* An angle between vector represented by complex number and X axis. * An angle between vector represented by complex number and X axis.
*/ */
val Complex.theta: Double public val Complex.theta: Double
get() = atan(im / re) get() = atan(im / re)
private val PI_DIV_2 = Complex(PI / 2, 0) private val PI_DIV_2 = Complex(PI / 2, 0)
@ -41,14 +40,14 @@ private val PI_DIV_2 = Complex(PI / 2, 0)
/** /**
* A field of [Complex]. * A field of [Complex].
*/ */
object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> { public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
override val zero: Complex = 0.0.toComplex() override val zero: Complex = 0.0.toComplex()
override val one: Complex = 1.0.toComplex() override val one: Complex = 1.0.toComplex()
/** /**
* The imaginary unit. * The imaginary unit.
*/ */
val i: Complex = Complex(0.0, 1.0) public val i: Complex = Complex(0.0, 1.0)
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
@ -116,7 +115,7 @@ object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
* @param c the augend. * @param c the augend.
* @return the sum. * @return the sum.
*/ */
operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c) public operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c)
/** /**
* Subtracts complex number from real one. * Subtracts complex number from real one.
@ -125,7 +124,7 @@ object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
* @param c the subtrahend. * @param c the subtrahend.
* @return the difference. * @return the difference.
*/ */
operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c) public operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c)
/** /**
* Adds real number to complex one. * Adds real number to complex one.
@ -134,7 +133,7 @@ object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
* @param d the augend. * @param d the augend.
* @return the sum. * @return the sum.
*/ */
operator fun Complex.plus(d: Double): Complex = d + this public operator fun Complex.plus(d: Double): Complex = d + this
/** /**
* Subtracts real number from complex one. * Subtracts real number from complex one.
@ -143,7 +142,7 @@ object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
* @param d the subtrahend. * @param d the subtrahend.
* @return the difference. * @return the difference.
*/ */
operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex()) public operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex())
/** /**
* Multiplies real number by complex one. * Multiplies real number by complex one.
@ -152,7 +151,7 @@ object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
* @param c the multiplicand. * @param c the multiplicand.
* @receiver the product. * @receiver the product.
*/ */
operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) public operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this)
override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg)
@ -165,8 +164,9 @@ object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
* @property re The real part. * @property re The real part.
* @property im The imaginary part. * @property im The imaginary part.
*/ */
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> { public data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>,
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) Comparable<Complex> {
public constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
override val context: ComplexField get() = ComplexField override val context: ComplexField get() = ComplexField
@ -176,7 +176,7 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
override fun compareTo(other: Complex): Int = r.compareTo(other.r) override fun compareTo(other: Complex): Int = r.compareTo(other.r)
companion object : MemorySpec<Complex> { public companion object : MemorySpec<Complex> {
override val objectSize: Int = 16 override val objectSize: Int = 16
override fun MemoryReader.read(offset: Int): Complex = override fun MemoryReader.read(offset: Int): Complex =
@ -195,14 +195,10 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
* @receiver the real part. * @receiver the real part.
* @return the new complex number. * @return the new complex number.
*/ */
fun Number.toComplex(): Complex = Complex(this, 0.0) public fun Number.toComplex(): Complex = Complex(this, 0.0)
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> { public inline fun Buffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer<Complex> =
contract { callsInPlace(init) } MemoryBuffer.create(Complex, size, init)
return MemoryBuffer.create(Complex, size, init)
}
inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> { public inline fun MutableBuffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer<Complex> =
contract { callsInPlace(init) } MemoryBuffer.create(Complex, size, init)
return MemoryBuffer.create(Complex, size, init)
}

View File

@ -0,0 +1,266 @@
package kscience.kmath.operations
import kotlin.math.abs
import kotlin.math.pow as kpow
/**
* Advanced Number-like semifield that implements basic operations.
*/
public interface ExtendedFieldOperations<T> :
FieldOperations<T>,
TrigonometricOperations<T>,
HyperbolicOperations<T>,
PowerOperations<T>,
ExponentialOperations<T> {
public override fun tan(arg: T): T = sin(arg) / cos(arg)
public override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
public override fun unaryOperation(operation: String, arg: T): T = when (operation) {
TrigonometricOperations.COS_OPERATION -> cos(arg)
TrigonometricOperations.SIN_OPERATION -> sin(arg)
TrigonometricOperations.TAN_OPERATION -> tan(arg)
TrigonometricOperations.ACOS_OPERATION -> acos(arg)
TrigonometricOperations.ASIN_OPERATION -> asin(arg)
TrigonometricOperations.ATAN_OPERATION -> atan(arg)
HyperbolicOperations.COSH_OPERATION -> cosh(arg)
HyperbolicOperations.SINH_OPERATION -> sinh(arg)
HyperbolicOperations.TANH_OPERATION -> tanh(arg)
HyperbolicOperations.ACOSH_OPERATION -> acosh(arg)
HyperbolicOperations.ASINH_OPERATION -> asinh(arg)
HyperbolicOperations.ATANH_OPERATION -> atanh(arg)
PowerOperations.SQRT_OPERATION -> sqrt(arg)
ExponentialOperations.EXP_OPERATION -> exp(arg)
ExponentialOperations.LN_OPERATION -> ln(arg)
else -> super.unaryOperation(operation, arg)
}
}
/**
* Advanced Number-like field that implements basic operations.
*/
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
public override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2
public override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2
public override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
public override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg)
public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
PowerOperations.POW_OPERATION -> power(left, right)
else -> super.rightSideNumberOperation(operation, left, right)
}
}
/**
* Real field element wrapping double.
*
* @property value the [Double] value wrapped by this [Real].
*
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
*/
public inline class Real(public val value: Double) : FieldElement<Double, Real, RealField> {
public override val context: RealField
get() = RealField
public override fun unwrap(): Double = value
public override fun Double.wrap(): Real = Real(value)
public companion object
}
/**
* A field for [Double] without boxing. Does not produce appropriate field element.
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object RealField : ExtendedField<Double>, Norm<Double, Double> {
public override val zero: Double
get() = 0.0
public override val one: Double
get() = 1.0
public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
public override inline fun add(a: Double, b: Double): Double = a + b
public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
public override inline fun multiply(a: Double, b: Double): Double = a * b
public override inline fun divide(a: Double, b: Double): Double = a / b
public override inline fun sin(arg: Double): Double = kotlin.math.sin(arg)
public override inline fun cos(arg: Double): Double = kotlin.math.cos(arg)
public override inline fun tan(arg: Double): Double = kotlin.math.tan(arg)
public override inline fun acos(arg: Double): Double = kotlin.math.acos(arg)
public override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
public override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
public override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg)
public override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg)
public override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg)
public override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg)
public override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg)
public override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg)
public override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble())
public override inline fun exp(arg: Double): Double = kotlin.math.exp(arg)
public override inline fun ln(arg: Double): Double = kotlin.math.ln(arg)
public override inline fun norm(arg: Double): Double = abs(arg)
public override inline fun Double.unaryMinus(): Double = -this
public override inline fun Double.plus(b: Double): Double = this + b
public override inline fun Double.minus(b: Double): Double = this - b
public override inline fun Double.times(b: Double): Double = this * b
public override inline fun Double.div(b: Double): Double = this / b
}
/**
* A field for [Float] without boxing. Does not produce appropriate field element.
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
public override val zero: Float
get() = 0.0f
public override val one: Float
get() = 1.0f
public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
public override inline fun add(a: Float, b: Float): Float = a + b
public override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat()
public override inline fun multiply(a: Float, b: Float): Float = a * b
public override inline fun divide(a: Float, b: Float): Float = a / b
public override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
public override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
public override inline fun tan(arg: Float): Float = kotlin.math.tan(arg)
public override inline fun acos(arg: Float): Float = kotlin.math.acos(arg)
public override inline fun asin(arg: Float): Float = kotlin.math.asin(arg)
public override inline fun atan(arg: Float): Float = kotlin.math.atan(arg)
public override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg)
public override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg)
public override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg)
public override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg)
public override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg)
public override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg)
public override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat())
public override inline fun exp(arg: Float): Float = kotlin.math.exp(arg)
public override inline fun ln(arg: Float): Float = kotlin.math.ln(arg)
public override inline fun norm(arg: Float): Float = abs(arg)
public override inline fun Float.unaryMinus(): Float = -this
public override inline fun Float.plus(b: Float): Float = this + b
public override inline fun Float.minus(b: Float): Float = this - b
public override inline fun Float.times(b: Float): Float = this * b
public override inline fun Float.div(b: Float): Float = this / b
}
/**
* A field for [Int] without boxing. Does not produce corresponding ring element.
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object IntRing : Ring<Int>, Norm<Int, Int> {
public override val zero: Int
get() = 0
public override val one: Int
get() = 1
public override inline fun add(a: Int, b: Int): Int = a + b
public override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
public override inline fun multiply(a: Int, b: Int): Int = a * b
public override inline fun norm(arg: Int): Int = abs(arg)
public override inline fun Int.unaryMinus(): Int = -this
public override inline fun Int.plus(b: Int): Int = this + b
public override inline fun Int.minus(b: Int): Int = this - b
public override inline fun Int.times(b: Int): Int = this * b
}
/**
* A field for [Short] without boxing. Does not produce appropriate ring element.
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object ShortRing : Ring<Short>, Norm<Short, Short> {
public override val zero: Short
get() = 0
public override val one: Short
get() = 1
public override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
public override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
public override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
public override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
public override inline fun Short.unaryMinus(): Short = (-this).toShort()
public override inline fun Short.plus(b: Short): Short = (this + b).toShort()
public override inline fun Short.minus(b: Short): Short = (this - b).toShort()
public override inline fun Short.times(b: Short): Short = (this * b).toShort()
}
/**
* A field for [Byte] without boxing. Does not produce appropriate ring element.
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
public override val zero: Byte
get() = 0
public override val one: Byte
get() = 1
public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
public override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
public override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
public override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
public override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
public override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
public override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
public override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
}
/**
* A field for [Double] without boxing. Does not produce appropriate ring element.
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object LongRing : Ring<Long>, Norm<Long, Long> {
public override val zero: Long
get() = 0
public override val one: Long
get() = 1
public override inline fun add(a: Long, b: Long): Long = a + b
public override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
public override inline fun multiply(a: Long, b: Long): Long = a * b
public override fun norm(arg: Long): Long = abs(arg)
public override inline fun Long.unaryMinus(): Long = (-this)
public override inline fun Long.plus(b: Long): Long = (this + b)
public override inline fun Long.minus(b: Long): Long = (this - b)
public override inline fun Long.times(b: Long): Long = (this * b)
}

View File

@ -1,234 +1,234 @@
package scientifik.kmath.operations package kscience.kmath.operations
/** /**
* A container for trigonometric operations for specific type. * A container for trigonometric operations for specific type.
* *
* @param T the type of element of this structure. * @param T the type of element of this structure.
*/ */
interface TrigonometricOperations<T> : Algebra<T> { public interface TrigonometricOperations<T> : Algebra<T> {
/** /**
* Computes the sine of [arg]. * Computes the sine of [arg].
*/ */
fun sin(arg: T): T public fun sin(arg: T): T
/** /**
* Computes the cosine of [arg]. * Computes the cosine of [arg].
*/ */
fun cos(arg: T): T public fun cos(arg: T): T
/** /**
* Computes the tangent of [arg]. * Computes the tangent of [arg].
*/ */
fun tan(arg: T): T public fun tan(arg: T): T
/** /**
* Computes the inverse sine of [arg]. * Computes the inverse sine of [arg].
*/ */
fun asin(arg: T): T public fun asin(arg: T): T
/** /**
* Computes the inverse cosine of [arg]. * Computes the inverse cosine of [arg].
*/ */
fun acos(arg: T): T public fun acos(arg: T): T
/** /**
* Computes the inverse tangent of [arg]. * Computes the inverse tangent of [arg].
*/ */
fun atan(arg: T): T public fun atan(arg: T): T
companion object { public companion object {
/** /**
* The identifier of sine. * The identifier of sine.
*/ */
const val SIN_OPERATION: String = "sin" public const val SIN_OPERATION: String = "sin"
/** /**
* The identifier of cosine. * The identifier of cosine.
*/ */
const val COS_OPERATION: String = "cos" public const val COS_OPERATION: String = "cos"
/** /**
* The identifier of tangent. * The identifier of tangent.
*/ */
const val TAN_OPERATION: String = "tan" public const val TAN_OPERATION: String = "tan"
/** /**
* The identifier of inverse sine. * The identifier of inverse sine.
*/ */
const val ASIN_OPERATION: String = "asin" public const val ASIN_OPERATION: String = "asin"
/** /**
* The identifier of inverse cosine. * The identifier of inverse cosine.
*/ */
const val ACOS_OPERATION: String = "acos" public const val ACOS_OPERATION: String = "acos"
/** /**
* The identifier of inverse tangent. * The identifier of inverse tangent.
*/ */
const val ATAN_OPERATION: String = "atan" public const val ATAN_OPERATION: String = "atan"
} }
} }
/** /**
* Computes the sine of [arg]. * Computes the sine of [arg].
*/ */
fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg) public fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
/** /**
* Computes the cosine of [arg]. * Computes the cosine of [arg].
*/ */
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg) public fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
/** /**
* Computes the tangent of [arg]. * Computes the tangent of [arg].
*/ */
fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg) public fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
/** /**
* Computes the inverse sine of [arg]. * Computes the inverse sine of [arg].
*/ */
fun <T : MathElement<out TrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg) public fun <T : MathElement<out TrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
/** /**
* Computes the inverse cosine of [arg]. * Computes the inverse cosine of [arg].
*/ */
fun <T : MathElement<out TrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg) public fun <T : MathElement<out TrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
/** /**
* Computes the inverse tangent of [arg]. * Computes the inverse tangent of [arg].
*/ */
fun <T : MathElement<out TrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg) public fun <T : MathElement<out TrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
/** /**
* A container for hyperbolic trigonometric operations for specific type. * A container for hyperbolic trigonometric operations for specific type.
* *
* @param T the type of element of this structure. * @param T the type of element of this structure.
*/ */
interface HyperbolicOperations<T> : Algebra<T> { public interface HyperbolicOperations<T> : Algebra<T> {
/** /**
* Computes the hyperbolic sine of [arg]. * Computes the hyperbolic sine of [arg].
*/ */
fun sinh(arg: T): T public fun sinh(arg: T): T
/** /**
* Computes the hyperbolic cosine of [arg]. * Computes the hyperbolic cosine of [arg].
*/ */
fun cosh(arg: T): T public fun cosh(arg: T): T
/** /**
* Computes the hyperbolic tangent of [arg]. * Computes the hyperbolic tangent of [arg].
*/ */
fun tanh(arg: T): T public fun tanh(arg: T): T
/** /**
* Computes the inverse hyperbolic sine of [arg]. * Computes the inverse hyperbolic sine of [arg].
*/ */
fun asinh(arg: T): T public fun asinh(arg: T): T
/** /**
* Computes the inverse hyperbolic cosine of [arg]. * Computes the inverse hyperbolic cosine of [arg].
*/ */
fun acosh(arg: T): T public fun acosh(arg: T): T
/** /**
* Computes the inverse hyperbolic tangent of [arg]. * Computes the inverse hyperbolic tangent of [arg].
*/ */
fun atanh(arg: T): T public fun atanh(arg: T): T
companion object { public companion object {
/** /**
* The identifier of hyperbolic sine. * The identifier of hyperbolic sine.
*/ */
const val SINH_OPERATION: String = "sinh" public const val SINH_OPERATION: String = "sinh"
/** /**
* The identifier of hyperbolic cosine. * The identifier of hyperbolic cosine.
*/ */
const val COSH_OPERATION: String = "cosh" public const val COSH_OPERATION: String = "cosh"
/** /**
* The identifier of hyperbolic tangent. * The identifier of hyperbolic tangent.
*/ */
const val TANH_OPERATION: String = "tanh" public const val TANH_OPERATION: String = "tanh"
/** /**
* The identifier of inverse hyperbolic sine. * The identifier of inverse hyperbolic sine.
*/ */
const val ASINH_OPERATION: String = "asinh" public const val ASINH_OPERATION: String = "asinh"
/** /**
* The identifier of inverse hyperbolic cosine. * The identifier of inverse hyperbolic cosine.
*/ */
const val ACOSH_OPERATION: String = "acosh" public const val ACOSH_OPERATION: String = "acosh"
/** /**
* The identifier of inverse hyperbolic tangent. * The identifier of inverse hyperbolic tangent.
*/ */
const val ATANH_OPERATION: String = "atanh" public const val ATANH_OPERATION: String = "atanh"
} }
} }
/** /**
* Computes the hyperbolic sine of [arg]. * Computes the hyperbolic sine of [arg].
*/ */
fun <T : MathElement<out HyperbolicOperations<T>>> sinh(arg: T): T = arg.context.sinh(arg) public fun <T : MathElement<out HyperbolicOperations<T>>> sinh(arg: T): T = arg.context.sinh(arg)
/** /**
* Computes the hyperbolic cosine of [arg]. * Computes the hyperbolic cosine of [arg].
*/ */
fun <T : MathElement<out HyperbolicOperations<T>>> cosh(arg: T): T = arg.context.cosh(arg) public fun <T : MathElement<out HyperbolicOperations<T>>> cosh(arg: T): T = arg.context.cosh(arg)
/** /**
* Computes the hyperbolic tangent of [arg]. * Computes the hyperbolic tangent of [arg].
*/ */
fun <T : MathElement<out HyperbolicOperations<T>>> tanh(arg: T): T = arg.context.tanh(arg) public fun <T : MathElement<out HyperbolicOperations<T>>> tanh(arg: T): T = arg.context.tanh(arg)
/** /**
* Computes the inverse hyperbolic sine of [arg]. * Computes the inverse hyperbolic sine of [arg].
*/ */
fun <T : MathElement<out HyperbolicOperations<T>>> asinh(arg: T): T = arg.context.asinh(arg) public fun <T : MathElement<out HyperbolicOperations<T>>> asinh(arg: T): T = arg.context.asinh(arg)
/** /**
* Computes the inverse hyperbolic cosine of [arg]. * Computes the inverse hyperbolic cosine of [arg].
*/ */
fun <T : MathElement<out HyperbolicOperations<T>>> acosh(arg: T): T = arg.context.acosh(arg) public fun <T : MathElement<out HyperbolicOperations<T>>> acosh(arg: T): T = arg.context.acosh(arg)
/** /**
* Computes the inverse hyperbolic tangent of [arg]. * Computes the inverse hyperbolic tangent of [arg].
*/ */
fun <T : MathElement<out HyperbolicOperations<T>>> atanh(arg: T): T = arg.context.atanh(arg) public fun <T : MathElement<out HyperbolicOperations<T>>> atanh(arg: T): T = arg.context.atanh(arg)
/** /**
* A context extension to include power operations based on exponentiation. * A context extension to include power operations based on exponentiation.
* *
* @param T the type of element of this structure. * @param T the type of element of this structure.
*/ */
interface PowerOperations<T> : Algebra<T> { public interface PowerOperations<T> : Algebra<T> {
/** /**
* Raises [arg] to the power [pow]. * Raises [arg] to the power [pow].
*/ */
fun power(arg: T, pow: Number): T public fun power(arg: T, pow: Number): T
/** /**
* Computes the square root of the value [arg]. * Computes the square root of the value [arg].
*/ */
fun sqrt(arg: T): T = power(arg, 0.5) public fun sqrt(arg: T): T = power(arg, 0.5)
/** /**
* Raises this value to the power [pow]. * Raises this value to the power [pow].
*/ */
infix fun T.pow(pow: Number): T = power(this, pow) public infix fun T.pow(pow: Number): T = power(this, pow)
companion object { public companion object {
/** /**
* The identifier of exponentiation. * The identifier of exponentiation.
*/ */
const val POW_OPERATION: String = "pow" public const val POW_OPERATION: String = "pow"
/** /**
* The identifier of square root. * The identifier of square root.
*/ */
const val SQRT_OPERATION: String = "sqrt" public const val SQRT_OPERATION: String = "sqrt"
} }
} }
@ -239,56 +239,56 @@ interface PowerOperations<T> : Algebra<T> {
* @param power the exponent. * @param power the exponent.
* @return the base raised to the power. * @return the base raised to the power.
*/ */
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power) public infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
/** /**
* Computes the square root of the value [arg]. * Computes the square root of the value [arg].
*/ */
fun <T : MathElement<out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5 public fun <T : MathElement<out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5
/** /**
* Computes the square of the value [arg]. * Computes the square of the value [arg].
*/ */
fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0 public fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
/** /**
* A container for operations related to `exp` and `ln` functions. * A container for operations related to `exp` and `ln` functions.
* *
* @param T the type of element of this structure. * @param T the type of element of this structure.
*/ */
interface ExponentialOperations<T> : Algebra<T> { public interface ExponentialOperations<T> : Algebra<T> {
/** /**
* Computes Euler's number `e` raised to the power of the value [arg]. * Computes Euler's number `e` raised to the power of the value [arg].
*/ */
fun exp(arg: T): T public fun exp(arg: T): T
/** /**
* Computes the natural logarithm (base `e`) of the value [arg]. * Computes the natural logarithm (base `e`) of the value [arg].
*/ */
fun ln(arg: T): T public fun ln(arg: T): T
companion object { public companion object {
/** /**
* The identifier of exponential function. * The identifier of exponential function.
*/ */
const val EXP_OPERATION: String = "exp" public const val EXP_OPERATION: String = "exp"
/** /**
* The identifier of natural logarithm. * The identifier of natural logarithm.
*/ */
const val LN_OPERATION: String = "ln" public const val LN_OPERATION: String = "ln"
} }
} }
/** /**
* The identifier of exponential function. * The identifier of exponential function.
*/ */
fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg) public fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
/** /**
* The identifier of natural logarithm. * The identifier of natural logarithm.
*/ */
fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg) public fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg)
/** /**
* A container for norm functional on element. * A container for norm functional on element.
@ -296,14 +296,14 @@ fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.
* @param T the type of element having norm defined. * @param T the type of element having norm defined.
* @param R the type of norm. * @param R the type of norm.
*/ */
interface Norm<in T : Any, out R> { public interface Norm<in T : Any, out R> {
/** /**
* Computes the norm of [arg] (i.e. absolute value or vector length). * Computes the norm of [arg] (i.e. absolute value or vector length).
*/ */
fun norm(arg: T): R public fun norm(arg: T): R
} }
/** /**
* Computes the norm of [arg] (i.e. absolute value or vector length). * Computes the norm of [arg] (i.e. absolute value or vector length).
*/ */
fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg) public fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg)

View File

@ -1,31 +1,30 @@
package scientifik.kmath.structures package kscience.kmath.structures
import scientifik.kmath.operations.Field import kscience.kmath.operations.Field
import scientifik.kmath.operations.FieldElement import kscience.kmath.operations.FieldElement
class BoxingNDField<T, F : Field<T>>( public class BoxingNDField<T, F : Field<T>>(
override val shape: IntArray, public override val shape: IntArray,
override val elementContext: F, public override val elementContext: F,
val bufferFactory: BufferFactory<T> public val bufferFactory: BufferFactory<T>
) : BufferedNDField<T, F> { ) : BufferedNDField<T, F> {
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } } public override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } } public override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
override val strides: Strides = DefaultStrides(shape) public override val strides: Strides = DefaultStrides(shape)
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
bufferFactory(size, initializer) bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> { public override fun check(vararg elements: NDBuffer<T>) {
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
return elements
} }
override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> = public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
BufferedNDFieldElement( BufferedNDFieldElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }) buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> { public override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
@ -37,7 +36,7 @@ class BoxingNDField<T, F : Field<T>>(
} }
override fun mapIndexed( public override fun mapIndexed(
arg: NDBuffer<T>, arg: NDBuffer<T>,
transform: F.(index: IntArray, T) -> T transform: F.(index: IntArray, T) -> T
): BufferedNDFieldElement<T, F> { ): BufferedNDFieldElement<T, F> {
@ -56,7 +55,7 @@ class BoxingNDField<T, F : Field<T>>(
// return BufferedNDFieldElement(this, buffer) // return BufferedNDFieldElement(this, buffer)
} }
override fun combine( public override fun combine(
a: NDBuffer<T>, a: NDBuffer<T>,
b: NDBuffer<T>, b: NDBuffer<T>,
transform: F.(T, T) -> T transform: F.(T, T) -> T
@ -67,15 +66,15 @@ class BoxingNDField<T, F : Field<T>>(
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
} }
override fun NDBuffer<T>.toElement(): FieldElement<NDBuffer<T>, *, out BufferedNDField<T, F>> = public override fun NDBuffer<T>.toElement(): FieldElement<NDBuffer<T>, *, out BufferedNDField<T, F>> =
BufferedNDFieldElement(this@BoxingNDField, buffer) BufferedNDFieldElement(this@BoxingNDField, buffer)
} }
inline fun <T : Any, F : Field<T>, R> F.nd( public inline fun <T : Any, F : Field<T>, R> F.nd(
noinline bufferFactory: BufferFactory<T>, noinline bufferFactory: BufferFactory<T>,
vararg shape: Int, vararg shape: Int,
action: NDField<T, F, *>.() -> R action: NDField<T, F, *>.() -> R
): R { ): R {
val ndfield: BoxingNDField<T, F> = NDField.boxing(this, *shape, bufferFactory = bufferFactory) val ndfield = NDField.boxing(this, *shape, bufferFactory = bufferFactory)
return ndfield.action() return ndfield.action()
} }

View File

@ -1,18 +1,18 @@
package scientifik.kmath.structures package kscience.kmath.structures
import scientifik.kmath.operations.Ring import kscience.kmath.operations.Ring
import scientifik.kmath.operations.RingElement import kscience.kmath.operations.RingElement
class BoxingNDRing<T, R : Ring<T>>( public class BoxingNDRing<T, R : Ring<T>>(
override val shape: IntArray, override val shape: IntArray,
override val elementContext: R, override val elementContext: R,
val bufferFactory: BufferFactory<T> public val bufferFactory: BufferFactory<T>
) : BufferedNDRing<T, R> { ) : BufferedNDRing<T, R> {
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } } override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
override val one: BufferedNDRingElement<T, R> by lazy { produce { one } } override val one: BufferedNDRingElement<T, R> by lazy { produce { one } }
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer) public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> { override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
@ -60,6 +60,7 @@ class BoxingNDRing<T, R : Ring<T>>(
transform: R.(T, T) -> T transform: R.(T, T) -> T
): BufferedNDRingElement<T, R> { ): BufferedNDRingElement<T, R> {
check(a, b) check(a, b)
return BufferedNDRingElement( return BufferedNDRingElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })

View File

@ -1,28 +1,27 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
* A context that allows to operate on a [MutableBuffer] as on 2d array * A context that allows to operate on a [MutableBuffer] as on 2d array
*/ */
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) { public class BufferAccessor2D<T : Any>(public val type: KClass<T>, public val rowNum: Int, public val colNum: Int) {
operator fun Buffer<T>.get(i: Int, j: Int): T = get(i + colNum * j) public operator fun Buffer<T>.get(i: Int, j: Int): T = get(i + colNum * j)
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) { public operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
set(i + colNum * j, value) set(i + colNum * j, value)
} }
inline fun create(init: (i: Int, j: Int) -> T): MutableBuffer<T> = public inline fun create(init: (i: Int, j: Int) -> T): MutableBuffer<T> =
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) } MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] } public fun create(mat: Structure2D<T>): MutableBuffer<T> = create { i, j -> mat[i, j] }
//TODO optimize wrapper //TODO optimize wrapper
fun MutableBuffer<T>.collect(): Structure2D<T> = public fun MutableBuffer<T>.collect(): Structure2D<T> =
NDStructure.auto(type, rowNum, colNum) { (i, j) -> get(i, j) }.as2D() NDStructure.auto(type, rowNum, colNum) { (i, j) -> get(i, j) }.as2D()
public inner class Row(public val buffer: MutableBuffer<T>, public val rowIndex: Int) : MutableBuffer<T> {
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
override val size: Int get() = colNum override val size: Int get() = colNum
override operator fun get(index: Int): T = buffer[rowIndex, index] override operator fun get(index: Int): T = buffer[rowIndex, index]
@ -39,5 +38,5 @@ class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum
/** /**
* Get row * Get row
*/ */
fun MutableBuffer<T>.row(i: Int): Row = Row(this, i) public fun MutableBuffer<T>.row(i: Int): Row = Row(this, i)
} }

View File

@ -0,0 +1,41 @@
package kscience.kmath.structures
import kscience.kmath.operations.*
public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
public val strides: Strides
public override fun check(vararg elements: NDBuffer<T>): Unit =
require(elements.all { it.strides == strides }) { ("Strides mismatch") }
/**
* Convert any [NDStructure] to buffered structure using strides from this context.
* If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over
* indices.
*
* If the argument is [NDBuffer] with different strides structure, the new element will be produced.
*/
public fun NDStructure<T>.toBuffer(): NDBuffer<T> =
if (this is NDBuffer<T> && this.strides == this@BufferedNDAlgebra.strides)
this
else
produce { index -> this@toBuffer[index] }
/**
* Convert a buffer to element of this algebra
*/
public fun NDBuffer<T>.toElement(): MathElement<out BufferedNDAlgebra<T, C>>
}
public interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T, S> {
public override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>>
}
public interface BufferedNDRing<T, R : Ring<T>> : NDRing<T, R, NDBuffer<T>>, BufferedNDSpace<T, R> {
override fun NDBuffer<T>.toElement(): RingElement<NDBuffer<T>, *, out BufferedNDRing<T, R>>
}
public interface BufferedNDField<T, F : Field<T>> : NDField<T, F, NDBuffer<T>>, BufferedNDRing<T, F> {
override fun NDBuffer<T>.toElement(): FieldElement<NDBuffer<T>, *, out BufferedNDField<T, F>>
}

View File

@ -1,11 +1,11 @@
package scientifik.kmath.structures package kscience.kmath.structures
import scientifik.kmath.operations.* import kscience.kmath.operations.*
/** /**
* Base class for an element with context, containing strides * Base class for an element with context, containing strides
*/ */
abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer<T>> { public abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer<T>> {
abstract override val context: BufferedNDAlgebra<T, C> abstract override val context: BufferedNDAlgebra<T, C>
override val strides: Strides get() = context.strides override val strides: Strides get() = context.strides
@ -13,7 +13,7 @@ abstract class BufferedNDElement<T, C> : NDBuffer<T>(), NDElement<T, C, NDBuffer
override val shape: IntArray get() = context.shape override val shape: IntArray get() = context.shape
} }
class BufferedNDSpaceElement<T, S : Space<T>>( public class BufferedNDSpaceElement<T, S : Space<T>>(
override val context: BufferedNDSpace<T, S>, override val context: BufferedNDSpace<T, S>,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : BufferedNDElement<T, S>(), SpaceElement<NDBuffer<T>, BufferedNDSpaceElement<T, S>, BufferedNDSpace<T, S>> { ) : BufferedNDElement<T, S>(), SpaceElement<NDBuffer<T>, BufferedNDSpaceElement<T, S>, BufferedNDSpace<T, S>> {
@ -26,7 +26,7 @@ class BufferedNDSpaceElement<T, S : Space<T>>(
} }
} }
class BufferedNDRingElement<T, R : Ring<T>>( public class BufferedNDRingElement<T, R : Ring<T>>(
override val context: BufferedNDRing<T, R>, override val context: BufferedNDRing<T, R>,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> { ) : BufferedNDElement<T, R>(), RingElement<NDBuffer<T>, BufferedNDRingElement<T, R>, BufferedNDRing<T, R>> {
@ -38,7 +38,7 @@ class BufferedNDRingElement<T, R : Ring<T>>(
} }
} }
class BufferedNDFieldElement<T, F : Field<T>>( public class BufferedNDFieldElement<T, F : Field<T>>(
override val context: BufferedNDField<T, F>, override val context: BufferedNDField<T, F>,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> { ) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
@ -54,22 +54,21 @@ class BufferedNDFieldElement<T, F : Field<T>>(
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy. * Element by element application of any operation on elements to the whole array. Just like in numpy.
*/ */
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke( public operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>): MathElement<out BufferedNDAlgebra<T, F>> =
ndElement: BufferedNDElement<T, F> ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
): MathElement<out BufferedNDAlgebra<T, F>> = ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
/* plus and minus */ /* plus and minus */
/** /**
* Summation operation for [BufferedNDElement] and single element * Summation operation for [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.plus(arg: T): NDElement<T, F, NDBuffer<T>> = public operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.plus(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it + arg }.wrap() context.map(this) { it + arg }.wrap()
/** /**
* Subtraction operation between [BufferedNDElement] and single element * Subtraction operation between [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T): NDElement<T, F, NDBuffer<T>> = public operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it - arg }.wrap() context.map(this) { it - arg }.wrap()
/* prod and div */ /* prod and div */
@ -77,11 +76,11 @@ operator fun <T : Any, F : Space<T>> BufferedNDElement<T, F>.minus(arg: T): NDEl
/** /**
* Product operation for [BufferedNDElement] and single element * Product operation for [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Ring<T>> BufferedNDElement<T, F>.times(arg: T): NDElement<T, F, NDBuffer<T>> = public operator fun <T : Any, F : Ring<T>> BufferedNDElement<T, F>.times(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it * arg }.wrap() context.map(this) { it * arg }.wrap()
/** /**
* Division operation between [BufferedNDElement] and single element * Division operation between [BufferedNDElement] and single element
*/ */
operator fun <T : Any, F : Field<T>> BufferedNDElement<T, F>.div(arg: T): NDElement<T, F, NDBuffer<T>> = public operator fun <T : Any, F : Field<T>> BufferedNDElement<T, F>.div(arg: T): NDElement<T, F, NDBuffer<T>> =
context.map(this) { it / arg }.wrap() context.map(this) { it / arg }.wrap()

View File

@ -1,8 +1,7 @@
package scientifik.kmath.structures package kscience.kmath.structures
import scientifik.kmath.operations.Complex import kscience.kmath.operations.Complex
import scientifik.kmath.operations.complex import kscience.kmath.operations.complex
import kotlin.contracts.contract
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -10,44 +9,44 @@ import kotlin.reflect.KClass
* *
* @param T the type of buffer. * @param T the type of buffer.
*/ */
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T> public typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
/** /**
* Function that produces [MutableBuffer] from its size and function that supplies values. * Function that produces [MutableBuffer] from its size and function that supplies values.
* *
* @param T the type of buffer. * @param T the type of buffer.
*/ */
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T> public typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
/** /**
* A generic immutable random-access structure for both primitives and objects. * A generic immutable random-access structure for both primitives and objects.
* *
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
*/ */
interface Buffer<T> { public interface Buffer<T> {
/** /**
* The size of this buffer. * The size of this buffer.
*/ */
val size: Int public val size: Int
/** /**
* Gets element at given index. * Gets element at given index.
*/ */
operator fun get(index: Int): T public operator fun get(index: Int): T
/** /**
* Iterates over all elements. * Iterates over all elements.
*/ */
operator fun iterator(): Iterator<T> public operator fun iterator(): Iterator<T>
/** /**
* Checks content equality with another buffer. * Checks content equality with another buffer.
*/ */
fun contentEquals(other: Buffer<*>): Boolean = public fun contentEquals(other: Buffer<*>): Boolean =
asSequence().mapIndexed { index, value -> value == other[index] }.all { it } asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
companion object { public companion object {
inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer { public inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer {
val array = DoubleArray(size) { initializer(it) } val array = DoubleArray(size) { initializer(it) }
return RealBuffer(array) return RealBuffer(array)
} }
@ -55,10 +54,11 @@ interface Buffer<T> {
/** /**
* Create a boxing buffer of given type * Create a boxing buffer of given type
*/ */
inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> = ListBuffer(List(size, initializer)) public inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> =
ListBuffer(List(size, initializer))
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> { public inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
//TODO add resolution based on Annotation or companion resolution //TODO add resolution based on Annotation or companion resolution
return when (type) { return when (type) {
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T> Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
@ -74,7 +74,7 @@ interface Buffer<T> {
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible * Create most appropriate immutable buffer for given type avoiding boxing wherever possible
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <reified T : Any> auto(size: Int, crossinline initializer: (Int) -> T): Buffer<T> = public inline fun <reified T : Any> auto(size: Int, crossinline initializer: (Int) -> T): Buffer<T> =
auto(T::class, size, initializer) auto(T::class, size, initializer)
} }
} }
@ -82,43 +82,43 @@ interface Buffer<T> {
/** /**
* Creates a sequence that returns all elements from this [Buffer]. * Creates a sequence that returns all elements from this [Buffer].
*/ */
fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator) public fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator)
/** /**
* Creates an iterable that returns all elements from this [Buffer]. * Creates an iterable that returns all elements from this [Buffer].
*/ */
fun <T> Buffer<T>.asIterable(): Iterable<T> = Iterable(::iterator) public fun <T> Buffer<T>.asIterable(): Iterable<T> = Iterable(::iterator)
/** /**
* Returns an [IntRange] of the valid indices for this [Buffer]. * Returns an [IntRange] of the valid indices for this [Buffer].
*/ */
val Buffer<*>.indices: IntRange get() = 0 until size public val Buffer<*>.indices: IntRange get() = 0 until size
/** /**
* A generic mutable random-access structure for both primitives and objects. * A generic mutable random-access structure for both primitives and objects.
* *
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
*/ */
interface MutableBuffer<T> : Buffer<T> { public interface MutableBuffer<T> : Buffer<T> {
/** /**
* Sets the array element at the specified [index] to the specified [value]. * Sets the array element at the specified [index] to the specified [value].
*/ */
operator fun set(index: Int, value: T) public operator fun set(index: Int, value: T)
/** /**
* Returns a shallow copy of the buffer. * Returns a shallow copy of the buffer.
*/ */
fun copy(): MutableBuffer<T> public fun copy(): MutableBuffer<T>
companion object { public companion object {
/** /**
* Create a boxing mutable buffer of given type * Create a boxing mutable buffer of given type
*/ */
inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> = public inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
MutableListBuffer(MutableList(size, initializer)) MutableListBuffer(MutableList(size, initializer))
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> = public inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
when (type) { when (type) {
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T> Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T> Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
@ -131,12 +131,11 @@ interface MutableBuffer<T> : Buffer<T> {
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible * Create most appropriate mutable buffer for given type avoiding boxing wherever possible
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> = public inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
auto(T::class, size, initializer) auto(T::class, size, initializer)
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double -> public val real: MutableBufferFactory<Double> =
RealBuffer(DoubleArray(size) { initializer(it) }) { size, initializer -> RealBuffer(DoubleArray(size) { initializer(it) }) }
}
} }
} }
@ -146,7 +145,7 @@ interface MutableBuffer<T> : Buffer<T> {
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
* @property list The underlying list. * @property list The underlying list.
*/ */
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> { public inline class ListBuffer<T>(public val list: List<T>) : Buffer<T> {
override val size: Int override val size: Int
get() = list.size get() = list.size
@ -157,7 +156,7 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
/** /**
* Returns an [ListBuffer] that wraps the original list. * Returns an [ListBuffer] that wraps the original list.
*/ */
fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this) public fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
/** /**
* Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified * Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified
@ -166,10 +165,7 @@ fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an array element given its index. * It should return the value for an array element given its index.
*/ */
inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> { public inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = List(size, init).asBuffer()
contract { callsInPlace(init) }
return List(size, init).asBuffer()
}
/** /**
* [MutableBuffer] implementation over [MutableList]. * [MutableBuffer] implementation over [MutableList].
@ -177,7 +173,7 @@ inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> {
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
* @property list The underlying list. * @property list The underlying list.
*/ */
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> { public inline class MutableListBuffer<T>(public val list: MutableList<T>) : MutableBuffer<T> {
override val size: Int override val size: Int
get() = list.size get() = list.size
@ -197,7 +193,7 @@ inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
* @property array The underlying array. * @property array The underlying array.
*/ */
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> { public class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
// Can't inline because array is invariant // Can't inline because array is invariant
override val size: Int override val size: Int
get() = array.size get() = array.size
@ -215,7 +211,7 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
/** /**
* Returns an [ArrayBuffer] that wraps the original array. * Returns an [ArrayBuffer] that wraps the original array.
*/ */
fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this) public fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
/** /**
* Immutable wrapper for [MutableBuffer]. * Immutable wrapper for [MutableBuffer].
@ -223,7 +219,7 @@ fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
* @property buffer The underlying buffer. * @property buffer The underlying buffer.
*/ */
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> { public inline class ReadOnlyBuffer<T>(public val buffer: MutableBuffer<T>) : Buffer<T> {
override val size: Int get() = buffer.size override val size: Int get() = buffer.size
override operator fun get(index: Int): T = buffer[index] override operator fun get(index: Int): T = buffer[index]
@ -237,7 +233,7 @@ inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
* *
* @param T the type of elements provided by the buffer. * @param T the type of elements provided by the buffer.
*/ */
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> { public class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
override operator fun get(index: Int): T { override operator fun get(index: Int): T {
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index") if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
return generator(index) return generator(index)
@ -257,14 +253,14 @@ class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T
/** /**
* Convert this buffer to read-only buffer. * Convert this buffer to read-only buffer.
*/ */
fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) ReadOnlyBuffer(this) else this public fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) ReadOnlyBuffer(this) else this
/** /**
* Typealias for buffer transformations. * Typealias for buffer transformations.
*/ */
typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R> public typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R>
/** /**
* Typealias for buffer transformations with suspend function. * Typealias for buffer transformations with suspend function.
*/ */
typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R> public typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R>

View File

@ -1,18 +1,18 @@
package scientifik.kmath.structures package kscience.kmath.structures
import scientifik.kmath.operations.Complex import kscience.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import kscience.kmath.operations.ComplexField
import scientifik.kmath.operations.FieldElement import kscience.kmath.operations.FieldElement
import scientifik.kmath.operations.complex import kscience.kmath.operations.complex
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField> public typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField>
/** /**
* An optimized nd-field for complex numbers * An optimized nd-field for complex numbers
*/ */
class ComplexNDField(override val shape: IntArray) : public class ComplexNDField(override val shape: IntArray) :
BufferedNDField<Complex, ComplexField>, BufferedNDField<Complex, ComplexField>,
ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> { ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> {
@ -21,7 +21,7 @@ class ComplexNDField(override val shape: IntArray) :
override val zero: ComplexNDElement by lazy { produce { zero } } override val zero: ComplexNDElement by lazy { produce { zero } }
override val one: ComplexNDElement by lazy { produce { one } } override val one: ComplexNDElement by lazy { produce { one } }
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> = public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> =
Buffer.complex(size) { initializer(it) } Buffer.complex(size) { initializer(it) }
/** /**
@ -97,7 +97,7 @@ class ComplexNDField(override val shape: IntArray) :
/** /**
* Fast element production using function inlining * Fast element production using function inlining
*/ */
inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline initializer: ComplexField.(Int) -> Complex): ComplexNDElement { public inline fun BufferedNDField<Complex, ComplexField>.produceInline(initializer: ComplexField.(Int) -> Complex): ComplexNDElement {
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) } val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
return BufferedNDFieldElement(this, buffer) return BufferedNDFieldElement(this, buffer)
} }
@ -105,14 +105,13 @@ inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline init
/** /**
* Map one [ComplexNDElement] using function with indices. * Map one [ComplexNDElement] using function with indices.
*/ */
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement = public inline fun ComplexNDElement.mapIndexed(transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement =
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
/** /**
* Map one [ComplexNDElement] using function without indices. * Map one [ComplexNDElement] using function without indices.
*/ */
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement { public inline fun ComplexNDElement.map(transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
contract { callsInPlace(transform) }
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) } val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
return BufferedNDFieldElement(context, buffer) return BufferedNDFieldElement(context, buffer)
} }
@ -120,38 +119,35 @@ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) ->
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy
*/ */
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement): ComplexNDElement = public operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement): ComplexNDElement =
ndElement.map { this@invoke(it) } ndElement.map { this@invoke(it) }
/* plus and minus */ /* plus and minus */
/** /**
* Summation operation for [BufferedNDElement] and single element * Summation operation for [BufferedNDElement] and single element
*/ */
operator fun ComplexNDElement.plus(arg: Complex): ComplexNDElement = map { it + arg } public operator fun ComplexNDElement.plus(arg: Complex): ComplexNDElement = map { it + arg }
/** /**
* Subtraction operation between [BufferedNDElement] and single element * Subtraction operation between [BufferedNDElement] and single element
*/ */
operator fun ComplexNDElement.minus(arg: Complex): ComplexNDElement = public operator fun ComplexNDElement.minus(arg: Complex): ComplexNDElement = map { it - arg }
map { it - arg }
operator fun ComplexNDElement.plus(arg: Double): ComplexNDElement = public operator fun ComplexNDElement.plus(arg: Double): ComplexNDElement = map { it + arg }
map { it + arg } public operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement = map { it - arg }
operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement = public fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
map { it - arg }
fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape) public fun NDElement.Companion.complex(
vararg shape: Int,
fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(IntArray) -> Complex): ComplexNDElement = initializer: ComplexField.(IntArray) -> Complex
NDField.complex(*shape).produce(initializer) ): ComplexNDElement = NDField.complex(*shape).produce(initializer)
/** /**
* Produce a context for n-dimensional operations inside this real field * Produce a context for n-dimensional operations inside this real field
*/ */
inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R { public inline fun <R> ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return NDField.complex(*shape).action() return NDField.complex(*shape).action()
} }

View File

@ -1,6 +1,6 @@
package scientifik.kmath.structures package kscience.kmath.structures
import scientifik.kmath.operations.ExtendedField import kscience.kmath.operations.ExtendedField
/** /**
* [ExtendedField] over [NDStructure]. * [ExtendedField] over [NDStructure].
@ -9,7 +9,7 @@ import scientifik.kmath.operations.ExtendedField
* @param N the type of ND structure. * @param N the type of ND structure.
* @param F the extended field of structure elements. * @param F the extended field of structure elements.
*/ */
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N> public interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N>
///** ///**
// * NDField that supports [ExtendedField] operations on its elements // * NDField that supports [ExtendedField] operations on its elements

View File

@ -1,6 +1,5 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kotlin.contracts.contract
import kotlin.experimental.and import kotlin.experimental.and
/** /**
@ -8,7 +7,7 @@ import kotlin.experimental.and
* *
* @property mask bit mask value of this flag. * @property mask bit mask value of this flag.
*/ */
enum class ValueFlag(val mask: Byte) { public enum class ValueFlag(public val mask: Byte) {
/** /**
* Reports the value is NaN. * Reports the value is NaN.
*/ */
@ -33,23 +32,23 @@ enum class ValueFlag(val mask: Byte) {
/** /**
* A buffer with flagged values. * A buffer with flagged values.
*/ */
interface FlaggedBuffer<T> : Buffer<T> { public interface FlaggedBuffer<T> : Buffer<T> {
fun getFlag(index: Int): Byte public fun getFlag(index: Int): Byte
} }
/** /**
* The value is valid if all flags are down * The value is valid if all flags are down
*/ */
fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte() public fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte()
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte() public fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte()
fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING) public fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING)
/** /**
* A real buffer which supports flags for each value like NaN or Missing * A real buffer which supports flags for each value like NaN or Missing
*/ */
class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> { public class FlaggedRealBuffer(public val values: DoubleArray, public val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> {
init { init {
require(values.size == flags.size) { "Values and flags must have the same dimensions" } require(values.size == flags.size) { "Values and flags must have the same dimensions" }
} }
@ -65,9 +64,7 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged
}.iterator() }.iterator()
} }
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { public inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
contract { callsInPlace(block) }
indices indices
.asSequence() .asSequence()
.filter(::isValid) .filter(::isValid)

View File

@ -1,13 +1,12 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [FloatArray]. * Specialized [MutableBuffer] implementation over [FloatArray].
* *
* @property array the underlying array. * @property array the underlying array.
* @author Iaroslav Postovalov
*/ */
inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> { public inline class FloatBuffer(public val array: FloatArray) : MutableBuffer<Float> {
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Float = array[index] override operator fun get(index: Int): Float = array[index]
@ -29,20 +28,17 @@ inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer { public inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) })
contract { callsInPlace(init) }
return FloatBuffer(FloatArray(size) { init(it) })
}
/** /**
* Returns a new [FloatBuffer] of given elements. * Returns a new [FloatBuffer] of given elements.
*/ */
fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats) public fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats)
/** /**
* Returns a [FloatArray] containing all of the elements of this [MutableBuffer]. * Returns a [FloatArray] containing all of the elements of this [MutableBuffer].
*/ */
val MutableBuffer<out Float>.array: FloatArray public val MutableBuffer<out Float>.array: FloatArray
get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) }) get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) })
/** /**
@ -51,4 +47,4 @@ val MutableBuffer<out Float>.array: FloatArray
* @receiver the array. * @receiver the array.
* @return the new buffer. * @return the new buffer.
*/ */
fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this) public fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this)

View File

@ -1,13 +1,11 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [IntArray]. * Specialized [MutableBuffer] implementation over [IntArray].
* *
* @property array the underlying array. * @property array the underlying array.
*/ */
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> { public inline class IntBuffer(public val array: IntArray) : MutableBuffer<Int> {
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Int = array[index] override operator fun get(index: Int): Int = array[index]
@ -29,17 +27,17 @@ inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) }) public inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) })
/** /**
* Returns a new [IntBuffer] of given elements. * Returns a new [IntBuffer] of given elements.
*/ */
fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints) public fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints)
/** /**
* Returns a [IntArray] containing all of the elements of this [MutableBuffer]. * Returns a [IntArray] containing all of the elements of this [MutableBuffer].
*/ */
val MutableBuffer<out Int>.array: IntArray public val MutableBuffer<out Int>.array: IntArray
get() = (if (this is IntBuffer) array else IntArray(size) { get(it) }) get() = (if (this is IntBuffer) array else IntArray(size) { get(it) })
/** /**
@ -48,4 +46,4 @@ val MutableBuffer<out Int>.array: IntArray
* @receiver the array. * @receiver the array.
* @return the new buffer. * @return the new buffer.
*/ */
fun IntArray.asBuffer(): IntBuffer = IntBuffer(this) public fun IntArray.asBuffer(): IntBuffer = IntBuffer(this)

View File

@ -1,13 +1,11 @@
package scientifik.kmath.structures package kscience.kmath.structures
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [LongArray]. * Specialized [MutableBuffer] implementation over [LongArray].
* *
* @property array the underlying array. * @property array the underlying array.
*/ */
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> { public inline class LongBuffer(public val array: LongArray) : MutableBuffer<Long> {
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Long = array[index] override operator fun get(index: Int): Long = array[index]
@ -20,7 +18,6 @@ inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
override fun copy(): MutableBuffer<Long> = override fun copy(): MutableBuffer<Long> =
LongBuffer(array.copyOf()) LongBuffer(array.copyOf())
} }
/** /**
@ -30,20 +27,17 @@ inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer { public inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) })
contract { callsInPlace(init) }
return LongBuffer(LongArray(size) { init(it) })
}
/** /**
* Returns a new [LongBuffer] of given elements. * Returns a new [LongBuffer] of given elements.
*/ */
fun LongBuffer(vararg longs: Long): LongBuffer = LongBuffer(longs) public fun LongBuffer(vararg longs: Long): LongBuffer = LongBuffer(longs)
/** /**
* Returns a [IntArray] containing all of the elements of this [MutableBuffer]. * Returns a [IntArray] containing all of the elements of this [MutableBuffer].
*/ */
val MutableBuffer<out Long>.array: LongArray public val MutableBuffer<out Long>.array: LongArray
get() = (if (this is LongBuffer) array else LongArray(size) { get(it) }) get() = (if (this is LongBuffer) array else LongArray(size) { get(it) })
/** /**
@ -52,4 +46,4 @@ val MutableBuffer<out Long>.array: LongArray
* @receiver the array. * @receiver the array.
* @return the new buffer. * @return the new buffer.
*/ */
fun LongArray.asBuffer(): LongBuffer = LongBuffer(this) public fun LongArray.asBuffer(): LongBuffer = LongBuffer(this)

View File

@ -1,6 +1,6 @@
package scientifik.kmath.structures package kscience.kmath.structures
import scientifik.memory.* import kscience.memory.*
/** /**
* A non-boxing buffer over [Memory] object. * A non-boxing buffer over [Memory] object.
@ -9,7 +9,7 @@ import scientifik.memory.*
* @property memory the underlying memory segment. * @property memory the underlying memory segment.
* @property spec the spec of [T] type. * @property spec the spec of [T] type.
*/ */
open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> { public open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> {
override val size: Int get() = memory.size / spec.objectSize override val size: Int get() = memory.size / spec.objectSize
private val reader: MemoryReader = memory.reader() private val reader: MemoryReader = memory.reader()
@ -17,19 +17,16 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index) override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator() override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
companion object { public companion object {
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> = public fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
MemoryBuffer(Memory.allocate(size * spec.objectSize), spec) MemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
inline fun <T : Any> create( public inline fun <T : Any> create(
spec: MemorySpec<T>, spec: MemorySpec<T>,
size: Int, size: Int,
crossinline initializer: (Int) -> T initializer: (Int) -> T
): MemoryBuffer<T> = ): MemoryBuffer<T> = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> (0 until size).forEach { buffer[it] = initializer(it) }
(0 until size).forEach {
buffer[it] = initializer(it)
}
} }
} }
} }
@ -41,7 +38,7 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
* @property memory the underlying memory segment. * @property memory the underlying memory segment.
* @property spec the spec of [T] type. * @property spec the spec of [T] type.
*/ */
class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec), public class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec),
MutableBuffer<T> { MutableBuffer<T> {
private val writer: MemoryWriter = memory.writer() private val writer: MemoryWriter = memory.writer()
@ -49,19 +46,16 @@ class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : Memory
override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec) override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
companion object { public companion object {
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MutableMemoryBuffer<T> = public fun <T : Any> create(spec: MemorySpec<T>, size: Int): MutableMemoryBuffer<T> =
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec) MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
inline fun <T : Any> create( public inline fun <T : Any> create(
spec: MemorySpec<T>, spec: MemorySpec<T>,
size: Int, size: Int,
crossinline initializer: (Int) -> T crossinline initializer: (Int) -> T
): MutableMemoryBuffer<T> = ): MutableMemoryBuffer<T> = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> (0 until size).forEach { buffer[it] = initializer(it) }
(0 until size).forEach {
buffer[it] = initializer(it)
}
} }
} }
} }

Some files were not shown because too many files have changed in this diff Show More