forked from kscience/kmath
commit
82e4ea8897
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,6 +1,7 @@
|
||||
.gradle
|
||||
**/build/
|
||||
/.idea/
|
||||
build/
|
||||
out/
|
||||
.idea/
|
||||
|
||||
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
||||
!gradle-wrapper.jar
|
||||
|
@ -21,7 +21,7 @@ dependencies {
|
||||
//jmh project(':kmath-core')
|
||||
}
|
||||
|
||||
jmh{
|
||||
jmh {
|
||||
warmupIterations = 1
|
||||
}
|
||||
|
||||
|
@ -4,16 +4,14 @@ import org.openjdk.jmh.annotations.Benchmark
|
||||
import org.openjdk.jmh.annotations.Scope
|
||||
import org.openjdk.jmh.annotations.State
|
||||
import scientifik.kmath.operations.Complex
|
||||
import scientifik.kmath.operations.complex
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
open class BufferBenchmark {
|
||||
|
||||
@Benchmark
|
||||
fun genericDoubleBufferReadWrite() {
|
||||
val buffer = Double.createBuffer(size)
|
||||
(0 until size).forEach {
|
||||
buffer[it] = it.toDouble()
|
||||
}
|
||||
val buffer = DoubleBuffer(size){it.toDouble()}
|
||||
|
||||
(0 until size).forEach {
|
||||
buffer[it]
|
||||
@ -22,10 +20,7 @@ open class BufferBenchmark {
|
||||
|
||||
@Benchmark
|
||||
fun complexBufferReadWrite() {
|
||||
val buffer = MutableBuffer.complex(size / 2)
|
||||
(0 until size / 2).forEach {
|
||||
buffer[it] = Complex(it.toDouble(), -it.toDouble())
|
||||
}
|
||||
val buffer = MutableBuffer.complex(size / 2){Complex(it.toDouble(), -it.toDouble())}
|
||||
|
||||
(0 until size / 2).forEach {
|
||||
buffer[it]
|
||||
|
@ -34,19 +34,6 @@ open class NDFieldBenchmark {
|
||||
}
|
||||
|
||||
|
||||
@Benchmark
|
||||
fun lazyFieldAdd() {
|
||||
lazyNDField.run {
|
||||
var res = one
|
||||
repeat(n) {
|
||||
res += one
|
||||
}
|
||||
|
||||
res.elements().sumByDouble { it.second }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Benchmark
|
||||
fun boxingFieldAdd() {
|
||||
genericField.run {
|
||||
@ -61,9 +48,8 @@ open class NDFieldBenchmark {
|
||||
val dim = 1000
|
||||
val n = 100
|
||||
|
||||
val bufferedField = NDField.auto(RealField, intArrayOf(dim, dim))
|
||||
val specializedField = NDField.real(intArrayOf(dim, dim))
|
||||
val bufferedField = NDField.auto(RealField, dim, dim)
|
||||
val specializedField = NDField.real(dim, dim)
|
||||
val genericField = NDField.buffered(intArrayOf(dim, dim), RealField)
|
||||
val lazyNDField = NDField.lazy(intArrayOf(dim, dim), RealField)
|
||||
}
|
||||
}
|
@ -1,39 +1,42 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import koma.matrix.ejml.EJMLMatrixFactory
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import kotlin.contracts.ExperimentalContracts
|
||||
import kotlin.random.Random
|
||||
import kotlin.system.measureTimeMillis
|
||||
|
||||
@ExperimentalContracts
|
||||
fun main() {
|
||||
val random = Random(12224)
|
||||
val random = Random(1224)
|
||||
val dim = 100
|
||||
//creating invertible matrix
|
||||
val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||
val l = Matrix.real(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 }
|
||||
val matrix = l dot u
|
||||
|
||||
val n = 500 // iterations
|
||||
val n = 5000 // iterations
|
||||
|
||||
val solver = LUSolver.real
|
||||
MatrixContext.real.run {
|
||||
|
||||
repeat(50) {
|
||||
val res = solver.inverse(matrix)
|
||||
}
|
||||
|
||||
val inverseTime = measureTimeMillis {
|
||||
repeat(n) {
|
||||
val res = solver.inverse(matrix)
|
||||
repeat(50) {
|
||||
val res = inverse(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
||||
val inverseTime = measureTimeMillis {
|
||||
repeat(n) {
|
||||
val res = inverse(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
||||
}
|
||||
|
||||
//commons-math
|
||||
|
||||
val cmContext = CMMatrixContext
|
||||
|
||||
val commonsTime = measureTimeMillis {
|
||||
cmContext.run {
|
||||
CMMatrixContext.run {
|
||||
val cm = matrix.toCM() //avoid overhead on conversion
|
||||
repeat(n) {
|
||||
val res = inverse(cm)
|
||||
@ -46,10 +49,8 @@ fun main() {
|
||||
|
||||
//koma-ejml
|
||||
|
||||
val komaContext = KomaMatrixContext(EJMLMatrixFactory())
|
||||
|
||||
val komaTime = measureTimeMillis {
|
||||
komaContext.run {
|
||||
KomaMatrixContext(EJMLMatrixFactory(), RealField).run {
|
||||
val km = matrix.toKoma() //avoid overhead on conversion
|
||||
repeat(n) {
|
||||
val res = inverse(km)
|
||||
|
@ -1,6 +1,8 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import koma.matrix.ejml.EJMLMatrixFactory
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import kotlin.random.Random
|
||||
import kotlin.system.measureTimeMillis
|
||||
|
||||
@ -26,7 +28,7 @@ fun main() {
|
||||
}
|
||||
|
||||
|
||||
KomaMatrixContext(EJMLMatrixFactory()).run {
|
||||
KomaMatrixContext(EJMLMatrixFactory(), RealField).run {
|
||||
val komaMatrix1 = matrix1.toKoma()
|
||||
val komaMatrix2 = matrix2.toKoma()
|
||||
|
||||
|
@ -23,7 +23,7 @@ fun main() {
|
||||
|
||||
val complexTime = measureTimeMillis {
|
||||
complexField.run {
|
||||
var res: ComplexNDElement = one
|
||||
var res = one
|
||||
repeat(n) {
|
||||
res += 1.0
|
||||
}
|
||||
|
121
build.gradle.kts
121
build.gradle.kts
@ -1,110 +1,29 @@
|
||||
import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension
|
||||
|
||||
buildscript {
|
||||
val kotlinVersion: String by rootProject.extra("1.3.21")
|
||||
val ioVersion: String by rootProject.extra("0.1.5")
|
||||
val coroutinesVersion: String by rootProject.extra("1.1.1")
|
||||
val atomicfuVersion: String by rootProject.extra("0.12.1")
|
||||
val dokkaVersion: String by rootProject.extra("0.9.17")
|
||||
|
||||
repositories {
|
||||
//maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
jcenter()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
classpath("org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlinVersion")
|
||||
classpath("org.jfrog.buildinfo:build-info-extractor-gradle:4+")
|
||||
classpath("com.jfrog.bintray.gradle:gradle-bintray-plugin:1.8.4")
|
||||
classpath("org.jetbrains.dokka:dokka-gradle-plugin:$dokkaVersion")
|
||||
}
|
||||
}
|
||||
|
||||
plugins {
|
||||
id("com.jfrog.artifactory") version "4.9.1" apply false
|
||||
}
|
||||
|
||||
val kmathVersion by extra("0.1.0")
|
||||
val kmathVersion by extra("0.1.2-dev-4")
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
jcenter()
|
||||
maven("https://kotlin.bintray.com/kotlinx")
|
||||
}
|
||||
|
||||
group = "scientifik"
|
||||
version = kmathVersion
|
||||
|
||||
repositories {
|
||||
//maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
jcenter()
|
||||
}
|
||||
|
||||
apply(plugin = "maven")
|
||||
apply(plugin = "maven-publish")
|
||||
|
||||
// apply bintray configuration
|
||||
apply(from = "${rootProject.rootDir}/gradle/bintray.gradle")
|
||||
|
||||
//apply artifactory configuration
|
||||
apply(from = "${rootProject.rootDir}/gradle/artifactory.gradle")
|
||||
}
|
||||
|
||||
subprojects {
|
||||
if (!name.startsWith("kmath")) return@subprojects
|
||||
|
||||
|
||||
extensions.findByType<KotlinMultiplatformExtension>()?.apply {
|
||||
jvm {
|
||||
compilations.all {
|
||||
kotlinOptions {
|
||||
jvmTarget = "1.8"
|
||||
}
|
||||
}
|
||||
}
|
||||
targets.all {
|
||||
sourceSets.all {
|
||||
languageSettings.progressiveMode = true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
extensions.findByType<PublishingExtension>()?.apply {
|
||||
publications.filterIsInstance<MavenPublication>().forEach { publication ->
|
||||
if (publication.name == "kotlinMultiplatform") {
|
||||
// for our root metadata publication, set artifactId with a package and project name
|
||||
publication.artifactId = project.name
|
||||
} else {
|
||||
// for targets, set artifactId with a package, project name and target name (e.g. iosX64)
|
||||
publication.artifactId = "${project.name}-${publication.name}"
|
||||
}
|
||||
}
|
||||
|
||||
// Create empty jar for sources classifier to satisfy maven requirements
|
||||
val stubSources by tasks.registering(Jar::class) {
|
||||
archiveClassifier.set("sources")
|
||||
//from(sourceSets.main.get().allSource)
|
||||
}
|
||||
|
||||
// Create empty jar for javadoc classifier to satisfy maven requirements
|
||||
val stubJavadoc by tasks.registering(Jar::class) {
|
||||
archiveClassifier.set("javadoc")
|
||||
}
|
||||
|
||||
extensions.findByType<KotlinMultiplatformExtension>()?.apply {
|
||||
|
||||
targets.forEach { target ->
|
||||
val publication = publications.findByName(target.name) as MavenPublication
|
||||
|
||||
// Patch publications with fake javadoc
|
||||
publication.artifact(stubJavadoc)
|
||||
|
||||
// Remove gradle metadata publishing from all targets which are not native
|
||||
// if (target.platformType.name != "native") {
|
||||
// publication.gradleModuleMetadataFile = null
|
||||
// tasks.matching { it.name == "generateMetadataFileFor${name.capitalize()}Publication" }.all {
|
||||
// onlyIf { false }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
// Actually, probably we should apply it to plugins explicitly
|
||||
// We also can merge them to single kmath-publish plugin
|
||||
if (name.startsWith("kmath")) {
|
||||
apply(plugin = "bintray-config")
|
||||
apply(plugin = "artifactory-config")
|
||||
}
|
||||
// dokka {
|
||||
// outputFormat = "html"
|
||||
// outputDirectory = javadoc.destinationDir
|
||||
// }
|
||||
//
|
||||
// task dokkaJar (type: Jar, dependsOn: dokka) {
|
||||
// from javadoc . destinationDir
|
||||
// classifier = "javadoc"
|
||||
// }
|
||||
}
|
||||
|
||||
|
||||
|
18
buildSrc/build.gradle.kts
Normal file
18
buildSrc/build.gradle.kts
Normal file
@ -0,0 +1,18 @@
|
||||
plugins {
|
||||
`kotlin-dsl`
|
||||
}
|
||||
|
||||
repositories {
|
||||
gradlePluginPortal()
|
||||
jcenter()
|
||||
}
|
||||
|
||||
val kotlinVersion = "1.3.31"
|
||||
|
||||
// Add plugins used in buildSrc as dependencies, also we should specify version only here
|
||||
dependencies {
|
||||
implementation("org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlinVersion")
|
||||
implementation("org.jfrog.buildinfo:build-info-extractor-gradle:4.9.5")
|
||||
implementation("com.jfrog.bintray.gradle:gradle-bintray-plugin:1.8.4")
|
||||
implementation("com.moowork.gradle:gradle-node-plugin:1.3.1")
|
||||
}
|
0
buildSrc/settings.gradle.kts
Normal file
0
buildSrc/settings.gradle.kts
Normal file
10
buildSrc/src/main/kotlin/Versions.kt
Normal file
10
buildSrc/src/main/kotlin/Versions.kt
Normal file
@ -0,0 +1,10 @@
|
||||
// Instead of defining runtime properties and use them dynamically
|
||||
// define version in buildSrc and have autocompletion and compile-time check
|
||||
// Also dependencies itself can be moved here
|
||||
object Versions {
|
||||
val ioVersion = "0.1.8"
|
||||
val coroutinesVersion = "1.2.1"
|
||||
val atomicfuVersion = "0.12.6"
|
||||
// This version is not used and IDEA shows this property as unused
|
||||
val dokkaVersion = "0.9.18"
|
||||
}
|
38
buildSrc/src/main/kotlin/artifactory-config.gradle.kts
Normal file
38
buildSrc/src/main/kotlin/artifactory-config.gradle.kts
Normal file
@ -0,0 +1,38 @@
|
||||
import groovy.lang.GroovyObject
|
||||
import org.jfrog.gradle.plugin.artifactory.dsl.PublisherConfig
|
||||
import org.jfrog.gradle.plugin.artifactory.dsl.ResolverConfig
|
||||
|
||||
plugins {
|
||||
id("com.jfrog.artifactory")
|
||||
}
|
||||
|
||||
artifactory {
|
||||
val artifactoryUser: String? by project
|
||||
val artifactoryPassword: String? by project
|
||||
val artifactoryContextUrl = "http://npm.mipt.ru:8081/artifactory"
|
||||
|
||||
setContextUrl(artifactoryContextUrl)//The base Artifactory URL if not overridden by the publisher/resolver
|
||||
publish(delegateClosureOf<PublisherConfig> {
|
||||
repository(delegateClosureOf<GroovyObject> {
|
||||
setProperty("repoKey", "gradle-dev-local")
|
||||
setProperty("username", artifactoryUser)
|
||||
setProperty("password", artifactoryPassword)
|
||||
})
|
||||
|
||||
defaults(delegateClosureOf<GroovyObject>{
|
||||
invokeMethod("publications", arrayOf("jvm", "js", "kotlinMultiplatform", "metadata"))
|
||||
//TODO: This property is not available for ArtifactoryTask
|
||||
//setProperty("publishBuildInfo", false)
|
||||
setProperty("publishArtifacts", true)
|
||||
setProperty("publishPom", true)
|
||||
setProperty("publishIvy", false)
|
||||
})
|
||||
})
|
||||
resolve(delegateClosureOf<ResolverConfig> {
|
||||
repository(delegateClosureOf<GroovyObject> {
|
||||
setProperty("repoKey", "gradle-dev")
|
||||
setProperty("username", artifactoryUser)
|
||||
setProperty("password", artifactoryPassword)
|
||||
})
|
||||
})
|
||||
}
|
97
buildSrc/src/main/kotlin/bintray-config.gradle.kts
Normal file
97
buildSrc/src/main/kotlin/bintray-config.gradle.kts
Normal file
@ -0,0 +1,97 @@
|
||||
@file:Suppress("UnstableApiUsage")
|
||||
|
||||
import com.jfrog.bintray.gradle.BintrayExtension.PackageConfig
|
||||
import com.jfrog.bintray.gradle.BintrayExtension.VersionConfig
|
||||
|
||||
// Old bintray.gradle script converted to real Gradle plugin (precompiled script plugin)
|
||||
// It now has own dependencies and support type safe accessors
|
||||
// Syntax is pretty close to what we had in Groovy
|
||||
// (excluding Property.set and bintray dynamic configs)
|
||||
|
||||
plugins {
|
||||
id("com.jfrog.bintray")
|
||||
`maven-publish`
|
||||
}
|
||||
|
||||
val vcs = "https://github.com/mipt-npm/kmath"
|
||||
|
||||
// Configure publishing
|
||||
publishing {
|
||||
repositories {
|
||||
maven("https://bintray.com/mipt-npm/scientifik")
|
||||
}
|
||||
|
||||
// Process each publication we have in this project
|
||||
publications.filterIsInstance<MavenPublication>().forEach { publication ->
|
||||
|
||||
// use type safe pom config GSL insterad of old dynamic
|
||||
publication.pom {
|
||||
name.set(project.name)
|
||||
description.set(project.description)
|
||||
url.set(vcs)
|
||||
|
||||
licenses {
|
||||
license {
|
||||
name.set("The Apache Software License, Version 2.0")
|
||||
url.set("http://www.apache.org/licenses/LICENSE-2.0.txt")
|
||||
distribution.set("repo")
|
||||
}
|
||||
}
|
||||
developers {
|
||||
developer {
|
||||
id.set("MIPT-NPM")
|
||||
name.set("MIPT nuclear physics methods laboratory")
|
||||
organization.set("MIPT")
|
||||
organizationUrl.set("http://npm.mipt.ru")
|
||||
}
|
||||
|
||||
}
|
||||
scm {
|
||||
url.set(vcs)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
bintray {
|
||||
// delegates for runtime properties
|
||||
val bintrayUser: String? by project
|
||||
val bintrayApiKey: String? by project
|
||||
user = bintrayUser ?: System.getenv("BINTRAY_USER")
|
||||
key = bintrayApiKey ?: System.getenv("BINTRAY_API_KEY")
|
||||
publish = true
|
||||
override = true // for multi-platform Kotlin/Native publishing
|
||||
|
||||
// We have to use delegateClosureOf because bintray supports only dynamic groovy syntax
|
||||
// this is a problem of this plugin
|
||||
pkg(delegateClosureOf<PackageConfig> {
|
||||
userOrg = "mipt-npm"
|
||||
repo = "scientifik"
|
||||
name = "scientifik.kmath"
|
||||
issueTrackerUrl = "https://github.com/mipt-npm/kmath/issues"
|
||||
setLicenses("Apache-2.0")
|
||||
vcsUrl = vcs
|
||||
version(delegateClosureOf<VersionConfig> {
|
||||
name = project.version.toString()
|
||||
vcsTag = project.version.toString()
|
||||
released = java.util.Date().toString()
|
||||
})
|
||||
})
|
||||
|
||||
tasks {
|
||||
bintrayUpload {
|
||||
dependsOn(publishToMavenLocal)
|
||||
doFirst {
|
||||
setPublications(project.publishing.publications
|
||||
.filterIsInstance<MavenPublication>()
|
||||
.filter { !it.name.contains("-test") && it.name != "kotlinMultiplatform" }
|
||||
.map {
|
||||
println("""Uploading artifact "${it.groupId}:${it.artifactId}:${it.version}" from publication "${it.name}""")
|
||||
it.name //https://github.com/bintray/gradle-bintray-plugin/issues/256
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
44
buildSrc/src/main/kotlin/js-test.gradle.kts
Normal file
44
buildSrc/src/main/kotlin/js-test.gradle.kts
Normal file
@ -0,0 +1,44 @@
|
||||
import com.moowork.gradle.node.npm.NpmTask
|
||||
import com.moowork.gradle.node.task.NodeTask
|
||||
import org.gradle.kotlin.dsl.*
|
||||
import org.jetbrains.kotlin.gradle.tasks.Kotlin2JsCompile
|
||||
|
||||
plugins {
|
||||
id("com.moowork.node")
|
||||
kotlin("multiplatform")
|
||||
}
|
||||
|
||||
node {
|
||||
nodeModulesDir = file("$buildDir/node_modules")
|
||||
}
|
||||
|
||||
val compileKotlinJs by tasks.getting(Kotlin2JsCompile::class)
|
||||
val compileTestKotlinJs by tasks.getting(Kotlin2JsCompile::class)
|
||||
|
||||
val populateNodeModules by tasks.registering(Copy::class) {
|
||||
dependsOn(compileKotlinJs)
|
||||
from(compileKotlinJs.destinationDir)
|
||||
|
||||
kotlin.js().compilations["test"].runtimeDependencyFiles.forEach {
|
||||
if (it.exists() && !it.isDirectory) {
|
||||
from(zipTree(it.absolutePath).matching { include("*.js") })
|
||||
}
|
||||
}
|
||||
|
||||
into("$buildDir/node_modules")
|
||||
}
|
||||
|
||||
val installMocha by tasks.registering(NpmTask::class) {
|
||||
setWorkingDir(buildDir)
|
||||
setArgs(listOf("install", "mocha"))
|
||||
}
|
||||
|
||||
val runMocha by tasks.registering(NodeTask::class) {
|
||||
dependsOn(compileTestKotlinJs, populateNodeModules, installMocha)
|
||||
setScript(file("$buildDir/node_modules/mocha/bin/mocha"))
|
||||
setArgs(listOf(compileTestKotlinJs.outputFile))
|
||||
}
|
||||
|
||||
tasks["jsTest"].dependsOn(runMocha)
|
||||
|
||||
|
118
buildSrc/src/main/kotlin/multiplatform-config.gradle.kts
Normal file
118
buildSrc/src/main/kotlin/multiplatform-config.gradle.kts
Normal file
@ -0,0 +1,118 @@
|
||||
import org.gradle.kotlin.dsl.*
|
||||
|
||||
plugins {
|
||||
kotlin("multiplatform")
|
||||
`maven-publish`
|
||||
}
|
||||
|
||||
|
||||
kotlin {
|
||||
jvm {
|
||||
compilations.all {
|
||||
kotlinOptions {
|
||||
jvmTarget = "1.8"
|
||||
// This was used in kmath-koma, but probably if we need it better to apply it for all modules
|
||||
freeCompilerArgs = listOf("-progressive")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
js {
|
||||
compilations.all {
|
||||
kotlinOptions {
|
||||
metaInfo = true
|
||||
sourceMap = true
|
||||
sourceMapEmbedSources = "always"
|
||||
moduleKind = "commonjs"
|
||||
}
|
||||
}
|
||||
|
||||
compilations.named("main") {
|
||||
kotlinOptions {
|
||||
main = "call"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sourceSets.invoke {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(kotlin("stdlib"))
|
||||
}
|
||||
}
|
||||
commonTest {
|
||||
dependencies {
|
||||
implementation(kotlin("test-common"))
|
||||
implementation(kotlin("test-annotations-common"))
|
||||
}
|
||||
}
|
||||
"jvmMain" {
|
||||
dependencies {
|
||||
api(kotlin("stdlib-jdk8"))
|
||||
}
|
||||
}
|
||||
"jvmTest" {
|
||||
dependencies {
|
||||
implementation(kotlin("test"))
|
||||
implementation(kotlin("test-junit"))
|
||||
}
|
||||
}
|
||||
"jsMain" {
|
||||
dependencies {
|
||||
api(kotlin("stdlib-js"))
|
||||
}
|
||||
}
|
||||
"jsTest" {
|
||||
dependencies {
|
||||
implementation(kotlin("test-js"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
targets.all {
|
||||
sourceSets.all {
|
||||
languageSettings.progressiveMode = true
|
||||
languageSettings.enableLanguageFeature("InlineClasses")
|
||||
}
|
||||
}
|
||||
|
||||
// Create empty jar for sources classifier to satisfy maven requirements
|
||||
val stubSources by tasks.registering(Jar::class){
|
||||
archiveClassifier.set("sources")
|
||||
//from(sourceSets.main.get().allSource)
|
||||
}
|
||||
|
||||
// Create empty jar for javadoc classifier to satisfy maven requirements
|
||||
val stubJavadoc by tasks.registering(Jar::class) {
|
||||
archiveClassifier.set("javadoc")
|
||||
}
|
||||
|
||||
|
||||
publishing {
|
||||
|
||||
publications.filterIsInstance<MavenPublication>().forEach { publication ->
|
||||
if (publication.name == "kotlinMultiplatform") {
|
||||
// for our root metadata publication, set artifactId with a package and project name
|
||||
publication.artifactId = project.name
|
||||
} else {
|
||||
// for targets, set artifactId with a package, project name and target name (e.g. iosX64)
|
||||
publication.artifactId = "${project.name}-${publication.name}"
|
||||
}
|
||||
}
|
||||
|
||||
targets.all {
|
||||
val publication = publications.findByName(name) as MavenPublication
|
||||
|
||||
// Patch publications with fake javadoc
|
||||
publication.artifact(stubJavadoc.get())
|
||||
}
|
||||
}
|
||||
|
||||
// Apply JS test configuration
|
||||
val runJsTests by ext(false)
|
||||
|
||||
if (runJsTests) {
|
||||
apply(plugin = "js-test")
|
||||
}
|
||||
|
||||
}
|
@ -1 +1,19 @@
|
||||
**TODO**
|
||||
## Basic linear algebra layout
|
||||
|
||||
Kmath support for linear algebra organized in a context-oriented way. Meaning that operations are in most cases declared
|
||||
in context classes, and are not the members of classes that store data. This allows more flexible approach to maintain multiple
|
||||
back-ends. The new operations added as extensions to contexts instead of being member functions of data structures.
|
||||
|
||||
Two major contexts used for linear algebra and hyper-geometry:
|
||||
|
||||
* `VectorSpace` forms a mathematical space on top of array-like structure (`Buffer` and its typealias `Point` used for geometry).
|
||||
|
||||
* `MatrixContext` forms a space-like context for 2d-structures. It does not store matrix size and therefore does not implement
|
||||
`Space` interface (it is not possible to create zero element without knowing the matrix size).
|
||||
|
||||
## Vector spaces
|
||||
|
||||
|
||||
## Matrix operations
|
||||
|
||||
## Back-end overview
|
BIN
gradle/wrapper/gradle-wrapper.jar
vendored
BIN
gradle/wrapper/gradle-wrapper.jar
vendored
Binary file not shown.
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-5.0-bin.zip
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-5.4-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
|
18
gradlew
vendored
18
gradlew
vendored
@ -1,5 +1,21 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
#
|
||||
# Copyright 2015 the original author or authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
##############################################################################
|
||||
##
|
||||
## Gradle start up script for UN*X
|
||||
@ -28,7 +44,7 @@ APP_NAME="Gradle"
|
||||
APP_BASE_NAME=`basename "$0"`
|
||||
|
||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
DEFAULT_JVM_OPTS='"-Xmx64m"'
|
||||
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
||||
|
||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||
MAX_FD="maximum"
|
||||
|
18
gradlew.bat
vendored
18
gradlew.bat
vendored
@ -1,3 +1,19 @@
|
||||
@rem
|
||||
@rem Copyright 2015 the original author or authors.
|
||||
@rem
|
||||
@rem Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@rem you may not use this file except in compliance with the License.
|
||||
@rem You may obtain a copy of the License at
|
||||
@rem
|
||||
@rem http://www.apache.org/licenses/LICENSE-2.0
|
||||
@rem
|
||||
@rem Unless required by applicable law or agreed to in writing, software
|
||||
@rem distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
@rem See the License for the specific language governing permissions and
|
||||
@rem limitations under the License.
|
||||
@rem
|
||||
|
||||
@if "%DEBUG%" == "" @echo off
|
||||
@rem ##########################################################################
|
||||
@rem
|
||||
@ -14,7 +30,7 @@ set APP_BASE_NAME=%~n0
|
||||
set APP_HOME=%DIRNAME%
|
||||
|
||||
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
set DEFAULT_JVM_OPTS="-Xmx64m"
|
||||
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
|
||||
|
||||
@rem Find java.exe
|
||||
if defined JAVA_HOME goto findJavaFromJavaHome
|
||||
|
@ -7,7 +7,7 @@ description = "Commons math binding for kmath"
|
||||
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api(project(":kmath-sequential"))
|
||||
api(project(":kmath-coroutines"))
|
||||
api("org.apache.commons:commons-math3:3.6.1")
|
||||
testImplementation("org.jetbrains.kotlin:kotlin-test")
|
||||
testImplementation("org.jetbrains.kotlin:kotlin-test-junit")
|
||||
|
@ -3,13 +3,14 @@ package scientifik.kmath.linear
|
||||
import org.apache.commons.math3.linear.*
|
||||
import org.apache.commons.math3.linear.RealMatrix
|
||||
import org.apache.commons.math3.linear.RealVector
|
||||
import scientifik.kmath.structures.Matrix
|
||||
|
||||
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : Matrix<Double> {
|
||||
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
||||
override val rowNum: Int get() = origin.rowDimension
|
||||
override val colNum: Int get() = origin.columnDimension
|
||||
|
||||
override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
|
||||
if(origin is DiagonalMatrix) yield(DiagonalFeature)
|
||||
if (origin is DiagonalMatrix) yield(DiagonalFeature)
|
||||
}.toSet()
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
||||
@ -26,7 +27,7 @@ fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
||||
CMMatrix(Array2DRowRealMatrix(array))
|
||||
}
|
||||
|
||||
fun RealMatrix.toMatrix() = CMMatrix(this)
|
||||
fun RealMatrix.asMatrix() = CMMatrix(this)
|
||||
|
||||
class CMVector(val origin: RealVector) : Point<Double> {
|
||||
override val size: Int get() = origin.dimension
|
||||
@ -45,46 +46,31 @@ fun Point<Double>.toCM(): CMVector = if (this is CMVector) {
|
||||
|
||||
fun RealVector.toPoint() = CMVector(this)
|
||||
|
||||
object CMMatrixContext : MatrixContext<Double>, LinearSolver<Double> {
|
||||
|
||||
object CMMatrixContext : MatrixContext<Double> {
|
||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
|
||||
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
|
||||
return CMMatrix(Array2DRowRealMatrix(array))
|
||||
}
|
||||
|
||||
override fun solve(a: Matrix<Double>, b: Matrix<Double>): CMMatrix {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.solve(b.toCM().origin).toMatrix()
|
||||
}
|
||||
|
||||
override fun solve(a: Matrix<Double>, b: Point<Double>): CMVector {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.solve(b.toCM().origin).toPoint()
|
||||
}
|
||||
|
||||
override fun inverse(a: Matrix<Double>): CMMatrix {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.inverse.toMatrix()
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.dot(other: Matrix<Double>) =
|
||||
CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
|
||||
|
||||
override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
|
||||
CMVector(this.toCM().origin.preMultiply(vector.toCM().origin))
|
||||
|
||||
|
||||
override fun Matrix<Double>.unaryMinus(): CMMatrix =
|
||||
produce(rowNum, colNum) { i, j -> -get(i, j) }
|
||||
|
||||
override fun Matrix<Double>.plus(b: Matrix<Double>) =
|
||||
CMMatrix(this.toCM().origin.multiply(b.toCM().origin))
|
||||
override fun add(a: Matrix<Double>, b: Matrix<Double>) =
|
||||
CMMatrix(a.toCM().origin.multiply(b.toCM().origin))
|
||||
|
||||
override fun Matrix<Double>.minus(b: Matrix<Double>) =
|
||||
CMMatrix(this.toCM().origin.subtract(b.toCM().origin))
|
||||
|
||||
override fun Matrix<Double>.times(value: Double) =
|
||||
CMMatrix(this.toCM().origin.scalarMultiply(value.toDouble()))
|
||||
override fun multiply(a: Matrix<Double>, k: Number) =
|
||||
CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble()))
|
||||
|
||||
override fun Matrix<Double>.times(value: Double): Matrix<Double> = produce(rowNum,colNum){i,j-> get(i,j)*value}
|
||||
}
|
||||
|
||||
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin))
|
||||
|
@ -0,0 +1,39 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import org.apache.commons.math3.linear.*
|
||||
import scientifik.kmath.structures.Matrix
|
||||
|
||||
enum class CMDecomposition {
|
||||
LUP,
|
||||
QR,
|
||||
RRQR,
|
||||
EIGEN,
|
||||
CHOLESKY
|
||||
}
|
||||
|
||||
|
||||
fun CMMatrixContext.solver(a: Matrix<Double>, decomposition: CMDecomposition = CMDecomposition.LUP) =
|
||||
when (decomposition) {
|
||||
CMDecomposition.LUP -> LUDecomposition(a.toCM().origin).solver
|
||||
CMDecomposition.RRQR -> RRQRDecomposition(a.toCM().origin).solver
|
||||
CMDecomposition.QR -> QRDecomposition(a.toCM().origin).solver
|
||||
CMDecomposition.EIGEN -> EigenDecomposition(a.toCM().origin).solver
|
||||
CMDecomposition.CHOLESKY -> CholeskyDecomposition(a.toCM().origin).solver
|
||||
}
|
||||
|
||||
fun CMMatrixContext.solve(
|
||||
a: Matrix<Double>,
|
||||
b: Matrix<Double>,
|
||||
decomposition: CMDecomposition = CMDecomposition.LUP
|
||||
) = solver(a, decomposition).solve(b.toCM().origin).asMatrix()
|
||||
|
||||
fun CMMatrixContext.solve(
|
||||
a: Matrix<Double>,
|
||||
b: Point<Double>,
|
||||
decomposition: CMDecomposition = CMDecomposition.LUP
|
||||
) = solver(a, decomposition).solve(b.toCM().origin).toPoint()
|
||||
|
||||
fun CMMatrixContext.inverse(
|
||||
a: Matrix<Double>,
|
||||
decomposition: CMDecomposition = CMDecomposition.LUP
|
||||
) = solver(a, decomposition).inverse.asMatrix()
|
@ -1,10 +1,12 @@
|
||||
package scientifik.kmath.transform
|
||||
|
||||
import kotlinx.coroutines.FlowPreview
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.map
|
||||
import org.apache.commons.math3.transform.*
|
||||
import scientifik.kmath.operations.Complex
|
||||
import scientifik.kmath.sequential.Processor
|
||||
import scientifik.kmath.sequential.Producer
|
||||
import scientifik.kmath.sequential.map
|
||||
import scientifik.kmath.streaming.chunked
|
||||
import scientifik.kmath.streaming.spread
|
||||
import scientifik.kmath.structures.*
|
||||
|
||||
|
||||
@ -33,54 +35,75 @@ object Transformations {
|
||||
fun fourier(
|
||||
normalization: DftNormalization = DftNormalization.STANDARD,
|
||||
direction: TransformType = TransformType.FORWARD
|
||||
): BufferTransform<Complex, Complex> = {
|
||||
): SuspendBufferTransform<Complex, Complex> = {
|
||||
FastFourierTransformer(normalization).transform(it.toArray(), direction).asBuffer()
|
||||
}
|
||||
|
||||
fun realFourier(
|
||||
normalization: DftNormalization = DftNormalization.STANDARD,
|
||||
direction: TransformType = TransformType.FORWARD
|
||||
): BufferTransform<Double, Complex> = {
|
||||
): SuspendBufferTransform<Double, Complex> = {
|
||||
FastFourierTransformer(normalization).transform(it.asArray(), direction).asBuffer()
|
||||
}
|
||||
|
||||
fun sine(
|
||||
normalization: DstNormalization = DstNormalization.STANDARD_DST_I,
|
||||
direction: TransformType = TransformType.FORWARD
|
||||
): BufferTransform<Double, Double> = {
|
||||
): SuspendBufferTransform<Double, Double> = {
|
||||
FastSineTransformer(normalization).transform(it.asArray(), direction).asBuffer()
|
||||
}
|
||||
|
||||
fun cosine(
|
||||
normalization: DctNormalization = DctNormalization.STANDARD_DCT_I,
|
||||
direction: TransformType = TransformType.FORWARD
|
||||
): BufferTransform<Double, Double> = {
|
||||
): SuspendBufferTransform<Double, Double> = {
|
||||
FastCosineTransformer(normalization).transform(it.asArray(), direction).asBuffer()
|
||||
}
|
||||
|
||||
fun hadamard(
|
||||
direction: TransformType = TransformType.FORWARD
|
||||
): BufferTransform<Double, Double> = {
|
||||
): SuspendBufferTransform<Double, Double> = {
|
||||
FastHadamardTransformer().transform(it.asArray(), direction).asBuffer()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process given [Producer] with commons-math fft transformation
|
||||
* Process given [Flow] with commons-math fft transformation
|
||||
*/
|
||||
fun Producer<Buffer<Complex>>.FFT(
|
||||
@FlowPreview
|
||||
fun Flow<Buffer<Complex>>.FFT(
|
||||
normalization: DftNormalization = DftNormalization.STANDARD,
|
||||
direction: TransformType = TransformType.FORWARD
|
||||
): Processor<Buffer<Complex>, Buffer<Complex>> {
|
||||
): Flow<Buffer<Complex>> {
|
||||
val transform = Transformations.fourier(normalization, direction)
|
||||
return map { transform(it) }
|
||||
}
|
||||
|
||||
@FlowPreview
|
||||
@JvmName("realFFT")
|
||||
fun Producer<Buffer<Double>>.FFT(
|
||||
fun Flow<Buffer<Double>>.FFT(
|
||||
normalization: DftNormalization = DftNormalization.STANDARD,
|
||||
direction: TransformType = TransformType.FORWARD
|
||||
): Processor<Buffer<Double>, Buffer<Complex>> {
|
||||
): Flow<Buffer<Complex>> {
|
||||
val transform = Transformations.realFourier(normalization, direction)
|
||||
return map { transform(it) }
|
||||
}
|
||||
return map(transform)
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a continous flow of real numbers in FFT splitting it in chunks of [bufferSize].
|
||||
*/
|
||||
@FlowPreview
|
||||
@JvmName("realFFT")
|
||||
fun Flow<Double>.FFT(
|
||||
bufferSize: Int = Int.MAX_VALUE,
|
||||
normalization: DftNormalization = DftNormalization.STANDARD,
|
||||
direction: TransformType = TransformType.FORWARD
|
||||
): Flow<Complex> {
|
||||
return chunked(bufferSize).FFT(normalization,direction).spread()
|
||||
}
|
||||
|
||||
/**
|
||||
* Map a complex flow into real flow by taking real part of each number
|
||||
*/
|
||||
@FlowPreview
|
||||
fun Flow<Complex>.real(): Flow<Double> = map{it.re}
|
@ -1,51 +1,13 @@
|
||||
plugins {
|
||||
kotlin("multiplatform")
|
||||
`multiplatform-config`
|
||||
}
|
||||
|
||||
val ioVersion: String by rootProject.extra
|
||||
|
||||
|
||||
kotlin {
|
||||
jvm()
|
||||
js()
|
||||
|
||||
sourceSets {
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-memory"))
|
||||
api(kotlin("stdlib"))
|
||||
}
|
||||
kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-memory"))
|
||||
}
|
||||
val commonTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-common"))
|
||||
implementation(kotlin("test-annotations-common"))
|
||||
}
|
||||
}
|
||||
val jvmMain by getting {
|
||||
dependencies {
|
||||
api(kotlin("stdlib-jdk8"))
|
||||
}
|
||||
}
|
||||
val jvmTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test"))
|
||||
implementation(kotlin("test-junit"))
|
||||
}
|
||||
}
|
||||
val jsMain by getting {
|
||||
dependencies {
|
||||
api(kotlin("stdlib-js"))
|
||||
}
|
||||
}
|
||||
val jsTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-js"))
|
||||
}
|
||||
}
|
||||
// mingwMain {
|
||||
// }
|
||||
// mingwTest {
|
||||
// }
|
||||
}
|
||||
//mingwMain {}
|
||||
//mingwTest {}
|
||||
}
|
@ -1,8 +1,8 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.*
|
||||
import kotlin.jvm.JvmSynthetic
|
||||
|
||||
/**
|
||||
* Basic implementation of Matrix space based on [NDStructure]
|
||||
@ -18,6 +18,22 @@ class BufferMatrixContext<T : Any, R : Ring<T>>(
|
||||
}
|
||||
|
||||
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
||||
|
||||
companion object {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
||||
|
||||
override val elementContext = RealField
|
||||
|
||||
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
||||
val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
}
|
||||
|
||||
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = DoubleBuffer(size,initializer)
|
||||
}
|
||||
|
||||
class BufferMatrix<T : Any>(
|
||||
@ -25,7 +41,7 @@ class BufferMatrix<T : Any>(
|
||||
override val colNum: Int,
|
||||
val buffer: Buffer<out T>,
|
||||
override val features: Set<MatrixFeature> = emptySet()
|
||||
) : Matrix<T> {
|
||||
) : FeaturedMatrix<T> {
|
||||
|
||||
init {
|
||||
if (buffer.size != rowNum * colNum) {
|
||||
|
@ -0,0 +1,96 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import scientifik.kmath.structures.Structure2D
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
* A 2d structure plus optional matrix-specific features
|
||||
*/
|
||||
interface FeaturedMatrix<T : Any> : Matrix<T> {
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
val features: Set<MatrixFeature>
|
||||
|
||||
/**
|
||||
* Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure.
|
||||
*
|
||||
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to
|
||||
* add only those features that are valid.
|
||||
*/
|
||||
fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
|
||||
|
||||
companion object {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
||||
MatrixContext.real.produce(rows, columns, initializer)
|
||||
|
||||
/**
|
||||
* Build a square matrix from given elements.
|
||||
*/
|
||||
fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
|
||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(size, size, buffer)
|
||||
}
|
||||
|
||||
fun <T : Any> Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)
|
||||
|
||||
class MatrixBuilder<T : Any>(val rows: Int, val columns: Int) {
|
||||
operator fun invoke(vararg elements: T): FeaturedMatrix<T> {
|
||||
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns")
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
}
|
||||
}
|
||||
|
||||
val Matrix<*>.features get() = (this as? FeaturedMatrix)?.features?: emptySet()
|
||||
|
||||
/**
|
||||
* Check if matrix has the given feature class
|
||||
*/
|
||||
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean =
|
||||
features.find { it is T } != null
|
||||
|
||||
/**
|
||||
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
||||
*/
|
||||
inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
|
||||
features.filterIsInstance<T>().firstOrNull()
|
||||
|
||||
/**
|
||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||
VirtualMatrix(rows, columns, DiagonalFeature) { i, j ->
|
||||
if (i == j) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A virtual matrix of zeroes
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero }
|
||||
|
||||
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
|
||||
|
||||
/**
|
||||
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
||||
*/
|
||||
fun <T : Any> Matrix<T>.transpose(): Matrix<T> {
|
||||
return this.getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
|
||||
this.colNum,
|
||||
this.rowNum,
|
||||
setOf(TransposedFeature(this))
|
||||
) { i, j -> get(j, i) }
|
||||
}
|
||||
|
||||
infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = with(MatrixContext.real) { dot(other) }
|
@ -1,26 +1,31 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.MutableBuffer
|
||||
import scientifik.kmath.structures.MutableBuffer.Companion.boxing
|
||||
import scientifik.kmath.structures.MutableBufferFactory
|
||||
import scientifik.kmath.structures.NDStructure
|
||||
import scientifik.kmath.structures.get
|
||||
import scientifik.kmath.structures.BufferAccessor2D
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import scientifik.kmath.structures.Structure2D
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
class LUPDecomposition<T : Comparable<T>>(
|
||||
private val elementContext: Ring<T>,
|
||||
internal val lu: NDStructure<T>,
|
||||
/**
|
||||
* Common implementation of [LUPDecompositionFeature]
|
||||
*/
|
||||
class LUPDecomposition<T : Any>(
|
||||
val context: GenericMatrixContext<T, out Field<T>>,
|
||||
val lu: Structure2D<T>,
|
||||
val pivot: IntArray,
|
||||
private val even: Boolean
|
||||
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
||||
|
||||
val elementContext get() = context.elementContext
|
||||
|
||||
/**
|
||||
* Returns the matrix L of the decomposition.
|
||||
*
|
||||
* L is a lower-triangular matrix with [Ring.one] in diagonal
|
||||
*/
|
||||
override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
|
||||
override val l: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
|
||||
when {
|
||||
j < i -> lu[i, j]
|
||||
j == i -> elementContext.one
|
||||
@ -34,7 +39,7 @@ class LUPDecomposition<T : Comparable<T>>(
|
||||
*
|
||||
* U is an upper-triangular matrix including the diagonal
|
||||
*/
|
||||
override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
|
||||
override val u: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
|
||||
if (j >= i) lu[i, j] else elementContext.zero
|
||||
}
|
||||
|
||||
@ -45,7 +50,7 @@ class LUPDecomposition<T : Comparable<T>>(
|
||||
* P is a sparse matrix with exactly one element set to [Ring.one] in
|
||||
* each row and each column, all other elements being set to [Ring.zero].
|
||||
*/
|
||||
override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
override val p: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
if (j == pivot[i]) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
@ -62,70 +67,86 @@ class LUPDecomposition<T : Comparable<T>>(
|
||||
|
||||
}
|
||||
|
||||
|
||||
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||
val context: GenericMatrixContext<T, F>,
|
||||
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
||||
val singularityCheck: (T) -> Boolean
|
||||
) : LinearSolver<T> {
|
||||
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.abs(value: T) =
|
||||
if (value > elementContext.zero) value else with(elementContext) { -value }
|
||||
|
||||
|
||||
private fun abs(value: T) =
|
||||
if (value > context.elementContext.zero) value else with(context.elementContext) { -value }
|
||||
/**
|
||||
* Create a lup decomposition of generic matrix
|
||||
*/
|
||||
fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||
type: KClass<T>,
|
||||
matrix: Matrix<T>,
|
||||
checkSingular: (T) -> Boolean
|
||||
): LUPDecomposition<T> {
|
||||
if (matrix.rowNum != matrix.colNum) {
|
||||
error("LU decomposition supports only square matrices")
|
||||
}
|
||||
|
||||
fun buildDecomposition(matrix: Matrix<T>): LUPDecomposition<T> {
|
||||
if (matrix.rowNum != matrix.colNum) {
|
||||
error("LU decomposition supports only square matrices")
|
||||
}
|
||||
val m = matrix.colNum
|
||||
val pivot = IntArray(matrix.rowNum)
|
||||
|
||||
val m = matrix.colNum
|
||||
val pivot = IntArray(matrix.rowNum)
|
||||
//TODO just waits for KEEP-176
|
||||
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
|
||||
elementContext.run {
|
||||
|
||||
val lu = Mutable2DStructure.create(matrix.rowNum, matrix.colNum, bufferFactory) { i, j ->
|
||||
matrix[i, j]
|
||||
}
|
||||
val lu = create(matrix)
|
||||
|
||||
|
||||
with(context.elementContext) {
|
||||
// Initialize permutation array and parity
|
||||
for (row in 0 until m) {
|
||||
pivot[row] = row
|
||||
}
|
||||
var even = true
|
||||
|
||||
// Initialize permutation array and parity
|
||||
for (row in 0 until m) {
|
||||
pivot[row] = row
|
||||
}
|
||||
|
||||
// Loop over columns
|
||||
for (col in 0 until m) {
|
||||
|
||||
// upper
|
||||
for (row in 0 until col) {
|
||||
var sum = lu[row, col]
|
||||
val luRow = lu.row(row)
|
||||
var sum = luRow[col]
|
||||
for (i in 0 until row) {
|
||||
sum -= lu[row, i] * lu[i, col]
|
||||
sum -= luRow[i] * lu[i, col]
|
||||
}
|
||||
lu[row, col] = sum
|
||||
luRow[col] = sum
|
||||
}
|
||||
|
||||
// lower
|
||||
val max = (col until m).maxBy { row ->
|
||||
var sum = lu[row, col]
|
||||
var max = col // permutation row
|
||||
var largest = -one
|
||||
for (row in col until m) {
|
||||
val luRow = lu.row(row)
|
||||
var sum = luRow[col]
|
||||
for (i in 0 until col) {
|
||||
sum -= lu[row, i] * lu[i, col]
|
||||
sum -= luRow[i] * lu[i, col]
|
||||
}
|
||||
lu[row, col] = sum
|
||||
luRow[col] = sum
|
||||
|
||||
abs(sum)
|
||||
} ?: col
|
||||
// maintain best permutation choice
|
||||
if (abs(sum) > largest) {
|
||||
largest = abs(sum)
|
||||
max = row
|
||||
}
|
||||
}
|
||||
|
||||
// Singularity check
|
||||
if (singularityCheck(lu[max, col])) {
|
||||
error("Singular matrix")
|
||||
if (checkSingular(abs(lu[max, col]))) {
|
||||
error("The matrix is singular")
|
||||
}
|
||||
|
||||
// Pivot if necessary
|
||||
if (max != col) {
|
||||
val luMax = lu.row(max)
|
||||
val luCol = lu.row(col)
|
||||
for (i in 0 until m) {
|
||||
lu[max, i] = lu[col, i]
|
||||
lu[col, i] = lu[max, i]
|
||||
val tmp = luMax[i]
|
||||
luMax[i] = luCol[i]
|
||||
luCol[i] = tmp
|
||||
}
|
||||
val temp = pivot[max]
|
||||
pivot[max] = pivot[col]
|
||||
@ -136,70 +157,96 @@ class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||
// Divide the lower elements by the "winning" diagonal elt.
|
||||
val luDiag = lu[col, col]
|
||||
for (row in col + 1 until m) {
|
||||
lu[row, col] = lu[row, col] / luDiag
|
||||
lu[row, col] /= luDiag
|
||||
}
|
||||
}
|
||||
return LUPDecomposition(context.elementContext, lu, pivot, even)
|
||||
|
||||
return LUPDecomposition(this@lup, lu.collect(), pivot, even)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Produce a matrix with added decomposition feature
|
||||
*/
|
||||
fun decompose(matrix: Matrix<T>): Matrix<T> {
|
||||
if (matrix.hasFeature<LUPDecomposition<*>>()) {
|
||||
return matrix
|
||||
} else {
|
||||
val decomposition = buildDecomposition(matrix)
|
||||
return VirtualMatrix.wrap(matrix, decomposition)
|
||||
}
|
||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||
matrix: Matrix<T>,
|
||||
noinline checkSingular: (T) -> Boolean
|
||||
) = lup(T::class, matrix, checkSingular)
|
||||
|
||||
fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>) = lup(Double::class, matrix) { it < 1e-11 }
|
||||
|
||||
fun <T : Any> LUPDecomposition<T>.solve(type: KClass<T>, matrix: Matrix<T>): Matrix<T> {
|
||||
|
||||
if (matrix.rowNum != pivot.size) {
|
||||
error("Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}")
|
||||
}
|
||||
|
||||
BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run {
|
||||
elementContext.run {
|
||||
|
||||
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||
if (b.rowNum != a.colNum) {
|
||||
error("Matrix dimension mismatch expected ${a.rowNum}, but got ${b.colNum}")
|
||||
}
|
||||
// Apply permutations to b
|
||||
val bp = create { i, j -> zero }
|
||||
|
||||
// Use existing decomposition if it is provided by matrix
|
||||
val decomposition = a.getFeature() ?: buildDecomposition(a)
|
||||
|
||||
with(decomposition) {
|
||||
with(context.elementContext) {
|
||||
// Apply permutations to b
|
||||
val bp = Mutable2DStructure.create(a.rowNum, a.colNum, bufferFactory) { i, j ->
|
||||
b[pivot[i], j]
|
||||
for (row in 0 until pivot.size) {
|
||||
val bpRow = bp.row(row)
|
||||
val pRow = pivot[row]
|
||||
for (col in 0 until matrix.colNum) {
|
||||
bpRow[col] = matrix[pRow, col]
|
||||
}
|
||||
|
||||
// Solve LY = b
|
||||
for (col in 0 until a.rowNum) {
|
||||
for (i in col + 1 until a.rowNum) {
|
||||
for (j in 0 until b.colNum) {
|
||||
bp[i, j] -= bp[col, j] * lu[i, col]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Solve UX = Y
|
||||
for (col in a.rowNum - 1 downTo 0) {
|
||||
for (j in 0 until b.colNum) {
|
||||
bp[col, j] /= lu[col, col]
|
||||
}
|
||||
for (i in 0 until col) {
|
||||
for (j in 0 until b.colNum) {
|
||||
bp[i, j] -= bp[col, j] * lu[i, col]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return context.produce(a.rowNum, a.colNum) { i, j -> bp[i, j] }
|
||||
}
|
||||
|
||||
// Solve LY = b
|
||||
for (col in 0 until pivot.size) {
|
||||
val bpCol = bp.row(col)
|
||||
for (i in col + 1 until pivot.size) {
|
||||
val bpI = bp.row(i)
|
||||
val luICol = lu[i, col]
|
||||
for (j in 0 until matrix.colNum) {
|
||||
bpI[j] -= bpCol[j] * luICol
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Solve UX = Y
|
||||
for (col in pivot.size - 1 downTo 0) {
|
||||
val bpCol = bp.row(col)
|
||||
val luDiag = lu[col, col]
|
||||
for (j in 0 until matrix.colNum) {
|
||||
bpCol[j] /= luDiag
|
||||
}
|
||||
for (i in 0 until col) {
|
||||
val bpI = bp.row(i)
|
||||
val luICol = lu[i, col]
|
||||
for (j in 0 until matrix.colNum) {
|
||||
bpI[j] -= bpCol[j] * luICol
|
||||
}
|
||||
}
|
||||
}
|
||||
return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun inverse(a: Matrix<T>): Matrix<T> = solve(a, context.one(a.rowNum, a.colNum))
|
||||
inline fun <reified T : Any> LUPDecomposition<T>.solve(matrix: Matrix<T>) = solve(T::class, matrix)
|
||||
|
||||
companion object {
|
||||
val real = LUSolver(MatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 }
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Solve a linear equation **a*x = b**
|
||||
*/
|
||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.solve(
|
||||
a: Matrix<T>,
|
||||
b: Matrix<T>,
|
||||
noinline checkSingular: (T) -> Boolean
|
||||
): Matrix<T> {
|
||||
// Use existing decomposition if it is provided by matrix
|
||||
val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular)
|
||||
return decomposition.solve(T::class, b)
|
||||
}
|
||||
|
||||
fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>) =
|
||||
solve(a, b) { it < 1e-11 }
|
||||
|
||||
inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
|
||||
matrix: Matrix<T>,
|
||||
noinline checkSingular: (T) -> Boolean
|
||||
) = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
||||
|
||||
fun RealMatrixContext.inverse(matrix: Matrix<Double>) =
|
||||
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }
|
@ -3,6 +3,7 @@ package scientifik.kmath.linear
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Norm
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import scientifik.kmath.structures.VirtualBuffer
|
||||
import scientifik.kmath.structures.asSequence
|
||||
|
||||
@ -10,9 +11,9 @@ import scientifik.kmath.structures.asSequence
|
||||
/**
|
||||
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
|
||||
*/
|
||||
interface LinearSolver<T : Any> {
|
||||
interface LinearSolver<T : Any> {
|
||||
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
||||
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).toVector()
|
||||
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.asMatrix()).asPoint()
|
||||
fun inverse(a: Matrix<T>): Matrix<T>
|
||||
}
|
||||
|
||||
@ -32,14 +33,14 @@ object VectorL2Norm : Norm<Point<out Number>, Double> {
|
||||
typealias RealVector = Vector<Double, RealField>
|
||||
typealias RealMatrix = Matrix<Double>
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Convert matrix to vector if it is possible
|
||||
*/
|
||||
fun <T: Any> Matrix<T>.toVector(): Point<T> =
|
||||
fun <T : Any> Matrix<T>.asPoint(): Point<T> =
|
||||
if (this.colNum == 1) {
|
||||
VirtualBuffer(rowNum){ get(it, 0) }
|
||||
} else error("Can't convert matrix with more than one column to vector")
|
||||
VirtualBuffer(rowNum) { get(it, 0) }
|
||||
} else {
|
||||
error("Can't convert matrix with more than one column to vector")
|
||||
}
|
||||
|
||||
fun <T: Any> Point<T>.toMatrix(): Matrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
||||
fun <T : Any> Point<T>.asMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
@ -1,213 +0,0 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.operations.sum
|
||||
import scientifik.kmath.structures.*
|
||||
import scientifik.kmath.structures.Buffer.Companion.boxing
|
||||
import kotlin.math.sqrt
|
||||
|
||||
|
||||
interface MatrixContext<T : Any> {
|
||||
/**
|
||||
* Produce a matrix with this context and given dimensions
|
||||
*/
|
||||
fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
||||
|
||||
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
|
||||
|
||||
infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
|
||||
|
||||
operator fun Matrix<T>.unaryMinus(): Matrix<T>
|
||||
|
||||
operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T>
|
||||
|
||||
operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T>
|
||||
|
||||
operator fun Matrix<T>.times(value: T): Matrix<T>
|
||||
|
||||
operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Non-boxing double matrix
|
||||
*/
|
||||
val real = BufferMatrixContext(RealField, Buffer.Companion::auto)
|
||||
|
||||
/**
|
||||
* A structured matrix with custom buffer
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> buffered(
|
||||
ring: R,
|
||||
bufferFactory: BufferFactory<T> = ::boxing
|
||||
): GenericMatrixContext<T, R> =
|
||||
BufferMatrixContext(ring, bufferFactory)
|
||||
|
||||
/**
|
||||
* Automatic buffered matrix, unboxed if it is possible
|
||||
*/
|
||||
inline fun <reified T : Any, R : Ring<T>> auto(ring: R): GenericMatrixContext<T, R> =
|
||||
buffered(ring, Buffer.Companion::auto)
|
||||
}
|
||||
}
|
||||
|
||||
interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
||||
/**
|
||||
* The ring context for matrix elements
|
||||
*/
|
||||
val elementContext: R
|
||||
|
||||
/**
|
||||
* Produce a point compatible with matrix space
|
||||
*/
|
||||
fun point(size: Int, initializer: (Int) -> T): Point<T>
|
||||
|
||||
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
||||
//TODO add typed error
|
||||
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})")
|
||||
return produce(rowNum, other.colNum) { i, j ->
|
||||
val row = rows[i]
|
||||
val column = other.columns[j]
|
||||
with(elementContext) {
|
||||
sum(row.asSequence().zip(column.asSequence(), ::multiply))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
||||
//TODO add typed error
|
||||
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})")
|
||||
return point(rowNum) { i ->
|
||||
val row = rows[i]
|
||||
with(elementContext) {
|
||||
sum(row.asSequence().zip(vector.asSequence(), ::multiply))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override operator fun Matrix<T>.unaryMinus() =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
||||
|
||||
override operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T> {
|
||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] + [${b.rowNum},${b.colNum}]")
|
||||
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
||||
}
|
||||
|
||||
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]")
|
||||
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
||||
}
|
||||
|
||||
operator fun Matrix<T>.times(number: Number): Matrix<T> =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } }
|
||||
|
||||
operator fun Number.times(matrix: Matrix<T>): Matrix<T> = matrix * this
|
||||
|
||||
override fun Matrix<T>.times(value: T): Matrix<T> =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } }
|
||||
}
|
||||
|
||||
/**
|
||||
* Specialized 2-d structure
|
||||
*/
|
||||
interface Matrix<T : Any> : Structure2D<T> {
|
||||
val rowNum: Int
|
||||
val colNum: Int
|
||||
|
||||
val features: Set<MatrixFeature>
|
||||
|
||||
/**
|
||||
* Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure.
|
||||
*
|
||||
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to
|
||||
* add only those features that are valid.
|
||||
*/
|
||||
fun suggestFeature(vararg features: MatrixFeature): Matrix<T>
|
||||
|
||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
val rows: Point<Point<T>>
|
||||
get() = VirtualBuffer(rowNum) { i ->
|
||||
VirtualBuffer(colNum) { j -> get(i, j) }
|
||||
}
|
||||
|
||||
val columns: Point<Point<T>>
|
||||
get() = VirtualBuffer(colNum) { j ->
|
||||
VirtualBuffer(rowNum) { i -> get(i, j) }
|
||||
}
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||
for (i in (0 until rowNum)) {
|
||||
for (j in (0 until colNum)) {
|
||||
yield(intArrayOf(i, j) to get(i, j))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
||||
MatrixContext.real.produce(rows, columns, initializer)
|
||||
|
||||
/**
|
||||
* Build a square matrix from given elements.
|
||||
*/
|
||||
fun <T : Any> square(vararg elements: T): Matrix<T> {
|
||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(size, size, buffer)
|
||||
}
|
||||
|
||||
fun <T : Any> build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)
|
||||
}
|
||||
}
|
||||
|
||||
class MatrixBuilder<T : Any>(val rows: Int, val columns: Int) {
|
||||
operator fun invoke(vararg elements: T): Matrix<T> {
|
||||
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns")
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if matrix has the given feature class
|
||||
*/
|
||||
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean = features.find { it is T } != null
|
||||
|
||||
/**
|
||||
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
||||
*/
|
||||
inline fun <reified T : Any> Matrix<*>.getFeature(): T? = features.filterIsInstance<T>().firstOrNull()
|
||||
|
||||
/**
|
||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||
if (i == j) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A virtual matrix of zeroes
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero }
|
||||
|
||||
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
|
||||
|
||||
/**
|
||||
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> Matrix<T>.transpose(): Matrix<T> {
|
||||
return this.getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
|
||||
this.colNum,
|
||||
this.rowNum,
|
||||
setOf(TransposedFeature(this))
|
||||
) { i, j -> get(j, i) }
|
||||
}
|
||||
|
||||
infix fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = with(MatrixContext.real) { dot(other) }
|
@ -0,0 +1,105 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.operations.SpaceOperations
|
||||
import scientifik.kmath.operations.sum
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.BufferFactory
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import scientifik.kmath.structures.asSequence
|
||||
|
||||
/**
|
||||
* Basic operations on matrices. Operates on [Matrix]
|
||||
*/
|
||||
interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
|
||||
/**
|
||||
* Produce a matrix with this context and given dimensions
|
||||
*/
|
||||
fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
||||
|
||||
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
|
||||
|
||||
infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
|
||||
|
||||
operator fun Matrix<T>.times(value: T): Matrix<T>
|
||||
|
||||
operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Non-boxing double matrix
|
||||
*/
|
||||
val real = RealMatrixContext
|
||||
|
||||
/**
|
||||
* A structured matrix with custom buffer
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> buffered(
|
||||
ring: R,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||
): GenericMatrixContext<T, R> =
|
||||
BufferMatrixContext(ring, bufferFactory)
|
||||
|
||||
/**
|
||||
* Automatic buffered matrix, unboxed if it is possible
|
||||
*/
|
||||
inline fun <reified T : Any, R : Ring<T>> auto(ring: R): GenericMatrixContext<T, R> =
|
||||
buffered(ring, Buffer.Companion::auto)
|
||||
}
|
||||
}
|
||||
|
||||
interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
||||
/**
|
||||
* The ring context for matrix elements
|
||||
*/
|
||||
val elementContext: R
|
||||
|
||||
/**
|
||||
* Produce a point compatible with matrix space
|
||||
*/
|
||||
fun point(size: Int, initializer: (Int) -> T): Point<T>
|
||||
|
||||
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
||||
//TODO add typed error
|
||||
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})")
|
||||
return produce(rowNum, other.colNum) { i, j ->
|
||||
val row = rows[i]
|
||||
val column = other.columns[j]
|
||||
with(elementContext) {
|
||||
sum(row.asSequence().zip(column.asSequence(), ::multiply))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
||||
//TODO add typed error
|
||||
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})")
|
||||
return point(rowNum) { i ->
|
||||
val row = rows[i]
|
||||
with(elementContext) {
|
||||
sum(row.asSequence().zip(vector.asSequence(), ::multiply))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override operator fun Matrix<T>.unaryMinus() =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
||||
|
||||
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||
if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]")
|
||||
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) + b[i, j] } }
|
||||
}
|
||||
|
||||
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]")
|
||||
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
||||
}
|
||||
|
||||
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
||||
produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a.get(i, j) * k } }
|
||||
|
||||
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
||||
|
||||
override fun Matrix<T>.times(value: T): Matrix<T> =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } }
|
||||
}
|
@ -25,7 +25,7 @@ object UnitFeature : MatrixFeature
|
||||
* Inverted matrix feature
|
||||
*/
|
||||
interface InverseMatrixFeature<T : Any> : MatrixFeature {
|
||||
val inverse: Matrix<T>
|
||||
val inverse: FeaturedMatrix<T>
|
||||
}
|
||||
|
||||
/**
|
||||
@ -54,9 +54,9 @@ object UFeature: MatrixFeature
|
||||
* TODO add documentation
|
||||
*/
|
||||
interface LUPDecompositionFeature<T : Any> : MatrixFeature {
|
||||
val l: Matrix<T>
|
||||
val u: Matrix<T>
|
||||
val p: Matrix<T>
|
||||
val l: FeaturedMatrix<T>
|
||||
val u: FeaturedMatrix<T>
|
||||
val p: FeaturedMatrix<T>
|
||||
}
|
||||
|
||||
//TODO add sparse matrix feature
|
@ -1,40 +0,0 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.structures.MutableBuffer
|
||||
import scientifik.kmath.structures.MutableBufferFactory
|
||||
import scientifik.kmath.structures.MutableNDStructure
|
||||
|
||||
class Mutable2DStructure<T>(val rowNum: Int, val colNum: Int, val buffer: MutableBuffer<T>) : MutableNDStructure<T> {
|
||||
override val shape: IntArray
|
||||
get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
|
||||
|
||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||
for (i in 0 until rowNum) {
|
||||
for (j in 0 until colNum) {
|
||||
yield(intArrayOf(i, j) to get(i, j))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
operator fun set(i: Int, j: Int, value: T) {
|
||||
buffer[i * colNum + j] = value
|
||||
}
|
||||
|
||||
override fun set(index: IntArray, value: T) = set(index[0], index[1], value)
|
||||
|
||||
companion object {
|
||||
fun <T> create(
|
||||
rowNum: Int,
|
||||
colNum: Int,
|
||||
bufferFactory: MutableBufferFactory<T>,
|
||||
init: (i: Int, j: Int) -> T
|
||||
): Mutable2DStructure<T> {
|
||||
val buffer = bufferFactory(rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
||||
return Mutable2DStructure(rowNum, colNum, buffer)
|
||||
}
|
||||
}
|
||||
}
|
@ -4,68 +4,23 @@ import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.operations.Space
|
||||
import scientifik.kmath.operations.SpaceElement
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.BufferFactory
|
||||
import scientifik.kmath.structures.asSequence
|
||||
import kotlin.jvm.JvmName
|
||||
|
||||
typealias Point<T> = Buffer<T>
|
||||
|
||||
/**
|
||||
* A linear space for vectors.
|
||||
* Could be used on any point-like structure
|
||||
*/
|
||||
interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
||||
|
||||
val size: Int
|
||||
|
||||
val space: S
|
||||
|
||||
fun produce(initializer: (Int) -> T): Point<T>
|
||||
|
||||
/**
|
||||
* Produce a space-element of this vector space for expressions
|
||||
*/
|
||||
fun produceElement(initializer: (Int) -> T): Vector<T, S>
|
||||
|
||||
override val zero: Point<T> get() = produce { space.zero }
|
||||
|
||||
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } }
|
||||
|
||||
override fun multiply(a: Point<T>, k: Number): Point<T> = produce { with(space) { a[it] * k } }
|
||||
|
||||
//TODO add basis
|
||||
|
||||
companion object {
|
||||
|
||||
private val realSpaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
||||
|
||||
/**
|
||||
* Non-boxing double vector space
|
||||
*/
|
||||
fun real(size: Int): BufferVectorSpace<Double, RealField> {
|
||||
return realSpaceCache.getOrPut(size) { BufferVectorSpace(size, RealField, Buffer.Companion::auto) }
|
||||
}
|
||||
|
||||
/**
|
||||
* A structured vector space with custom buffer
|
||||
*/
|
||||
fun <T : Any, S : Space<T>> buffered(
|
||||
size: Int,
|
||||
space: S,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||
): VectorSpace<T, S> = BufferVectorSpace(size, space, bufferFactory)
|
||||
|
||||
/**
|
||||
* Automatic buffered vector, unboxed if it is possible
|
||||
*/
|
||||
inline fun <reified T : Any, S : Space<T>> smart(size: Int, space: S): VectorSpace<T, S> =
|
||||
buffered(size, space, Buffer.Companion::auto)
|
||||
}
|
||||
}
|
||||
fun <T : Any, S : Space<T>> BufferVectorSpace<T, S>.produceElement(initializer: (Int) -> T): Vector<T, S> =
|
||||
BufferVector(this, produce(initializer))
|
||||
|
||||
@JvmName("produceRealElement")
|
||||
fun BufferVectorSpace<Double, RealField>.produceElement(initializer: (Int) -> Double): Vector<Double, RealField> =
|
||||
BufferVector(this, produce(initializer))
|
||||
|
||||
/**
|
||||
* A point coupled to the linear space
|
||||
*/
|
||||
@Deprecated("Use VectorContext instead")
|
||||
interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, Vector<T, S>, VectorSpace<T, S>>, Point<T> {
|
||||
override val size: Int get() = context.size
|
||||
|
||||
@ -90,16 +45,7 @@ interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, Vector<T, S>, V
|
||||
}
|
||||
}
|
||||
|
||||
data class BufferVectorSpace<T : Any, S : Space<T>>(
|
||||
override val size: Int,
|
||||
override val space: S,
|
||||
val bufferFactory: BufferFactory<T>
|
||||
) : VectorSpace<T, S> {
|
||||
override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer)
|
||||
override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
|
||||
}
|
||||
|
||||
|
||||
@Deprecated("Use VectorContext instead")
|
||||
data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace<T, S>, val buffer: Buffer<T>) :
|
||||
Vector<T, S> {
|
||||
|
||||
|
@ -0,0 +1,75 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.operations.Space
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.BufferFactory
|
||||
|
||||
/**
|
||||
* A linear space for vectors.
|
||||
* Could be used on any point-like structure
|
||||
*/
|
||||
interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
||||
|
||||
val size: Int
|
||||
|
||||
val space: S
|
||||
|
||||
fun produce(initializer: (Int) -> T): Point<T>
|
||||
|
||||
/**
|
||||
* Produce a space-element of this vector space for expressions
|
||||
*/
|
||||
//fun produceElement(initializer: (Int) -> T): Vector<T, S>
|
||||
|
||||
override val zero: Point<T> get() = produce { space.zero }
|
||||
|
||||
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } }
|
||||
|
||||
override fun multiply(a: Point<T>, k: Number): Point<T> = produce { with(space) { a[it] * k } }
|
||||
|
||||
//TODO add basis
|
||||
|
||||
companion object {
|
||||
|
||||
private val realSpaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
||||
|
||||
/**
|
||||
* Non-boxing double vector space
|
||||
*/
|
||||
fun real(size: Int): BufferVectorSpace<Double, RealField> {
|
||||
return realSpaceCache.getOrPut(size) {
|
||||
BufferVectorSpace(
|
||||
size,
|
||||
RealField,
|
||||
Buffer.Companion::auto
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A structured vector space with custom buffer
|
||||
*/
|
||||
fun <T : Any, S : Space<T>> buffered(
|
||||
size: Int,
|
||||
space: S,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||
) = BufferVectorSpace(size, space, bufferFactory)
|
||||
|
||||
/**
|
||||
* Automatic buffered vector, unboxed if it is possible
|
||||
*/
|
||||
inline fun <reified T : Any, S : Space<T>> auto(size: Int, space: S): VectorSpace<T, S> =
|
||||
buffered(size, space, Buffer.Companion::auto)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BufferVectorSpace<T : Any, S : Space<T>>(
|
||||
override val size: Int,
|
||||
override val space: S,
|
||||
val bufferFactory: BufferFactory<T>
|
||||
) : VectorSpace<T, S> {
|
||||
override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer)
|
||||
//override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
|
||||
}
|
@ -1,11 +1,23 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.structures.Matrix
|
||||
|
||||
class VirtualMatrix<T : Any>(
|
||||
override val rowNum: Int,
|
||||
override val colNum: Int,
|
||||
override val features: Set<MatrixFeature> = emptySet(),
|
||||
val generator: (i: Int, j: Int) -> T
|
||||
) : Matrix<T> {
|
||||
) : FeaturedMatrix<T> {
|
||||
|
||||
constructor(rowNum: Int, colNum: Int, vararg features: MatrixFeature, generator: (i: Int, j: Int) -> T) : this(
|
||||
rowNum,
|
||||
colNum,
|
||||
setOf(*features),
|
||||
generator
|
||||
)
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
override fun get(i: Int, j: Int): T = generator(i, j)
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
||||
@ -13,7 +25,7 @@ class VirtualMatrix<T : Any>(
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is Matrix<*>) return false
|
||||
if (other !is FeaturedMatrix<*>) return false
|
||||
|
||||
if (rowNum != other.rowNum) return false
|
||||
if (colNum != other.colNum) return false
|
||||
@ -34,7 +46,7 @@ class VirtualMatrix<T : Any>(
|
||||
/**
|
||||
* Wrap a matrix adding additional features to it
|
||||
*/
|
||||
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): Matrix<T> {
|
||||
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): FeaturedMatrix<T> {
|
||||
return if (matrix is VirtualMatrix) {
|
||||
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
|
||||
} else {
|
||||
|
@ -1,4 +1,4 @@
|
||||
package scientifik.kmath.sequential
|
||||
package scientifik.kmath.misc
|
||||
|
||||
import scientifik.kmath.operations.Space
|
||||
import kotlin.jvm.JvmName
|
@ -0,0 +1,218 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/*
|
||||
* Implementation of backward-mode automatic differentiation.
|
||||
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
||||
*/
|
||||
|
||||
/**
|
||||
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
||||
* with respect to this variable.
|
||||
*/
|
||||
open class Variable(val value: Double) {
|
||||
constructor(x: Number) : this(x.toDouble())
|
||||
}
|
||||
|
||||
class DerivationResult(value: Double, val deriv: Map<Variable, Double>) : Variable(value) {
|
||||
fun deriv(variable: Variable) = deriv[variable] ?: 0.0
|
||||
|
||||
/**
|
||||
* compute divergence
|
||||
*/
|
||||
fun div() = deriv.values.sum()
|
||||
|
||||
/**
|
||||
* Compute a gradient for variables in given order
|
||||
*/
|
||||
fun grad(vararg variables: Variable): Point<Double> = if (variables.isEmpty()) {
|
||||
error("Variable order is not provided for gradient construction")
|
||||
} else {
|
||||
variables.map(::deriv).toDoubleArray().asBuffer()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
||||
*
|
||||
* The partial derivatives are placed in argument `d` variable
|
||||
*
|
||||
* Example:
|
||||
* ```
|
||||
* val x = Variable(2) // define variable(s) and their values
|
||||
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
||||
* assertEquals(17.0, y.x) // the value of result (y)
|
||||
* assertEquals(9.0, x.d) // dy/dx
|
||||
* ```
|
||||
*/
|
||||
fun deriv(body: AutoDiffField.() -> Variable): DerivationResult =
|
||||
AutoDiffContext().run {
|
||||
val result = body()
|
||||
result.d = 1.0 // computing derivative w.r.t result
|
||||
runBackwardPass()
|
||||
DerivationResult(result.value, derivatives)
|
||||
}
|
||||
|
||||
|
||||
abstract class AutoDiffField : Field<Variable> {
|
||||
/**
|
||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||
*
|
||||
* For example, implementation of `sin` function is:
|
||||
*
|
||||
* ```
|
||||
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
||||
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
abstract fun <R> derive(value: R, block: (R) -> Unit): R
|
||||
|
||||
/**
|
||||
* A variable accessing inner state of derivatives.
|
||||
* Use this function in inner builders to avoid creating additional derivative bindings
|
||||
*/
|
||||
abstract var Variable.d: Double
|
||||
|
||||
abstract fun variable(value: Double): Variable
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
operator fun Number.plus(that: Variable): Variable = derive(variable(this.toDouble() + that.value)) { z ->
|
||||
that.d += z.d
|
||||
}
|
||||
|
||||
operator fun Variable.plus(b: Number): Variable = b.plus(this)
|
||||
|
||||
operator fun Number.minus(that: Variable): Variable = derive(variable(this.toDouble() - that.value)) { z ->
|
||||
that.d -= z.d
|
||||
}
|
||||
|
||||
operator fun Variable.minus(that: Number): Variable = derive(variable(this.value - that.toDouble())) { z ->
|
||||
this.d += z.d
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Automatic Differentiation context class.
|
||||
*/
|
||||
private class AutoDiffContext : AutoDiffField() {
|
||||
|
||||
// this stack contains pairs of blocks and values to apply them to
|
||||
private var stack = arrayOfNulls<Any?>(8)
|
||||
private var sp = 0
|
||||
|
||||
internal val derivatives = HashMap<Variable, Double>()
|
||||
|
||||
|
||||
/**
|
||||
* A variable coupled with its derivative. For internal use only
|
||||
*/
|
||||
class VariableWithDeriv(x: Double, var d: Double = 0.0) : Variable(x)
|
||||
|
||||
|
||||
override fun variable(value: Double): Variable = VariableWithDeriv(value)
|
||||
|
||||
override var Variable.d: Double
|
||||
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: 0.0
|
||||
set(value) {
|
||||
if (this is VariableWithDeriv) {
|
||||
d = value
|
||||
} else {
|
||||
derivatives[this] = value
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <R> derive(value: R, block: (R) -> Unit): R {
|
||||
// save block to stack for backward pass
|
||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||
stack[sp++] = block
|
||||
stack[sp++] = value
|
||||
return value
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun runBackwardPass() {
|
||||
while (sp > 0) {
|
||||
val value = stack[--sp]
|
||||
val block = stack[--sp] as (Any?) -> Unit
|
||||
block(value)
|
||||
}
|
||||
}
|
||||
|
||||
// Basic math (+, -, *, /)
|
||||
|
||||
|
||||
override fun add(a: Variable, b: Variable): Variable =
|
||||
derive(variable(a.value + b.value)) { z ->
|
||||
a.d += z.d
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
override fun multiply(a: Variable, b: Variable): Variable =
|
||||
derive(variable(a.value * b.value)) { z ->
|
||||
a.d += z.d * b.value
|
||||
b.d += z.d * a.value
|
||||
}
|
||||
|
||||
override fun divide(a: Variable, b: Variable): Variable =
|
||||
derive(Variable(a.value / b.value)) { z ->
|
||||
a.d += z.d / b.value
|
||||
b.d -= z.d * a.value / (b.value * b.value)
|
||||
}
|
||||
|
||||
override fun multiply(a: Variable, k: Number): Variable =
|
||||
derive(variable(k.toDouble() * a.value)) { z ->
|
||||
a.d += z.d * k.toDouble()
|
||||
}
|
||||
|
||||
override val zero: Variable get() = Variable(0.0)
|
||||
override val one: Variable get() = Variable(1.0)
|
||||
}
|
||||
|
||||
// Extensions for differentiation of various basic mathematical functions
|
||||
|
||||
// x ^ 2
|
||||
fun AutoDiffField.sqr(x: Variable): Variable = derive(variable(x.value * x.value)) { z ->
|
||||
x.d += z.d * 2 * x.value
|
||||
}
|
||||
|
||||
// x ^ 1/2
|
||||
fun AutoDiffField.sqrt(x: Variable): Variable = derive(variable(sqrt(x.value))) { z ->
|
||||
x.d += z.d * 0.5 / z.value
|
||||
}
|
||||
|
||||
// x ^ y (const)
|
||||
fun AutoDiffField.pow(x: Variable, y: Double): Variable = derive(variable(x.value.pow(y))) { z ->
|
||||
x.d += z.d * y * x.value.pow(y - 1)
|
||||
}
|
||||
|
||||
fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble())
|
||||
|
||||
// exp(x)
|
||||
fun AutoDiffField.exp(x: Variable): Variable = derive(variable(kotlin.math.exp(x.value))) { z ->
|
||||
x.d += z.d * z.value
|
||||
}
|
||||
|
||||
// ln(x)
|
||||
fun AutoDiffField.ln(x: Variable): Variable = derive(Variable(kotlin.math.ln(x.value))) { z ->
|
||||
x.d += z.d / x.value
|
||||
}
|
||||
|
||||
// x ^ y (any)
|
||||
fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x))
|
||||
|
||||
// sin(x)
|
||||
fun AutoDiffField.sin(x: Variable): Variable = derive(variable(kotlin.math.sin(x.value))) { z ->
|
||||
x.d += z.d * kotlin.math.cos(x.value)
|
||||
}
|
||||
|
||||
// cos(x)
|
||||
fun AutoDiffField.cos(x: Variable): Variable = derive(variable(kotlin.math.cos(x.value))) { z ->
|
||||
x.d -= z.d * kotlin.math.sin(x.value)
|
||||
}
|
@ -55,23 +55,14 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
||||
/**
|
||||
* Complex number class
|
||||
*/
|
||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField> {
|
||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
|
||||
override fun unwrap(): Complex = this
|
||||
|
||||
override fun Complex.wrap(): Complex = this
|
||||
|
||||
override val context: ComplexField get() = ComplexField
|
||||
|
||||
/**
|
||||
* A complex conjugate
|
||||
*/
|
||||
val conjugate: Complex get() = Complex(re, -im)
|
||||
|
||||
val square: Double get() = re * re + im * im
|
||||
|
||||
val abs: Double get() = sqrt(square)
|
||||
|
||||
val theta: Double get() = atan(im / re)
|
||||
override fun compareTo(other: Complex): Int = abs.compareTo(other.abs)
|
||||
|
||||
companion object : MemorySpec<Complex> {
|
||||
override val objectSize: Int = 16
|
||||
@ -86,6 +77,26 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A complex conjugate
|
||||
*/
|
||||
val Complex.conjugate: Complex get() = Complex(re, -im)
|
||||
|
||||
/**
|
||||
* Square of the complex number
|
||||
*/
|
||||
val Complex.square: Double get() = re * re + im * im
|
||||
|
||||
/**
|
||||
* Absolute value of complex number
|
||||
*/
|
||||
val Complex.abs: Double get() = sqrt(square)
|
||||
|
||||
/**
|
||||
* An angle between vector represented by complex number and X axis
|
||||
*/
|
||||
val Complex.theta: Double get() = atan(im / re)
|
||||
|
||||
fun Double.toComplex() = Complex(this, 0.0)
|
||||
|
||||
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
||||
|
@ -1,5 +1,6 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.pow
|
||||
|
||||
/**
|
||||
@ -31,79 +32,152 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
||||
/**
|
||||
* A field for double without boxing. Does not produce appropriate field element
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double, Double> {
|
||||
override val zero: Double = 0.0
|
||||
override fun add(a: Double, b: Double): Double = a + b
|
||||
override fun multiply(a: Double, b: Double): Double = a * b
|
||||
override fun multiply(a: Double, k: Number): Double = a * k.toDouble()
|
||||
override inline fun add(a: Double, b: Double) = a + b
|
||||
override inline fun multiply(a: Double, b: Double) = a * b
|
||||
override inline fun multiply(a: Double, k: Number) = a * k.toDouble()
|
||||
|
||||
override val one: Double = 1.0
|
||||
override fun divide(a: Double, b: Double): Double = a / b
|
||||
override inline fun divide(a: Double, b: Double) = a / b
|
||||
|
||||
override fun sin(arg: Double): Double = kotlin.math.sin(arg)
|
||||
override fun cos(arg: Double): Double = kotlin.math.cos(arg)
|
||||
override inline fun sin(arg: Double) = kotlin.math.sin(arg)
|
||||
override inline fun cos(arg: Double) = kotlin.math.cos(arg)
|
||||
|
||||
override fun power(arg: Double, pow: Number): Double = arg.pow(pow.toDouble())
|
||||
override inline fun power(arg: Double, pow: Number) = arg.pow(pow.toDouble())
|
||||
|
||||
override fun exp(arg: Double): Double = kotlin.math.exp(arg)
|
||||
override fun ln(arg: Double): Double = kotlin.math.ln(arg)
|
||||
override inline fun exp(arg: Double) = kotlin.math.exp(arg)
|
||||
override inline fun ln(arg: Double) = kotlin.math.ln(arg)
|
||||
|
||||
override fun norm(arg: Double): Double = kotlin.math.abs(arg)
|
||||
override inline fun norm(arg: Double) = abs(arg)
|
||||
|
||||
override fun Double.unaryMinus(): Double = -this
|
||||
override inline fun Double.unaryMinus() = -this
|
||||
|
||||
override fun Double.minus(b: Double): Double = this - b
|
||||
override inline fun Double.plus(b: Double) = this + b
|
||||
|
||||
override inline fun Double.minus(b: Double) = this - b
|
||||
|
||||
override inline fun Double.times(b: Double) = this * b
|
||||
|
||||
override inline fun Double.div(b: Double) = this / b
|
||||
}
|
||||
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object FloatField : Field<Float>, ExtendedFieldOperations<Float>, Norm<Float, Float> {
|
||||
override val zero: Float = 0f
|
||||
override inline fun add(a: Float, b: Float) = a + b
|
||||
override inline fun multiply(a: Float, b: Float) = a * b
|
||||
override inline fun multiply(a: Float, k: Number) = a * k.toFloat()
|
||||
|
||||
override val one: Float = 1f
|
||||
override inline fun divide(a: Float, b: Float) = a / b
|
||||
|
||||
override inline fun sin(arg: Float) = kotlin.math.sin(arg)
|
||||
override inline fun cos(arg: Float) = kotlin.math.cos(arg)
|
||||
|
||||
override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat())
|
||||
|
||||
override inline fun exp(arg: Float) = kotlin.math.exp(arg)
|
||||
override inline fun ln(arg: Float) = kotlin.math.ln(arg)
|
||||
|
||||
override inline fun norm(arg: Float) = abs(arg)
|
||||
|
||||
override inline fun Float.unaryMinus() = -this
|
||||
|
||||
override inline fun Float.plus(b: Float) = this + b
|
||||
|
||||
override inline fun Float.minus(b: Float) = this - b
|
||||
|
||||
override inline fun Float.times(b: Float) = this * b
|
||||
|
||||
override inline fun Float.div(b: Float) = this / b
|
||||
}
|
||||
|
||||
/**
|
||||
* A field for [Int] without boxing. Does not produce corresponding field element
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||
override val zero: Int = 0
|
||||
override fun add(a: Int, b: Int): Int = a + b
|
||||
override fun multiply(a: Int, b: Int): Int = a * b
|
||||
override fun multiply(a: Int, k: Number): Int = (k * a)
|
||||
override inline fun add(a: Int, b: Int) = a + b
|
||||
override inline fun multiply(a: Int, b: Int) = a * b
|
||||
override inline fun multiply(a: Int, k: Number) = (k * a)
|
||||
override val one: Int = 1
|
||||
|
||||
override fun norm(arg: Int): Int = arg
|
||||
override inline fun norm(arg: Int) = abs(arg)
|
||||
|
||||
override inline fun Int.unaryMinus() = -this
|
||||
|
||||
override inline fun Int.plus(b: Int): Int = this + b
|
||||
|
||||
override inline fun Int.minus(b: Int): Int = this - b
|
||||
|
||||
override inline fun Int.times(b: Int): Int = this * b
|
||||
}
|
||||
|
||||
/**
|
||||
* A field for [Short] without boxing. Does not produce appropriate field element
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||
override val zero: Short = 0
|
||||
override fun add(a: Short, b: Short): Short = (a + b).toShort()
|
||||
override fun multiply(a: Short, b: Short): Short = (a * b).toShort()
|
||||
override fun multiply(a: Short, k: Number): Short = (a * k)
|
||||
override inline fun add(a: Short, b: Short) = (a + b).toShort()
|
||||
override inline fun multiply(a: Short, b: Short) = (a * b).toShort()
|
||||
override inline fun multiply(a: Short, k: Number) = (a * k)
|
||||
override val one: Short = 1
|
||||
|
||||
override fun norm(arg: Short): Short = arg
|
||||
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
||||
|
||||
override inline fun Short.unaryMinus() = (-this).toShort()
|
||||
|
||||
override inline fun Short.plus(b: Short) = (this + b).toShort()
|
||||
|
||||
override inline fun Short.minus(b: Short) = (this - b).toShort()
|
||||
|
||||
override inline fun Short.times(b: Short) = (this * b).toShort()
|
||||
}
|
||||
|
||||
/**
|
||||
* A field for [Byte] values
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||
override val zero: Byte = 0
|
||||
override fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
||||
override fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
|
||||
override fun multiply(a: Byte, k: Number): Byte = (a * k)
|
||||
override inline fun add(a: Byte, b: Byte) = (a + b).toByte()
|
||||
override inline fun multiply(a: Byte, b: Byte) = (a * b).toByte()
|
||||
override inline fun multiply(a: Byte, k: Number) = (a * k)
|
||||
override val one: Byte = 1
|
||||
|
||||
override fun norm(arg: Byte): Byte = arg
|
||||
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
|
||||
|
||||
override inline fun Byte.unaryMinus() = (-this).toByte()
|
||||
|
||||
override inline fun Byte.plus(b: Byte) = (this + b).toByte()
|
||||
|
||||
override inline fun Byte.minus(b: Byte) = (this - b).toByte()
|
||||
|
||||
override inline fun Byte.times(b: Byte) = (this * b).toByte()
|
||||
}
|
||||
|
||||
/**
|
||||
* A field for [Long] values
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object LongRing : Ring<Long>, Norm<Long, Long> {
|
||||
override val zero: Long = 0
|
||||
override fun add(a: Long, b: Long): Long = (a + b)
|
||||
override fun multiply(a: Long, b: Long): Long = (a * b)
|
||||
override fun multiply(a: Long, k: Number): Long = (a * k)
|
||||
override inline fun add(a: Long, b: Long) = (a + b)
|
||||
override inline fun multiply(a: Long, b: Long) = (a * b)
|
||||
override inline fun multiply(a: Long, k: Number) = (a * k)
|
||||
override val one: Long = 1
|
||||
|
||||
override fun norm(arg: Long): Long = arg
|
||||
override fun norm(arg: Long): Long = abs(arg)
|
||||
|
||||
override inline fun Long.unaryMinus() = (-this)
|
||||
|
||||
override inline fun Long.plus(b: Long) = (this + b)
|
||||
|
||||
override inline fun Long.minus(b: Long) = (this - b)
|
||||
|
||||
override inline fun Long.times(b: Long) = (this * b)
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
* A context that allows to operate on a [MutableBuffer] as on 2d array
|
||||
*/
|
||||
class BufferAccessor2D<T : Any>(val type: KClass<T>, val rowNum: Int, val colNum: Int) {
|
||||
|
||||
operator fun Buffer<T>.get(i: Int, j: Int) = get(i + colNum * j)
|
||||
|
||||
operator fun MutableBuffer<T>.set(i: Int, j: Int, value: T) {
|
||||
set(i + colNum * j, value)
|
||||
}
|
||||
|
||||
inline fun create(init: (i: Int, j: Int) -> T) =
|
||||
MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
||||
|
||||
fun create(mat: Structure2D<T>) = create { i, j -> mat[i, j] }
|
||||
|
||||
//TODO optimize wrapper
|
||||
fun MutableBuffer<T>.collect(): Structure2D<T> =
|
||||
NDStructure.auto(type, rowNum, colNum) { (i, j) -> get(i, j) }.as2D()
|
||||
|
||||
|
||||
inner class Row(val buffer: MutableBuffer<T>, val rowIndex: Int) : MutableBuffer<T> {
|
||||
override val size: Int get() = colNum
|
||||
|
||||
override fun get(index: Int): T = buffer[rowIndex, index]
|
||||
|
||||
override fun set(index: Int, value: T) {
|
||||
buffer[rowIndex, index] = value
|
||||
}
|
||||
|
||||
override fun copy(): MutableBuffer<T> = MutableBuffer.auto(type, colNum) { get(it) }
|
||||
|
||||
override fun iterator(): Iterator<T> = (0 until colNum).map(::get).iterator()
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Get row
|
||||
*/
|
||||
fun MutableBuffer<T>.row(i: Int) = Row(this, i)
|
||||
}
|
@ -2,6 +2,7 @@ package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.Complex
|
||||
import scientifik.kmath.operations.complex
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
|
||||
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
||||
@ -36,18 +37,20 @@ interface Buffer<T> {
|
||||
|
||||
companion object {
|
||||
|
||||
inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer {
|
||||
val array = DoubleArray(size) { initializer(it) }
|
||||
return DoubleBuffer(array)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a boxing buffer of given type
|
||||
*/
|
||||
inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> = ListBuffer(List(size, initializer))
|
||||
|
||||
/**
|
||||
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <reified T : Any> auto(size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
||||
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
||||
//TODO add resolution based on Annotation or companion resolution
|
||||
return when (T::class) {
|
||||
return when (type) {
|
||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
||||
@ -56,6 +59,13 @@ interface Buffer<T> {
|
||||
else -> boxing(size, initializer)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <reified T : Any> auto(size: Int, crossinline initializer: (Int) -> T): Buffer<T> =
|
||||
auto(T::class, size, initializer)
|
||||
}
|
||||
}
|
||||
|
||||
@ -78,12 +88,9 @@ interface MutableBuffer<T> : Buffer<T> {
|
||||
inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||
MutableListBuffer(MutableList(size, initializer))
|
||||
|
||||
/**
|
||||
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
||||
return when (T::class) {
|
||||
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
||||
return when (type) {
|
||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
||||
@ -91,6 +98,17 @@ interface MutableBuffer<T> : Buffer<T> {
|
||||
else -> boxing(size, initializer)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||
auto(T::class, size, initializer)
|
||||
|
||||
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
|
||||
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -236,7 +254,10 @@ inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
||||
* Useful when one needs single element from the buffer.
|
||||
*/
|
||||
class VirtualBuffer<T>(override val size: Int, private val generator: (Int) -> T) : Buffer<T> {
|
||||
override fun get(index: Int): T = generator(index)
|
||||
override fun get(index: Int): T {
|
||||
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
|
||||
return generator(index)
|
||||
}
|
||||
|
||||
override fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator()
|
||||
|
||||
@ -261,4 +282,6 @@ fun <T> Buffer<T>.asReadOnly(): Buffer<T> = if (this is MutableBuffer) {
|
||||
/**
|
||||
* Typealias for buffer transformations
|
||||
*/
|
||||
typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R>
|
||||
typealias BufferTransform<T, R> = (Buffer<T>) -> Buffer<R>
|
||||
|
||||
typealias SuspendBufferTransform<T, R> = suspend (Buffer<T>) -> Buffer<R>
|
@ -73,6 +73,7 @@ interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T,
|
||||
*/
|
||||
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
|
||||
|
||||
//TODO move to extensions after KEEP-176
|
||||
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) }
|
||||
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) }
|
||||
|
||||
@ -90,6 +91,7 @@ interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N>
|
||||
*/
|
||||
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
||||
|
||||
//TODO move to extensions after KEEP-176
|
||||
operator fun N.times(arg: T) = map(this) { value -> multiply(arg, value) }
|
||||
operator fun T.times(arg: N) = map(arg) { value -> multiply(this@times, value) }
|
||||
}
|
||||
@ -109,6 +111,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
||||
*/
|
||||
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
||||
|
||||
//TODO move to extensions after KEEP-176
|
||||
operator fun N.div(arg: T) = map(this) { value -> divide(arg, value) }
|
||||
operator fun T.div(arg: N) = map(arg) { divide(it, this@div) }
|
||||
|
||||
|
@ -1,12 +1,14 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import kotlin.jvm.JvmName
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
|
||||
interface NDStructure<T> {
|
||||
|
||||
val shape: IntArray
|
||||
|
||||
val dimension
|
||||
get() = shape.size
|
||||
val dimension get() = shape.size
|
||||
|
||||
operator fun get(index: IntArray): T
|
||||
|
||||
@ -41,6 +43,9 @@ interface NDStructure<T> {
|
||||
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
inline fun <T : Any> auto(type: KClass<T>, strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
fun <T> build(
|
||||
shape: IntArray,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
@ -49,6 +54,13 @@ interface NDStructure<T> {
|
||||
|
||||
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||
auto(DefaultStrides(shape), initializer)
|
||||
|
||||
@JvmName("autoVarArg")
|
||||
inline fun <reified T : Any> auto(vararg shape: Int, crossinline initializer: (IntArray) -> T) =
|
||||
auto(DefaultStrides(shape), initializer)
|
||||
|
||||
inline fun <T : Any> auto(type: KClass<T>, vararg shape: Int, crossinline initializer: (IntArray) -> T) =
|
||||
auto(type, DefaultStrides(shape), initializer)
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,7 +70,7 @@ interface MutableNDStructure<T> : NDStructure<T> {
|
||||
operator fun set(index: IntArray, value: T)
|
||||
}
|
||||
|
||||
fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
||||
inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
||||
elements().forEach { (index, oldValue) ->
|
||||
this[index] = action(index, oldValue)
|
||||
}
|
||||
|
@ -27,20 +27,6 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||
Structure1DWrapper(this)
|
||||
} else {
|
||||
error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
fun <T> NDBuffer<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||
Buffer1DWrapper(this.buffer)
|
||||
} else {
|
||||
error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
/**
|
||||
* A structure wrapper for buffer
|
||||
@ -56,39 +42,21 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
|
||||
override fun get(index: Int): T = buffer.get(index)
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent this buffer as 1D structure
|
||||
*/
|
||||
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
||||
|
||||
/**
|
||||
* A structure that is guaranteed to be two-dimensional
|
||||
*/
|
||||
interface Structure2D<T> : NDStructure<T> {
|
||||
operator fun get(i: Int, j: Int): T
|
||||
|
||||
override fun get(index: IntArray): T {
|
||||
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
|
||||
return get(index[0], index[1])
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A 2D wrapper for nd-structure
|
||||
*/
|
||||
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||
|
||||
override val shape: IntArray get() = structure.shape
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
||||
Structure2DWrapper(this)
|
||||
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||
if( this is NDBuffer){
|
||||
Buffer1DWrapper(this.buffer)
|
||||
} else {
|
||||
Structure1DWrapper(this)
|
||||
}
|
||||
} else {
|
||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Represent this buffer as 1D structure
|
||||
*/
|
||||
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
@ -0,0 +1,79 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
/**
|
||||
* A structure that is guaranteed to be two-dimensional
|
||||
*/
|
||||
interface Structure2D<T> : NDStructure<T> {
|
||||
val rowNum: Int get() = shape[0]
|
||||
val colNum: Int get() = shape[1]
|
||||
|
||||
operator fun get(i: Int, j: Int): T
|
||||
|
||||
override fun get(index: IntArray): T {
|
||||
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
|
||||
return get(index[0], index[1])
|
||||
}
|
||||
|
||||
|
||||
val rows: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(rowNum) { i ->
|
||||
VirtualBuffer(colNum) { j -> get(i, j) }
|
||||
}
|
||||
|
||||
val columns: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(colNum) { j ->
|
||||
VirtualBuffer(rowNum) { i -> get(i, j) }
|
||||
}
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||
for (i in (0 until rowNum)) {
|
||||
for (j in (0 until colNum)) {
|
||||
yield(intArrayOf(i, j) to get(i, j))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A 2D wrapper for nd-structure
|
||||
*/
|
||||
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||
|
||||
override val shape: IntArray get() = structure.shape
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
||||
Structure2DWrapper(this)
|
||||
} else {
|
||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent this 2D structure as 1D if it has exactly one column. Throw error otherwise.
|
||||
*/
|
||||
fun <T> Structure2D<T>.as1D() = if (colNum == 1) {
|
||||
object : Structure1D<T> {
|
||||
override fun get(index: Int): T = get(index, 0)
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum)
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = elements()
|
||||
|
||||
override val size: Int get() = rowNum
|
||||
}
|
||||
} else {
|
||||
error("Can't convert matrix with more than one column to vector")
|
||||
}
|
||||
|
||||
|
||||
typealias Matrix<T> = Structure2D<T>
|
@ -1,5 +1,6 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@ -16,7 +17,7 @@ class MatrixTest {
|
||||
@Test
|
||||
fun testVectorToMatrix() {
|
||||
val vector = Vector.real(5) { it.toDouble() }
|
||||
val matrix = vector.toMatrix()
|
||||
val matrix = vector.asMatrix()
|
||||
assertEquals(4.0, matrix[4, 0])
|
||||
}
|
||||
|
||||
@ -33,8 +34,8 @@ class MatrixTest {
|
||||
val vector1 = Vector.real(5) { it.toDouble() }
|
||||
val vector2 = Vector.real(5) { 5 - it.toDouble() }
|
||||
|
||||
val matrix1 = vector1.toMatrix()
|
||||
val matrix2 = vector2.toMatrix().transpose()
|
||||
val matrix1 = vector1.asMatrix()
|
||||
val matrix2 = vector2.asMatrix().transpose()
|
||||
val product = MatrixContext.real.run { matrix1 dot matrix2 }
|
||||
|
||||
|
||||
@ -51,4 +52,26 @@ class MatrixTest {
|
||||
|
||||
assertEquals(2.0, matrix[1, 2])
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMatrixExtension() {
|
||||
val transitionMatrix: Matrix<Double> = VirtualMatrix(6, 6) { row, col ->
|
||||
when {
|
||||
col == 0 -> .50
|
||||
row + 1 == col -> .50
|
||||
row == 5 && col == 5 -> 1.0
|
||||
else -> 0.0
|
||||
}
|
||||
}
|
||||
|
||||
infix fun Matrix<Double>.pow(power: Int): Matrix<Double> {
|
||||
var res = this
|
||||
repeat(power - 1) {
|
||||
res = res dot this
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
val toTenthPower = transitionMatrix pow 10
|
||||
}
|
||||
}
|
@ -1,16 +1,37 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import kotlin.contracts.ExperimentalContracts
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@ExperimentalContracts
|
||||
class RealLUSolverTest {
|
||||
|
||||
@Test
|
||||
fun testInvertOne() {
|
||||
val matrix = MatrixContext.real.one(2, 2)
|
||||
val inverted = LUSolver.real.inverse(matrix)
|
||||
val inverted = MatrixContext.real.inverse(matrix)
|
||||
assertEquals(matrix, inverted)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDecomposition() {
|
||||
val matrix = Matrix.square(
|
||||
3.0, 1.0,
|
||||
1.0, 3.0
|
||||
)
|
||||
|
||||
MatrixContext.real.run {
|
||||
val lup = lup(matrix)
|
||||
|
||||
//Check determinant
|
||||
assertEquals(8.0, lup.determinant)
|
||||
|
||||
assertEquals(lup.p dot matrix, lup.l dot lup.u)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testInvert() {
|
||||
val matrix = Matrix.square(
|
||||
@ -18,18 +39,7 @@ class RealLUSolverTest {
|
||||
1.0, 3.0
|
||||
)
|
||||
|
||||
val decomposed = LUSolver.real.decompose(matrix)
|
||||
val decomposition = decomposed.getFeature<LUPDecomposition<Double>>()!!
|
||||
|
||||
//Check determinant
|
||||
assertEquals(8.0, decomposition.determinant)
|
||||
|
||||
//Check decomposition
|
||||
with(MatrixContext.real) {
|
||||
assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u)
|
||||
}
|
||||
|
||||
val inverted = LUSolver.real.inverse(decomposed)
|
||||
val inverted = MatrixContext.real.inverse(matrix)
|
||||
|
||||
val expected = Matrix.square(
|
||||
0.375, -0.125,
|
||||
|
@ -1,6 +1,5 @@
|
||||
package scientifik.kmath.misc
|
||||
|
||||
import scientifik.kmath.sequential.cumulativeSum
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
@ -0,0 +1,176 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
import kotlin.math.PI
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class AutoDiffTest {
|
||||
@Test
|
||||
fun testPlusX2() {
|
||||
val x = Variable(3) // diff w.r.t this x at 3
|
||||
val y = deriv { x + x }
|
||||
assertEquals(6.0, y.value) // y = x + x = 6
|
||||
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPlus() {
|
||||
// two variables
|
||||
val x = Variable(2)
|
||||
val y = Variable(3)
|
||||
val z = deriv { x + y }
|
||||
assertEquals(5.0, z.value) // z = x + y = 5
|
||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
||||
assertEquals(1.0, z.deriv(y)) // dz/dy = 1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus() {
|
||||
// two variables
|
||||
val x = Variable(7)
|
||||
val y = Variable(3)
|
||||
val z = deriv { x - y }
|
||||
assertEquals(4.0, z.value) // z = x - y = 4
|
||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
||||
assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMulX2() {
|
||||
val x = Variable(3) // diff w.r.t this x at 3
|
||||
val y = deriv { x * x }
|
||||
assertEquals(9.0, y.value) // y = x * x = 9
|
||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSqr() {
|
||||
val x = Variable(3)
|
||||
val y = deriv { sqr(x) }
|
||||
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSqrSqr() {
|
||||
val x = Variable(2)
|
||||
val y = deriv { sqr(sqr(x)) }
|
||||
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
||||
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testX3() {
|
||||
val x = Variable(2) // diff w.r.t this x at 2
|
||||
val y = deriv { x * x * x }
|
||||
assertEquals(8.0, y.value) // y = x * x * x = 8
|
||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDiv() {
|
||||
val x = Variable(5)
|
||||
val y = Variable(2)
|
||||
val z = deriv { x / y }
|
||||
assertEquals(2.5, z.value) // z = x / y = 2.5
|
||||
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
|
||||
assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPow3() {
|
||||
val x = Variable(2) // diff w.r.t this x at 2
|
||||
val y = deriv { pow(x, 3) }
|
||||
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPowFull() {
|
||||
val x = Variable(2)
|
||||
val y = Variable(3)
|
||||
val z = deriv { pow(x, y) }
|
||||
assertApprox(8.0, z.value) // z = x ^ y = 8
|
||||
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
|
||||
assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFromPaper() {
|
||||
val x = Variable(3)
|
||||
val y = deriv { 2 * x + x * x * x }
|
||||
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
||||
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testInnerVariable() {
|
||||
val x = Variable(1)
|
||||
val y = deriv {
|
||||
Variable(1) * x
|
||||
}
|
||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLongChain() {
|
||||
val n = 10_000
|
||||
val x = Variable(1)
|
||||
val y = deriv {
|
||||
var res = Variable(1)
|
||||
for (i in 1..n) res *= x
|
||||
res
|
||||
}
|
||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||
assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testExample() {
|
||||
val x = Variable(2)
|
||||
val y = deriv { sqr(x) + 5 * x + 3 }
|
||||
assertEquals(17.0, y.value) // the value of result (y)
|
||||
assertEquals(9.0, y.deriv(x)) // dy/dx
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSqrt() {
|
||||
val x = Variable(16)
|
||||
val y = deriv { sqrt(x) }
|
||||
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
||||
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSin() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { sin(x) }
|
||||
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
||||
assertApprox(kotlin.math.sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(PI/6) = sqrt(3)/2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCos() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { cos(x) }
|
||||
assertApprox(kotlin.math.sqrt(3.0) / 2, y.value) // y = cos(PI/6) = sqrt(3)/2
|
||||
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(PI/6) = -0.5
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivGrad() {
|
||||
val x = Variable(1.0)
|
||||
val y = Variable(2.0)
|
||||
val res = deriv { x * x + y * y }
|
||||
assertEquals(6.0, res.div())
|
||||
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
||||
}
|
||||
|
||||
private fun assertApprox(a: Double, b: Double) {
|
||||
if ((a - b) > 1e-10) assertEquals(a, b)
|
||||
}
|
||||
|
||||
}
|
@ -1,46 +1,26 @@
|
||||
plugins {
|
||||
kotlin("multiplatform")
|
||||
`multiplatform-config`
|
||||
id("kotlinx-atomicfu") version Versions.atomicfuVersion
|
||||
}
|
||||
|
||||
val coroutinesVersion: String by rootProject.extra
|
||||
|
||||
kotlin {
|
||||
jvm()
|
||||
js()
|
||||
|
||||
sourceSets {
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:$coroutinesVersion")
|
||||
}
|
||||
kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Versions.coroutinesVersion}")
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||
}
|
||||
val commonTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-common"))
|
||||
implementation(kotlin("test-annotations-common"))
|
||||
}
|
||||
}
|
||||
jvmMain {
|
||||
dependencies {
|
||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutinesVersion}")
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||
}
|
||||
val jvmMain by getting {
|
||||
dependencies {
|
||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:$coroutinesVersion")
|
||||
}
|
||||
}
|
||||
val jvmTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test"))
|
||||
implementation(kotlin("test-junit"))
|
||||
}
|
||||
}
|
||||
val jsMain by getting {
|
||||
dependencies {
|
||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:$coroutinesVersion")
|
||||
}
|
||||
}
|
||||
val jsTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-js"))
|
||||
}
|
||||
}
|
||||
jsMain {
|
||||
dependencies {
|
||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Versions.coroutinesVersion}")
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,102 @@
|
||||
package scientifik.kmath
|
||||
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.channels.produce
|
||||
import kotlinx.coroutines.flow.*
|
||||
|
||||
val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default
|
||||
|
||||
/**
|
||||
* An imitator of [Deferred] which holds a suspended function block and dispatcher
|
||||
*/
|
||||
internal class LazyDeferred<T>(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) {
|
||||
private var deferred: Deferred<T>? = null
|
||||
|
||||
internal fun start(scope: CoroutineScope) {
|
||||
if (deferred == null) {
|
||||
deferred = scope.async(dispatcher, block = block)
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun await(): T = deferred?.await() ?: error("Coroutine not started")
|
||||
}
|
||||
|
||||
@FlowPreview
|
||||
class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
|
||||
override suspend fun collect(collector: FlowCollector<T>) {
|
||||
deferredFlow.collect {
|
||||
collector.emit((it.await()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@FlowPreview
|
||||
fun <T, R> Flow<T>.async(
|
||||
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
||||
block: suspend CoroutineScope.(T) -> R
|
||||
): AsyncFlow<R> {
|
||||
val flow = map {
|
||||
LazyDeferred(dispatcher) { block(it) }
|
||||
}
|
||||
return AsyncFlow(flow)
|
||||
}
|
||||
|
||||
@FlowPreview
|
||||
fun <T, R> AsyncFlow<T>.map(action: (T) -> R) = AsyncFlow(deferredFlow.map { input ->
|
||||
//TODO add function composition
|
||||
LazyDeferred(input.dispatcher) {
|
||||
input.start(this)
|
||||
action(input.await())
|
||||
}
|
||||
})
|
||||
|
||||
@ExperimentalCoroutinesApi
|
||||
@FlowPreview
|
||||
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>) {
|
||||
require(concurrency >= 1) { "Buffer size should be more than 1, but was $concurrency" }
|
||||
coroutineScope {
|
||||
//Starting up to N deferred coroutines ahead of time
|
||||
val channel = produce(capacity = concurrency - 1) {
|
||||
deferredFlow.collect { value ->
|
||||
value.start(this@coroutineScope)
|
||||
send(value)
|
||||
}
|
||||
}
|
||||
|
||||
(channel as Job).invokeOnCompletion {
|
||||
if (it is CancellationException && it.cause == null) cancel()
|
||||
}
|
||||
|
||||
for (element in channel) {
|
||||
collector.emit(element.await())
|
||||
}
|
||||
|
||||
val producer = channel as Job
|
||||
if (producer.isCancelled) {
|
||||
producer.join()
|
||||
//throw producer.getCancellationException()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ExperimentalCoroutinesApi
|
||||
@FlowPreview
|
||||
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, action: suspend (value: T) -> Unit): Unit {
|
||||
collect(concurrency, object : FlowCollector<T> {
|
||||
override suspend fun emit(value: T) = action(value)
|
||||
})
|
||||
}
|
||||
|
||||
@FlowPreview
|
||||
fun <T, R> Flow<T>.map(
|
||||
dispatcher: CoroutineDispatcher,
|
||||
concurrencyLevel: Int = 16,
|
||||
bufferSize: Int = concurrencyLevel,
|
||||
transform: suspend (T) -> R
|
||||
): Flow<R> {
|
||||
return flatMapMerge(concurrencyLevel, bufferSize) { value ->
|
||||
flow { emit(transform(value)) }
|
||||
}.flowOn(dispatcher)
|
||||
}
|
||||
|
||||
|
@ -14,14 +14,10 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package scientifik.kmath.sequential
|
||||
package scientifik.kmath.chains
|
||||
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.channels.ReceiveChannel
|
||||
import kotlinx.coroutines.channels.produce
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.FlowPreview
|
||||
|
||||
|
||||
/**
|
||||
@ -30,7 +26,7 @@ import kotlinx.coroutines.isActive
|
||||
*/
|
||||
interface Chain<out R> {
|
||||
/**
|
||||
* Last value of the chain. Returns null if [next] was not called
|
||||
* Last cached value of the chain. Returns null if [next] was not called
|
||||
*/
|
||||
val value: R?
|
||||
|
||||
@ -47,10 +43,11 @@ interface Chain<out R> {
|
||||
}
|
||||
|
||||
/**
|
||||
* Chain as a coroutine receive channel
|
||||
* Chain as a coroutine flow. The flow emit affects chain state and vice versa
|
||||
*/
|
||||
@ExperimentalCoroutinesApi
|
||||
fun <R> Chain<R>.asChannel(scope: CoroutineScope): ReceiveChannel<R> = scope.produce { while (isActive) send(next()) }
|
||||
@FlowPreview
|
||||
val <R> Chain<R>.flow
|
||||
get() = kotlinx.coroutines.flow.flow { while (true) emit(next()) }
|
||||
|
||||
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
|
||||
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
@ -0,0 +1,80 @@
|
||||
package scientifik.kmath.streaming
|
||||
|
||||
import kotlinx.coroutines.FlowPreview
|
||||
import kotlinx.coroutines.flow.*
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.BufferFactory
|
||||
import scientifik.kmath.structures.DoubleBuffer
|
||||
|
||||
/**
|
||||
* Create a [Flow] from buffer
|
||||
*/
|
||||
@FlowPreview
|
||||
fun <T> Buffer<T>.asFlow() = iterator().asFlow()
|
||||
|
||||
/**
|
||||
* Flat map a [Flow] of [Buffer] into continuous [Flow] of elements
|
||||
*/
|
||||
@FlowPreview
|
||||
fun <T> Flow<Buffer<out T>>.spread(): Flow<T> = flatMapConcat { it.asFlow() }
|
||||
|
||||
/**
|
||||
* Collect incoming flow into fixed size chunks
|
||||
*/
|
||||
@FlowPreview
|
||||
fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>) = flow {
|
||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||
val list = ArrayList<T>(bufferSize)
|
||||
var counter = 0
|
||||
|
||||
this@chunked.collect { element ->
|
||||
list.add(element)
|
||||
counter++
|
||||
if (counter == bufferSize) {
|
||||
val buffer = bufferFactory(bufferSize) { list[it] }
|
||||
emit(buffer)
|
||||
list.clear()
|
||||
counter = 0
|
||||
}
|
||||
}
|
||||
if (counter > 0) {
|
||||
emit(bufferFactory(counter) { list[it] })
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Specialized flow chunker for real buffer
|
||||
*/
|
||||
@FlowPreview
|
||||
fun Flow<Double>.chunked(bufferSize: Int) = flow {
|
||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||
val array = DoubleArray(bufferSize)
|
||||
var counter = 0
|
||||
|
||||
this@chunked.collect { element ->
|
||||
array[counter] = element
|
||||
counter++
|
||||
if (counter == bufferSize) {
|
||||
val buffer = DoubleBuffer(array)
|
||||
emit(buffer)
|
||||
counter = 0
|
||||
}
|
||||
}
|
||||
if (counter > 0) {
|
||||
emit(DoubleBuffer(counter) { array[it] })
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Map a flow to a moving window buffer. The window step is one.
|
||||
* In order to get different steps, one could use skip operation.
|
||||
*/
|
||||
@FlowPreview
|
||||
fun <T> Flow<T>.windowed(window: Int): Flow<Buffer<T>> = flow {
|
||||
require(window > 1) { "Window size must be more than one" }
|
||||
val ringBuffer = RingBuffer.boxing<T>(window)
|
||||
this@windowed.collect { element ->
|
||||
ringBuffer.push(element)
|
||||
emit(ringBuffer.snapshot())
|
||||
}
|
||||
}
|
@ -1,16 +1,18 @@
|
||||
package scientifik.kmath.sequential
|
||||
package scientifik.kmath.streaming
|
||||
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.MutableBuffer
|
||||
import scientifik.kmath.structures.VirtualBuffer
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
* Thread-safe ring buffer
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
internal class RingBuffer<T>(
|
||||
private val buffer: MutableBuffer<T>,
|
||||
private val buffer: MutableBuffer<T?>,
|
||||
private var startIndex: Int = 0,
|
||||
size: Int = 0
|
||||
) : Buffer<T> {
|
||||
@ -23,7 +25,7 @@ internal class RingBuffer<T>(
|
||||
override fun get(index: Int): T {
|
||||
require(index >= 0) { "Index must be positive" }
|
||||
require(index < size) { "Index $index is out of circular buffer size $size" }
|
||||
return buffer[startIndex.forward(index)]
|
||||
return buffer[startIndex.forward(index)] as T
|
||||
}
|
||||
|
||||
fun isFull() = size == buffer.size
|
||||
@ -40,7 +42,7 @@ internal class RingBuffer<T>(
|
||||
if (count == 0) {
|
||||
done()
|
||||
} else {
|
||||
setNext(copy[index])
|
||||
setNext(copy[index] as T)
|
||||
index = index.forward(1)
|
||||
count--
|
||||
}
|
||||
@ -53,7 +55,9 @@ internal class RingBuffer<T>(
|
||||
suspend fun snapshot(): Buffer<T> {
|
||||
mutex.withLock {
|
||||
val copy = buffer.copy()
|
||||
return VirtualBuffer(size) { i -> copy[startIndex.forward(i)] }
|
||||
return VirtualBuffer(size) { i ->
|
||||
copy[startIndex.forward(i)] as T
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,14 +78,14 @@ internal class RingBuffer<T>(
|
||||
|
||||
companion object {
|
||||
inline fun <reified T : Any> build(size: Int, empty: T): RingBuffer<T> {
|
||||
val buffer = MutableBuffer.auto(size) { empty }
|
||||
val buffer = MutableBuffer.auto(size) { empty } as MutableBuffer<T?>
|
||||
return RingBuffer(buffer)
|
||||
}
|
||||
|
||||
/**
|
||||
* Slow yet universal buffer
|
||||
*/
|
||||
fun <T> boxing(size: Int): RingBuffer<T?> {
|
||||
fun <T> boxing(size: Int): RingBuffer<T> {
|
||||
val buffer: MutableBuffer<T?> = MutableBuffer.boxing(size) { null }
|
||||
return RingBuffer(buffer)
|
||||
}
|
@ -1,9 +0,0 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import kotlinx.coroutines.CoroutineDispatcher
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
import kotlin.coroutines.EmptyCoroutineContext
|
||||
|
||||
val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default
|
@ -1,4 +1,4 @@
|
||||
package scientifik.kmath.sequential
|
||||
package scientifik.kmath.chains
|
||||
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kotlin.sequences.Sequence
|
@ -1,6 +1,7 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import kotlinx.coroutines.*
|
||||
import scientifik.kmath.Math
|
||||
|
||||
class LazyNDStructure<T>(
|
||||
val scope: CoroutineScope,
|
||||
|
@ -0,0 +1,48 @@
|
||||
package scientifik.kmath.streaming
|
||||
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.flow.asFlow
|
||||
import kotlinx.coroutines.flow.collect
|
||||
import org.junit.Test
|
||||
import scientifik.kmath.async
|
||||
import scientifik.kmath.collect
|
||||
import scientifik.kmath.map
|
||||
import java.util.concurrent.Executors
|
||||
|
||||
|
||||
@ExperimentalCoroutinesApi
|
||||
@InternalCoroutinesApi
|
||||
@FlowPreview
|
||||
class BufferFlowTest {
|
||||
|
||||
val dispatcher = Executors.newFixedThreadPool(4).asCoroutineDispatcher()
|
||||
|
||||
@Test(timeout = 2000)
|
||||
fun map() {
|
||||
runBlocking {
|
||||
(1..20).asFlow().map( dispatcher) {
|
||||
//println("Started $it on ${Thread.currentThread().name}")
|
||||
@Suppress("BlockingMethodInNonBlockingContext")
|
||||
Thread.sleep(200)
|
||||
it
|
||||
}.collect {
|
||||
println("Completed $it on ${Thread.currentThread().name}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 2000)
|
||||
fun async() {
|
||||
runBlocking {
|
||||
(1..20).asFlow().async(dispatcher) {
|
||||
//println("Started $it on ${Thread.currentThread().name}")
|
||||
@Suppress("BlockingMethodInNonBlockingContext")
|
||||
Thread.sleep(200)
|
||||
it
|
||||
}.collect(4) {
|
||||
println("Completed $it on ${Thread.currentThread().name}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
package scientifik.kmath.streaming
|
||||
|
||||
import kotlinx.coroutines.flow.*
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.junit.Test
|
||||
import scientifik.kmath.structures.asSequence
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class RingBufferTest {
|
||||
@Test
|
||||
fun push() {
|
||||
val buffer = RingBuffer.build(20, Double.NaN)
|
||||
runBlocking {
|
||||
for (i in 1..30) {
|
||||
buffer.push(i.toDouble())
|
||||
}
|
||||
assertEquals(410.0, buffer.asSequence().sum())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun windowed(){
|
||||
val flow = flow{
|
||||
var i = 0
|
||||
while(true){
|
||||
emit(i++)
|
||||
}
|
||||
}
|
||||
val windowed = flow.windowed(10)
|
||||
runBlocking {
|
||||
val first = windowed.take(1).single()
|
||||
val res = windowed.take(15).map { it -> it.asSequence().average() }.toList()
|
||||
assertEquals(0.0, res[0])
|
||||
assertEquals(4.5, res[9])
|
||||
assertEquals(9.5, res[14])
|
||||
}
|
||||
}
|
||||
}
|
@ -1,34 +1,10 @@
|
||||
plugins {
|
||||
kotlin("multiplatform")
|
||||
`multiplatform-config`
|
||||
}
|
||||
|
||||
kotlin {
|
||||
jvm()
|
||||
js()
|
||||
|
||||
sourceSets {
|
||||
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
}
|
||||
val commonTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-common"))
|
||||
implementation(kotlin("test-annotations-common"))
|
||||
}
|
||||
}
|
||||
val jvmTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test"))
|
||||
implementation(kotlin("test-junit"))
|
||||
}
|
||||
}
|
||||
val jsTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-js"))
|
||||
}
|
||||
}
|
||||
// Just an example how we can collapse nested DSL for simple declarations
|
||||
kotlin.sourceSets.commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
}
|
@ -1,57 +1,31 @@
|
||||
plugins {
|
||||
kotlin("multiplatform")
|
||||
`multiplatform-config`
|
||||
}
|
||||
|
||||
repositories {
|
||||
maven("http://dl.bintray.com/kyonifer/maven")
|
||||
}
|
||||
|
||||
kotlin {
|
||||
jvm {
|
||||
compilations.all {
|
||||
kotlinOptions {
|
||||
jvmTarget = "1.8"
|
||||
freeCompilerArgs += "-progressive"
|
||||
}
|
||||
kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api("com.kyonifer:koma-core-api-common:0.12")
|
||||
}
|
||||
}
|
||||
js()
|
||||
|
||||
sourceSets {
|
||||
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api("com.kyonifer:koma-core-api-common:0.12")
|
||||
}
|
||||
}
|
||||
val commonTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-common"))
|
||||
implementation(kotlin("test-annotations-common"))
|
||||
}
|
||||
}
|
||||
val jvmMain by getting {
|
||||
dependencies {
|
||||
api("com.kyonifer:koma-core-api-jvm:0.12")
|
||||
}
|
||||
}
|
||||
val jvmTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test"))
|
||||
implementation(kotlin("test-junit"))
|
||||
implementation("com.kyonifer:koma-core-ejml:0.12")
|
||||
}
|
||||
}
|
||||
val jsMain by getting {
|
||||
dependencies {
|
||||
api("com.kyonifer:koma-core-api-js:0.12")
|
||||
}
|
||||
}
|
||||
val jsTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-js"))
|
||||
}
|
||||
jvmMain {
|
||||
dependencies {
|
||||
api("com.kyonifer:koma-core-api-jvm:0.12")
|
||||
}
|
||||
}
|
||||
}
|
||||
jvmTest {
|
||||
dependencies {
|
||||
implementation("com.kyonifer:koma-core-ejml:0.12")
|
||||
}
|
||||
}
|
||||
jsMain {
|
||||
dependencies {
|
||||
api("com.kyonifer:koma-core-api-js:0.12")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,9 +2,14 @@ package scientifik.kmath.linear
|
||||
|
||||
import koma.extensions.fill
|
||||
import koma.matrix.MatrixFactory
|
||||
import scientifik.kmath.operations.Space
|
||||
import scientifik.kmath.structures.Matrix
|
||||
|
||||
class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T>>) : MatrixContext<T>,
|
||||
LinearSolver<T> {
|
||||
class KomaMatrixContext<T : Any>(
|
||||
private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
|
||||
private val space: Space<T>
|
||||
) :
|
||||
MatrixContext<T> {
|
||||
|
||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T) =
|
||||
KomaMatrix(factory.zeros(rows, columns).fill(initializer))
|
||||
@ -31,41 +36,52 @@ class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T
|
||||
override fun Matrix<T>.unaryMinus() =
|
||||
KomaMatrix(this.toKoma().origin.unaryMinus())
|
||||
|
||||
override fun Matrix<T>.plus(b: Matrix<T>) =
|
||||
KomaMatrix(this.toKoma().origin + b.toKoma().origin)
|
||||
override fun add(a: Matrix<T>, b: Matrix<T>) =
|
||||
KomaMatrix(a.toKoma().origin + b.toKoma().origin)
|
||||
|
||||
override fun Matrix<T>.minus(b: Matrix<T>) =
|
||||
KomaMatrix(this.toKoma().origin - b.toKoma().origin)
|
||||
|
||||
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
||||
produce(a.rowNum, a.colNum) { i, j -> space.run { a[i, j] * k } }
|
||||
|
||||
override fun Matrix<T>.times(value: T) =
|
||||
KomaMatrix(this.toKoma().origin * value)
|
||||
|
||||
companion object {
|
||||
|
||||
override fun solve(a: Matrix<T>, b: Matrix<T>) =
|
||||
KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin))
|
||||
}
|
||||
|
||||
override fun inverse(a: Matrix<T>) =
|
||||
KomaMatrix(a.toKoma().origin.inv())
|
||||
}
|
||||
|
||||
class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) :
|
||||
Matrix<T> {
|
||||
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Matrix<T>) =
|
||||
KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin))
|
||||
|
||||
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Point<T>) =
|
||||
KomaVector(a.toKoma().origin.solve(b.toKoma().origin))
|
||||
|
||||
fun <T : Any> KomaMatrixContext<T>.inverse(a: Matrix<T>) =
|
||||
KomaMatrix(a.toKoma().origin.inv())
|
||||
|
||||
class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) : FeaturedMatrix<T> {
|
||||
override val rowNum: Int get() = origin.numRows()
|
||||
override val colNum: Int get() = origin.numCols()
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols())
|
||||
|
||||
override val features: Set<MatrixFeature> = features ?: setOf(
|
||||
object : DeterminantFeature<T> {
|
||||
override val determinant: T get() = origin.det()
|
||||
},
|
||||
object : LUPDecompositionFeature<T> {
|
||||
private val lup by lazy { origin.LU() }
|
||||
override val l: Matrix<T> get() = KomaMatrix(lup.second)
|
||||
override val u: Matrix<T> get() = KomaMatrix(lup.third)
|
||||
override val p: Matrix<T> get() = KomaMatrix(lup.first)
|
||||
override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
|
||||
override val u: FeaturedMatrix<T> get() = KomaMatrix(lup.third)
|
||||
override val p: FeaturedMatrix<T> get() = KomaMatrix(lup.first)
|
||||
}
|
||||
)
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): Matrix<T> =
|
||||
override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
|
||||
KomaMatrix(this.origin, this.features + features)
|
||||
|
||||
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
||||
|
@ -1,50 +1,3 @@
|
||||
plugins {
|
||||
kotlin("multiplatform")
|
||||
`multiplatform-config`
|
||||
}
|
||||
|
||||
val ioVersion: String by rootProject.extra
|
||||
|
||||
|
||||
kotlin {
|
||||
jvm()
|
||||
js()
|
||||
|
||||
sourceSets {
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(kotlin("stdlib"))
|
||||
}
|
||||
}
|
||||
val commonTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-common"))
|
||||
implementation(kotlin("test-annotations-common"))
|
||||
}
|
||||
}
|
||||
val jvmMain by getting {
|
||||
dependencies {
|
||||
api(kotlin("stdlib-jdk8"))
|
||||
}
|
||||
}
|
||||
val jvmTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test"))
|
||||
implementation(kotlin("test-junit"))
|
||||
}
|
||||
}
|
||||
val jsMain by getting {
|
||||
dependencies {
|
||||
api(kotlin("stdlib-js"))
|
||||
}
|
||||
}
|
||||
val jsTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-js"))
|
||||
}
|
||||
}
|
||||
// mingwMain {
|
||||
// }
|
||||
// mingwTest {
|
||||
// }
|
||||
}
|
||||
}
|
29
kmath-prob/build.gradle.kts
Normal file
29
kmath-prob/build.gradle.kts
Normal file
@ -0,0 +1,29 @@
|
||||
plugins {
|
||||
`multiplatform-config`
|
||||
id("kotlinx-atomicfu") version Versions.atomicfuVersion
|
||||
}
|
||||
|
||||
kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api(project(":kmath-coroutines"))
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||
}
|
||||
}
|
||||
jvmMain {
|
||||
dependencies {
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||
}
|
||||
}
|
||||
jsMain {
|
||||
dependencies {
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
atomicfu {
|
||||
variant = "VH"
|
||||
}
|
@ -1,53 +0,0 @@
|
||||
plugins {
|
||||
kotlin("multiplatform")
|
||||
id("kotlinx-atomicfu") version "0.12.1"
|
||||
}
|
||||
|
||||
val atomicfuVersion: String by rootProject.extra
|
||||
|
||||
kotlin {
|
||||
jvm ()
|
||||
//js()
|
||||
|
||||
sourceSets {
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api(project(":kmath-coroutines"))
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:$atomicfuVersion")
|
||||
}
|
||||
}
|
||||
val commonTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-common"))
|
||||
implementation(kotlin("test-annotations-common"))
|
||||
}
|
||||
}
|
||||
val jvmMain by getting {
|
||||
dependencies {
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu:$atomicfuVersion")
|
||||
}
|
||||
}
|
||||
val jvmTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test"))
|
||||
implementation(kotlin("test-junit"))
|
||||
}
|
||||
}
|
||||
// val jsMain by getting {
|
||||
// dependencies {
|
||||
// compileOnly("org.jetbrains.kotlinx:atomicfu-js:$atomicfuVersion")
|
||||
// }
|
||||
// }
|
||||
// val jsTest by getting {
|
||||
// dependencies {
|
||||
// implementation(kotlin("test-js"))
|
||||
// }
|
||||
// }
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
atomicfu {
|
||||
variant = "VH"
|
||||
}
|
@ -1,82 +0,0 @@
|
||||
package scientifik.kmath.sequential
|
||||
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
import kotlinx.coroutines.channels.produce
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.BufferFactory
|
||||
|
||||
/**
|
||||
* A processor that collects incoming elements into fixed size buffers
|
||||
*/
|
||||
@ExperimentalCoroutinesApi
|
||||
class JoinProcessor<T>(
|
||||
scope: CoroutineScope,
|
||||
bufferSize: Int,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||
) : AbstractProcessor<T, Buffer<T>>(scope) {
|
||||
|
||||
private val input = Channel<T>(bufferSize)
|
||||
|
||||
private val output = produce(coroutineContext) {
|
||||
val list = ArrayList<T>(bufferSize)
|
||||
while (isActive) {
|
||||
list.clear()
|
||||
repeat(bufferSize) {
|
||||
list.add(input.receive())
|
||||
}
|
||||
val buffer = bufferFactory(bufferSize) { list[it] }
|
||||
send(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun receive(): Buffer<T> = output.receive()
|
||||
|
||||
override suspend fun send(value: T) {
|
||||
input.send(value)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A processor that splits incoming buffers into individual elements
|
||||
*/
|
||||
class SplitProcessor<T>(scope: CoroutineScope) : AbstractProcessor<Buffer<T>, T>(scope) {
|
||||
|
||||
private val input = Channel<Buffer<T>>()
|
||||
|
||||
private val mutex = Mutex()
|
||||
|
||||
private var currentBuffer: Buffer<T>? = null
|
||||
|
||||
private var pos = 0
|
||||
|
||||
|
||||
override suspend fun receive(): T {
|
||||
mutex.withLock {
|
||||
while (currentBuffer == null || pos == currentBuffer!!.size) {
|
||||
currentBuffer = input.receive()
|
||||
pos = 0
|
||||
}
|
||||
return currentBuffer!![pos].also { pos++ }
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun send(value: Buffer<T>) {
|
||||
input.send(value)
|
||||
}
|
||||
}
|
||||
|
||||
@ExperimentalCoroutinesApi
|
||||
fun <T> Producer<T>.chunked(chunkSize: Int, bufferFactory: BufferFactory<T>) =
|
||||
JoinProcessor<T>(this, chunkSize, bufferFactory).also { connect(it) }
|
||||
|
||||
@ExperimentalCoroutinesApi
|
||||
inline fun <reified T : Any> Producer<T>.chunked(chunkSize: Int) =
|
||||
JoinProcessor<T>(this, chunkSize, Buffer.Companion::auto).also { connect(it) }
|
||||
|
||||
|
||||
|
@ -1,275 +0,0 @@
|
||||
package scientifik.kmath.sequential
|
||||
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.GlobalScope
|
||||
import kotlinx.coroutines.channels.*
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
/**
|
||||
* Initial chain block. Could produce an element sequence and be connected to single [Consumer]
|
||||
*
|
||||
* The general rule is that channel is created on first call. Also each element is responsible for its connection so
|
||||
* while the connections are symmetric, the scope, used for making the connection is responsible for cancelation.
|
||||
*
|
||||
* Also connections are not reversible. Once connected block stays faithful until it finishes processing.
|
||||
* Manually putting elements to connected block could lead to undetermined behavior and must be avoided.
|
||||
*/
|
||||
interface Producer<T> : CoroutineScope {
|
||||
fun connect(consumer: Consumer<T>)
|
||||
|
||||
suspend fun receive(): T
|
||||
|
||||
val consumer: Consumer<T>?
|
||||
|
||||
val outputIsConnected: Boolean get() = consumer != null
|
||||
|
||||
//fun close()
|
||||
}
|
||||
|
||||
/**
|
||||
* Terminal chain block. Could consume an element sequence and be connected to signle [Producer]
|
||||
*/
|
||||
interface Consumer<T> : CoroutineScope {
|
||||
fun connect(producer: Producer<T>)
|
||||
|
||||
suspend fun send(value: T)
|
||||
|
||||
val producer: Producer<T>?
|
||||
|
||||
val inputIsConnected: Boolean get() = producer != null
|
||||
|
||||
//fun close()
|
||||
}
|
||||
|
||||
interface Processor<T, R> : Consumer<T>, Producer<R>
|
||||
|
||||
abstract class AbstractProducer<T>(scope: CoroutineScope) : Producer<T> {
|
||||
override val coroutineContext: CoroutineContext = scope.coroutineContext
|
||||
|
||||
override var consumer: Consumer<T>? = null
|
||||
protected set
|
||||
|
||||
override fun connect(consumer: Consumer<T>) {
|
||||
//Ignore if already connected to specific consumer
|
||||
if (consumer != this.consumer) {
|
||||
if (outputIsConnected) error("The output slot of producer is occupied")
|
||||
if (consumer.inputIsConnected) error("The input slot of consumer is occupied")
|
||||
this.consumer = consumer
|
||||
if (consumer.producer != null) {
|
||||
//No need to save the job, it will be canceled on scope cancel
|
||||
connectOutput(consumer)
|
||||
// connect back, consumer is already set so no circular reference
|
||||
consumer.connect(this)
|
||||
} else error("Unreachable statement")
|
||||
}
|
||||
}
|
||||
|
||||
protected open fun connectOutput(consumer: Consumer<T>) {
|
||||
launch {
|
||||
while (this.isActive) {
|
||||
consumer.send(receive())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract class AbstractConsumer<T>(scope: CoroutineScope) : Consumer<T> {
|
||||
override val coroutineContext: CoroutineContext = scope.coroutineContext
|
||||
|
||||
override var producer: Producer<T>? = null
|
||||
protected set
|
||||
|
||||
override fun connect(producer: Producer<T>) {
|
||||
//Ignore if already connected to specific consumer
|
||||
if (producer != this.producer) {
|
||||
if (inputIsConnected) error("The input slot of consumer is occupied")
|
||||
if (producer.outputIsConnected) error("The input slot of producer is occupied")
|
||||
this.producer = producer
|
||||
//No need to save the job, it will be canceled on scope cancel
|
||||
if (producer.consumer != null) {
|
||||
connectInput(producer)
|
||||
// connect back
|
||||
producer.connect(this)
|
||||
} else error("Unreachable statement")
|
||||
}
|
||||
}
|
||||
|
||||
protected open fun connectInput(producer: Producer<T>) {
|
||||
launch {
|
||||
while (isActive) {
|
||||
send(producer.receive())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract class AbstractProcessor<T, R>(scope: CoroutineScope) : Processor<T, R>, AbstractProducer<R>(scope) {
|
||||
|
||||
override var producer: Producer<T>? = null
|
||||
protected set
|
||||
|
||||
override fun connect(producer: Producer<T>) {
|
||||
//Ignore if already connected to specific consumer
|
||||
if (producer != this.producer) {
|
||||
if (inputIsConnected) error("The input slot of consumer is occupied")
|
||||
if (producer.outputIsConnected) error("The input slot of producer is occupied")
|
||||
this.producer = producer
|
||||
//No need to save the job, it will be canceled on scope cancel
|
||||
if (producer.consumer != null) {
|
||||
connectInput(producer)
|
||||
// connect back
|
||||
producer.connect(this)
|
||||
} else error("Unreachable statement")
|
||||
}
|
||||
}
|
||||
|
||||
protected open fun connectInput(producer: Producer<T>) {
|
||||
launch {
|
||||
while (isActive) {
|
||||
send(producer.receive())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A simple [produce]-based producer
|
||||
*/
|
||||
class GenericProducer<T>(
|
||||
scope: CoroutineScope,
|
||||
capacity: Int = Channel.UNLIMITED,
|
||||
block: suspend ProducerScope<T>.() -> Unit
|
||||
) : AbstractProducer<T>(scope) {
|
||||
|
||||
private val channel: ReceiveChannel<T> by lazy { produce(capacity = capacity, block = block) }
|
||||
|
||||
override suspend fun receive(): T = channel.receive()
|
||||
}
|
||||
|
||||
/**
|
||||
* A simple pipeline [Processor] block
|
||||
*/
|
||||
class PipeProcessor<T, R>(
|
||||
scope: CoroutineScope,
|
||||
capacity: Int = Channel.RENDEZVOUS,
|
||||
process: suspend (T) -> R
|
||||
) : AbstractProcessor<T, R>(scope) {
|
||||
|
||||
private val input = Channel<T>(capacity)
|
||||
private val output: ReceiveChannel<R> = input.map(coroutineContext, process)
|
||||
|
||||
override suspend fun receive(): R = output.receive()
|
||||
|
||||
override suspend fun send(value: T) {
|
||||
input.send(value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A moving window [Processor] with circular buffer
|
||||
*/
|
||||
class WindowedProcessor<T, R>(
|
||||
scope: CoroutineScope,
|
||||
window: Int,
|
||||
val process: suspend (Buffer<T?>) -> R
|
||||
) : AbstractProcessor<T, R>(scope) {
|
||||
|
||||
private val ringBuffer = RingBuffer.boxing<T>(window)
|
||||
|
||||
private val channel = Channel<R>(Channel.RENDEZVOUS)
|
||||
|
||||
override suspend fun receive(): R {
|
||||
return channel.receive()
|
||||
}
|
||||
|
||||
override suspend fun send(value: T) {
|
||||
ringBuffer.push(value)
|
||||
channel.send(process(ringBuffer.snapshot()))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Thread-safe aggregator of values from input. The aggregator does not store all incoming values, it uses fold procedure
|
||||
* to incorporate them into state on-arrival.
|
||||
* The current aggregated state could be accessed by [state]. The input channel is inactive unless requested
|
||||
* @param T - the type of the input element
|
||||
* @param S - the type of the aggregator
|
||||
*/
|
||||
class Reducer<T, S>(
|
||||
scope: CoroutineScope,
|
||||
initialState: S,
|
||||
val fold: suspend (S, T) -> S
|
||||
) : AbstractConsumer<T>(scope) {
|
||||
|
||||
var state: S = initialState
|
||||
private set
|
||||
|
||||
private val mutex = Mutex()
|
||||
|
||||
override suspend fun send(value: T) = mutex.withLock {
|
||||
state = fold(state, value)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Collector that accumulates all values in a list. List could be accessed from non-suspending environment via [list] value.
|
||||
*/
|
||||
class Collector<T>(scope: CoroutineScope) : AbstractConsumer<T>(scope) {
|
||||
|
||||
private val _list = ArrayList<T>()
|
||||
private val mutex = Mutex()
|
||||
val list: List<T> get() = _list
|
||||
|
||||
override suspend fun send(value: T) {
|
||||
mutex.withLock {
|
||||
_list.add(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a sequence to [Producer]
|
||||
*/
|
||||
fun <T> Sequence<T>.produce(scope: CoroutineScope = GlobalScope) =
|
||||
GenericProducer<T>(scope) { forEach { send(it) } }
|
||||
|
||||
/**
|
||||
* Convert a [ReceiveChannel] to [Producer]
|
||||
*/
|
||||
fun <T> ReceiveChannel<T>.produce(scope: CoroutineScope = GlobalScope) =
|
||||
GenericProducer<T>(scope) { for (e in this@produce) send(e) }
|
||||
|
||||
|
||||
fun <T, C : Consumer<T>> Producer<T>.consumer(consumerFactory: () -> C): C =
|
||||
consumerFactory().also { connect(it) }
|
||||
|
||||
fun <T, R> Producer<T>.map(capacity: Int = Channel.RENDEZVOUS, process: suspend (T) -> R) =
|
||||
PipeProcessor(this, capacity, process).also { connect(it) }
|
||||
|
||||
/**
|
||||
* Create a reducer and connect this producer to reducer
|
||||
*/
|
||||
fun <T, S> Producer<T>.reduce(initialState: S, fold: suspend (S, T) -> S) =
|
||||
Reducer(this, initialState, fold).also { connect(it) }
|
||||
|
||||
/**
|
||||
* Create a [Collector] and attach it to this [Producer]
|
||||
*/
|
||||
fun <T> Producer<T>.collect() =
|
||||
Collector<T>(this).also { connect(it) }
|
||||
|
||||
fun <T, R, P : Processor<T, R>> Producer<T>.process(processorBuilder: () -> P): P =
|
||||
processorBuilder().also { connect(it) }
|
||||
|
||||
fun <T, R> Producer<T>.process(capacity: Int = Channel.RENDEZVOUS, process: suspend (T) -> R) =
|
||||
PipeProcessor<T, R>(this, capacity, process).also { connect(it) }
|
||||
|
||||
|
||||
fun <T, R> Producer<T>.windowed(window: Int, process: suspend (Buffer<T?>) -> R) =
|
||||
WindowedProcessor(this, window, process).also { connect(it) }
|
@ -1,19 +0,0 @@
|
||||
package scientifik.kmath.sequential
|
||||
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import scientifik.kmath.structures.asSequence
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class RingBufferTest {
|
||||
@Test
|
||||
fun testPush() {
|
||||
val buffer = RingBuffer.build(20, Double.NaN)
|
||||
runBlocking {
|
||||
for (i in 1..30) {
|
||||
buffer.push(i.toDouble())
|
||||
}
|
||||
assertEquals(410.0, buffer.asSequence().sum())
|
||||
}
|
||||
}
|
||||
}
|
@ -26,6 +26,6 @@ include(
|
||||
":kmath-histograms",
|
||||
":kmath-commons",
|
||||
":kmath-koma",
|
||||
":kmath-sequential",
|
||||
":kmath-prob",
|
||||
":benchmarks"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user