forked from kscience/kmath
commit
82e4ea8897
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,6 +1,7 @@
|
|||||||
.gradle
|
.gradle
|
||||||
**/build/
|
build/
|
||||||
/.idea/
|
out/
|
||||||
|
.idea/
|
||||||
|
|
||||||
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
||||||
!gradle-wrapper.jar
|
!gradle-wrapper.jar
|
||||||
|
@ -21,7 +21,7 @@ dependencies {
|
|||||||
//jmh project(':kmath-core')
|
//jmh project(':kmath-core')
|
||||||
}
|
}
|
||||||
|
|
||||||
jmh{
|
jmh {
|
||||||
warmupIterations = 1
|
warmupIterations = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,16 +4,14 @@ import org.openjdk.jmh.annotations.Benchmark
|
|||||||
import org.openjdk.jmh.annotations.Scope
|
import org.openjdk.jmh.annotations.Scope
|
||||||
import org.openjdk.jmh.annotations.State
|
import org.openjdk.jmh.annotations.State
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
|
import scientifik.kmath.operations.complex
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
open class BufferBenchmark {
|
open class BufferBenchmark {
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun genericDoubleBufferReadWrite() {
|
fun genericDoubleBufferReadWrite() {
|
||||||
val buffer = Double.createBuffer(size)
|
val buffer = DoubleBuffer(size){it.toDouble()}
|
||||||
(0 until size).forEach {
|
|
||||||
buffer[it] = it.toDouble()
|
|
||||||
}
|
|
||||||
|
|
||||||
(0 until size).forEach {
|
(0 until size).forEach {
|
||||||
buffer[it]
|
buffer[it]
|
||||||
@ -22,10 +20,7 @@ open class BufferBenchmark {
|
|||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun complexBufferReadWrite() {
|
fun complexBufferReadWrite() {
|
||||||
val buffer = MutableBuffer.complex(size / 2)
|
val buffer = MutableBuffer.complex(size / 2){Complex(it.toDouble(), -it.toDouble())}
|
||||||
(0 until size / 2).forEach {
|
|
||||||
buffer[it] = Complex(it.toDouble(), -it.toDouble())
|
|
||||||
}
|
|
||||||
|
|
||||||
(0 until size / 2).forEach {
|
(0 until size / 2).forEach {
|
||||||
buffer[it]
|
buffer[it]
|
||||||
|
@ -34,19 +34,6 @@ open class NDFieldBenchmark {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Benchmark
|
|
||||||
fun lazyFieldAdd() {
|
|
||||||
lazyNDField.run {
|
|
||||||
var res = one
|
|
||||||
repeat(n) {
|
|
||||||
res += one
|
|
||||||
}
|
|
||||||
|
|
||||||
res.elements().sumByDouble { it.second }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun boxingFieldAdd() {
|
fun boxingFieldAdd() {
|
||||||
genericField.run {
|
genericField.run {
|
||||||
@ -61,9 +48,8 @@ open class NDFieldBenchmark {
|
|||||||
val dim = 1000
|
val dim = 1000
|
||||||
val n = 100
|
val n = 100
|
||||||
|
|
||||||
val bufferedField = NDField.auto(RealField, intArrayOf(dim, dim))
|
val bufferedField = NDField.auto(RealField, dim, dim)
|
||||||
val specializedField = NDField.real(intArrayOf(dim, dim))
|
val specializedField = NDField.real(dim, dim)
|
||||||
val genericField = NDField.buffered(intArrayOf(dim, dim), RealField)
|
val genericField = NDField.buffered(intArrayOf(dim, dim), RealField)
|
||||||
val lazyNDField = NDField.lazy(intArrayOf(dim, dim), RealField)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,39 +1,42 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import koma.matrix.ejml.EJMLMatrixFactory
|
import koma.matrix.ejml.EJMLMatrixFactory
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
@ExperimentalContracts
|
||||||
fun main() {
|
fun main() {
|
||||||
val random = Random(12224)
|
val random = Random(1224)
|
||||||
val dim = 100
|
val dim = 100
|
||||||
//creating invertible matrix
|
//creating invertible matrix
|
||||||
val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||||
val l = Matrix.real(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 }
|
val l = Matrix.real(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 }
|
||||||
val matrix = l dot u
|
val matrix = l dot u
|
||||||
|
|
||||||
val n = 500 // iterations
|
val n = 5000 // iterations
|
||||||
|
|
||||||
val solver = LUSolver.real
|
MatrixContext.real.run {
|
||||||
|
|
||||||
repeat(50) {
|
repeat(50) {
|
||||||
val res = solver.inverse(matrix)
|
val res = inverse(matrix)
|
||||||
}
|
|
||||||
|
|
||||||
val inverseTime = measureTimeMillis {
|
|
||||||
repeat(n) {
|
|
||||||
val res = solver.inverse(matrix)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
val inverseTime = measureTimeMillis {
|
||||||
|
repeat(n) {
|
||||||
|
val res = inverse(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
||||||
|
}
|
||||||
|
|
||||||
//commons-math
|
//commons-math
|
||||||
|
|
||||||
val cmContext = CMMatrixContext
|
|
||||||
|
|
||||||
val commonsTime = measureTimeMillis {
|
val commonsTime = measureTimeMillis {
|
||||||
cmContext.run {
|
CMMatrixContext.run {
|
||||||
val cm = matrix.toCM() //avoid overhead on conversion
|
val cm = matrix.toCM() //avoid overhead on conversion
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
val res = inverse(cm)
|
val res = inverse(cm)
|
||||||
@ -46,10 +49,8 @@ fun main() {
|
|||||||
|
|
||||||
//koma-ejml
|
//koma-ejml
|
||||||
|
|
||||||
val komaContext = KomaMatrixContext(EJMLMatrixFactory())
|
|
||||||
|
|
||||||
val komaTime = measureTimeMillis {
|
val komaTime = measureTimeMillis {
|
||||||
komaContext.run {
|
KomaMatrixContext(EJMLMatrixFactory(), RealField).run {
|
||||||
val km = matrix.toKoma() //avoid overhead on conversion
|
val km = matrix.toKoma() //avoid overhead on conversion
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
val res = inverse(km)
|
val res = inverse(km)
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import koma.matrix.ejml.EJMLMatrixFactory
|
import koma.matrix.ejml.EJMLMatrixFactory
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
@ -26,7 +28,7 @@ fun main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
KomaMatrixContext(EJMLMatrixFactory()).run {
|
KomaMatrixContext(EJMLMatrixFactory(), RealField).run {
|
||||||
val komaMatrix1 = matrix1.toKoma()
|
val komaMatrix1 = matrix1.toKoma()
|
||||||
val komaMatrix2 = matrix2.toKoma()
|
val komaMatrix2 = matrix2.toKoma()
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ fun main() {
|
|||||||
|
|
||||||
val complexTime = measureTimeMillis {
|
val complexTime = measureTimeMillis {
|
||||||
complexField.run {
|
complexField.run {
|
||||||
var res: ComplexNDElement = one
|
var res = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
res += 1.0
|
res += 1.0
|
||||||
}
|
}
|
||||||
|
117
build.gradle.kts
117
build.gradle.kts
@ -1,110 +1,29 @@
|
|||||||
import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension
|
val kmathVersion by extra("0.1.2-dev-4")
|
||||||
|
|
||||||
buildscript {
|
|
||||||
val kotlinVersion: String by rootProject.extra("1.3.21")
|
|
||||||
val ioVersion: String by rootProject.extra("0.1.5")
|
|
||||||
val coroutinesVersion: String by rootProject.extra("1.1.1")
|
|
||||||
val atomicfuVersion: String by rootProject.extra("0.12.1")
|
|
||||||
val dokkaVersion: String by rootProject.extra("0.9.17")
|
|
||||||
|
|
||||||
repositories {
|
|
||||||
//maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
|
||||||
jcenter()
|
|
||||||
}
|
|
||||||
|
|
||||||
dependencies {
|
|
||||||
classpath("org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlinVersion")
|
|
||||||
classpath("org.jfrog.buildinfo:build-info-extractor-gradle:4+")
|
|
||||||
classpath("com.jfrog.bintray.gradle:gradle-bintray-plugin:1.8.4")
|
|
||||||
classpath("org.jetbrains.dokka:dokka-gradle-plugin:$dokkaVersion")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
plugins {
|
|
||||||
id("com.jfrog.artifactory") version "4.9.1" apply false
|
|
||||||
}
|
|
||||||
|
|
||||||
val kmathVersion by extra("0.1.0")
|
|
||||||
|
|
||||||
allprojects {
|
allprojects {
|
||||||
group = "scientifik"
|
|
||||||
version = kmathVersion
|
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
//maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
|
||||||
jcenter()
|
jcenter()
|
||||||
|
maven("https://kotlin.bintray.com/kotlinx")
|
||||||
}
|
}
|
||||||
|
|
||||||
apply(plugin = "maven")
|
group = "scientifik"
|
||||||
apply(plugin = "maven-publish")
|
version = kmathVersion
|
||||||
|
|
||||||
// apply bintray configuration
|
|
||||||
apply(from = "${rootProject.rootDir}/gradle/bintray.gradle")
|
|
||||||
|
|
||||||
//apply artifactory configuration
|
|
||||||
apply(from = "${rootProject.rootDir}/gradle/artifactory.gradle")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
subprojects {
|
subprojects {
|
||||||
if (!name.startsWith("kmath")) return@subprojects
|
// Actually, probably we should apply it to plugins explicitly
|
||||||
|
// We also can merge them to single kmath-publish plugin
|
||||||
|
if (name.startsWith("kmath")) {
|
||||||
extensions.findByType<KotlinMultiplatformExtension>()?.apply {
|
apply(plugin = "bintray-config")
|
||||||
jvm {
|
apply(plugin = "artifactory-config")
|
||||||
compilations.all {
|
|
||||||
kotlinOptions {
|
|
||||||
jvmTarget = "1.8"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
targets.all {
|
|
||||||
sourceSets.all {
|
|
||||||
languageSettings.progressiveMode = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
extensions.findByType<PublishingExtension>()?.apply {
|
|
||||||
publications.filterIsInstance<MavenPublication>().forEach { publication ->
|
|
||||||
if (publication.name == "kotlinMultiplatform") {
|
|
||||||
// for our root metadata publication, set artifactId with a package and project name
|
|
||||||
publication.artifactId = project.name
|
|
||||||
} else {
|
|
||||||
// for targets, set artifactId with a package, project name and target name (e.g. iosX64)
|
|
||||||
publication.artifactId = "${project.name}-${publication.name}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create empty jar for sources classifier to satisfy maven requirements
|
|
||||||
val stubSources by tasks.registering(Jar::class) {
|
|
||||||
archiveClassifier.set("sources")
|
|
||||||
//from(sourceSets.main.get().allSource)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create empty jar for javadoc classifier to satisfy maven requirements
|
|
||||||
val stubJavadoc by tasks.registering(Jar::class) {
|
|
||||||
archiveClassifier.set("javadoc")
|
|
||||||
}
|
|
||||||
|
|
||||||
extensions.findByType<KotlinMultiplatformExtension>()?.apply {
|
|
||||||
|
|
||||||
targets.forEach { target ->
|
|
||||||
val publication = publications.findByName(target.name) as MavenPublication
|
|
||||||
|
|
||||||
// Patch publications with fake javadoc
|
|
||||||
publication.artifact(stubJavadoc)
|
|
||||||
|
|
||||||
// Remove gradle metadata publishing from all targets which are not native
|
|
||||||
// if (target.platformType.name != "native") {
|
|
||||||
// publication.gradleModuleMetadataFile = null
|
|
||||||
// tasks.matching { it.name == "generateMetadataFileFor${name.capitalize()}Publication" }.all {
|
|
||||||
// onlyIf { false }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
// dokka {
|
||||||
|
// outputFormat = "html"
|
||||||
|
// outputDirectory = javadoc.destinationDir
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// task dokkaJar (type: Jar, dependsOn: dokka) {
|
||||||
|
// from javadoc . destinationDir
|
||||||
|
// classifier = "javadoc"
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
18
buildSrc/build.gradle.kts
Normal file
18
buildSrc/build.gradle.kts
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
plugins {
|
||||||
|
`kotlin-dsl`
|
||||||
|
}
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
gradlePluginPortal()
|
||||||
|
jcenter()
|
||||||
|
}
|
||||||
|
|
||||||
|
val kotlinVersion = "1.3.31"
|
||||||
|
|
||||||
|
// Add plugins used in buildSrc as dependencies, also we should specify version only here
|
||||||
|
dependencies {
|
||||||
|
implementation("org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlinVersion")
|
||||||
|
implementation("org.jfrog.buildinfo:build-info-extractor-gradle:4.9.5")
|
||||||
|
implementation("com.jfrog.bintray.gradle:gradle-bintray-plugin:1.8.4")
|
||||||
|
implementation("com.moowork.gradle:gradle-node-plugin:1.3.1")
|
||||||
|
}
|
0
buildSrc/settings.gradle.kts
Normal file
0
buildSrc/settings.gradle.kts
Normal file
10
buildSrc/src/main/kotlin/Versions.kt
Normal file
10
buildSrc/src/main/kotlin/Versions.kt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
// Instead of defining runtime properties and use them dynamically
|
||||||
|
// define version in buildSrc and have autocompletion and compile-time check
|
||||||
|
// Also dependencies itself can be moved here
|
||||||
|
object Versions {
|
||||||
|
val ioVersion = "0.1.8"
|
||||||
|
val coroutinesVersion = "1.2.1"
|
||||||
|
val atomicfuVersion = "0.12.6"
|
||||||
|
// This version is not used and IDEA shows this property as unused
|
||||||
|
val dokkaVersion = "0.9.18"
|
||||||
|
}
|
38
buildSrc/src/main/kotlin/artifactory-config.gradle.kts
Normal file
38
buildSrc/src/main/kotlin/artifactory-config.gradle.kts
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import groovy.lang.GroovyObject
|
||||||
|
import org.jfrog.gradle.plugin.artifactory.dsl.PublisherConfig
|
||||||
|
import org.jfrog.gradle.plugin.artifactory.dsl.ResolverConfig
|
||||||
|
|
||||||
|
plugins {
|
||||||
|
id("com.jfrog.artifactory")
|
||||||
|
}
|
||||||
|
|
||||||
|
artifactory {
|
||||||
|
val artifactoryUser: String? by project
|
||||||
|
val artifactoryPassword: String? by project
|
||||||
|
val artifactoryContextUrl = "http://npm.mipt.ru:8081/artifactory"
|
||||||
|
|
||||||
|
setContextUrl(artifactoryContextUrl)//The base Artifactory URL if not overridden by the publisher/resolver
|
||||||
|
publish(delegateClosureOf<PublisherConfig> {
|
||||||
|
repository(delegateClosureOf<GroovyObject> {
|
||||||
|
setProperty("repoKey", "gradle-dev-local")
|
||||||
|
setProperty("username", artifactoryUser)
|
||||||
|
setProperty("password", artifactoryPassword)
|
||||||
|
})
|
||||||
|
|
||||||
|
defaults(delegateClosureOf<GroovyObject>{
|
||||||
|
invokeMethod("publications", arrayOf("jvm", "js", "kotlinMultiplatform", "metadata"))
|
||||||
|
//TODO: This property is not available for ArtifactoryTask
|
||||||
|
//setProperty("publishBuildInfo", false)
|
||||||
|
setProperty("publishArtifacts", true)
|
||||||
|
setProperty("publishPom", true)
|
||||||
|
setProperty("publishIvy", false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
resolve(delegateClosureOf<ResolverConfig> {
|
||||||
|
repository(delegateClosureOf<GroovyObject> {
|
||||||
|
setProperty("repoKey", "gradle-dev")
|
||||||
|
setProperty("username", artifactoryUser)
|
||||||
|
setProperty("password", artifactoryPassword)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
97
buildSrc/src/main/kotlin/bintray-config.gradle.kts
Normal file
97
buildSrc/src/main/kotlin/bintray-config.gradle.kts
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
@file:Suppress("UnstableApiUsage")
|
||||||
|
|
||||||
|
import com.jfrog.bintray.gradle.BintrayExtension.PackageConfig
|
||||||
|
import com.jfrog.bintray.gradle.BintrayExtension.VersionConfig
|
||||||
|
|
||||||
|
// Old bintray.gradle script converted to real Gradle plugin (precompiled script plugin)
|
||||||
|
// It now has own dependencies and support type safe accessors
|
||||||
|
// Syntax is pretty close to what we had in Groovy
|
||||||
|
// (excluding Property.set and bintray dynamic configs)
|
||||||
|
|
||||||
|
plugins {
|
||||||
|
id("com.jfrog.bintray")
|
||||||
|
`maven-publish`
|
||||||
|
}
|
||||||
|
|
||||||
|
val vcs = "https://github.com/mipt-npm/kmath"
|
||||||
|
|
||||||
|
// Configure publishing
|
||||||
|
publishing {
|
||||||
|
repositories {
|
||||||
|
maven("https://bintray.com/mipt-npm/scientifik")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each publication we have in this project
|
||||||
|
publications.filterIsInstance<MavenPublication>().forEach { publication ->
|
||||||
|
|
||||||
|
// use type safe pom config GSL insterad of old dynamic
|
||||||
|
publication.pom {
|
||||||
|
name.set(project.name)
|
||||||
|
description.set(project.description)
|
||||||
|
url.set(vcs)
|
||||||
|
|
||||||
|
licenses {
|
||||||
|
license {
|
||||||
|
name.set("The Apache Software License, Version 2.0")
|
||||||
|
url.set("http://www.apache.org/licenses/LICENSE-2.0.txt")
|
||||||
|
distribution.set("repo")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
developers {
|
||||||
|
developer {
|
||||||
|
id.set("MIPT-NPM")
|
||||||
|
name.set("MIPT nuclear physics methods laboratory")
|
||||||
|
organization.set("MIPT")
|
||||||
|
organizationUrl.set("http://npm.mipt.ru")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
scm {
|
||||||
|
url.set(vcs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bintray {
|
||||||
|
// delegates for runtime properties
|
||||||
|
val bintrayUser: String? by project
|
||||||
|
val bintrayApiKey: String? by project
|
||||||
|
user = bintrayUser ?: System.getenv("BINTRAY_USER")
|
||||||
|
key = bintrayApiKey ?: System.getenv("BINTRAY_API_KEY")
|
||||||
|
publish = true
|
||||||
|
override = true // for multi-platform Kotlin/Native publishing
|
||||||
|
|
||||||
|
// We have to use delegateClosureOf because bintray supports only dynamic groovy syntax
|
||||||
|
// this is a problem of this plugin
|
||||||
|
pkg(delegateClosureOf<PackageConfig> {
|
||||||
|
userOrg = "mipt-npm"
|
||||||
|
repo = "scientifik"
|
||||||
|
name = "scientifik.kmath"
|
||||||
|
issueTrackerUrl = "https://github.com/mipt-npm/kmath/issues"
|
||||||
|
setLicenses("Apache-2.0")
|
||||||
|
vcsUrl = vcs
|
||||||
|
version(delegateClosureOf<VersionConfig> {
|
||||||
|
name = project.version.toString()
|
||||||
|
vcsTag = project.version.toString()
|
||||||
|
released = java.util.Date().toString()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
tasks {
|
||||||
|
bintrayUpload {
|
||||||
|
dependsOn(publishToMavenLocal)
|
||||||
|
doFirst {
|
||||||
|
setPublications(project.publishing.publications
|
||||||
|
.filterIsInstance<MavenPublication>()
|
||||||
|
.filter { !it.name.contains("-test") && it.name != "kotlinMultiplatform" }
|
||||||
|
.map {
|
||||||
|
println("""Uploading artifact "${it.groupId}:${it.artifactId}:${it.version}" from publication "${it.name}""")
|
||||||
|
it.name //https://github.com/bintray/gradle-bintray-plugin/issues/256
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
44
buildSrc/src/main/kotlin/js-test.gradle.kts
Normal file
44
buildSrc/src/main/kotlin/js-test.gradle.kts
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import com.moowork.gradle.node.npm.NpmTask
|
||||||
|
import com.moowork.gradle.node.task.NodeTask
|
||||||
|
import org.gradle.kotlin.dsl.*
|
||||||
|
import org.jetbrains.kotlin.gradle.tasks.Kotlin2JsCompile
|
||||||
|
|
||||||
|
plugins {
|
||||||
|
id("com.moowork.node")
|
||||||
|
kotlin("multiplatform")
|
||||||
|
}
|
||||||
|
|
||||||
|
node {
|
||||||
|
nodeModulesDir = file("$buildDir/node_modules")
|
||||||
|
}
|
||||||
|
|
||||||
|
val compileKotlinJs by tasks.getting(Kotlin2JsCompile::class)
|
||||||
|
val compileTestKotlinJs by tasks.getting(Kotlin2JsCompile::class)
|
||||||
|
|
||||||
|
val populateNodeModules by tasks.registering(Copy::class) {
|
||||||
|
dependsOn(compileKotlinJs)
|
||||||
|
from(compileKotlinJs.destinationDir)
|
||||||
|
|
||||||
|
kotlin.js().compilations["test"].runtimeDependencyFiles.forEach {
|
||||||
|
if (it.exists() && !it.isDirectory) {
|
||||||
|
from(zipTree(it.absolutePath).matching { include("*.js") })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
into("$buildDir/node_modules")
|
||||||
|
}
|
||||||
|
|
||||||
|
val installMocha by tasks.registering(NpmTask::class) {
|
||||||
|
setWorkingDir(buildDir)
|
||||||
|
setArgs(listOf("install", "mocha"))
|
||||||
|
}
|
||||||
|
|
||||||
|
val runMocha by tasks.registering(NodeTask::class) {
|
||||||
|
dependsOn(compileTestKotlinJs, populateNodeModules, installMocha)
|
||||||
|
setScript(file("$buildDir/node_modules/mocha/bin/mocha"))
|
||||||
|
setArgs(listOf(compileTestKotlinJs.outputFile))
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks["jsTest"].dependsOn(runMocha)
|
||||||
|
|
||||||
|
|
118
buildSrc/src/main/kotlin/multiplatform-config.gradle.kts
Normal file
118
buildSrc/src/main/kotlin/multiplatform-config.gradle.kts
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
import org.gradle.kotlin.dsl.*
|
||||||
|
|
||||||
|
plugins {
|
||||||
|
kotlin("multiplatform")
|
||||||
|
`maven-publish`
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
kotlin {
|
||||||
|
jvm {
|
||||||
|
compilations.all {
|
||||||
|
kotlinOptions {
|
||||||
|
jvmTarget = "1.8"
|
||||||
|
// This was used in kmath-koma, but probably if we need it better to apply it for all modules
|
||||||
|
freeCompilerArgs = listOf("-progressive")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
js {
|
||||||
|
compilations.all {
|
||||||
|
kotlinOptions {
|
||||||
|
metaInfo = true
|
||||||
|
sourceMap = true
|
||||||
|
sourceMapEmbedSources = "always"
|
||||||
|
moduleKind = "commonjs"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
compilations.named("main") {
|
||||||
|
kotlinOptions {
|
||||||
|
main = "call"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceSets.invoke {
|
||||||
|
commonMain {
|
||||||
|
dependencies {
|
||||||
|
api(kotlin("stdlib"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
commonTest {
|
||||||
|
dependencies {
|
||||||
|
implementation(kotlin("test-common"))
|
||||||
|
implementation(kotlin("test-annotations-common"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"jvmMain" {
|
||||||
|
dependencies {
|
||||||
|
api(kotlin("stdlib-jdk8"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"jvmTest" {
|
||||||
|
dependencies {
|
||||||
|
implementation(kotlin("test"))
|
||||||
|
implementation(kotlin("test-junit"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"jsMain" {
|
||||||
|
dependencies {
|
||||||
|
api(kotlin("stdlib-js"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"jsTest" {
|
||||||
|
dependencies {
|
||||||
|
implementation(kotlin("test-js"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
targets.all {
|
||||||
|
sourceSets.all {
|
||||||
|
languageSettings.progressiveMode = true
|
||||||
|
languageSettings.enableLanguageFeature("InlineClasses")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create empty jar for sources classifier to satisfy maven requirements
|
||||||
|
val stubSources by tasks.registering(Jar::class){
|
||||||
|
archiveClassifier.set("sources")
|
||||||
|
//from(sourceSets.main.get().allSource)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create empty jar for javadoc classifier to satisfy maven requirements
|
||||||
|
val stubJavadoc by tasks.registering(Jar::class) {
|
||||||
|
archiveClassifier.set("javadoc")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
publishing {
|
||||||
|
|
||||||
|
publications.filterIsInstance<MavenPublication>().forEach { publication ->
|
||||||
|
if (publication.name == "kotlinMultiplatform") {
|
||||||
|
// for our root metadata publication, set artifactId with a package and project name
|
||||||
|
publication.artifactId = project.name
|
||||||
|
} else {
|
||||||
|
// for targets, set artifactId with a package, project name and target name (e.g. iosX64)
|
||||||
|
publication.artifactId = "${project.name}-${publication.name}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
targets.all {
|
||||||
|
val publication = publications.findByName(name) as MavenPublication
|
||||||
|
|
||||||
|
// Patch publications with fake javadoc
|
||||||
|
publication.artifact(stubJavadoc.get())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply JS test configuration
|
||||||
|
val runJsTests by ext(false)
|
||||||
|
|
||||||
|
if (runJsTests) {
|
||||||
|
apply(plugin = "js-test")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1 +1,19 @@
|
|||||||
**TODO**
|
## Basic linear algebra layout
|
||||||
|
|
||||||
|
Kmath support for linear algebra organized in a context-oriented way. Meaning that operations are in most cases declared
|
||||||
|
in context classes, and are not the members of classes that store data. This allows more flexible approach to maintain multiple
|
||||||
|
back-ends. The new operations added as extensions to contexts instead of being member functions of data structures.
|
||||||
|
|
||||||
|
Two major contexts used for linear algebra and hyper-geometry:
|
||||||
|
|
||||||
|
* `VectorSpace` forms a mathematical space on top of array-like structure (`Buffer` and its typealias `Point` used for geometry).
|
||||||
|
|
||||||
|
* `MatrixContext` forms a space-like context for 2d-structures. It does not store matrix size and therefore does not implement
|
||||||
|
`Space` interface (it is not possible to create zero element without knowing the matrix size).
|
||||||
|
|
||||||
|
## Vector spaces
|
||||||
|
|
||||||
|
|
||||||
|
## Matrix operations
|
||||||
|
|
||||||
|
## Back-end overview
|
BIN
gradle/wrapper/gradle-wrapper.jar
vendored
BIN
gradle/wrapper/gradle-wrapper.jar
vendored
Binary file not shown.
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
|||||||
distributionBase=GRADLE_USER_HOME
|
distributionBase=GRADLE_USER_HOME
|
||||||
distributionPath=wrapper/dists
|
distributionPath=wrapper/dists
|
||||||
distributionUrl=https\://services.gradle.org/distributions/gradle-5.0-bin.zip
|
distributionUrl=https\://services.gradle.org/distributions/gradle-5.4-bin.zip
|
||||||
zipStoreBase=GRADLE_USER_HOME
|
zipStoreBase=GRADLE_USER_HOME
|
||||||
zipStorePath=wrapper/dists
|
zipStorePath=wrapper/dists
|
||||||
|
18
gradlew
vendored
18
gradlew
vendored
@ -1,5 +1,21 @@
|
|||||||
#!/usr/bin/env sh
|
#!/usr/bin/env sh
|
||||||
|
|
||||||
|
#
|
||||||
|
# Copyright 2015 the original author or authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
##
|
##
|
||||||
## Gradle start up script for UN*X
|
## Gradle start up script for UN*X
|
||||||
@ -28,7 +44,7 @@ APP_NAME="Gradle"
|
|||||||
APP_BASE_NAME=`basename "$0"`
|
APP_BASE_NAME=`basename "$0"`
|
||||||
|
|
||||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||||
DEFAULT_JVM_OPTS='"-Xmx64m"'
|
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
||||||
|
|
||||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||||
MAX_FD="maximum"
|
MAX_FD="maximum"
|
||||||
|
18
gradlew.bat
vendored
18
gradlew.bat
vendored
@ -1,3 +1,19 @@
|
|||||||
|
@rem
|
||||||
|
@rem Copyright 2015 the original author or authors.
|
||||||
|
@rem
|
||||||
|
@rem Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@rem you may not use this file except in compliance with the License.
|
||||||
|
@rem You may obtain a copy of the License at
|
||||||
|
@rem
|
||||||
|
@rem http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
@rem
|
||||||
|
@rem Unless required by applicable law or agreed to in writing, software
|
||||||
|
@rem distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
@rem See the License for the specific language governing permissions and
|
||||||
|
@rem limitations under the License.
|
||||||
|
@rem
|
||||||
|
|
||||||
@if "%DEBUG%" == "" @echo off
|
@if "%DEBUG%" == "" @echo off
|
||||||
@rem ##########################################################################
|
@rem ##########################################################################
|
||||||
@rem
|
@rem
|
||||||
@ -14,7 +30,7 @@ set APP_BASE_NAME=%~n0
|
|||||||
set APP_HOME=%DIRNAME%
|
set APP_HOME=%DIRNAME%
|
||||||
|
|
||||||
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||||
set DEFAULT_JVM_OPTS="-Xmx64m"
|
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
|
||||||
|
|
||||||
@rem Find java.exe
|
@rem Find java.exe
|
||||||
if defined JAVA_HOME goto findJavaFromJavaHome
|
if defined JAVA_HOME goto findJavaFromJavaHome
|
||||||
|
@ -7,7 +7,7 @@ description = "Commons math binding for kmath"
|
|||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
api(project(":kmath-sequential"))
|
api(project(":kmath-coroutines"))
|
||||||
api("org.apache.commons:commons-math3:3.6.1")
|
api("org.apache.commons:commons-math3:3.6.1")
|
||||||
testImplementation("org.jetbrains.kotlin:kotlin-test")
|
testImplementation("org.jetbrains.kotlin:kotlin-test")
|
||||||
testImplementation("org.jetbrains.kotlin:kotlin-test-junit")
|
testImplementation("org.jetbrains.kotlin:kotlin-test-junit")
|
||||||
|
@ -3,13 +3,14 @@ package scientifik.kmath.linear
|
|||||||
import org.apache.commons.math3.linear.*
|
import org.apache.commons.math3.linear.*
|
||||||
import org.apache.commons.math3.linear.RealMatrix
|
import org.apache.commons.math3.linear.RealMatrix
|
||||||
import org.apache.commons.math3.linear.RealVector
|
import org.apache.commons.math3.linear.RealVector
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
|
|
||||||
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : Matrix<Double> {
|
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
||||||
override val rowNum: Int get() = origin.rowDimension
|
override val rowNum: Int get() = origin.rowDimension
|
||||||
override val colNum: Int get() = origin.columnDimension
|
override val colNum: Int get() = origin.columnDimension
|
||||||
|
|
||||||
override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
|
override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
|
||||||
if(origin is DiagonalMatrix) yield(DiagonalFeature)
|
if (origin is DiagonalMatrix) yield(DiagonalFeature)
|
||||||
}.toSet()
|
}.toSet()
|
||||||
|
|
||||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
override fun suggestFeature(vararg features: MatrixFeature) =
|
||||||
@ -26,7 +27,7 @@ fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
|||||||
CMMatrix(Array2DRowRealMatrix(array))
|
CMMatrix(Array2DRowRealMatrix(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
fun RealMatrix.toMatrix() = CMMatrix(this)
|
fun RealMatrix.asMatrix() = CMMatrix(this)
|
||||||
|
|
||||||
class CMVector(val origin: RealVector) : Point<Double> {
|
class CMVector(val origin: RealVector) : Point<Double> {
|
||||||
override val size: Int get() = origin.dimension
|
override val size: Int get() = origin.dimension
|
||||||
@ -45,46 +46,31 @@ fun Point<Double>.toCM(): CMVector = if (this is CMVector) {
|
|||||||
|
|
||||||
fun RealVector.toPoint() = CMVector(this)
|
fun RealVector.toPoint() = CMVector(this)
|
||||||
|
|
||||||
object CMMatrixContext : MatrixContext<Double>, LinearSolver<Double> {
|
object CMMatrixContext : MatrixContext<Double> {
|
||||||
|
|
||||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
|
||||||
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
|
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
|
||||||
return CMMatrix(Array2DRowRealMatrix(array))
|
return CMMatrix(Array2DRowRealMatrix(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun solve(a: Matrix<Double>, b: Matrix<Double>): CMMatrix {
|
|
||||||
val decomposition = LUDecomposition(a.toCM().origin)
|
|
||||||
return decomposition.solver.solve(b.toCM().origin).toMatrix()
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun solve(a: Matrix<Double>, b: Point<Double>): CMVector {
|
|
||||||
val decomposition = LUDecomposition(a.toCM().origin)
|
|
||||||
return decomposition.solver.solve(b.toCM().origin).toPoint()
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun inverse(a: Matrix<Double>): CMMatrix {
|
|
||||||
val decomposition = LUDecomposition(a.toCM().origin)
|
|
||||||
return decomposition.solver.inverse.toMatrix()
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun Matrix<Double>.dot(other: Matrix<Double>) =
|
override fun Matrix<Double>.dot(other: Matrix<Double>) =
|
||||||
CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
|
CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
|
||||||
|
|
||||||
override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
|
override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
|
||||||
CMVector(this.toCM().origin.preMultiply(vector.toCM().origin))
|
CMVector(this.toCM().origin.preMultiply(vector.toCM().origin))
|
||||||
|
|
||||||
|
|
||||||
override fun Matrix<Double>.unaryMinus(): CMMatrix =
|
override fun Matrix<Double>.unaryMinus(): CMMatrix =
|
||||||
produce(rowNum, colNum) { i, j -> -get(i, j) }
|
produce(rowNum, colNum) { i, j -> -get(i, j) }
|
||||||
|
|
||||||
override fun Matrix<Double>.plus(b: Matrix<Double>) =
|
override fun add(a: Matrix<Double>, b: Matrix<Double>) =
|
||||||
CMMatrix(this.toCM().origin.multiply(b.toCM().origin))
|
CMMatrix(a.toCM().origin.multiply(b.toCM().origin))
|
||||||
|
|
||||||
override fun Matrix<Double>.minus(b: Matrix<Double>) =
|
override fun Matrix<Double>.minus(b: Matrix<Double>) =
|
||||||
CMMatrix(this.toCM().origin.subtract(b.toCM().origin))
|
CMMatrix(this.toCM().origin.subtract(b.toCM().origin))
|
||||||
|
|
||||||
override fun Matrix<Double>.times(value: Double) =
|
override fun multiply(a: Matrix<Double>, k: Number) =
|
||||||
CMMatrix(this.toCM().origin.scalarMultiply(value.toDouble()))
|
CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble()))
|
||||||
|
|
||||||
|
override fun Matrix<Double>.times(value: Double): Matrix<Double> = produce(rowNum,colNum){i,j-> get(i,j)*value}
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin))
|
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin))
|
||||||
|
@ -0,0 +1,39 @@
|
|||||||
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import org.apache.commons.math3.linear.*
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
|
|
||||||
|
enum class CMDecomposition {
|
||||||
|
LUP,
|
||||||
|
QR,
|
||||||
|
RRQR,
|
||||||
|
EIGEN,
|
||||||
|
CHOLESKY
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fun CMMatrixContext.solver(a: Matrix<Double>, decomposition: CMDecomposition = CMDecomposition.LUP) =
|
||||||
|
when (decomposition) {
|
||||||
|
CMDecomposition.LUP -> LUDecomposition(a.toCM().origin).solver
|
||||||
|
CMDecomposition.RRQR -> RRQRDecomposition(a.toCM().origin).solver
|
||||||
|
CMDecomposition.QR -> QRDecomposition(a.toCM().origin).solver
|
||||||
|
CMDecomposition.EIGEN -> EigenDecomposition(a.toCM().origin).solver
|
||||||
|
CMDecomposition.CHOLESKY -> CholeskyDecomposition(a.toCM().origin).solver
|
||||||
|
}
|
||||||
|
|
||||||
|
fun CMMatrixContext.solve(
|
||||||
|
a: Matrix<Double>,
|
||||||
|
b: Matrix<Double>,
|
||||||
|
decomposition: CMDecomposition = CMDecomposition.LUP
|
||||||
|
) = solver(a, decomposition).solve(b.toCM().origin).asMatrix()
|
||||||
|
|
||||||
|
fun CMMatrixContext.solve(
|
||||||
|
a: Matrix<Double>,
|
||||||
|
b: Point<Double>,
|
||||||
|
decomposition: CMDecomposition = CMDecomposition.LUP
|
||||||
|
) = solver(a, decomposition).solve(b.toCM().origin).toPoint()
|
||||||
|
|
||||||
|
fun CMMatrixContext.inverse(
|
||||||
|
a: Matrix<Double>,
|
||||||
|
decomposition: CMDecomposition = CMDecomposition.LUP
|
||||||
|
) = solver(a, decomposition).inverse.asMatrix()
|
@ -1,10 +1,12 @@
|
|||||||
package scientifik.kmath.transform
|
package scientifik.kmath.transform
|
||||||
|
|
||||||
|
import kotlinx.coroutines.FlowPreview
|
||||||
|
import kotlinx.coroutines.flow.Flow
|
||||||
|
import kotlinx.coroutines.flow.map
|
||||||
import org.apache.commons.math3.transform.*
|
import org.apache.commons.math3.transform.*
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.sequential.Processor
|
import scientifik.kmath.streaming.chunked
|
||||||
import scientifik.kmath.sequential.Producer
|
import scientifik.kmath.streaming.spread
|
||||||
import scientifik.kmath.sequential.map
|
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
|
|
||||||
|
|
||||||
@ -33,54 +35,75 @@ object Transformations {
|
|||||||
fun fourier(
|
fun fourier(
|
||||||
normalization: DftNormalization = DftNormalization.STANDARD,
|
normalization: DftNormalization = DftNormalization.STANDARD,
|
||||||
direction: TransformType = TransformType.FORWARD
|
direction: TransformType = TransformType.FORWARD
|
||||||
): BufferTransform<Complex, Complex> = {
|
): SuspendBufferTransform<Complex, Complex> = {
|
||||||
FastFourierTransformer(normalization).transform(it.toArray(), direction).asBuffer()
|
FastFourierTransformer(normalization).transform(it.toArray(), direction).asBuffer()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun realFourier(
|
fun realFourier(
|
||||||
normalization: DftNormalization = DftNormalization.STANDARD,
|
normalization: DftNormalization = DftNormalization.STANDARD,
|
||||||
direction: TransformType = TransformType.FORWARD
|
direction: TransformType = TransformType.FORWARD
|
||||||
): BufferTransform<Double, Complex> = {
|
): SuspendBufferTransform<Double, Complex> = {
|
||||||
FastFourierTransformer(normalization).transform(it.asArray(), direction).asBuffer()
|
FastFourierTransformer(normalization).transform(it.asArray(), direction).asBuffer()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun sine(
|
fun sine(
|
||||||
normalization: DstNormalization = DstNormalization.STANDARD_DST_I,
|
normalization: DstNormalization = DstNormalization.STANDARD_DST_I,
|
||||||
direction: TransformType = TransformType.FORWARD
|
direction: TransformType = TransformType.FORWARD
|
||||||
): BufferTransform<Double, Double> = {
|
): SuspendBufferTransform<Double, Double> = {
|
||||||
FastSineTransformer(normalization).transform(it.asArray(), direction).asBuffer()
|
FastSineTransformer(normalization).transform(it.asArray(), direction).asBuffer()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun cosine(
|
fun cosine(
|
||||||
normalization: DctNormalization = DctNormalization.STANDARD_DCT_I,
|
normalization: DctNormalization = DctNormalization.STANDARD_DCT_I,
|
||||||
direction: TransformType = TransformType.FORWARD
|
direction: TransformType = TransformType.FORWARD
|
||||||
): BufferTransform<Double, Double> = {
|
): SuspendBufferTransform<Double, Double> = {
|
||||||
FastCosineTransformer(normalization).transform(it.asArray(), direction).asBuffer()
|
FastCosineTransformer(normalization).transform(it.asArray(), direction).asBuffer()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun hadamard(
|
fun hadamard(
|
||||||
direction: TransformType = TransformType.FORWARD
|
direction: TransformType = TransformType.FORWARD
|
||||||
): BufferTransform<Double, Double> = {
|
): SuspendBufferTransform<Double, Double> = {
|
||||||
FastHadamardTransformer().transform(it.asArray(), direction).asBuffer()
|
FastHadamardTransformer().transform(it.asArray(), direction).asBuffer()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Process given [Producer] with commons-math fft transformation
|
* Process given [Flow] with commons-math fft transformation
|
||||||
*/
|
*/
|
||||||
fun Producer<Buffer<Complex>>.FFT(
|
@FlowPreview
|
||||||
|
fun Flow<Buffer<Complex>>.FFT(
|
||||||
normalization: DftNormalization = DftNormalization.STANDARD,
|
normalization: DftNormalization = DftNormalization.STANDARD,
|
||||||
direction: TransformType = TransformType.FORWARD
|
direction: TransformType = TransformType.FORWARD
|
||||||
): Processor<Buffer<Complex>, Buffer<Complex>> {
|
): Flow<Buffer<Complex>> {
|
||||||
val transform = Transformations.fourier(normalization, direction)
|
val transform = Transformations.fourier(normalization, direction)
|
||||||
return map { transform(it) }
|
return map { transform(it) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@FlowPreview
|
||||||
@JvmName("realFFT")
|
@JvmName("realFFT")
|
||||||
fun Producer<Buffer<Double>>.FFT(
|
fun Flow<Buffer<Double>>.FFT(
|
||||||
normalization: DftNormalization = DftNormalization.STANDARD,
|
normalization: DftNormalization = DftNormalization.STANDARD,
|
||||||
direction: TransformType = TransformType.FORWARD
|
direction: TransformType = TransformType.FORWARD
|
||||||
): Processor<Buffer<Double>, Buffer<Complex>> {
|
): Flow<Buffer<Complex>> {
|
||||||
val transform = Transformations.realFourier(normalization, direction)
|
val transform = Transformations.realFourier(normalization, direction)
|
||||||
return map { transform(it) }
|
return map(transform)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Process a continous flow of real numbers in FFT splitting it in chunks of [bufferSize].
|
||||||
|
*/
|
||||||
|
@FlowPreview
|
||||||
|
@JvmName("realFFT")
|
||||||
|
fun Flow<Double>.FFT(
|
||||||
|
bufferSize: Int = Int.MAX_VALUE,
|
||||||
|
normalization: DftNormalization = DftNormalization.STANDARD,
|
||||||
|
direction: TransformType = TransformType.FORWARD
|
||||||
|
): Flow<Complex> {
|
||||||
|
return chunked(bufferSize).FFT(normalization,direction).spread()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map a complex flow into real flow by taking real part of each number
|
||||||
|
*/
|
||||||
|
@FlowPreview
|
||||||
|
fun Flow<Complex>.real(): Flow<Double> = map{it.re}
|
@ -1,51 +1,13 @@
|
|||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
`multiplatform-config`
|
||||||
}
|
}
|
||||||
|
|
||||||
val ioVersion: String by rootProject.extra
|
kotlin.sourceSets {
|
||||||
|
commonMain {
|
||||||
|
dependencies {
|
||||||
kotlin {
|
api(project(":kmath-memory"))
|
||||||
jvm()
|
|
||||||
js()
|
|
||||||
|
|
||||||
sourceSets {
|
|
||||||
val commonMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api(project(":kmath-memory"))
|
|
||||||
api(kotlin("stdlib"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
val commonTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test-common"))
|
|
||||||
implementation(kotlin("test-annotations-common"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jvmMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api(kotlin("stdlib-jdk8"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jvmTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test"))
|
|
||||||
implementation(kotlin("test-junit"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jsMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api(kotlin("stdlib-js"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jsTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test-js"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// mingwMain {
|
|
||||||
// }
|
|
||||||
// mingwTest {
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
//mingwMain {}
|
||||||
|
//mingwTest {}
|
||||||
}
|
}
|
@ -1,8 +1,8 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
import kotlin.jvm.JvmSynthetic
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Basic implementation of Matrix space based on [NDStructure]
|
* Basic implementation of Matrix space based on [NDStructure]
|
||||||
@ -18,6 +18,22 @@ class BufferMatrixContext<T : Any, R : Ring<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
||||||
|
|
||||||
|
override val elementContext = RealField
|
||||||
|
|
||||||
|
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
||||||
|
val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||||
|
return BufferMatrix(rows, columns, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = DoubleBuffer(size,initializer)
|
||||||
}
|
}
|
||||||
|
|
||||||
class BufferMatrix<T : Any>(
|
class BufferMatrix<T : Any>(
|
||||||
@ -25,7 +41,7 @@ class BufferMatrix<T : Any>(
|
|||||||
override val colNum: Int,
|
override val colNum: Int,
|
||||||
val buffer: Buffer<out T>,
|
val buffer: Buffer<out T>,
|
||||||
override val features: Set<MatrixFeature> = emptySet()
|
override val features: Set<MatrixFeature> = emptySet()
|
||||||
) : Matrix<T> {
|
) : FeaturedMatrix<T> {
|
||||||
|
|
||||||
init {
|
init {
|
||||||
if (buffer.size != rowNum * colNum) {
|
if (buffer.size != rowNum * colNum) {
|
||||||
|
@ -0,0 +1,96 @@
|
|||||||
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Ring
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
|
import scientifik.kmath.structures.Structure2D
|
||||||
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A 2d structure plus optional matrix-specific features
|
||||||
|
*/
|
||||||
|
interface FeaturedMatrix<T : Any> : Matrix<T> {
|
||||||
|
|
||||||
|
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||||
|
|
||||||
|
val features: Set<MatrixFeature>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure.
|
||||||
|
*
|
||||||
|
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to
|
||||||
|
* add only those features that are valid.
|
||||||
|
*/
|
||||||
|
fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
||||||
|
MatrixContext.real.produce(rows, columns, initializer)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build a square matrix from given elements.
|
||||||
|
*/
|
||||||
|
fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
|
||||||
|
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||||
|
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
|
||||||
|
val buffer = elements.asBuffer()
|
||||||
|
return BufferMatrix(size, size, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T : Any> Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)
|
||||||
|
|
||||||
|
class MatrixBuilder<T : Any>(val rows: Int, val columns: Int) {
|
||||||
|
operator fun invoke(vararg elements: T): FeaturedMatrix<T> {
|
||||||
|
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns")
|
||||||
|
val buffer = elements.asBuffer()
|
||||||
|
return BufferMatrix(rows, columns, buffer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val Matrix<*>.features get() = (this as? FeaturedMatrix)?.features?: emptySet()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if matrix has the given feature class
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean =
|
||||||
|
features.find { it is T } != null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
|
||||||
|
features.filterIsInstance<T>().firstOrNull()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||||
|
*/
|
||||||
|
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||||
|
VirtualMatrix(rows, columns, DiagonalFeature) { i, j ->
|
||||||
|
if (i == j) elementContext.one else elementContext.zero
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A virtual matrix of zeroes
|
||||||
|
*/
|
||||||
|
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||||
|
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero }
|
||||||
|
|
||||||
|
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
||||||
|
*/
|
||||||
|
fun <T : Any> Matrix<T>.transpose(): Matrix<T> {
|
||||||
|
return this.getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
|
||||||
|
this.colNum,
|
||||||
|
this.rowNum,
|
||||||
|
setOf(TransposedFeature(this))
|
||||||
|
) { i, j -> get(j, i) }
|
||||||
|
}
|
||||||
|
|
||||||
|
infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = with(MatrixContext.real) { dot(other) }
|
@ -1,26 +1,31 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.structures.MutableBuffer
|
import scientifik.kmath.structures.BufferAccessor2D
|
||||||
import scientifik.kmath.structures.MutableBuffer.Companion.boxing
|
import scientifik.kmath.structures.Matrix
|
||||||
import scientifik.kmath.structures.MutableBufferFactory
|
import scientifik.kmath.structures.Structure2D
|
||||||
import scientifik.kmath.structures.NDStructure
|
import kotlin.reflect.KClass
|
||||||
import scientifik.kmath.structures.get
|
|
||||||
|
|
||||||
class LUPDecomposition<T : Comparable<T>>(
|
/**
|
||||||
private val elementContext: Ring<T>,
|
* Common implementation of [LUPDecompositionFeature]
|
||||||
internal val lu: NDStructure<T>,
|
*/
|
||||||
|
class LUPDecomposition<T : Any>(
|
||||||
|
val context: GenericMatrixContext<T, out Field<T>>,
|
||||||
|
val lu: Structure2D<T>,
|
||||||
val pivot: IntArray,
|
val pivot: IntArray,
|
||||||
private val even: Boolean
|
private val even: Boolean
|
||||||
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
||||||
|
|
||||||
|
val elementContext get() = context.elementContext
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the matrix L of the decomposition.
|
* Returns the matrix L of the decomposition.
|
||||||
*
|
*
|
||||||
* L is a lower-triangular matrix with [Ring.one] in diagonal
|
* L is a lower-triangular matrix with [Ring.one] in diagonal
|
||||||
*/
|
*/
|
||||||
override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
|
override val l: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
|
||||||
when {
|
when {
|
||||||
j < i -> lu[i, j]
|
j < i -> lu[i, j]
|
||||||
j == i -> elementContext.one
|
j == i -> elementContext.one
|
||||||
@ -34,7 +39,7 @@ class LUPDecomposition<T : Comparable<T>>(
|
|||||||
*
|
*
|
||||||
* U is an upper-triangular matrix including the diagonal
|
* U is an upper-triangular matrix including the diagonal
|
||||||
*/
|
*/
|
||||||
override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
|
override val u: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
|
||||||
if (j >= i) lu[i, j] else elementContext.zero
|
if (j >= i) lu[i, j] else elementContext.zero
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,7 +50,7 @@ class LUPDecomposition<T : Comparable<T>>(
|
|||||||
* P is a sparse matrix with exactly one element set to [Ring.one] in
|
* P is a sparse matrix with exactly one element set to [Ring.one] in
|
||||||
* each row and each column, all other elements being set to [Ring.zero].
|
* each row and each column, all other elements being set to [Ring.zero].
|
||||||
*/
|
*/
|
||||||
override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
override val p: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||||
if (j == pivot[i]) elementContext.one else elementContext.zero
|
if (j == pivot[i]) elementContext.one else elementContext.zero
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,70 +67,86 @@ class LUPDecomposition<T : Comparable<T>>(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) =
|
||||||
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
if (value > elementContext.zero) value else with(elementContext) { -value }
|
||||||
val context: GenericMatrixContext<T, F>,
|
|
||||||
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
|
||||||
val singularityCheck: (T) -> Boolean
|
|
||||||
) : LinearSolver<T> {
|
|
||||||
|
|
||||||
|
|
||||||
private fun abs(value: T) =
|
/**
|
||||||
if (value > context.elementContext.zero) value else with(context.elementContext) { -value }
|
* Create a lup decomposition of generic matrix
|
||||||
|
*/
|
||||||
|
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||||
|
type: KClass<T>,
|
||||||
|
matrix: Matrix<T>,
|
||||||
|
checkSingular: (T) -> Boolean
|
||||||
|
): LUPDecomposition<T> {
|
||||||
|
if (matrix.rowNum != matrix.colNum) {
|
||||||
|
error("LU decomposition supports only square matrices")
|
||||||
|
}
|
||||||
|
|
||||||
fun buildDecomposition(matrix: Matrix<T>): LUPDecomposition<T> {
|
val m = matrix.colNum
|
||||||
if (matrix.rowNum != matrix.colNum) {
|
val pivot = IntArray(matrix.rowNum)
|
||||||
error("LU decomposition supports only square matrices")
|
|
||||||
}
|
|
||||||
|
|
||||||
val m = matrix.colNum
|
//TODO just waits for KEEP-176
|
||||||
val pivot = IntArray(matrix.rowNum)
|
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
|
||||||
|
elementContext.run {
|
||||||
|
|
||||||
val lu = Mutable2DStructure.create(matrix.rowNum, matrix.colNum, bufferFactory) { i, j ->
|
val lu = create(matrix)
|
||||||
matrix[i, j]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
with(context.elementContext) {
|
|
||||||
// Initialize permutation array and parity
|
// Initialize permutation array and parity
|
||||||
for (row in 0 until m) {
|
for (row in 0 until m) {
|
||||||
pivot[row] = row
|
pivot[row] = row
|
||||||
}
|
}
|
||||||
var even = true
|
var even = true
|
||||||
|
|
||||||
|
// Initialize permutation array and parity
|
||||||
|
for (row in 0 until m) {
|
||||||
|
pivot[row] = row
|
||||||
|
}
|
||||||
|
|
||||||
// Loop over columns
|
// Loop over columns
|
||||||
for (col in 0 until m) {
|
for (col in 0 until m) {
|
||||||
|
|
||||||
// upper
|
// upper
|
||||||
for (row in 0 until col) {
|
for (row in 0 until col) {
|
||||||
var sum = lu[row, col]
|
val luRow = lu.row(row)
|
||||||
|
var sum = luRow[col]
|
||||||
for (i in 0 until row) {
|
for (i in 0 until row) {
|
||||||
sum -= lu[row, i] * lu[i, col]
|
sum -= luRow[i] * lu[i, col]
|
||||||
}
|
}
|
||||||
lu[row, col] = sum
|
luRow[col] = sum
|
||||||
}
|
}
|
||||||
|
|
||||||
// lower
|
// lower
|
||||||
val max = (col until m).maxBy { row ->
|
var max = col // permutation row
|
||||||
var sum = lu[row, col]
|
var largest = -one
|
||||||
|
for (row in col until m) {
|
||||||
|
val luRow = lu.row(row)
|
||||||
|
var sum = luRow[col]
|
||||||
for (i in 0 until col) {
|
for (i in 0 until col) {
|
||||||
sum -= lu[row, i] * lu[i, col]
|
sum -= luRow[i] * lu[i, col]
|
||||||
}
|
}
|
||||||
lu[row, col] = sum
|
luRow[col] = sum
|
||||||
|
|
||||||
abs(sum)
|
// maintain best permutation choice
|
||||||
} ?: col
|
if (abs(sum) > largest) {
|
||||||
|
largest = abs(sum)
|
||||||
|
max = row
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Singularity check
|
// Singularity check
|
||||||
if (singularityCheck(lu[max, col])) {
|
if (checkSingular(abs(lu[max, col]))) {
|
||||||
error("Singular matrix")
|
error("The matrix is singular")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pivot if necessary
|
// Pivot if necessary
|
||||||
if (max != col) {
|
if (max != col) {
|
||||||
|
val luMax = lu.row(max)
|
||||||
|
val luCol = lu.row(col)
|
||||||
for (i in 0 until m) {
|
for (i in 0 until m) {
|
||||||
lu[max, i] = lu[col, i]
|
val tmp = luMax[i]
|
||||||
lu[col, i] = lu[max, i]
|
luMax[i] = luCol[i]
|
||||||
|
luCol[i] = tmp
|
||||||
}
|
}
|
||||||
val temp = pivot[max]
|
val temp = pivot[max]
|
||||||
pivot[max] = pivot[col]
|
pivot[max] = pivot[col]
|
||||||
@ -136,70 +157,96 @@ class LUSolver<T : Comparable<T>, F : Field<T>>(
|
|||||||
// Divide the lower elements by the "winning" diagonal elt.
|
// Divide the lower elements by the "winning" diagonal elt.
|
||||||
val luDiag = lu[col, col]
|
val luDiag = lu[col, col]
|
||||||
for (row in col + 1 until m) {
|
for (row in col + 1 until m) {
|
||||||
lu[row, col] = lu[row, col] / luDiag
|
lu[row, col] /= luDiag
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return LUPDecomposition(context.elementContext, lu, pivot, even)
|
|
||||||
|
return LUPDecomposition(this@lup, lu.collect(), pivot, even)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Produce a matrix with added decomposition feature
|
|
||||||
*/
|
|
||||||
fun decompose(matrix: Matrix<T>): Matrix<T> {
|
|
||||||
if (matrix.hasFeature<LUPDecomposition<*>>()) {
|
|
||||||
return matrix
|
|
||||||
} else {
|
|
||||||
val decomposition = buildDecomposition(matrix)
|
|
||||||
return VirtualMatrix.wrap(matrix, decomposition)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
|
||||||
if (b.rowNum != a.colNum) {
|
|
||||||
error("Matrix dimension mismatch expected ${a.rowNum}, but got ${b.colNum}")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use existing decomposition if it is provided by matrix
|
|
||||||
val decomposition = a.getFeature() ?: buildDecomposition(a)
|
|
||||||
|
|
||||||
with(decomposition) {
|
|
||||||
with(context.elementContext) {
|
|
||||||
// Apply permutations to b
|
|
||||||
val bp = Mutable2DStructure.create(a.rowNum, a.colNum, bufferFactory) { i, j ->
|
|
||||||
b[pivot[i], j]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Solve LY = b
|
|
||||||
for (col in 0 until a.rowNum) {
|
|
||||||
for (i in col + 1 until a.rowNum) {
|
|
||||||
for (j in 0 until b.colNum) {
|
|
||||||
bp[i, j] -= bp[col, j] * lu[i, col]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Solve UX = Y
|
|
||||||
for (col in a.rowNum - 1 downTo 0) {
|
|
||||||
for (j in 0 until b.colNum) {
|
|
||||||
bp[col, j] /= lu[col, col]
|
|
||||||
}
|
|
||||||
for (i in 0 until col) {
|
|
||||||
for (j in 0 until b.colNum) {
|
|
||||||
bp[i, j] -= bp[col, j] * lu[i, col]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return context.produce(a.rowNum, a.colNum) { i, j -> bp[i, j] }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun inverse(a: Matrix<T>): Matrix<T> = solve(a, context.one(a.rowNum, a.colNum))
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
val real = LUSolver(MatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||||
|
matrix: Matrix<T>,
|
||||||
|
noinline checkSingular: (T) -> Boolean
|
||||||
|
) = lup(T::class, matrix, checkSingular)
|
||||||
|
|
||||||
|
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(Double::class, matrix) { it < 1e-11 }
|
||||||
|
|
||||||
|
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
|
||||||
|
|
||||||
|
if (matrix.rowNum != pivot.size) {
|
||||||
|
error("Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}")
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
|
||||||
|
elementContext.run {
|
||||||
|
|
||||||
|
// Apply permutations to b
|
||||||
|
val bp = create { i, j -> zero }
|
||||||
|
|
||||||
|
for (row in 0 until pivot.size) {
|
||||||
|
val bpRow = bp.row(row)
|
||||||
|
val pRow = pivot[row]
|
||||||
|
for (col in 0 until matrix.colNum) {
|
||||||
|
bpRow[col] = matrix[pRow, col]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Solve LY = b
|
||||||
|
for (col in 0 until pivot.size) {
|
||||||
|
val bpCol = bp.row(col)
|
||||||
|
for (i in col + 1 until pivot.size) {
|
||||||
|
val bpI = bp.row(i)
|
||||||
|
val luICol = lu[i, col]
|
||||||
|
for (j in 0 until matrix.colNum) {
|
||||||
|
bpI[j] -= bpCol[j] * luICol
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Solve UX = Y
|
||||||
|
for (col in pivot.size - 1 downTo 0) {
|
||||||
|
val bpCol = bp.row(col)
|
||||||
|
val luDiag = lu[col, col]
|
||||||
|
for (j in 0 until matrix.colNum) {
|
||||||
|
bpCol[j] /= luDiag
|
||||||
|
}
|
||||||
|
for (i in 0 until col) {
|
||||||
|
val bpI = bp.row(i)
|
||||||
|
val luICol = lu[i, col]
|
||||||
|
for (j in 0 until matrix.colNum) {
|
||||||
|
bpI[j] -= bpCol[j] * luICol
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>) = solve(T::class, matrix)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Solve a linear equation **a*x = b**
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.solve(
|
||||||
|
a: Matrix<T>,
|
||||||
|
b: Matrix<T>,
|
||||||
|
noinline checkSingular: (T) -> Boolean
|
||||||
|
): Matrix<T> {
|
||||||
|
// Use existing decomposition if it is provided by matrix
|
||||||
|
val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular)
|
||||||
|
return decomposition.solve(T::class, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>) =
|
||||||
|
solve(a, b) { it < 1e-11 }
|
||||||
|
|
||||||
|
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
|
||||||
|
matrix: Matrix<T>,
|
||||||
|
noinline checkSingular: (T) -> Boolean
|
||||||
|
) = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
||||||
|
|
||||||
|
fun RealMatrixContext.inverse(matrix: Matrix<Double>) =
|
||||||
|
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }
|
@ -3,6 +3,7 @@ package scientifik.kmath.linear
|
|||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.Norm
|
import scientifik.kmath.operations.Norm
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
import scientifik.kmath.structures.VirtualBuffer
|
import scientifik.kmath.structures.VirtualBuffer
|
||||||
import scientifik.kmath.structures.asSequence
|
import scientifik.kmath.structures.asSequence
|
||||||
|
|
||||||
@ -10,9 +11,9 @@ import scientifik.kmath.structures.asSequence
|
|||||||
/**
|
/**
|
||||||
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
|
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
|
||||||
*/
|
*/
|
||||||
interface LinearSolver<T : Any> {
|
interface LinearSolver<T : Any> {
|
||||||
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
||||||
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).toVector()
|
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.asMatrix()).asPoint()
|
||||||
fun inverse(a: Matrix<T>): Matrix<T>
|
fun inverse(a: Matrix<T>): Matrix<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,14 +33,14 @@ object VectorL2Norm : Norm<Point<out Number>, Double> {
|
|||||||
typealias RealVector = Vector<Double, RealField>
|
typealias RealVector = Vector<Double, RealField>
|
||||||
typealias RealMatrix = Matrix<Double>
|
typealias RealMatrix = Matrix<Double>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert matrix to vector if it is possible
|
* Convert matrix to vector if it is possible
|
||||||
*/
|
*/
|
||||||
fun <T: Any> Matrix<T>.toVector(): Point<T> =
|
fun <T : Any> Matrix<T>.asPoint(): Point<T> =
|
||||||
if (this.colNum == 1) {
|
if (this.colNum == 1) {
|
||||||
VirtualBuffer(rowNum){ get(it, 0) }
|
VirtualBuffer(rowNum) { get(it, 0) }
|
||||||
} else error("Can't convert matrix with more than one column to vector")
|
} else {
|
||||||
|
error("Can't convert matrix with more than one column to vector")
|
||||||
|
}
|
||||||
|
|
||||||
fun <T: Any> Point<T>.toMatrix(): Matrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
fun <T : Any> Point<T>.asMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
@ -1,213 +0,0 @@
|
|||||||
package scientifik.kmath.linear
|
|
||||||
|
|
||||||
import scientifik.kmath.operations.RealField
|
|
||||||
import scientifik.kmath.operations.Ring
|
|
||||||
import scientifik.kmath.operations.sum
|
|
||||||
import scientifik.kmath.structures.*
|
|
||||||
import scientifik.kmath.structures.Buffer.Companion.boxing
|
|
||||||
import kotlin.math.sqrt
|
|
||||||
|
|
||||||
|
|
||||||
interface MatrixContext<T : Any> {
|
|
||||||
/**
|
|
||||||
* Produce a matrix with this context and given dimensions
|
|
||||||
*/
|
|
||||||
fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
|
||||||
|
|
||||||
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
|
|
||||||
|
|
||||||
infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
|
|
||||||
|
|
||||||
operator fun Matrix<T>.unaryMinus(): Matrix<T>
|
|
||||||
|
|
||||||
operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T>
|
|
||||||
|
|
||||||
operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T>
|
|
||||||
|
|
||||||
operator fun Matrix<T>.times(value: T): Matrix<T>
|
|
||||||
|
|
||||||
operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
/**
|
|
||||||
* Non-boxing double matrix
|
|
||||||
*/
|
|
||||||
val real = BufferMatrixContext(RealField, Buffer.Companion::auto)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A structured matrix with custom buffer
|
|
||||||
*/
|
|
||||||
fun <T : Any, R : Ring<T>> buffered(
|
|
||||||
ring: R,
|
|
||||||
bufferFactory: BufferFactory<T> = ::boxing
|
|
||||||
): GenericMatrixContext<T, R> =
|
|
||||||
BufferMatrixContext(ring, bufferFactory)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Automatic buffered matrix, unboxed if it is possible
|
|
||||||
*/
|
|
||||||
inline fun <reified T : Any, R : Ring<T>> auto(ring: R): GenericMatrixContext<T, R> =
|
|
||||||
buffered(ring, Buffer.Companion::auto)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|
||||||
/**
|
|
||||||
* The ring context for matrix elements
|
|
||||||
*/
|
|
||||||
val elementContext: R
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Produce a point compatible with matrix space
|
|
||||||
*/
|
|
||||||
fun point(size: Int, initializer: (Int) -> T): Point<T>
|
|
||||||
|
|
||||||
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
|
||||||
//TODO add typed error
|
|
||||||
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})")
|
|
||||||
return produce(rowNum, other.colNum) { i, j ->
|
|
||||||
val row = rows[i]
|
|
||||||
val column = other.columns[j]
|
|
||||||
with(elementContext) {
|
|
||||||
sum(row.asSequence().zip(column.asSequence(), ::multiply))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
|
||||||
//TODO add typed error
|
|
||||||
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})")
|
|
||||||
return point(rowNum) { i ->
|
|
||||||
val row = rows[i]
|
|
||||||
with(elementContext) {
|
|
||||||
sum(row.asSequence().zip(vector.asSequence(), ::multiply))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Matrix<T>.unaryMinus() =
|
|
||||||
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
|
||||||
|
|
||||||
override operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T> {
|
|
||||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] + [${b.rowNum},${b.colNum}]")
|
|
||||||
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
|
||||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]")
|
|
||||||
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
|
||||||
}
|
|
||||||
|
|
||||||
operator fun Matrix<T>.times(number: Number): Matrix<T> =
|
|
||||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } }
|
|
||||||
|
|
||||||
operator fun Number.times(matrix: Matrix<T>): Matrix<T> = matrix * this
|
|
||||||
|
|
||||||
override fun Matrix<T>.times(value: T): Matrix<T> =
|
|
||||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Specialized 2-d structure
|
|
||||||
*/
|
|
||||||
interface Matrix<T : Any> : Structure2D<T> {
|
|
||||||
val rowNum: Int
|
|
||||||
val colNum: Int
|
|
||||||
|
|
||||||
val features: Set<MatrixFeature>
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure.
|
|
||||||
*
|
|
||||||
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to
|
|
||||||
* add only those features that are valid.
|
|
||||||
*/
|
|
||||||
fun suggestFeature(vararg features: MatrixFeature): Matrix<T>
|
|
||||||
|
|
||||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
|
||||||
|
|
||||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
|
||||||
|
|
||||||
val rows: Point<Point<T>>
|
|
||||||
get() = VirtualBuffer(rowNum) { i ->
|
|
||||||
VirtualBuffer(colNum) { j -> get(i, j) }
|
|
||||||
}
|
|
||||||
|
|
||||||
val columns: Point<Point<T>>
|
|
||||||
get() = VirtualBuffer(colNum) { j ->
|
|
||||||
VirtualBuffer(rowNum) { i -> get(i, j) }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
|
||||||
for (i in (0 until rowNum)) {
|
|
||||||
for (j in (0 until colNum)) {
|
|
||||||
yield(intArrayOf(i, j) to get(i, j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
|
||||||
MatrixContext.real.produce(rows, columns, initializer)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build a square matrix from given elements.
|
|
||||||
*/
|
|
||||||
fun <T : Any> square(vararg elements: T): Matrix<T> {
|
|
||||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
|
||||||
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
|
|
||||||
val buffer = elements.asBuffer()
|
|
||||||
return BufferMatrix(size, size, buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun <T : Any> build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class MatrixBuilder<T : Any>(val rows: Int, val columns: Int) {
|
|
||||||
operator fun invoke(vararg elements: T): Matrix<T> {
|
|
||||||
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns")
|
|
||||||
val buffer = elements.asBuffer()
|
|
||||||
return BufferMatrix(rows, columns, buffer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if matrix has the given feature class
|
|
||||||
*/
|
|
||||||
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean = features.find { it is T } != null
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
|
||||||
*/
|
|
||||||
inline fun <reified T : Any> Matrix<*>.getFeature(): T? = features.filterIsInstance<T>().firstOrNull()
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
|
||||||
*/
|
|
||||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
|
|
||||||
VirtualMatrix<T>(rows, columns) { i, j ->
|
|
||||||
if (i == j) elementContext.one else elementContext.zero
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A virtual matrix of zeroes
|
|
||||||
*/
|
|
||||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
|
|
||||||
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero }
|
|
||||||
|
|
||||||
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
|
||||||
*/
|
|
||||||
fun <T : Any, R : Ring<T>> Matrix<T>.transpose(): Matrix<T> {
|
|
||||||
return this.getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
|
|
||||||
this.colNum,
|
|
||||||
this.rowNum,
|
|
||||||
setOf(TransposedFeature(this))
|
|
||||||
) { i, j -> get(j, i) }
|
|
||||||
}
|
|
||||||
|
|
||||||
infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = with(MatrixContext.real) { dot(other) }
|
|
@ -0,0 +1,105 @@
|
|||||||
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Ring
|
||||||
|
import scientifik.kmath.operations.SpaceOperations
|
||||||
|
import scientifik.kmath.operations.sum
|
||||||
|
import scientifik.kmath.structures.Buffer
|
||||||
|
import scientifik.kmath.structures.BufferFactory
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
|
import scientifik.kmath.structures.asSequence
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Basic operations on matrices. Operates on [Matrix]
|
||||||
|
*/
|
||||||
|
interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
|
||||||
|
/**
|
||||||
|
* Produce a matrix with this context and given dimensions
|
||||||
|
*/
|
||||||
|
fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
||||||
|
|
||||||
|
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
|
||||||
|
|
||||||
|
infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
|
||||||
|
|
||||||
|
operator fun Matrix<T>.times(value: T): Matrix<T>
|
||||||
|
|
||||||
|
operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* Non-boxing double matrix
|
||||||
|
*/
|
||||||
|
val real = RealMatrixContext
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A structured matrix with custom buffer
|
||||||
|
*/
|
||||||
|
fun <T : Any, R : Ring<T>> buffered(
|
||||||
|
ring: R,
|
||||||
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||||
|
): GenericMatrixContext<T, R> =
|
||||||
|
BufferMatrixContext(ring, bufferFactory)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatic buffered matrix, unboxed if it is possible
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any, R : Ring<T>> auto(ring: R): GenericMatrixContext<T, R> =
|
||||||
|
buffered(ring, Buffer.Companion::auto)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
||||||
|
/**
|
||||||
|
* The ring context for matrix elements
|
||||||
|
*/
|
||||||
|
val elementContext: R
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce a point compatible with matrix space
|
||||||
|
*/
|
||||||
|
fun point(size: Int, initializer: (Int) -> T): Point<T>
|
||||||
|
|
||||||
|
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
||||||
|
//TODO add typed error
|
||||||
|
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})")
|
||||||
|
return produce(rowNum, other.colNum) { i, j ->
|
||||||
|
val row = rows[i]
|
||||||
|
val column = other.columns[j]
|
||||||
|
with(elementContext) {
|
||||||
|
sum(row.asSequence().zip(column.asSequence(), ::multiply))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
||||||
|
//TODO add typed error
|
||||||
|
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})")
|
||||||
|
return point(rowNum) { i ->
|
||||||
|
val row = rows[i]
|
||||||
|
with(elementContext) {
|
||||||
|
sum(row.asSequence().zip(vector.asSequence(), ::multiply))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Matrix<T>.unaryMinus() =
|
||||||
|
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
||||||
|
|
||||||
|
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||||
|
if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]")
|
||||||
|
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) + b[i, j] } }
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||||
|
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]")
|
||||||
|
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
||||||
|
produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) * k } }
|
||||||
|
|
||||||
|
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
||||||
|
|
||||||
|
override fun Matrix<T>.times(value: T): Matrix<T> =
|
||||||
|
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } }
|
||||||
|
}
|
@ -25,7 +25,7 @@ object UnitFeature : MatrixFeature
|
|||||||
* Inverted matrix feature
|
* Inverted matrix feature
|
||||||
*/
|
*/
|
||||||
interface InverseMatrixFeature<T : Any> : MatrixFeature {
|
interface InverseMatrixFeature<T : Any> : MatrixFeature {
|
||||||
val inverse: Matrix<T>
|
val inverse: FeaturedMatrix<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -54,9 +54,9 @@ object UFeature: MatrixFeature
|
|||||||
* TODO add documentation
|
* TODO add documentation
|
||||||
*/
|
*/
|
||||||
interface LUPDecompositionFeature<T : Any> : MatrixFeature {
|
interface LUPDecompositionFeature<T : Any> : MatrixFeature {
|
||||||
val l: Matrix<T>
|
val l: FeaturedMatrix<T>
|
||||||
val u: Matrix<T>
|
val u: FeaturedMatrix<T>
|
||||||
val p: Matrix<T>
|
val p: FeaturedMatrix<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO add sparse matrix feature
|
//TODO add sparse matrix feature
|
@ -1,40 +0,0 @@
|
|||||||
package scientifik.kmath.linear
|
|
||||||
|
|
||||||
import scientifik.kmath.structures.MutableBuffer
|
|
||||||
import scientifik.kmath.structures.MutableBufferFactory
|
|
||||||
import scientifik.kmath.structures.MutableNDStructure
|
|
||||||
|
|
||||||
class Mutable2DStructure<T>(val rowNum: Int, val colNum: Int, val buffer: MutableBuffer<T>) : MutableNDStructure<T> {
|
|
||||||
override val shape: IntArray
|
|
||||||
get() = intArrayOf(rowNum, colNum)
|
|
||||||
|
|
||||||
operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
|
|
||||||
|
|
||||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
|
||||||
for (i in 0 until rowNum) {
|
|
||||||
for (j in 0 until colNum) {
|
|
||||||
yield(intArrayOf(i, j) to get(i, j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
operator fun set(i: Int, j: Int, value: T) {
|
|
||||||
buffer[i * colNum + j] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun set(index: IntArray, value: T) = set(index[0], index[1], value)
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
fun <T> create(
|
|
||||||
rowNum: Int,
|
|
||||||
colNum: Int,
|
|
||||||
bufferFactory: MutableBufferFactory<T>,
|
|
||||||
init: (i: Int, j: Int) -> T
|
|
||||||
): Mutable2DStructure<T> {
|
|
||||||
val buffer = bufferFactory(rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
|
||||||
return Mutable2DStructure(rowNum, colNum, buffer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -4,68 +4,23 @@ import scientifik.kmath.operations.RealField
|
|||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
import scientifik.kmath.operations.SpaceElement
|
import scientifik.kmath.operations.SpaceElement
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import scientifik.kmath.structures.BufferFactory
|
|
||||||
import scientifik.kmath.structures.asSequence
|
import scientifik.kmath.structures.asSequence
|
||||||
|
import kotlin.jvm.JvmName
|
||||||
|
|
||||||
typealias Point<T> = Buffer<T>
|
typealias Point<T> = Buffer<T>
|
||||||
|
|
||||||
/**
|
|
||||||
* A linear space for vectors.
|
|
||||||
* Could be used on any point-like structure
|
|
||||||
*/
|
|
||||||
interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
|
||||||
|
|
||||||
val size: Int
|
fun <T : Any, S : Space<T>> BufferVectorSpace<T, S>.produceElement(initializer: (Int) -> T): Vector<T, S> =
|
||||||
|
BufferVector(this, produce(initializer))
|
||||||
val space: S
|
|
||||||
|
|
||||||
fun produce(initializer: (Int) -> T): Point<T>
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Produce a space-element of this vector space for expressions
|
|
||||||
*/
|
|
||||||
fun produceElement(initializer: (Int) -> T): Vector<T, S>
|
|
||||||
|
|
||||||
override val zero: Point<T> get() = produce { space.zero }
|
|
||||||
|
|
||||||
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } }
|
|
||||||
|
|
||||||
override fun multiply(a: Point<T>, k: Number): Point<T> = produce { with(space) { a[it] * k } }
|
|
||||||
|
|
||||||
//TODO add basis
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
|
|
||||||
private val realSpaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Non-boxing double vector space
|
|
||||||
*/
|
|
||||||
fun real(size: Int): BufferVectorSpace<Double, RealField> {
|
|
||||||
return realSpaceCache.getOrPut(size) { BufferVectorSpace(size, RealField, Buffer.Companion::auto) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A structured vector space with custom buffer
|
|
||||||
*/
|
|
||||||
fun <T : Any, S : Space<T>> buffered(
|
|
||||||
size: Int,
|
|
||||||
space: S,
|
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
|
||||||
): VectorSpace<T, S> = BufferVectorSpace(size, space, bufferFactory)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Automatic buffered vector, unboxed if it is possible
|
|
||||||
*/
|
|
||||||
inline fun <reified T : Any, S : Space<T>> smart(size: Int, space: S): VectorSpace<T, S> =
|
|
||||||
buffered(size, space, Buffer.Companion::auto)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@JvmName("produceRealElement")
|
||||||
|
fun BufferVectorSpace<Double, RealField>.produceElement(initializer: (Int) -> Double): Vector<Double, RealField> =
|
||||||
|
BufferVector(this, produce(initializer))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A point coupled to the linear space
|
* A point coupled to the linear space
|
||||||
*/
|
*/
|
||||||
|
@Deprecated("Use VectorContext instead")
|
||||||
interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, Vector<T, S>, VectorSpace<T, S>>, Point<T> {
|
interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, Vector<T, S>, VectorSpace<T, S>>, Point<T> {
|
||||||
override val size: Int get() = context.size
|
override val size: Int get() = context.size
|
||||||
|
|
||||||
@ -90,16 +45,7 @@ interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, Vector<T, S>, V
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data class BufferVectorSpace<T : Any, S : Space<T>>(
|
@Deprecated("Use VectorContext instead")
|
||||||
override val size: Int,
|
|
||||||
override val space: S,
|
|
||||||
val bufferFactory: BufferFactory<T>
|
|
||||||
) : VectorSpace<T, S> {
|
|
||||||
override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer)
|
|
||||||
override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace<T, S>, val buffer: Buffer<T>) :
|
data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace<T, S>, val buffer: Buffer<T>) :
|
||||||
Vector<T, S> {
|
Vector<T, S> {
|
||||||
|
|
||||||
|
@ -0,0 +1,75 @@
|
|||||||
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.structures.Buffer
|
||||||
|
import scientifik.kmath.structures.BufferFactory
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A linear space for vectors.
|
||||||
|
* Could be used on any point-like structure
|
||||||
|
*/
|
||||||
|
interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
||||||
|
|
||||||
|
val size: Int
|
||||||
|
|
||||||
|
val space: S
|
||||||
|
|
||||||
|
fun produce(initializer: (Int) -> T): Point<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Produce a space-element of this vector space for expressions
|
||||||
|
*/
|
||||||
|
//fun produceElement(initializer: (Int) -> T): Vector<T, S>
|
||||||
|
|
||||||
|
override val zero: Point<T> get() = produce { space.zero }
|
||||||
|
|
||||||
|
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } }
|
||||||
|
|
||||||
|
override fun multiply(a: Point<T>, k: Number): Point<T> = produce { with(space) { a[it] * k } }
|
||||||
|
|
||||||
|
//TODO add basis
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
|
||||||
|
private val realSpaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Non-boxing double vector space
|
||||||
|
*/
|
||||||
|
fun real(size: Int): BufferVectorSpace<Double, RealField> {
|
||||||
|
return realSpaceCache.getOrPut(size) {
|
||||||
|
BufferVectorSpace(
|
||||||
|
size,
|
||||||
|
RealField,
|
||||||
|
Buffer.Companion::auto
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A structured vector space with custom buffer
|
||||||
|
*/
|
||||||
|
fun <T : Any, S : Space<T>> buffered(
|
||||||
|
size: Int,
|
||||||
|
space: S,
|
||||||
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||||
|
) = BufferVectorSpace(size, space, bufferFactory)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatic buffered vector, unboxed if it is possible
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any, S : Space<T>> auto(size: Int, space: S): VectorSpace<T, S> =
|
||||||
|
buffered(size, space, Buffer.Companion::auto)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BufferVectorSpace<T : Any, S : Space<T>>(
|
||||||
|
override val size: Int,
|
||||||
|
override val space: S,
|
||||||
|
val bufferFactory: BufferFactory<T>
|
||||||
|
) : VectorSpace<T, S> {
|
||||||
|
override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer)
|
||||||
|
//override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
|
||||||
|
}
|
@ -1,11 +1,23 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
|
|
||||||
class VirtualMatrix<T : Any>(
|
class VirtualMatrix<T : Any>(
|
||||||
override val rowNum: Int,
|
override val rowNum: Int,
|
||||||
override val colNum: Int,
|
override val colNum: Int,
|
||||||
override val features: Set<MatrixFeature> = emptySet(),
|
override val features: Set<MatrixFeature> = emptySet(),
|
||||||
val generator: (i: Int, j: Int) -> T
|
val generator: (i: Int, j: Int) -> T
|
||||||
) : Matrix<T> {
|
) : FeaturedMatrix<T> {
|
||||||
|
|
||||||
|
constructor(rowNum: Int, colNum: Int, vararg features: MatrixFeature, generator: (i: Int, j: Int) -> T) : this(
|
||||||
|
rowNum,
|
||||||
|
colNum,
|
||||||
|
setOf(*features),
|
||||||
|
generator
|
||||||
|
)
|
||||||
|
|
||||||
|
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T = generator(i, j)
|
override fun get(i: Int, j: Int): T = generator(i, j)
|
||||||
|
|
||||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
override fun suggestFeature(vararg features: MatrixFeature) =
|
||||||
@ -13,7 +25,7 @@ class VirtualMatrix<T : Any>(
|
|||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
if (other !is Matrix<*>) return false
|
if (other !is FeaturedMatrix<*>) return false
|
||||||
|
|
||||||
if (rowNum != other.rowNum) return false
|
if (rowNum != other.rowNum) return false
|
||||||
if (colNum != other.colNum) return false
|
if (colNum != other.colNum) return false
|
||||||
@ -34,7 +46,7 @@ class VirtualMatrix<T : Any>(
|
|||||||
/**
|
/**
|
||||||
* Wrap a matrix adding additional features to it
|
* Wrap a matrix adding additional features to it
|
||||||
*/
|
*/
|
||||||
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): Matrix<T> {
|
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): FeaturedMatrix<T> {
|
||||||
return if (matrix is VirtualMatrix) {
|
return if (matrix is VirtualMatrix) {
|
||||||
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
|
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
|
||||||
} else {
|
} else {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath.sequential
|
package scientifik.kmath.misc
|
||||||
|
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
@ -0,0 +1,218 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
import kotlin.math.pow
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Implementation of backward-mode automatic differentiation.
|
||||||
|
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
||||||
|
* with respect to this variable.
|
||||||
|
*/
|
||||||
|
open class Variable(val value: Double) {
|
||||||
|
constructor(x: Number) : this(x.toDouble())
|
||||||
|
}
|
||||||
|
|
||||||
|
class DerivationResult(value: Double, val deriv: Map<Variable, Double>) : Variable(value) {
|
||||||
|
fun deriv(variable: Variable) = deriv[variable] ?: 0.0
|
||||||
|
|
||||||
|
/**
|
||||||
|
* compute divergence
|
||||||
|
*/
|
||||||
|
fun div() = deriv.values.sum()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute a gradient for variables in given order
|
||||||
|
*/
|
||||||
|
fun grad(vararg variables: Variable): Point<Double> = if (variables.isEmpty()) {
|
||||||
|
error("Variable order is not provided for gradient construction")
|
||||||
|
} else {
|
||||||
|
variables.map(::deriv).toDoubleArray().asBuffer()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
||||||
|
*
|
||||||
|
* The partial derivatives are placed in argument `d` variable
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* ```
|
||||||
|
* val x = Variable(2) // define variable(s) and their values
|
||||||
|
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
||||||
|
* assertEquals(17.0, y.x) // the value of result (y)
|
||||||
|
* assertEquals(9.0, x.d) // dy/dx
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
fun deriv(body: AutoDiffField.() -> Variable): DerivationResult =
|
||||||
|
AutoDiffContext().run {
|
||||||
|
val result = body()
|
||||||
|
result.d = 1.0 // computing derivative w.r.t result
|
||||||
|
runBackwardPass()
|
||||||
|
DerivationResult(result.value, derivatives)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
abstract class AutoDiffField : Field<Variable> {
|
||||||
|
/**
|
||||||
|
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||||
|
*
|
||||||
|
* For example, implementation of `sin` function is:
|
||||||
|
*
|
||||||
|
* ```
|
||||||
|
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
||||||
|
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
||||||
|
* }
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
abstract fun <R> derive(value: R, block: (R) -> Unit): R
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A variable accessing inner state of derivatives.
|
||||||
|
* Use this function in inner builders to avoid creating additional derivative bindings
|
||||||
|
*/
|
||||||
|
abstract var Variable.d: Double
|
||||||
|
|
||||||
|
abstract fun variable(value: Double): Variable
|
||||||
|
|
||||||
|
// Overloads for Double constants
|
||||||
|
|
||||||
|
operator fun Number.plus(that: Variable): Variable = derive(variable(this.toDouble() + that.value)) { z ->
|
||||||
|
that.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
operator fun Variable.plus(b: Number): Variable = b.plus(this)
|
||||||
|
|
||||||
|
operator fun Number.minus(that: Variable): Variable = derive(variable(this.toDouble() - that.value)) { z ->
|
||||||
|
that.d -= z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
operator fun Variable.minus(that: Number): Variable = derive(variable(this.value - that.toDouble())) { z ->
|
||||||
|
this.d += z.d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatic Differentiation context class.
|
||||||
|
*/
|
||||||
|
private class AutoDiffContext : AutoDiffField() {
|
||||||
|
|
||||||
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
|
private var stack = arrayOfNulls<Any?>(8)
|
||||||
|
private var sp = 0
|
||||||
|
|
||||||
|
internal val derivatives = HashMap<Variable, Double>()
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A variable coupled with its derivative. For internal use only
|
||||||
|
*/
|
||||||
|
class VariableWithDeriv(x: Double, var d: Double = 0.0) : Variable(x)
|
||||||
|
|
||||||
|
|
||||||
|
override fun variable(value: Double): Variable = VariableWithDeriv(value)
|
||||||
|
|
||||||
|
override var Variable.d: Double
|
||||||
|
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: 0.0
|
||||||
|
set(value) {
|
||||||
|
if (this is VariableWithDeriv) {
|
||||||
|
d = value
|
||||||
|
} else {
|
||||||
|
derivatives[this] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
override fun <R> derive(value: R, block: (R) -> Unit): R {
|
||||||
|
// save block to stack for backward pass
|
||||||
|
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||||
|
stack[sp++] = block
|
||||||
|
stack[sp++] = value
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
fun runBackwardPass() {
|
||||||
|
while (sp > 0) {
|
||||||
|
val value = stack[--sp]
|
||||||
|
val block = stack[--sp] as (Any?) -> Unit
|
||||||
|
block(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
|
|
||||||
|
override fun add(a: Variable, b: Variable): Variable =
|
||||||
|
derive(variable(a.value + b.value)) { z ->
|
||||||
|
a.d += z.d
|
||||||
|
b.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: Variable, b: Variable): Variable =
|
||||||
|
derive(variable(a.value * b.value)) { z ->
|
||||||
|
a.d += z.d * b.value
|
||||||
|
b.d += z.d * a.value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun divide(a: Variable, b: Variable): Variable =
|
||||||
|
derive(Variable(a.value / b.value)) { z ->
|
||||||
|
a.d += z.d / b.value
|
||||||
|
b.d -= z.d * a.value / (b.value * b.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: Variable, k: Number): Variable =
|
||||||
|
derive(variable(k.toDouble() * a.value)) { z ->
|
||||||
|
a.d += z.d * k.toDouble()
|
||||||
|
}
|
||||||
|
|
||||||
|
override val zero: Variable get() = Variable(0.0)
|
||||||
|
override val one: Variable get() = Variable(1.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
|
// x ^ 2
|
||||||
|
fun AutoDiffField.sqr(x: Variable): Variable = derive(variable(x.value * x.value)) { z ->
|
||||||
|
x.d += z.d * 2 * x.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// x ^ 1/2
|
||||||
|
fun AutoDiffField.sqrt(x: Variable): Variable = derive(variable(sqrt(x.value))) { z ->
|
||||||
|
x.d += z.d * 0.5 / z.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// x ^ y (const)
|
||||||
|
fun AutoDiffField.pow(x: Variable, y: Double): Variable = derive(variable(x.value.pow(y))) { z ->
|
||||||
|
x.d += z.d * y * x.value.pow(y - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble())
|
||||||
|
|
||||||
|
// exp(x)
|
||||||
|
fun AutoDiffField.exp(x: Variable): Variable = derive(variable(kotlin.math.exp(x.value))) { z ->
|
||||||
|
x.d += z.d * z.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// ln(x)
|
||||||
|
fun AutoDiffField.ln(x: Variable): Variable = derive(Variable(kotlin.math.ln(x.value))) { z ->
|
||||||
|
x.d += z.d / x.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// x ^ y (any)
|
||||||
|
fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x))
|
||||||
|
|
||||||
|
// sin(x)
|
||||||
|
fun AutoDiffField.sin(x: Variable): Variable = derive(variable(kotlin.math.sin(x.value))) { z ->
|
||||||
|
x.d += z.d * kotlin.math.cos(x.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cos(x)
|
||||||
|
fun AutoDiffField.cos(x: Variable): Variable = derive(variable(kotlin.math.cos(x.value))) { z ->
|
||||||
|
x.d -= z.d * kotlin.math.sin(x.value)
|
||||||
|
}
|
@ -55,23 +55,14 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
|||||||
/**
|
/**
|
||||||
* Complex number class
|
* Complex number class
|
||||||
*/
|
*/
|
||||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField> {
|
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
|
||||||
override fun unwrap(): Complex = this
|
override fun unwrap(): Complex = this
|
||||||
|
|
||||||
override fun Complex.wrap(): Complex = this
|
override fun Complex.wrap(): Complex = this
|
||||||
|
|
||||||
override val context: ComplexField get() = ComplexField
|
override val context: ComplexField get() = ComplexField
|
||||||
|
|
||||||
/**
|
override fun compareTo(other: Complex): Int = abs.compareTo(other.abs)
|
||||||
* A complex conjugate
|
|
||||||
*/
|
|
||||||
val conjugate: Complex get() = Complex(re, -im)
|
|
||||||
|
|
||||||
val square: Double get() = re * re + im * im
|
|
||||||
|
|
||||||
val abs: Double get() = sqrt(square)
|
|
||||||
|
|
||||||
val theta: Double get() = atan(im / re)
|
|
||||||
|
|
||||||
companion object : MemorySpec<Complex> {
|
companion object : MemorySpec<Complex> {
|
||||||
override val objectSize: Int = 16
|
override val objectSize: Int = 16
|
||||||
@ -86,6 +77,26 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A complex conjugate
|
||||||
|
*/
|
||||||
|
val Complex.conjugate: Complex get() = Complex(re, -im)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Square of the complex number
|
||||||
|
*/
|
||||||
|
val Complex.square: Double get() = re * re + im * im
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Absolute value of complex number
|
||||||
|
*/
|
||||||
|
val Complex.abs: Double get() = sqrt(square)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An angle between vector represented by complex number and X axis
|
||||||
|
*/
|
||||||
|
val Complex.theta: Double get() = atan(im / re)
|
||||||
|
|
||||||
fun Double.toComplex() = Complex(this, 0.0)
|
fun Double.toComplex() = Complex(this, 0.0)
|
||||||
|
|
||||||
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import kotlin.math.abs
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -31,79 +32,152 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
|||||||
/**
|
/**
|
||||||
* A field for double without boxing. Does not produce appropriate field element
|
* A field for double without boxing. Does not produce appropriate field element
|
||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double, Double> {
|
object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double, Double> {
|
||||||
override val zero: Double = 0.0
|
override val zero: Double = 0.0
|
||||||
override fun add(a: Double, b: Double): Double = a + b
|
override inline fun add(a: Double, b: Double) = a + b
|
||||||
override fun multiply(a: Double, b: Double): Double = a * b
|
override inline fun multiply(a: Double, b: Double) = a * b
|
||||||
override fun multiply(a: Double, k: Number): Double = a * k.toDouble()
|
override inline fun multiply(a: Double, k: Number) = a * k.toDouble()
|
||||||
|
|
||||||
override val one: Double = 1.0
|
override val one: Double = 1.0
|
||||||
override fun divide(a: Double, b: Double): Double = a / b
|
override inline fun divide(a: Double, b: Double) = a / b
|
||||||
|
|
||||||
override fun sin(arg: Double): Double = kotlin.math.sin(arg)
|
override inline fun sin(arg: Double) = kotlin.math.sin(arg)
|
||||||
override fun cos(arg: Double): Double = kotlin.math.cos(arg)
|
override inline fun cos(arg: Double) = kotlin.math.cos(arg)
|
||||||
|
|
||||||
override fun power(arg: Double, pow: Number): Double = arg.pow(pow.toDouble())
|
override inline fun power(arg: Double, pow: Number) = arg.pow(pow.toDouble())
|
||||||
|
|
||||||
override fun exp(arg: Double): Double = kotlin.math.exp(arg)
|
override inline fun exp(arg: Double) = kotlin.math.exp(arg)
|
||||||
override fun ln(arg: Double): Double = kotlin.math.ln(arg)
|
override inline fun ln(arg: Double) = kotlin.math.ln(arg)
|
||||||
|
|
||||||
override fun norm(arg: Double): Double = kotlin.math.abs(arg)
|
override inline fun norm(arg: Double) = abs(arg)
|
||||||
|
|
||||||
override fun Double.unaryMinus(): Double = -this
|
override inline fun Double.unaryMinus() = -this
|
||||||
|
|
||||||
override fun Double.minus(b: Double): Double = this - b
|
override inline fun Double.plus(b: Double) = this + b
|
||||||
|
|
||||||
|
override inline fun Double.minus(b: Double) = this - b
|
||||||
|
|
||||||
|
override inline fun Double.times(b: Double) = this * b
|
||||||
|
|
||||||
|
override inline fun Double.div(b: Double) = this / b
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
|
object FloatField : Field<Float>, ExtendedFieldOperations<Float>, Norm<Float, Float> {
|
||||||
|
override val zero: Float = 0f
|
||||||
|
override inline fun add(a: Float, b: Float) = a + b
|
||||||
|
override inline fun multiply(a: Float, b: Float) = a * b
|
||||||
|
override inline fun multiply(a: Float, k: Number) = a * k.toFloat()
|
||||||
|
|
||||||
|
override val one: Float = 1f
|
||||||
|
override inline fun divide(a: Float, b: Float) = a / b
|
||||||
|
|
||||||
|
override inline fun sin(arg: Float) = kotlin.math.sin(arg)
|
||||||
|
override inline fun cos(arg: Float) = kotlin.math.cos(arg)
|
||||||
|
|
||||||
|
override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat())
|
||||||
|
|
||||||
|
override inline fun exp(arg: Float) = kotlin.math.exp(arg)
|
||||||
|
override inline fun ln(arg: Float) = kotlin.math.ln(arg)
|
||||||
|
|
||||||
|
override inline fun norm(arg: Float) = abs(arg)
|
||||||
|
|
||||||
|
override inline fun Float.unaryMinus() = -this
|
||||||
|
|
||||||
|
override inline fun Float.plus(b: Float) = this + b
|
||||||
|
|
||||||
|
override inline fun Float.minus(b: Float) = this - b
|
||||||
|
|
||||||
|
override inline fun Float.times(b: Float) = this * b
|
||||||
|
|
||||||
|
override inline fun Float.div(b: Float) = this / b
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field for [Int] without boxing. Does not produce corresponding field element
|
* A field for [Int] without boxing. Does not produce corresponding field element
|
||||||
*/
|
*/
|
||||||
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object IntRing : Ring<Int>, Norm<Int, Int> {
|
object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||||
override val zero: Int = 0
|
override val zero: Int = 0
|
||||||
override fun add(a: Int, b: Int): Int = a + b
|
override inline fun add(a: Int, b: Int) = a + b
|
||||||
override fun multiply(a: Int, b: Int): Int = a * b
|
override inline fun multiply(a: Int, b: Int) = a * b
|
||||||
override fun multiply(a: Int, k: Number): Int = (k * a)
|
override inline fun multiply(a: Int, k: Number) = (k * a)
|
||||||
override val one: Int = 1
|
override val one: Int = 1
|
||||||
|
|
||||||
override fun norm(arg: Int): Int = arg
|
override inline fun norm(arg: Int) = abs(arg)
|
||||||
|
|
||||||
|
override inline fun Int.unaryMinus() = -this
|
||||||
|
|
||||||
|
override inline fun Int.plus(b: Int): Int = this + b
|
||||||
|
|
||||||
|
override inline fun Int.minus(b: Int): Int = this - b
|
||||||
|
|
||||||
|
override inline fun Int.times(b: Int): Int = this * b
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field for [Short] without boxing. Does not produce appropriate field element
|
* A field for [Short] without boxing. Does not produce appropriate field element
|
||||||
*/
|
*/
|
||||||
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object ShortRing : Ring<Short>, Norm<Short, Short> {
|
object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||||
override val zero: Short = 0
|
override val zero: Short = 0
|
||||||
override fun add(a: Short, b: Short): Short = (a + b).toShort()
|
override inline fun add(a: Short, b: Short) = (a + b).toShort()
|
||||||
override fun multiply(a: Short, b: Short): Short = (a * b).toShort()
|
override inline fun multiply(a: Short, b: Short) = (a * b).toShort()
|
||||||
override fun multiply(a: Short, k: Number): Short = (a * k)
|
override inline fun multiply(a: Short, k: Number) = (a * k)
|
||||||
override val one: Short = 1
|
override val one: Short = 1
|
||||||
|
|
||||||
override fun norm(arg: Short): Short = arg
|
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
||||||
|
|
||||||
|
override inline fun Short.unaryMinus() = (-this).toShort()
|
||||||
|
|
||||||
|
override inline fun Short.plus(b: Short) = (this + b).toShort()
|
||||||
|
|
||||||
|
override inline fun Short.minus(b: Short) = (this - b).toShort()
|
||||||
|
|
||||||
|
override inline fun Short.times(b: Short) = (this * b).toShort()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field for [Byte] values
|
* A field for [Byte] values
|
||||||
*/
|
*/
|
||||||
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||||
override val zero: Byte = 0
|
override val zero: Byte = 0
|
||||||
override fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
override inline fun add(a: Byte, b: Byte) = (a + b).toByte()
|
||||||
override fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
|
override inline fun multiply(a: Byte, b: Byte) = (a * b).toByte()
|
||||||
override fun multiply(a: Byte, k: Number): Byte = (a * k)
|
override inline fun multiply(a: Byte, k: Number) = (a * k)
|
||||||
override val one: Byte = 1
|
override val one: Byte = 1
|
||||||
|
|
||||||
override fun norm(arg: Byte): Byte = arg
|
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
|
||||||
|
|
||||||
|
override inline fun Byte.unaryMinus() = (-this).toByte()
|
||||||
|
|
||||||
|
override inline fun Byte.plus(b: Byte) = (this + b).toByte()
|
||||||
|
|
||||||
|
override inline fun Byte.minus(b: Byte) = (this - b).toByte()
|
||||||
|
|
||||||
|
override inline fun Byte.times(b: Byte) = (this * b).toByte()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field for [Long] values
|
* A field for [Long] values
|
||||||
*/
|
*/
|
||||||
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object LongRing : Ring<Long>, Norm<Long, Long> {
|
object LongRing : Ring<Long>, Norm<Long, Long> {
|
||||||
override val zero: Long = 0
|
override val zero: Long = 0
|
||||||
override fun add(a: Long, b: Long): Long = (a + b)
|
override inline fun add(a: Long, b: Long) = (a + b)
|
||||||
override fun multiply(a: Long, b: Long): Long = (a * b)
|
override inline fun multiply(a: Long, b: Long) = (a * b)
|
||||||
override fun multiply(a: Long, k: Number): Long = (a * k)
|
override inline fun multiply(a: Long, k: Number) = (a * k)
|
||||||
override val one: Long = 1
|
override val one: Long = 1
|
||||||
|
|
||||||
override fun norm(arg: Long): Long = arg
|
override fun norm(arg: Long): Long = abs(arg)
|
||||||
|
|
||||||
|
override inline fun Long.unaryMinus() = (-this)
|
||||||
|
|
||||||
|
override inline fun Long.plus(b: Long) = (this + b)
|
||||||
|
|
||||||
|
override inline fun Long.minus(b: Long) = (this - b)
|
||||||
|
|
||||||
|
override inline fun Long.times(b: Long) = (this * b)
|
||||||
}
|
}
|
@ -0,0 +1,45 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context that allows to operate on a [MutableBuffer] as on 2d array
|
||||||
|
*/
|
||||||
|
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
|
||||||
|
|
||||||
|
operator fun Buffer<T>.get(i: Int, j: Int) = get(i + colNum * j)
|
||||||
|
|
||||||
|
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
|
||||||
|
set(i + colNum * j, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun create(init: (i: Int, j: Int) -> T) =
|
||||||
|
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
||||||
|
|
||||||
|
fun create(mat: Structure2D<T>) = create { i, j -> mat[i, j] }
|
||||||
|
|
||||||
|
//TODO optimize wrapper
|
||||||
|
fun MutableBuffer<T>.collect(): Structure2D<T> =
|
||||||
|
NDStructure.auto(type, rowNum, colNum) { (i, j) -> get(i, j) }.as2D()
|
||||||
|
|
||||||
|
|
||||||
|
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
|
||||||
|
override val size: Int get() = colNum
|
||||||
|
|
||||||
|
override fun get(index: Int): T = buffer[rowIndex, index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: T) {
|
||||||
|
buffer[rowIndex, index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<T> = MutableBuffer.auto(type, colNum) { get(it) }
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<T> = (0 until colNum).map(::get).iterator()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get row
|
||||||
|
*/
|
||||||
|
fun MutableBuffer<T>.row(i: Int) = Row(this, i)
|
||||||
|
}
|
@ -2,6 +2,7 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.complex
|
import scientifik.kmath.operations.complex
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
|
||||||
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
||||||
@ -36,18 +37,20 @@ interface Buffer<T> {
|
|||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
|
||||||
|
inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer {
|
||||||
|
val array = DoubleArray(size) { initializer(it) }
|
||||||
|
return DoubleBuffer(array)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a boxing buffer of given type
|
* Create a boxing buffer of given type
|
||||||
*/
|
*/
|
||||||
inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> = ListBuffer(List(size, initializer))
|
inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> = ListBuffer(List(size, initializer))
|
||||||
|
|
||||||
/**
|
|
||||||
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible
|
|
||||||
*/
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <reified T : Any> auto(size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
||||||
//TODO add resolution based on Annotation or companion resolution
|
//TODO add resolution based on Annotation or companion resolution
|
||||||
return when (T::class) {
|
return when (type) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
||||||
@ -56,6 +59,13 @@ interface Buffer<T> {
|
|||||||
else -> boxing(size, initializer)
|
else -> boxing(size, initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
inline fun <reified T : Any> auto(size: Int, crossinline initializer: (Int) -> T): Buffer<T> =
|
||||||
|
auto(T::class, size, initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,12 +88,9 @@ interface MutableBuffer<T> : Buffer<T> {
|
|||||||
inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
MutableListBuffer(MutableList(size, initializer))
|
MutableListBuffer(MutableList(size, initializer))
|
||||||
|
|
||||||
/**
|
|
||||||
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
|
||||||
*/
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
||||||
return when (T::class) {
|
return when (type) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
||||||
@ -91,6 +98,17 @@ interface MutableBuffer<T> : Buffer<T> {
|
|||||||
else -> boxing(size, initializer)
|
else -> boxing(size, initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
|
auto(T::class, size, initializer)
|
||||||
|
|
||||||
|
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
|
||||||
|
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,7 +254,10 @@ inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
|||||||
* Useful when one needs single element from the buffer.
|
* Useful when one needs single element from the buffer.
|
||||||
*/
|
*/
|
||||||
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
|
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
|
||||||
override fun get(index: Int): T = generator(index)
|
override fun get(index: Int): T {
|
||||||
|
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
|
||||||
|
return generator(index)
|
||||||
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator()
|
override fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator()
|
||||||
|
|
||||||
@ -262,3 +283,5 @@ fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) {
|
|||||||
* Typealias for buffer transformations
|
* Typealias for buffer transformations
|
||||||
*/
|
*/
|
||||||
typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R>
|
typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R>
|
||||||
|
|
||||||
|
typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R>
|
@ -73,6 +73,7 @@ interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T,
|
|||||||
*/
|
*/
|
||||||
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
|
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
|
||||||
|
|
||||||
|
//TODO move to extensions after KEEP-176
|
||||||
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) }
|
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) }
|
||||||
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) }
|
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) }
|
||||||
|
|
||||||
@ -90,6 +91,7 @@ interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N>
|
|||||||
*/
|
*/
|
||||||
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
||||||
|
|
||||||
|
//TODO move to extensions after KEEP-176
|
||||||
operator fun N.times(arg: T) = map(this) { value -> multiply(arg, value) }
|
operator fun N.times(arg: T) = map(this) { value -> multiply(arg, value) }
|
||||||
operator fun T.times(arg: N) = map(arg) { value -> multiply(this@times, value) }
|
operator fun T.times(arg: N) = map(arg) { value -> multiply(this@times, value) }
|
||||||
}
|
}
|
||||||
@ -109,6 +111,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
|||||||
*/
|
*/
|
||||||
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
||||||
|
|
||||||
|
//TODO move to extensions after KEEP-176
|
||||||
operator fun N.div(arg: T) = map(this) { value -> divide(arg, value) }
|
operator fun N.div(arg: T) = map(this) { value -> divide(arg, value) }
|
||||||
operator fun T.div(arg: N) = map(arg) { divide(it, this@div) }
|
operator fun T.div(arg: N) = map(arg) { divide(it, this@div) }
|
||||||
|
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.jvm.JvmName
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
|
||||||
interface NDStructure<T> {
|
interface NDStructure<T> {
|
||||||
|
|
||||||
val shape: IntArray
|
val shape: IntArray
|
||||||
|
|
||||||
val dimension
|
val dimension get() = shape.size
|
||||||
get() = shape.size
|
|
||||||
|
|
||||||
operator fun get(index: IntArray): T
|
operator fun get(index: IntArray): T
|
||||||
|
|
||||||
@ -41,6 +43,9 @@ interface NDStructure<T> {
|
|||||||
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
|
inline fun <T : Any> auto(type: KClass<T>, strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||||
|
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
fun <T> build(
|
fun <T> build(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
@ -49,6 +54,13 @@ interface NDStructure<T> {
|
|||||||
|
|
||||||
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||||
auto(DefaultStrides(shape), initializer)
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
|
@JvmName("autoVarArg")
|
||||||
|
inline fun <reified T : Any> auto(vararg shape: Int, crossinline initializer: (IntArray) -> T) =
|
||||||
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
|
inline fun <T : Any> auto(type: KClass<T>, vararg shape: Int, crossinline initializer: (IntArray) -> T) =
|
||||||
|
auto(type, DefaultStrides(shape), initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,7 +70,7 @@ interface MutableNDStructure<T> : NDStructure<T> {
|
|||||||
operator fun set(index: IntArray, value: T)
|
operator fun set(index: IntArray, value: T)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
||||||
elements().forEach { (index, oldValue) ->
|
elements().forEach { (index, oldValue) ->
|
||||||
this[index] = action(index, oldValue)
|
this[index] = action(index, oldValue)
|
||||||
}
|
}
|
||||||
|
@ -27,20 +27,6 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
|
|||||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
|
||||||
*/
|
|
||||||
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
|
||||||
Structure1DWrapper(this)
|
|
||||||
} else {
|
|
||||||
error("Can't create 1d-structure from ${shape.size}d-structure")
|
|
||||||
}
|
|
||||||
|
|
||||||
fun <T> NDBuffer<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
|
||||||
Buffer1DWrapper(this.buffer)
|
|
||||||
} else {
|
|
||||||
error("Can't create 1d-structure from ${shape.size}d-structure")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A structure wrapper for buffer
|
* A structure wrapper for buffer
|
||||||
@ -56,39 +42,21 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
|
|||||||
override fun get(index: Int): T = buffer.get(index)
|
override fun get(index: Int): T = buffer.get(index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||||
|
*/
|
||||||
|
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||||
|
if( this is NDBuffer){
|
||||||
|
Buffer1DWrapper(this.buffer)
|
||||||
|
} else {
|
||||||
|
Structure1DWrapper(this)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represent this buffer as 1D structure
|
* Represent this buffer as 1D structure
|
||||||
*/
|
*/
|
||||||
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
||||||
|
|
||||||
/**
|
|
||||||
* A structure that is guaranteed to be two-dimensional
|
|
||||||
*/
|
|
||||||
interface Structure2D<T> : NDStructure<T> {
|
|
||||||
operator fun get(i: Int, j: Int): T
|
|
||||||
|
|
||||||
override fun get(index: IntArray): T {
|
|
||||||
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
|
|
||||||
return get(index[0], index[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A 2D wrapper for nd-structure
|
|
||||||
*/
|
|
||||||
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
|
||||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
|
||||||
|
|
||||||
override val shape: IntArray get() = structure.shape
|
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
|
||||||
*/
|
|
||||||
fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
|
||||||
Structure2DWrapper(this)
|
|
||||||
} else {
|
|
||||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
|
||||||
}
|
|
@ -0,0 +1,79 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A structure that is guaranteed to be two-dimensional
|
||||||
|
*/
|
||||||
|
interface Structure2D<T> : NDStructure<T> {
|
||||||
|
val rowNum: Int get() = shape[0]
|
||||||
|
val colNum: Int get() = shape[1]
|
||||||
|
|
||||||
|
operator fun get(i: Int, j: Int): T
|
||||||
|
|
||||||
|
override fun get(index: IntArray): T {
|
||||||
|
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
|
||||||
|
return get(index[0], index[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
val rows: Buffer<Buffer<T>>
|
||||||
|
get() = VirtualBuffer(rowNum) { i ->
|
||||||
|
VirtualBuffer(colNum) { j -> get(i, j) }
|
||||||
|
}
|
||||||
|
|
||||||
|
val columns: Buffer<Buffer<T>>
|
||||||
|
get() = VirtualBuffer(colNum) { j ->
|
||||||
|
VirtualBuffer(rowNum) { i -> get(i, j) }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||||
|
for (i in (0 until rowNum)) {
|
||||||
|
for (j in (0 until colNum)) {
|
||||||
|
yield(intArrayOf(i, j) to get(i, j))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A 2D wrapper for nd-structure
|
||||||
|
*/
|
||||||
|
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
||||||
|
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||||
|
|
||||||
|
override val shape: IntArray get() = structure.shape
|
||||||
|
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||||
|
*/
|
||||||
|
fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
||||||
|
Structure2DWrapper(this)
|
||||||
|
} else {
|
||||||
|
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent this 2D structure as 1D if it has exactly one column. Throw error otherwise.
|
||||||
|
*/
|
||||||
|
fun <T> Structure2D<T>.as1D() = if (colNum == 1) {
|
||||||
|
object : Structure1D<T> {
|
||||||
|
override fun get(index: Int): T = get(index, 0)
|
||||||
|
|
||||||
|
override val shape: IntArray get() = intArrayOf(rowNum)
|
||||||
|
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> = elements()
|
||||||
|
|
||||||
|
override val size: Int get() = rowNum
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
error("Can't convert matrix with more than one column to vector")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
typealias Matrix<T> = Structure2D<T>
|
@ -1,5 +1,6 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ class MatrixTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testVectorToMatrix() {
|
fun testVectorToMatrix() {
|
||||||
val vector = Vector.real(5) { it.toDouble() }
|
val vector = Vector.real(5) { it.toDouble() }
|
||||||
val matrix = vector.toMatrix()
|
val matrix = vector.asMatrix()
|
||||||
assertEquals(4.0, matrix[4, 0])
|
assertEquals(4.0, matrix[4, 0])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -33,8 +34,8 @@ class MatrixTest {
|
|||||||
val vector1 = Vector.real(5) { it.toDouble() }
|
val vector1 = Vector.real(5) { it.toDouble() }
|
||||||
val vector2 = Vector.real(5) { 5 - it.toDouble() }
|
val vector2 = Vector.real(5) { 5 - it.toDouble() }
|
||||||
|
|
||||||
val matrix1 = vector1.toMatrix()
|
val matrix1 = vector1.asMatrix()
|
||||||
val matrix2 = vector2.toMatrix().transpose()
|
val matrix2 = vector2.asMatrix().transpose()
|
||||||
val product = MatrixContext.real.run { matrix1 dot matrix2 }
|
val product = MatrixContext.real.run { matrix1 dot matrix2 }
|
||||||
|
|
||||||
|
|
||||||
@ -51,4 +52,26 @@ class MatrixTest {
|
|||||||
|
|
||||||
assertEquals(2.0, matrix[1, 2])
|
assertEquals(2.0, matrix[1, 2])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMatrixExtension() {
|
||||||
|
val transitionMatrix: Matrix<Double> = VirtualMatrix(6, 6) { row, col ->
|
||||||
|
when {
|
||||||
|
col == 0 -> .50
|
||||||
|
row + 1 == col -> .50
|
||||||
|
row == 5 && col == 5 -> 1.0
|
||||||
|
else -> 0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
infix fun Matrix<Double>.pow(power: Int): Matrix<Double> {
|
||||||
|
var res = this
|
||||||
|
repeat(power - 1) {
|
||||||
|
res = res dot this
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
val toTenthPower = transitionMatrix pow 10
|
||||||
|
}
|
||||||
}
|
}
|
@ -1,16 +1,37 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
@ExperimentalContracts
|
||||||
class RealLUSolverTest {
|
class RealLUSolverTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testInvertOne() {
|
fun testInvertOne() {
|
||||||
val matrix = MatrixContext.real.one(2, 2)
|
val matrix = MatrixContext.real.one(2, 2)
|
||||||
val inverted = LUSolver.real.inverse(matrix)
|
val inverted = MatrixContext.real.inverse(matrix)
|
||||||
assertEquals(matrix, inverted)
|
assertEquals(matrix, inverted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDecomposition() {
|
||||||
|
val matrix = Matrix.square(
|
||||||
|
3.0, 1.0,
|
||||||
|
1.0, 3.0
|
||||||
|
)
|
||||||
|
|
||||||
|
MatrixContext.real.run {
|
||||||
|
val lup = lup(matrix)
|
||||||
|
|
||||||
|
//Check determinant
|
||||||
|
assertEquals(8.0, lup.determinant)
|
||||||
|
|
||||||
|
assertEquals(lup.p dot matrix, lup.l dot lup.u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testInvert() {
|
fun testInvert() {
|
||||||
val matrix = Matrix.square(
|
val matrix = Matrix.square(
|
||||||
@ -18,18 +39,7 @@ class RealLUSolverTest {
|
|||||||
1.0, 3.0
|
1.0, 3.0
|
||||||
)
|
)
|
||||||
|
|
||||||
val decomposed = LUSolver.real.decompose(matrix)
|
val inverted = MatrixContext.real.inverse(matrix)
|
||||||
val decomposition = decomposed.getFeature<LUPDecomposition<Double>>()!!
|
|
||||||
|
|
||||||
//Check determinant
|
|
||||||
assertEquals(8.0, decomposition.determinant)
|
|
||||||
|
|
||||||
//Check decomposition
|
|
||||||
with(MatrixContext.real) {
|
|
||||||
assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u)
|
|
||||||
}
|
|
||||||
|
|
||||||
val inverted = LUSolver.real.inverse(decomposed)
|
|
||||||
|
|
||||||
val expected = Matrix.square(
|
val expected = Matrix.square(
|
||||||
0.375, -0.125,
|
0.375, -0.125,
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
package scientifik.kmath.misc
|
package scientifik.kmath.misc
|
||||||
|
|
||||||
import scientifik.kmath.sequential.cumulativeSum
|
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
@ -0,0 +1,176 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
import kotlin.math.PI
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
class AutoDiffTest {
|
||||||
|
@Test
|
||||||
|
fun testPlusX2() {
|
||||||
|
val x = Variable(3) // diff w.r.t this x at 3
|
||||||
|
val y = deriv { x + x }
|
||||||
|
assertEquals(6.0, y.value) // y = x + x = 6
|
||||||
|
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlus() {
|
||||||
|
// two variables
|
||||||
|
val x = Variable(2)
|
||||||
|
val y = Variable(3)
|
||||||
|
val z = deriv { x + y }
|
||||||
|
assertEquals(5.0, z.value) // z = x + y = 5
|
||||||
|
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
||||||
|
assertEquals(1.0, z.deriv(y)) // dz/dy = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMinus() {
|
||||||
|
// two variables
|
||||||
|
val x = Variable(7)
|
||||||
|
val y = Variable(3)
|
||||||
|
val z = deriv { x - y }
|
||||||
|
assertEquals(4.0, z.value) // z = x - y = 4
|
||||||
|
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
||||||
|
assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMulX2() {
|
||||||
|
val x = Variable(3) // diff w.r.t this x at 3
|
||||||
|
val y = deriv { x * x }
|
||||||
|
assertEquals(9.0, y.value) // y = x * x = 9
|
||||||
|
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqr() {
|
||||||
|
val x = Variable(3)
|
||||||
|
val y = deriv { sqr(x) }
|
||||||
|
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
||||||
|
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqrSqr() {
|
||||||
|
val x = Variable(2)
|
||||||
|
val y = deriv { sqr(sqr(x)) }
|
||||||
|
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
||||||
|
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testX3() {
|
||||||
|
val x = Variable(2) // diff w.r.t this x at 2
|
||||||
|
val y = deriv { x * x * x }
|
||||||
|
assertEquals(8.0, y.value) // y = x * x * x = 8
|
||||||
|
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDiv() {
|
||||||
|
val x = Variable(5)
|
||||||
|
val y = Variable(2)
|
||||||
|
val z = deriv { x / y }
|
||||||
|
assertEquals(2.5, z.value) // z = x / y = 2.5
|
||||||
|
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
|
||||||
|
assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPow3() {
|
||||||
|
val x = Variable(2) // diff w.r.t this x at 2
|
||||||
|
val y = deriv { pow(x, 3) }
|
||||||
|
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
||||||
|
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPowFull() {
|
||||||
|
val x = Variable(2)
|
||||||
|
val y = Variable(3)
|
||||||
|
val z = deriv { pow(x, y) }
|
||||||
|
assertApprox(8.0, z.value) // z = x ^ y = 8
|
||||||
|
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
|
||||||
|
assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testFromPaper() {
|
||||||
|
val x = Variable(3)
|
||||||
|
val y = deriv { 2 * x + x * x * x }
|
||||||
|
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
||||||
|
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testInnerVariable() {
|
||||||
|
val x = Variable(1)
|
||||||
|
val y = deriv {
|
||||||
|
Variable(1) * x
|
||||||
|
}
|
||||||
|
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||||
|
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLongChain() {
|
||||||
|
val n = 10_000
|
||||||
|
val x = Variable(1)
|
||||||
|
val y = deriv {
|
||||||
|
var res = Variable(1)
|
||||||
|
for (i in 1..n) res *= x
|
||||||
|
res
|
||||||
|
}
|
||||||
|
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||||
|
assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testExample() {
|
||||||
|
val x = Variable(2)
|
||||||
|
val y = deriv { sqr(x) + 5 * x + 3 }
|
||||||
|
assertEquals(17.0, y.value) // the value of result (y)
|
||||||
|
assertEquals(9.0, y.deriv(x)) // dy/dx
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqrt() {
|
||||||
|
val x = Variable(16)
|
||||||
|
val y = deriv { sqrt(x) }
|
||||||
|
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
||||||
|
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSin() {
|
||||||
|
val x = Variable(PI / 6)
|
||||||
|
val y = deriv { sin(x) }
|
||||||
|
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
||||||
|
assertApprox(kotlin.math.sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(PI/6) = sqrt(3)/2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCos() {
|
||||||
|
val x = Variable(PI / 6)
|
||||||
|
val y = deriv { cos(x) }
|
||||||
|
assertApprox(kotlin.math.sqrt(3.0) / 2, y.value) // y = cos(PI/6) = sqrt(3)/2
|
||||||
|
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(PI/6) = -0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDivGrad() {
|
||||||
|
val x = Variable(1.0)
|
||||||
|
val y = Variable(2.0)
|
||||||
|
val res = deriv { x * x + y * y }
|
||||||
|
assertEquals(6.0, res.div())
|
||||||
|
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun assertApprox(a: Double, b: Double) {
|
||||||
|
if ((a - b) > 1e-10) assertEquals(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,46 +1,26 @@
|
|||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
`multiplatform-config`
|
||||||
|
id("kotlinx-atomicfu") version Versions.atomicfuVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
val coroutinesVersion: String by rootProject.extra
|
kotlin.sourceSets {
|
||||||
|
commonMain {
|
||||||
kotlin {
|
dependencies {
|
||||||
jvm()
|
api(project(":kmath-core"))
|
||||||
js()
|
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Versions.coroutinesVersion}")
|
||||||
|
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||||
sourceSets {
|
|
||||||
val commonMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api(project(":kmath-core"))
|
|
||||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:$coroutinesVersion")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
val commonTest by getting {
|
}
|
||||||
dependencies {
|
jvmMain {
|
||||||
implementation(kotlin("test-common"))
|
dependencies {
|
||||||
implementation(kotlin("test-annotations-common"))
|
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutinesVersion}")
|
||||||
}
|
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
val jvmMain by getting {
|
}
|
||||||
dependencies {
|
jsMain {
|
||||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:$coroutinesVersion")
|
dependencies {
|
||||||
}
|
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Versions.coroutinesVersion}")
|
||||||
}
|
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||||
val jvmTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test"))
|
|
||||||
implementation(kotlin("test-junit"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jsMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:$coroutinesVersion")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jsTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test-js"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,102 @@
|
|||||||
|
package scientifik.kmath
|
||||||
|
|
||||||
|
import kotlinx.coroutines.*
|
||||||
|
import kotlinx.coroutines.channels.produce
|
||||||
|
import kotlinx.coroutines.flow.*
|
||||||
|
|
||||||
|
val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An imitator of [Deferred] which holds a suspended function block and dispatcher
|
||||||
|
*/
|
||||||
|
internal class LazyDeferred<T>(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) {
|
||||||
|
private var deferred: Deferred<T>? = null
|
||||||
|
|
||||||
|
internal fun start(scope: CoroutineScope) {
|
||||||
|
if (deferred == null) {
|
||||||
|
deferred = scope.async(dispatcher, block = block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
suspend fun await(): T = deferred?.await() ?: error("Coroutine not started")
|
||||||
|
}
|
||||||
|
|
||||||
|
@FlowPreview
|
||||||
|
class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
|
||||||
|
override suspend fun collect(collector: FlowCollector<T>) {
|
||||||
|
deferredFlow.collect {
|
||||||
|
collector.emit((it.await()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@FlowPreview
|
||||||
|
fun <T, R> Flow<T>.async(
|
||||||
|
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
||||||
|
block: suspend CoroutineScope.(T) -> R
|
||||||
|
): AsyncFlow<R> {
|
||||||
|
val flow = map {
|
||||||
|
LazyDeferred(dispatcher) { block(it) }
|
||||||
|
}
|
||||||
|
return AsyncFlow(flow)
|
||||||
|
}
|
||||||
|
|
||||||
|
@FlowPreview
|
||||||
|
fun <T, R> AsyncFlow<T>.map(action: (T) -> R) = AsyncFlow(deferredFlow.map { input ->
|
||||||
|
//TODO add function composition
|
||||||
|
LazyDeferred(input.dispatcher) {
|
||||||
|
input.start(this)
|
||||||
|
action(input.await())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
@ExperimentalCoroutinesApi
|
||||||
|
@FlowPreview
|
||||||
|
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>) {
|
||||||
|
require(concurrency >= 1) { "Buffer size should be more than 1, but was $concurrency" }
|
||||||
|
coroutineScope {
|
||||||
|
//Starting up to N deferred coroutines ahead of time
|
||||||
|
val channel = produce(capacity = concurrency - 1) {
|
||||||
|
deferredFlow.collect { value ->
|
||||||
|
value.start(this@coroutineScope)
|
||||||
|
send(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(channel as Job).invokeOnCompletion {
|
||||||
|
if (it is CancellationException && it.cause == null) cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
for (element in channel) {
|
||||||
|
collector.emit(element.await())
|
||||||
|
}
|
||||||
|
|
||||||
|
val producer = channel as Job
|
||||||
|
if (producer.isCancelled) {
|
||||||
|
producer.join()
|
||||||
|
//throw producer.getCancellationException()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ExperimentalCoroutinesApi
|
||||||
|
@FlowPreview
|
||||||
|
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, action: suspend (value: T) -> Unit): Unit {
|
||||||
|
collect(concurrency, object : FlowCollector<T> {
|
||||||
|
override suspend fun emit(value: T) = action(value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
@FlowPreview
|
||||||
|
fun <T, R> Flow<T>.map(
|
||||||
|
dispatcher: CoroutineDispatcher,
|
||||||
|
concurrencyLevel: Int = 16,
|
||||||
|
bufferSize: Int = concurrencyLevel,
|
||||||
|
transform: suspend (T) -> R
|
||||||
|
): Flow<R> {
|
||||||
|
return flatMapMerge(concurrencyLevel, bufferSize) { value ->
|
||||||
|
flow { emit(transform(value)) }
|
||||||
|
}.flowOn(dispatcher)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -14,14 +14,10 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package scientifik.kmath.sequential
|
package scientifik.kmath.chains
|
||||||
|
|
||||||
import kotlinx.atomicfu.atomic
|
import kotlinx.atomicfu.atomic
|
||||||
import kotlinx.coroutines.CoroutineScope
|
import kotlinx.coroutines.FlowPreview
|
||||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
|
||||||
import kotlinx.coroutines.channels.ReceiveChannel
|
|
||||||
import kotlinx.coroutines.channels.produce
|
|
||||||
import kotlinx.coroutines.isActive
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -30,7 +26,7 @@ import kotlinx.coroutines.isActive
|
|||||||
*/
|
*/
|
||||||
interface Chain<out R> {
|
interface Chain<out R> {
|
||||||
/**
|
/**
|
||||||
* Last value of the chain. Returns null if [next] was not called
|
* Last cached value of the chain. Returns null if [next] was not called
|
||||||
*/
|
*/
|
||||||
val value: R?
|
val value: R?
|
||||||
|
|
||||||
@ -47,10 +43,11 @@ interface Chain<out R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Chain as a coroutine receive channel
|
* Chain as a coroutine flow. The flow emit affects chain state and vice versa
|
||||||
*/
|
*/
|
||||||
@ExperimentalCoroutinesApi
|
@FlowPreview
|
||||||
fun <R> Chain<R>.asChannel(scope: CoroutineScope): ReceiveChannel<R> = scope.produce { while (isActive) send(next()) }
|
val <R> Chain<R>.flow
|
||||||
|
get() = kotlinx.coroutines.flow.flow { while (true) emit(next()) }
|
||||||
|
|
||||||
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
|
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
|
||||||
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
@ -0,0 +1,80 @@
|
|||||||
|
package scientifik.kmath.streaming
|
||||||
|
|
||||||
|
import kotlinx.coroutines.FlowPreview
|
||||||
|
import kotlinx.coroutines.flow.*
|
||||||
|
import scientifik.kmath.structures.Buffer
|
||||||
|
import scientifik.kmath.structures.BufferFactory
|
||||||
|
import scientifik.kmath.structures.DoubleBuffer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a [Flow] from buffer
|
||||||
|
*/
|
||||||
|
@FlowPreview
|
||||||
|
fun <T> Buffer<T>.asFlow() = iterator().asFlow()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Flat map a [Flow] of [Buffer] into continuous [Flow] of elements
|
||||||
|
*/
|
||||||
|
@FlowPreview
|
||||||
|
fun <T> Flow<Buffer<out T>>.spread(): Flow<T> = flatMapConcat { it.asFlow() }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Collect incoming flow into fixed size chunks
|
||||||
|
*/
|
||||||
|
@FlowPreview
|
||||||
|
fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>) = flow {
|
||||||
|
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||||
|
val list = ArrayList<T>(bufferSize)
|
||||||
|
var counter = 0
|
||||||
|
|
||||||
|
this@chunked.collect { element ->
|
||||||
|
list.add(element)
|
||||||
|
counter++
|
||||||
|
if (counter == bufferSize) {
|
||||||
|
val buffer = bufferFactory(bufferSize) { list[it] }
|
||||||
|
emit(buffer)
|
||||||
|
list.clear()
|
||||||
|
counter = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (counter > 0) {
|
||||||
|
emit(bufferFactory(counter) { list[it] })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialized flow chunker for real buffer
|
||||||
|
*/
|
||||||
|
@FlowPreview
|
||||||
|
fun Flow<Double>.chunked(bufferSize: Int) = flow {
|
||||||
|
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||||
|
val array = DoubleArray(bufferSize)
|
||||||
|
var counter = 0
|
||||||
|
|
||||||
|
this@chunked.collect { element ->
|
||||||
|
array[counter] = element
|
||||||
|
counter++
|
||||||
|
if (counter == bufferSize) {
|
||||||
|
val buffer = DoubleBuffer(array)
|
||||||
|
emit(buffer)
|
||||||
|
counter = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (counter > 0) {
|
||||||
|
emit(DoubleBuffer(counter) { array[it] })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map a flow to a moving window buffer. The window step is one.
|
||||||
|
* In order to get different steps, one could use skip operation.
|
||||||
|
*/
|
||||||
|
@FlowPreview
|
||||||
|
fun <T> Flow<T>.windowed(window: Int): Flow<Buffer<T>> = flow {
|
||||||
|
require(window > 1) { "Window size must be more than one" }
|
||||||
|
val ringBuffer = RingBuffer.boxing<T>(window)
|
||||||
|
this@windowed.collect { element ->
|
||||||
|
ringBuffer.push(element)
|
||||||
|
emit(ringBuffer.snapshot())
|
||||||
|
}
|
||||||
|
}
|
@ -1,16 +1,18 @@
|
|||||||
package scientifik.kmath.sequential
|
package scientifik.kmath.streaming
|
||||||
|
|
||||||
import kotlinx.coroutines.sync.Mutex
|
import kotlinx.coroutines.sync.Mutex
|
||||||
import kotlinx.coroutines.sync.withLock
|
import kotlinx.coroutines.sync.withLock
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import scientifik.kmath.structures.MutableBuffer
|
import scientifik.kmath.structures.MutableBuffer
|
||||||
import scientifik.kmath.structures.VirtualBuffer
|
import scientifik.kmath.structures.VirtualBuffer
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Thread-safe ring buffer
|
* Thread-safe ring buffer
|
||||||
*/
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
internal class RingBuffer<T>(
|
internal class RingBuffer<T>(
|
||||||
private val buffer: MutableBuffer<T>,
|
private val buffer: MutableBuffer<T?>,
|
||||||
private var startIndex: Int = 0,
|
private var startIndex: Int = 0,
|
||||||
size: Int = 0
|
size: Int = 0
|
||||||
) : Buffer<T> {
|
) : Buffer<T> {
|
||||||
@ -23,7 +25,7 @@ internal class RingBuffer<T>(
|
|||||||
override fun get(index: Int): T {
|
override fun get(index: Int): T {
|
||||||
require(index >= 0) { "Index must be positive" }
|
require(index >= 0) { "Index must be positive" }
|
||||||
require(index < size) { "Index $index is out of circular buffer size $size" }
|
require(index < size) { "Index $index is out of circular buffer size $size" }
|
||||||
return buffer[startIndex.forward(index)]
|
return buffer[startIndex.forward(index)] as T
|
||||||
}
|
}
|
||||||
|
|
||||||
fun isFull() = size == buffer.size
|
fun isFull() = size == buffer.size
|
||||||
@ -40,7 +42,7 @@ internal class RingBuffer<T>(
|
|||||||
if (count == 0) {
|
if (count == 0) {
|
||||||
done()
|
done()
|
||||||
} else {
|
} else {
|
||||||
setNext(copy[index])
|
setNext(copy[index] as T)
|
||||||
index = index.forward(1)
|
index = index.forward(1)
|
||||||
count--
|
count--
|
||||||
}
|
}
|
||||||
@ -53,7 +55,9 @@ internal class RingBuffer<T>(
|
|||||||
suspend fun snapshot(): Buffer<T> {
|
suspend fun snapshot(): Buffer<T> {
|
||||||
mutex.withLock {
|
mutex.withLock {
|
||||||
val copy = buffer.copy()
|
val copy = buffer.copy()
|
||||||
return VirtualBuffer(size) { i -> copy[startIndex.forward(i)] }
|
return VirtualBuffer(size) { i ->
|
||||||
|
copy[startIndex.forward(i)] as T
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,14 +78,14 @@ internal class RingBuffer<T>(
|
|||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
inline fun <reified T : Any> build(size: Int, empty: T): RingBuffer<T> {
|
inline fun <reified T : Any> build(size: Int, empty: T): RingBuffer<T> {
|
||||||
val buffer = MutableBuffer.auto(size) { empty }
|
val buffer = MutableBuffer.auto(size) { empty } as MutableBuffer<T?>
|
||||||
return RingBuffer(buffer)
|
return RingBuffer(buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Slow yet universal buffer
|
* Slow yet universal buffer
|
||||||
*/
|
*/
|
||||||
fun <T> boxing(size: Int): RingBuffer<T?> {
|
fun <T> boxing(size: Int): RingBuffer<T> {
|
||||||
val buffer: MutableBuffer<T?> = MutableBuffer.boxing(size) { null }
|
val buffer: MutableBuffer<T?> = MutableBuffer.boxing(size) { null }
|
||||||
return RingBuffer(buffer)
|
return RingBuffer(buffer)
|
||||||
}
|
}
|
@ -1,9 +0,0 @@
|
|||||||
package scientifik.kmath.structures
|
|
||||||
|
|
||||||
import kotlinx.coroutines.CoroutineDispatcher
|
|
||||||
import kotlinx.coroutines.CoroutineScope
|
|
||||||
import kotlinx.coroutines.Dispatchers
|
|
||||||
import kotlin.coroutines.CoroutineContext
|
|
||||||
import kotlin.coroutines.EmptyCoroutineContext
|
|
||||||
|
|
||||||
val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default
|
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath.sequential
|
package scientifik.kmath.chains
|
||||||
|
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import kotlin.sequences.Sequence
|
import kotlin.sequences.Sequence
|
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
|
import scientifik.kmath.Math
|
||||||
|
|
||||||
class LazyNDStructure<T>(
|
class LazyNDStructure<T>(
|
||||||
val scope: CoroutineScope,
|
val scope: CoroutineScope,
|
||||||
|
@ -0,0 +1,48 @@
|
|||||||
|
package scientifik.kmath.streaming
|
||||||
|
|
||||||
|
import kotlinx.coroutines.*
|
||||||
|
import kotlinx.coroutines.flow.asFlow
|
||||||
|
import kotlinx.coroutines.flow.collect
|
||||||
|
import org.junit.Test
|
||||||
|
import scientifik.kmath.async
|
||||||
|
import scientifik.kmath.collect
|
||||||
|
import scientifik.kmath.map
|
||||||
|
import java.util.concurrent.Executors
|
||||||
|
|
||||||
|
|
||||||
|
@ExperimentalCoroutinesApi
|
||||||
|
@InternalCoroutinesApi
|
||||||
|
@FlowPreview
|
||||||
|
class BufferFlowTest {
|
||||||
|
|
||||||
|
val dispatcher = Executors.newFixedThreadPool(4).asCoroutineDispatcher()
|
||||||
|
|
||||||
|
@Test(timeout = 2000)
|
||||||
|
fun map() {
|
||||||
|
runBlocking {
|
||||||
|
(1..20).asFlow().map( dispatcher) {
|
||||||
|
//println("Started $it on ${Thread.currentThread().name}")
|
||||||
|
@Suppress("BlockingMethodInNonBlockingContext")
|
||||||
|
Thread.sleep(200)
|
||||||
|
it
|
||||||
|
}.collect {
|
||||||
|
println("Completed $it on ${Thread.currentThread().name}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(timeout = 2000)
|
||||||
|
fun async() {
|
||||||
|
runBlocking {
|
||||||
|
(1..20).asFlow().async(dispatcher) {
|
||||||
|
//println("Started $it on ${Thread.currentThread().name}")
|
||||||
|
@Suppress("BlockingMethodInNonBlockingContext")
|
||||||
|
Thread.sleep(200)
|
||||||
|
it
|
||||||
|
}.collect(4) {
|
||||||
|
println("Completed $it on ${Thread.currentThread().name}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,38 @@
|
|||||||
|
package scientifik.kmath.streaming
|
||||||
|
|
||||||
|
import kotlinx.coroutines.flow.*
|
||||||
|
import kotlinx.coroutines.runBlocking
|
||||||
|
import org.junit.Test
|
||||||
|
import scientifik.kmath.structures.asSequence
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class RingBufferTest {
|
||||||
|
@Test
|
||||||
|
fun push() {
|
||||||
|
val buffer = RingBuffer.build(20, Double.NaN)
|
||||||
|
runBlocking {
|
||||||
|
for (i in 1..30) {
|
||||||
|
buffer.push(i.toDouble())
|
||||||
|
}
|
||||||
|
assertEquals(410.0, buffer.asSequence().sum())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun windowed(){
|
||||||
|
val flow = flow{
|
||||||
|
var i = 0
|
||||||
|
while(true){
|
||||||
|
emit(i++)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val windowed = flow.windowed(10)
|
||||||
|
runBlocking {
|
||||||
|
val first = windowed.take(1).single()
|
||||||
|
val res = windowed.take(15).map { it -> it.asSequence().average() }.toList()
|
||||||
|
assertEquals(0.0, res[0])
|
||||||
|
assertEquals(4.5, res[9])
|
||||||
|
assertEquals(9.5, res[14])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,34 +1,10 @@
|
|||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
`multiplatform-config`
|
||||||
}
|
}
|
||||||
|
|
||||||
kotlin {
|
// Just an example how we can collapse nested DSL for simple declarations
|
||||||
jvm()
|
kotlin.sourceSets.commonMain {
|
||||||
js()
|
dependencies {
|
||||||
|
api(project(":kmath-core"))
|
||||||
sourceSets {
|
|
||||||
|
|
||||||
val commonMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api(project(":kmath-core"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val commonTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test-common"))
|
|
||||||
implementation(kotlin("test-annotations-common"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jvmTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test"))
|
|
||||||
implementation(kotlin("test-junit"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jsTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test-js"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,57 +1,31 @@
|
|||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
`multiplatform-config`
|
||||||
}
|
}
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
maven("http://dl.bintray.com/kyonifer/maven")
|
maven("http://dl.bintray.com/kyonifer/maven")
|
||||||
}
|
}
|
||||||
|
|
||||||
kotlin {
|
kotlin.sourceSets {
|
||||||
jvm {
|
commonMain {
|
||||||
compilations.all {
|
dependencies {
|
||||||
kotlinOptions {
|
api(project(":kmath-core"))
|
||||||
jvmTarget = "1.8"
|
api("com.kyonifer:koma-core-api-common:0.12")
|
||||||
freeCompilerArgs += "-progressive"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
js()
|
jvmMain {
|
||||||
|
dependencies {
|
||||||
sourceSets {
|
api("com.kyonifer:koma-core-api-jvm:0.12")
|
||||||
|
|
||||||
val commonMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api(project(":kmath-core"))
|
|
||||||
api("com.kyonifer:koma-core-api-common:0.12")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
val commonTest by getting {
|
}
|
||||||
dependencies {
|
jvmTest {
|
||||||
implementation(kotlin("test-common"))
|
dependencies {
|
||||||
implementation(kotlin("test-annotations-common"))
|
implementation("com.kyonifer:koma-core-ejml:0.12")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
val jvmMain by getting {
|
}
|
||||||
dependencies {
|
jsMain {
|
||||||
api("com.kyonifer:koma-core-api-jvm:0.12")
|
dependencies {
|
||||||
}
|
api("com.kyonifer:koma-core-api-js:0.12")
|
||||||
}
|
|
||||||
val jvmTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test"))
|
|
||||||
implementation(kotlin("test-junit"))
|
|
||||||
implementation("com.kyonifer:koma-core-ejml:0.12")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jsMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api("com.kyonifer:koma-core-api-js:0.12")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jsTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test-js"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -2,9 +2,14 @@ package scientifik.kmath.linear
|
|||||||
|
|
||||||
import koma.extensions.fill
|
import koma.extensions.fill
|
||||||
import koma.matrix.MatrixFactory
|
import koma.matrix.MatrixFactory
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
|
|
||||||
class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T>>) : MatrixContext<T>,
|
class KomaMatrixContext<T : Any>(
|
||||||
LinearSolver<T> {
|
private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
|
||||||
|
private val space: Space<T>
|
||||||
|
) :
|
||||||
|
MatrixContext<T> {
|
||||||
|
|
||||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T) =
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T) =
|
||||||
KomaMatrix(factory.zeros(rows, columns).fill(initializer))
|
KomaMatrix(factory.zeros(rows, columns).fill(initializer))
|
||||||
@ -31,41 +36,52 @@ class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T
|
|||||||
override fun Matrix<T>.unaryMinus() =
|
override fun Matrix<T>.unaryMinus() =
|
||||||
KomaMatrix(this.toKoma().origin.unaryMinus())
|
KomaMatrix(this.toKoma().origin.unaryMinus())
|
||||||
|
|
||||||
override fun Matrix<T>.plus(b: Matrix<T>) =
|
override fun add(a: Matrix<T>, b: Matrix<T>) =
|
||||||
KomaMatrix(this.toKoma().origin + b.toKoma().origin)
|
KomaMatrix(a.toKoma().origin + b.toKoma().origin)
|
||||||
|
|
||||||
override fun Matrix<T>.minus(b: Matrix<T>) =
|
override fun Matrix<T>.minus(b: Matrix<T>) =
|
||||||
KomaMatrix(this.toKoma().origin - b.toKoma().origin)
|
KomaMatrix(this.toKoma().origin - b.toKoma().origin)
|
||||||
|
|
||||||
|
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
||||||
|
produce(a.rowNum, a.colNum) { i, j -> space.run { a[i, j] * k } }
|
||||||
|
|
||||||
override fun Matrix<T>.times(value: T) =
|
override fun Matrix<T>.times(value: T) =
|
||||||
KomaMatrix(this.toKoma().origin * value)
|
KomaMatrix(this.toKoma().origin * value)
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
|
||||||
override fun solve(a: Matrix<T>, b: Matrix<T>) =
|
}
|
||||||
KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin))
|
|
||||||
|
|
||||||
override fun inverse(a: Matrix<T>) =
|
|
||||||
KomaMatrix(a.toKoma().origin.inv())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) :
|
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Matrix<T>) =
|
||||||
Matrix<T> {
|
KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin))
|
||||||
|
|
||||||
|
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Point<T>) =
|
||||||
|
KomaVector(a.toKoma().origin.solve(b.toKoma().origin))
|
||||||
|
|
||||||
|
fun <T : Any> KomaMatrixContext<T>.inverse(a: Matrix<T>) =
|
||||||
|
KomaMatrix(a.toKoma().origin.inv())
|
||||||
|
|
||||||
|
class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) : FeaturedMatrix<T> {
|
||||||
override val rowNum: Int get() = origin.numRows()
|
override val rowNum: Int get() = origin.numRows()
|
||||||
override val colNum: Int get() = origin.numCols()
|
override val colNum: Int get() = origin.numCols()
|
||||||
|
|
||||||
|
override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols())
|
||||||
|
|
||||||
override val features: Set<MatrixFeature> = features ?: setOf(
|
override val features: Set<MatrixFeature> = features ?: setOf(
|
||||||
object : DeterminantFeature<T> {
|
object : DeterminantFeature<T> {
|
||||||
override val determinant: T get() = origin.det()
|
override val determinant: T get() = origin.det()
|
||||||
},
|
},
|
||||||
object : LUPDecompositionFeature<T> {
|
object : LUPDecompositionFeature<T> {
|
||||||
private val lup by lazy { origin.LU() }
|
private val lup by lazy { origin.LU() }
|
||||||
override val l: Matrix<T> get() = KomaMatrix(lup.second)
|
override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
|
||||||
override val u: Matrix<T> get() = KomaMatrix(lup.third)
|
override val u: FeaturedMatrix<T> get() = KomaMatrix(lup.third)
|
||||||
override val p: Matrix<T> get() = KomaMatrix(lup.first)
|
override val p: FeaturedMatrix<T> get() = KomaMatrix(lup.first)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun suggestFeature(vararg features: MatrixFeature): Matrix<T> =
|
override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
|
||||||
KomaMatrix(this.origin, this.features + features)
|
KomaMatrix(this.origin, this.features + features)
|
||||||
|
|
||||||
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
||||||
|
@ -1,50 +1,3 @@
|
|||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
`multiplatform-config`
|
||||||
}
|
|
||||||
|
|
||||||
val ioVersion: String by rootProject.extra
|
|
||||||
|
|
||||||
|
|
||||||
kotlin {
|
|
||||||
jvm()
|
|
||||||
js()
|
|
||||||
|
|
||||||
sourceSets {
|
|
||||||
val commonMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api(kotlin("stdlib"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val commonTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test-common"))
|
|
||||||
implementation(kotlin("test-annotations-common"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jvmMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api(kotlin("stdlib-jdk8"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jvmTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test"))
|
|
||||||
implementation(kotlin("test-junit"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jsMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api(kotlin("stdlib-js"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jsTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test-js"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// mingwMain {
|
|
||||||
// }
|
|
||||||
// mingwTest {
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
}
|
}
|
29
kmath-prob/build.gradle.kts
Normal file
29
kmath-prob/build.gradle.kts
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
plugins {
|
||||||
|
`multiplatform-config`
|
||||||
|
id("kotlinx-atomicfu") version Versions.atomicfuVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
kotlin.sourceSets {
|
||||||
|
commonMain {
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-core"))
|
||||||
|
api(project(":kmath-coroutines"))
|
||||||
|
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
jvmMain {
|
||||||
|
dependencies {
|
||||||
|
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
jsMain {
|
||||||
|
dependencies {
|
||||||
|
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
atomicfu {
|
||||||
|
variant = "VH"
|
||||||
|
}
|
@ -1,53 +0,0 @@
|
|||||||
plugins {
|
|
||||||
kotlin("multiplatform")
|
|
||||||
id("kotlinx-atomicfu") version "0.12.1"
|
|
||||||
}
|
|
||||||
|
|
||||||
val atomicfuVersion: String by rootProject.extra
|
|
||||||
|
|
||||||
kotlin {
|
|
||||||
jvm ()
|
|
||||||
//js()
|
|
||||||
|
|
||||||
sourceSets {
|
|
||||||
val commonMain by getting {
|
|
||||||
dependencies {
|
|
||||||
api(project(":kmath-core"))
|
|
||||||
api(project(":kmath-coroutines"))
|
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:$atomicfuVersion")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val commonTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test-common"))
|
|
||||||
implementation(kotlin("test-annotations-common"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jvmMain by getting {
|
|
||||||
dependencies {
|
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu:$atomicfuVersion")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val jvmTest by getting {
|
|
||||||
dependencies {
|
|
||||||
implementation(kotlin("test"))
|
|
||||||
implementation(kotlin("test-junit"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// val jsMain by getting {
|
|
||||||
// dependencies {
|
|
||||||
// compileOnly("org.jetbrains.kotlinx:atomicfu-js:$atomicfuVersion")
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// val jsTest by getting {
|
|
||||||
// dependencies {
|
|
||||||
// implementation(kotlin("test-js"))
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
atomicfu {
|
|
||||||
variant = "VH"
|
|
||||||
}
|
|
@ -1,82 +0,0 @@
|
|||||||
package scientifik.kmath.sequential
|
|
||||||
|
|
||||||
import kotlinx.coroutines.CoroutineScope
|
|
||||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
|
||||||
import kotlinx.coroutines.channels.Channel
|
|
||||||
import kotlinx.coroutines.channels.produce
|
|
||||||
import kotlinx.coroutines.isActive
|
|
||||||
import kotlinx.coroutines.sync.Mutex
|
|
||||||
import kotlinx.coroutines.sync.withLock
|
|
||||||
import scientifik.kmath.structures.Buffer
|
|
||||||
import scientifik.kmath.structures.BufferFactory
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A processor that collects incoming elements into fixed size buffers
|
|
||||||
*/
|
|
||||||
@ExperimentalCoroutinesApi
|
|
||||||
class JoinProcessor<T>(
|
|
||||||
scope: CoroutineScope,
|
|
||||||
bufferSize: Int,
|
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
|
||||||
) : AbstractProcessor<T, Buffer<T>>(scope) {
|
|
||||||
|
|
||||||
private val input = Channel<T>(bufferSize)
|
|
||||||
|
|
||||||
private val output = produce(coroutineContext) {
|
|
||||||
val list = ArrayList<T>(bufferSize)
|
|
||||||
while (isActive) {
|
|
||||||
list.clear()
|
|
||||||
repeat(bufferSize) {
|
|
||||||
list.add(input.receive())
|
|
||||||
}
|
|
||||||
val buffer = bufferFactory(bufferSize) { list[it] }
|
|
||||||
send(buffer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override suspend fun receive(): Buffer<T> = output.receive()
|
|
||||||
|
|
||||||
override suspend fun send(value: T) {
|
|
||||||
input.send(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A processor that splits incoming buffers into individual elements
|
|
||||||
*/
|
|
||||||
class SplitProcessor<T>(scope: CoroutineScope) : AbstractProcessor<Buffer<T>, T>(scope) {
|
|
||||||
|
|
||||||
private val input = Channel<Buffer<T>>()
|
|
||||||
|
|
||||||
private val mutex = Mutex()
|
|
||||||
|
|
||||||
private var currentBuffer: Buffer<T>? = null
|
|
||||||
|
|
||||||
private var pos = 0
|
|
||||||
|
|
||||||
|
|
||||||
override suspend fun receive(): T {
|
|
||||||
mutex.withLock {
|
|
||||||
while (currentBuffer == null || pos == currentBuffer!!.size) {
|
|
||||||
currentBuffer = input.receive()
|
|
||||||
pos = 0
|
|
||||||
}
|
|
||||||
return currentBuffer!![pos].also { pos++ }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override suspend fun send(value: Buffer<T>) {
|
|
||||||
input.send(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@ExperimentalCoroutinesApi
|
|
||||||
fun <T> Producer<T>.chunked(chunkSize: Int, bufferFactory: BufferFactory<T>) =
|
|
||||||
JoinProcessor<T>(this, chunkSize, bufferFactory).also { connect(it) }
|
|
||||||
|
|
||||||
@ExperimentalCoroutinesApi
|
|
||||||
inline fun <reified T : Any> Producer<T>.chunked(chunkSize: Int) =
|
|
||||||
JoinProcessor<T>(this, chunkSize, Buffer.Companion::auto).also { connect(it) }
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,275 +0,0 @@
|
|||||||
package scientifik.kmath.sequential
|
|
||||||
|
|
||||||
import kotlinx.coroutines.CoroutineScope
|
|
||||||
import kotlinx.coroutines.GlobalScope
|
|
||||||
import kotlinx.coroutines.channels.*
|
|
||||||
import kotlinx.coroutines.isActive
|
|
||||||
import kotlinx.coroutines.launch
|
|
||||||
import kotlinx.coroutines.sync.Mutex
|
|
||||||
import kotlinx.coroutines.sync.withLock
|
|
||||||
import scientifik.kmath.structures.Buffer
|
|
||||||
import kotlin.coroutines.CoroutineContext
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Initial chain block. Could produce an element sequence and be connected to single [Consumer]
|
|
||||||
*
|
|
||||||
* The general rule is that channel is created on first call. Also each element is responsible for its connection so
|
|
||||||
* while the connections are symmetric, the scope, used for making the connection is responsible for cancelation.
|
|
||||||
*
|
|
||||||
* Also connections are not reversible. Once connected block stays faithful until it finishes processing.
|
|
||||||
* Manually putting elements to connected block could lead to undetermined behavior and must be avoided.
|
|
||||||
*/
|
|
||||||
interface Producer<T> : CoroutineScope {
|
|
||||||
fun connect(consumer: Consumer<T>)
|
|
||||||
|
|
||||||
suspend fun receive(): T
|
|
||||||
|
|
||||||
val consumer: Consumer<T>?
|
|
||||||
|
|
||||||
val outputIsConnected: Boolean get() = consumer != null
|
|
||||||
|
|
||||||
//fun close()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Terminal chain block. Could consume an element sequence and be connected to signle [Producer]
|
|
||||||
*/
|
|
||||||
interface Consumer<T> : CoroutineScope {
|
|
||||||
fun connect(producer: Producer<T>)
|
|
||||||
|
|
||||||
suspend fun send(value: T)
|
|
||||||
|
|
||||||
val producer: Producer<T>?
|
|
||||||
|
|
||||||
val inputIsConnected: Boolean get() = producer != null
|
|
||||||
|
|
||||||
//fun close()
|
|
||||||
}
|
|
||||||
|
|
||||||
interface Processor<T, R> : Consumer<T>, Producer<R>
|
|
||||||
|
|
||||||
abstract class AbstractProducer<T>(scope: CoroutineScope) : Producer<T> {
|
|
||||||
override val coroutineContext: CoroutineContext = scope.coroutineContext
|
|
||||||
|
|
||||||
override var consumer: Consumer<T>? = null
|
|
||||||
protected set
|
|
||||||
|
|
||||||
override fun connect(consumer: Consumer<T>) {
|
|
||||||
//Ignore if already connected to specific consumer
|
|
||||||
if (consumer != this.consumer) {
|
|
||||||
if (outputIsConnected) error("The output slot of producer is occupied")
|
|
||||||
if (consumer.inputIsConnected) error("The input slot of consumer is occupied")
|
|
||||||
this.consumer = consumer
|
|
||||||
if (consumer.producer != null) {
|
|
||||||
//No need to save the job, it will be canceled on scope cancel
|
|
||||||
connectOutput(consumer)
|
|
||||||
// connect back, consumer is already set so no circular reference
|
|
||||||
consumer.connect(this)
|
|
||||||
} else error("Unreachable statement")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected open fun connectOutput(consumer: Consumer<T>) {
|
|
||||||
launch {
|
|
||||||
while (this.isActive) {
|
|
||||||
consumer.send(receive())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
abstract class AbstractConsumer<T>(scope: CoroutineScope) : Consumer<T> {
|
|
||||||
override val coroutineContext: CoroutineContext = scope.coroutineContext
|
|
||||||
|
|
||||||
override var producer: Producer<T>? = null
|
|
||||||
protected set
|
|
||||||
|
|
||||||
override fun connect(producer: Producer<T>) {
|
|
||||||
//Ignore if already connected to specific consumer
|
|
||||||
if (producer != this.producer) {
|
|
||||||
if (inputIsConnected) error("The input slot of consumer is occupied")
|
|
||||||
if (producer.outputIsConnected) error("The input slot of producer is occupied")
|
|
||||||
this.producer = producer
|
|
||||||
//No need to save the job, it will be canceled on scope cancel
|
|
||||||
if (producer.consumer != null) {
|
|
||||||
connectInput(producer)
|
|
||||||
// connect back
|
|
||||||
producer.connect(this)
|
|
||||||
} else error("Unreachable statement")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected open fun connectInput(producer: Producer<T>) {
|
|
||||||
launch {
|
|
||||||
while (isActive) {
|
|
||||||
send(producer.receive())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
abstract class AbstractProcessor<T, R>(scope: CoroutineScope) : Processor<T, R>, AbstractProducer<R>(scope) {
|
|
||||||
|
|
||||||
override var producer: Producer<T>? = null
|
|
||||||
protected set
|
|
||||||
|
|
||||||
override fun connect(producer: Producer<T>) {
|
|
||||||
//Ignore if already connected to specific consumer
|
|
||||||
if (producer != this.producer) {
|
|
||||||
if (inputIsConnected) error("The input slot of consumer is occupied")
|
|
||||||
if (producer.outputIsConnected) error("The input slot of producer is occupied")
|
|
||||||
this.producer = producer
|
|
||||||
//No need to save the job, it will be canceled on scope cancel
|
|
||||||
if (producer.consumer != null) {
|
|
||||||
connectInput(producer)
|
|
||||||
// connect back
|
|
||||||
producer.connect(this)
|
|
||||||
} else error("Unreachable statement")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected open fun connectInput(producer: Producer<T>) {
|
|
||||||
launch {
|
|
||||||
while (isActive) {
|
|
||||||
send(producer.receive())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A simple [produce]-based producer
|
|
||||||
*/
|
|
||||||
class GenericProducer<T>(
|
|
||||||
scope: CoroutineScope,
|
|
||||||
capacity: Int = Channel.UNLIMITED,
|
|
||||||
block: suspend ProducerScope<T>.() -> Unit
|
|
||||||
) : AbstractProducer<T>(scope) {
|
|
||||||
|
|
||||||
private val channel: ReceiveChannel<T> by lazy { produce(capacity = capacity, block = block) }
|
|
||||||
|
|
||||||
override suspend fun receive(): T = channel.receive()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A simple pipeline [Processor] block
|
|
||||||
*/
|
|
||||||
class PipeProcessor<T, R>(
|
|
||||||
scope: CoroutineScope,
|
|
||||||
capacity: Int = Channel.RENDEZVOUS,
|
|
||||||
process: suspend (T) -> R
|
|
||||||
) : AbstractProcessor<T, R>(scope) {
|
|
||||||
|
|
||||||
private val input = Channel<T>(capacity)
|
|
||||||
private val output: ReceiveChannel<R> = input.map(coroutineContext, process)
|
|
||||||
|
|
||||||
override suspend fun receive(): R = output.receive()
|
|
||||||
|
|
||||||
override suspend fun send(value: T) {
|
|
||||||
input.send(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A moving window [Processor] with circular buffer
|
|
||||||
*/
|
|
||||||
class WindowedProcessor<T, R>(
|
|
||||||
scope: CoroutineScope,
|
|
||||||
window: Int,
|
|
||||||
val process: suspend (Buffer<T?>) -> R
|
|
||||||
) : AbstractProcessor<T, R>(scope) {
|
|
||||||
|
|
||||||
private val ringBuffer = RingBuffer.boxing<T>(window)
|
|
||||||
|
|
||||||
private val channel = Channel<R>(Channel.RENDEZVOUS)
|
|
||||||
|
|
||||||
override suspend fun receive(): R {
|
|
||||||
return channel.receive()
|
|
||||||
}
|
|
||||||
|
|
||||||
override suspend fun send(value: T) {
|
|
||||||
ringBuffer.push(value)
|
|
||||||
channel.send(process(ringBuffer.snapshot()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Thread-safe aggregator of values from input. The aggregator does not store all incoming values, it uses fold procedure
|
|
||||||
* to incorporate them into state on-arrival.
|
|
||||||
* The current aggregated state could be accessed by [state]. The input channel is inactive unless requested
|
|
||||||
* @param T - the type of the input element
|
|
||||||
* @param S - the type of the aggregator
|
|
||||||
*/
|
|
||||||
class Reducer<T, S>(
|
|
||||||
scope: CoroutineScope,
|
|
||||||
initialState: S,
|
|
||||||
val fold: suspend (S, T) -> S
|
|
||||||
) : AbstractConsumer<T>(scope) {
|
|
||||||
|
|
||||||
var state: S = initialState
|
|
||||||
private set
|
|
||||||
|
|
||||||
private val mutex = Mutex()
|
|
||||||
|
|
||||||
override suspend fun send(value: T) = mutex.withLock {
|
|
||||||
state = fold(state, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Collector that accumulates all values in a list. List could be accessed from non-suspending environment via [list] value.
|
|
||||||
*/
|
|
||||||
class Collector<T>(scope: CoroutineScope) : AbstractConsumer<T>(scope) {
|
|
||||||
|
|
||||||
private val _list = ArrayList<T>()
|
|
||||||
private val mutex = Mutex()
|
|
||||||
val list: List<T> get() = _list
|
|
||||||
|
|
||||||
override suspend fun send(value: T) {
|
|
||||||
mutex.withLock {
|
|
||||||
_list.add(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert a sequence to [Producer]
|
|
||||||
*/
|
|
||||||
fun <T> Sequence<T>.produce(scope: CoroutineScope = GlobalScope) =
|
|
||||||
GenericProducer<T>(scope) { forEach { send(it) } }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert a [ReceiveChannel] to [Producer]
|
|
||||||
*/
|
|
||||||
fun <T> ReceiveChannel<T>.produce(scope: CoroutineScope = GlobalScope) =
|
|
||||||
GenericProducer<T>(scope) { for (e in this@produce) send(e) }
|
|
||||||
|
|
||||||
|
|
||||||
fun <T, C : Consumer<T>> Producer<T>.consumer(consumerFactory: () -> C): C =
|
|
||||||
consumerFactory().also { connect(it) }
|
|
||||||
|
|
||||||
fun <T, R> Producer<T>.map(capacity: Int = Channel.RENDEZVOUS, process: suspend (T) -> R) =
|
|
||||||
PipeProcessor(this, capacity, process).also { connect(it) }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a reducer and connect this producer to reducer
|
|
||||||
*/
|
|
||||||
fun <T, S> Producer<T>.reduce(initialState: S, fold: suspend (S, T) -> S) =
|
|
||||||
Reducer(this, initialState, fold).also { connect(it) }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a [Collector] and attach it to this [Producer]
|
|
||||||
*/
|
|
||||||
fun <T> Producer<T>.collect() =
|
|
||||||
Collector<T>(this).also { connect(it) }
|
|
||||||
|
|
||||||
fun <T, R, P : Processor<T, R>> Producer<T>.process(processorBuilder: () -> P): P =
|
|
||||||
processorBuilder().also { connect(it) }
|
|
||||||
|
|
||||||
fun <T, R> Producer<T>.process(capacity: Int = Channel.RENDEZVOUS, process: suspend (T) -> R) =
|
|
||||||
PipeProcessor<T, R>(this, capacity, process).also { connect(it) }
|
|
||||||
|
|
||||||
|
|
||||||
fun <T, R> Producer<T>.windowed(window: Int, process: suspend (Buffer<T?>) -> R) =
|
|
||||||
WindowedProcessor(this, window, process).also { connect(it) }
|
|
@ -1,19 +0,0 @@
|
|||||||
package scientifik.kmath.sequential
|
|
||||||
|
|
||||||
import kotlinx.coroutines.runBlocking
|
|
||||||
import scientifik.kmath.structures.asSequence
|
|
||||||
import kotlin.test.Test
|
|
||||||
import kotlin.test.assertEquals
|
|
||||||
|
|
||||||
class RingBufferTest {
|
|
||||||
@Test
|
|
||||||
fun testPush() {
|
|
||||||
val buffer = RingBuffer.build(20, Double.NaN)
|
|
||||||
runBlocking {
|
|
||||||
for (i in 1..30) {
|
|
||||||
buffer.push(i.toDouble())
|
|
||||||
}
|
|
||||||
assertEquals(410.0, buffer.asSequence().sum())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -26,6 +26,6 @@ include(
|
|||||||
":kmath-histograms",
|
":kmath-histograms",
|
||||||
":kmath-commons",
|
":kmath-commons",
|
||||||
":kmath-koma",
|
":kmath-koma",
|
||||||
":kmath-sequential",
|
":kmath-prob",
|
||||||
":benchmarks"
|
":benchmarks"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user