Merge remote-tracking branch 'origin/dev' into dev
This commit is contained in:
commit
0622be2494
6
.github/workflows/build.yml
vendored
6
.github/workflows/build.yml
vendored
@ -13,9 +13,11 @@ jobs:
|
||||
- name: Checkout the repo
|
||||
uses: actions/checkout@v2
|
||||
- name: Set up JDK 11
|
||||
uses: actions/setup-java@v1
|
||||
uses: DeLaGuardo/setup-graalvm@4.0
|
||||
with:
|
||||
java-version: 11
|
||||
graalvm: 21.1.0
|
||||
java: java11
|
||||
arch: amd64
|
||||
- name: Add msys to path
|
||||
if: matrix.os == 'windows-latest'
|
||||
run: SETX PATH "%PATH%;C:\msys64\mingw64\bin"
|
||||
|
10
.github/workflows/pages.yml
vendored
10
.github/workflows/pages.yml
vendored
@ -12,9 +12,11 @@ jobs:
|
||||
- name: Checkout the repo
|
||||
uses: actions/checkout@v2
|
||||
- name: Set up JDK 11
|
||||
uses: actions/setup-java@v1
|
||||
uses: DeLaGuardo/setup-graalvm@4.0
|
||||
with:
|
||||
java-version: 11
|
||||
graalvm: 21.1.0
|
||||
java: java11
|
||||
arch: amd64
|
||||
- name: Cache gradle
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
@ -30,9 +32,7 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gradle-
|
||||
- name: Build
|
||||
run: |
|
||||
./gradlew dokkaHtmlMultiModule --no-daemon --no-parallel --stacktrace
|
||||
mv build/dokka/htmlMultiModule/-modules.html build/dokka/htmlMultiModule/index.html
|
||||
run: ./gradlew dokkaHtmlMultiModule --no-daemon --no-parallel --stacktrace
|
||||
- name: Deploy to GitHub Pages
|
||||
uses: JamesIves/github-pages-deploy-action@4.1.0
|
||||
with:
|
||||
|
6
.github/workflows/publish.yml
vendored
6
.github/workflows/publish.yml
vendored
@ -18,9 +18,11 @@ jobs:
|
||||
- name: Checkout the repo
|
||||
uses: actions/checkout@v2
|
||||
- name: Set up JDK 11
|
||||
uses: actions/setup-java@v1
|
||||
uses: DeLaGuardo/setup-graalvm@4.0
|
||||
with:
|
||||
java-version: 11
|
||||
graalvm: 21.1.0
|
||||
java: java11
|
||||
arch: amd64
|
||||
- name: Add msys to path
|
||||
if: matrix.os == 'windows-latest'
|
||||
run: SETX PATH "%PATH%;C:\msys64\mingw64\bin"
|
||||
|
@ -10,7 +10,8 @@
|
||||
- Blocking chains and Statistics
|
||||
- Multiplatform integration
|
||||
- Integration for any Field element
|
||||
- Extendend operations for ND4J fields
|
||||
- Extended operations for ND4J fields
|
||||
- Jupyter Notebook integration module (kmath-jupyter)
|
||||
|
||||
### Changed
|
||||
- Exponential operations merged with hyperbolic functions
|
||||
@ -24,6 +25,7 @@
|
||||
- Redesign MST. Remove MSTExpression.
|
||||
- Move MST to core
|
||||
- Separated benchmarks and examples
|
||||
- Rewritten EJML module without ejml-simple
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
26
README.md
26
README.md
@ -91,7 +91,7 @@ KMath is a modular library. Different modules provide different features with di
|
||||
* ### [kmath-ast](kmath-ast)
|
||||
>
|
||||
>
|
||||
> **Maturity**: PROTOTYPE
|
||||
> **Maturity**: EXPERIMENTAL
|
||||
>
|
||||
> **Features:**
|
||||
> - [expression-language](kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
|
||||
@ -154,9 +154,9 @@ performance calculations to code generation.
|
||||
> **Maturity**: PROTOTYPE
|
||||
>
|
||||
> **Features:**
|
||||
> - [ejml-vector](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt) : The Point implementation using SimpleMatrix.
|
||||
> - [ejml-matrix](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : The Matrix implementation using SimpleMatrix.
|
||||
> - [ejml-linear-space](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : The LinearSpace implementation using SimpleMatrix.
|
||||
> - [ejml-vector](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt) : Point implementations.
|
||||
> - [ejml-matrix](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : Matrix implementation.
|
||||
> - [ejml-linear-space](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : LinearSpace implementations.
|
||||
|
||||
<hr/>
|
||||
|
||||
@ -200,6 +200,12 @@ One can still use generic algebras though.
|
||||
> **Maturity**: PROTOTYPE
|
||||
<hr/>
|
||||
|
||||
* ### [kmath-jupyter](kmath-jupyter)
|
||||
>
|
||||
>
|
||||
> **Maturity**: PROTOTYPE
|
||||
<hr/>
|
||||
|
||||
* ### [kmath-kotlingrad](kmath-kotlingrad)
|
||||
>
|
||||
>
|
||||
@ -230,6 +236,18 @@ One can still use generic algebras though.
|
||||
> **Maturity**: EXPERIMENTAL
|
||||
<hr/>
|
||||
|
||||
* ### [kmath-tensors](kmath-tensors)
|
||||
>
|
||||
>
|
||||
> **Maturity**: PROTOTYPE
|
||||
>
|
||||
> **Features:**
|
||||
> - [tensor algebra](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic linear algebra operations on tensors (plus, dot, etc.)
|
||||
> - [tensor algebra with broadcasting](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
|
||||
> - [linear algebra operations](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.
|
||||
|
||||
<hr/>
|
||||
|
||||
* ### [kmath-viktor](kmath-viktor)
|
||||
>
|
||||
>
|
||||
|
@ -9,14 +9,10 @@ sourceSets.register("benchmarks")
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
jcenter()
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://clojars.org/repo")
|
||||
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
maven("https://jitpack.io")
|
||||
maven {
|
||||
setUrl("http://logicrunch.research.it.uu.se/maven/")
|
||||
maven("http://logicrunch.research.it.uu.se/maven") {
|
||||
isAllowInsecureProtocol = true
|
||||
}
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import kotlinx.benchmark.Blackhole
|
||||
import kotlinx.benchmark.Scope
|
||||
import kotlinx.benchmark.State
|
||||
import space.kscience.kmath.commons.linear.CMLinearSpace
|
||||
import space.kscience.kmath.ejml.EjmlLinearSpace
|
||||
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
||||
import space.kscience.kmath.linear.LinearSpace
|
||||
import space.kscience.kmath.linear.invoke
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
@ -29,8 +29,8 @@ internal class DotBenchmark {
|
||||
val cmMatrix1 = CMLinearSpace { matrix1.toCM() }
|
||||
val cmMatrix2 = CMLinearSpace { matrix2.toCM() }
|
||||
|
||||
val ejmlMatrix1 = EjmlLinearSpace { matrix1.toEjml() }
|
||||
val ejmlMatrix2 = EjmlLinearSpace { matrix2.toEjml() }
|
||||
val ejmlMatrix1 = EjmlLinearSpaceDDRM { matrix1.toEjml() }
|
||||
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@ -42,14 +42,14 @@ internal class DotBenchmark {
|
||||
|
||||
@Benchmark
|
||||
fun ejmlDot(blackhole: Blackhole) {
|
||||
EjmlLinearSpace {
|
||||
EjmlLinearSpaceDDRM {
|
||||
blackhole.consume(ejmlMatrix1 dot ejmlMatrix2)
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun ejmlDotWithConversion(blackhole: Blackhole) {
|
||||
EjmlLinearSpace {
|
||||
EjmlLinearSpaceDDRM {
|
||||
blackhole.consume(matrix1 dot matrix2)
|
||||
}
|
||||
}
|
||||
|
@ -11,25 +11,26 @@ import kotlinx.benchmark.Scope
|
||||
import kotlinx.benchmark.State
|
||||
import space.kscience.kmath.commons.linear.CMLinearSpace
|
||||
import space.kscience.kmath.commons.linear.inverse
|
||||
import space.kscience.kmath.ejml.EjmlLinearSpace
|
||||
import space.kscience.kmath.ejml.inverse
|
||||
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
||||
import space.kscience.kmath.linear.InverseMatrixFeature
|
||||
import space.kscience.kmath.linear.LinearSpace
|
||||
import space.kscience.kmath.linear.inverseWithLup
|
||||
import space.kscience.kmath.linear.invoke
|
||||
import space.kscience.kmath.nd.getFeature
|
||||
import kotlin.random.Random
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
internal class MatrixInverseBenchmark {
|
||||
companion object {
|
||||
val random = Random(1224)
|
||||
const val dim = 100
|
||||
private companion object {
|
||||
private val random = Random(1224)
|
||||
private const val dim = 100
|
||||
|
||||
private val space = LinearSpace.real
|
||||
|
||||
//creating invertible matrix
|
||||
val u = space.buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||
val l = space.buildMatrix(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 }
|
||||
val matrix = space { l dot u }
|
||||
private val u = space.buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||
private val l = space.buildMatrix(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 }
|
||||
private val matrix = space { l dot u }
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@ -46,8 +47,8 @@ internal class MatrixInverseBenchmark {
|
||||
|
||||
@Benchmark
|
||||
fun ejmlInverse(blackhole: Blackhole) {
|
||||
with(EjmlLinearSpace) {
|
||||
blackhole.consume(inverse(matrix))
|
||||
with(EjmlLinearSpaceDDRM) {
|
||||
blackhole.consume(matrix.getFeature<InverseMatrixFeature<Double>>()?.inverse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,17 +1,16 @@
|
||||
plugins {
|
||||
id("ru.mipt.npm.gradle.project")
|
||||
kotlin("jupyter.api") apply false
|
||||
}
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
jcenter()
|
||||
maven("https://clojars.org/repo")
|
||||
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
maven("https://jitpack.io")
|
||||
maven("http://logicrunch.research.it.uu.se/maven/") {
|
||||
maven("http://logicrunch.research.it.uu.se/maven") {
|
||||
isAllowInsecureProtocol = true
|
||||
}
|
||||
maven("https://maven.pkg.jetbrains.space/public/p/kotlinx-html/maven")
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
@ -23,22 +22,16 @@ subprojects {
|
||||
if (name.startsWith("kmath")) apply<MavenPublishPlugin>()
|
||||
|
||||
afterEvaluate {
|
||||
tasks.withType<org.jetbrains.dokka.gradle.DokkaTask> {
|
||||
dokkaSourceSets.all {
|
||||
val readmeFile = File(this@subprojects.projectDir, "./README.md")
|
||||
if (readmeFile.exists())
|
||||
includes.setFrom(includes + readmeFile.absolutePath)
|
||||
tasks.withType<org.jetbrains.dokka.gradle.DokkaTaskPartial> {
|
||||
dependsOn(tasks.getByName("assemble"))
|
||||
|
||||
arrayOf(
|
||||
"http://ejml.org/javadoc/",
|
||||
"https://commons.apache.org/proper/commons-math/javadocs/api-3.6.1/",
|
||||
"https://deeplearning4j.org/api/latest/"
|
||||
).map { java.net.URL("${it}package-list") to java.net.URL(it) }.forEach { (a, b) ->
|
||||
externalDocumentationLink {
|
||||
packageListUrl.set(a)
|
||||
url.set(b)
|
||||
}
|
||||
}
|
||||
dokkaSourceSets.all {
|
||||
val readmeFile = File(this@subprojects.projectDir, "README.md")
|
||||
if (readmeFile.exists()) includes.setFrom(includes + readmeFile.absolutePath)
|
||||
externalDocumentationLink("http://ejml.org/javadoc/")
|
||||
externalDocumentationLink("https://commons.apache.org/proper/commons-math/javadocs/api-3.6.1/")
|
||||
externalDocumentationLink("https://deeplearning4j.org/api/latest/")
|
||||
externalDocumentationLink("https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
6
docs/templates/ARTIFACT-TEMPLATE.md
vendored
6
docs/templates/ARTIFACT-TEMPLATE.md
vendored
@ -6,8 +6,7 @@ The Maven coordinates of this project are `${group}:${name}:${version}`.
|
||||
```gradle
|
||||
repositories {
|
||||
maven { url 'https://repo.kotlin.link' }
|
||||
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
@ -18,8 +17,7 @@ dependencies {
|
||||
```kotlin
|
||||
repositories {
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
|
@ -4,14 +4,11 @@ plugins {
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
jcenter()
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://clojars.org/repo")
|
||||
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
maven("https://jitpack.io")
|
||||
maven{
|
||||
setUrl("http://logicrunch.research.it.uu.se/maven/")
|
||||
maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers")
|
||||
maven("http://logicrunch.research.it.uu.se/maven") {
|
||||
isAllowInsecureProtocol = true
|
||||
}
|
||||
}
|
||||
@ -28,6 +25,7 @@ dependencies {
|
||||
implementation(project(":kmath-dimensions"))
|
||||
implementation(project(":kmath-ejml"))
|
||||
implementation(project(":kmath-nd4j"))
|
||||
implementation(project(":kmath-tensors"))
|
||||
|
||||
implementation(project(":kmath-for-real"))
|
||||
|
||||
|
@ -0,0 +1,46 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
|
||||
|
||||
|
||||
// Dataset normalization
|
||||
|
||||
fun main() {
|
||||
|
||||
// work in context with broadcast methods
|
||||
BroadcastDoubleTensorAlgebra {
|
||||
// take dataset of 5-element vectors from normal distribution
|
||||
val dataset = randomNormal(intArrayOf(100, 5)) * 1.5 // all elements from N(0, 1.5)
|
||||
|
||||
dataset += fromArray(
|
||||
intArrayOf(5),
|
||||
doubleArrayOf(0.0, 1.0, 1.5, 3.0, 5.0) // rows means
|
||||
)
|
||||
|
||||
|
||||
// find out mean and standard deviation of each column
|
||||
val mean = dataset.mean(0, false)
|
||||
val std = dataset.std(0, false)
|
||||
|
||||
println("Mean:\n$mean")
|
||||
println("Standard deviation:\n$std")
|
||||
|
||||
// also we can calculate other statistic as minimum and maximum of rows
|
||||
println("Minimum:\n${dataset.min(0, false)}")
|
||||
println("Maximum:\n${dataset.max(0, false)}")
|
||||
|
||||
// now we can scale dataset with mean normalization
|
||||
val datasetScaled = (dataset - mean) / std
|
||||
|
||||
// find out mean and std of scaled dataset
|
||||
|
||||
println("Mean of scaled:\n${datasetScaled.mean(0, false)}")
|
||||
println("Mean of scaled:\n${datasetScaled.std(0, false)}")
|
||||
}
|
||||
}
|
@ -0,0 +1,97 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
|
||||
|
||||
// solving linear system with LUP decomposition
|
||||
|
||||
fun main () {
|
||||
|
||||
// work in context with linear operations
|
||||
BroadcastDoubleTensorAlgebra {
|
||||
|
||||
// set true value of x
|
||||
val trueX = fromArray(
|
||||
intArrayOf(4),
|
||||
doubleArrayOf(-2.0, 1.5, 6.8, -2.4)
|
||||
)
|
||||
|
||||
// and A matrix
|
||||
val a = fromArray(
|
||||
intArrayOf(4, 4),
|
||||
doubleArrayOf(
|
||||
0.5, 10.5, 4.5, 1.0,
|
||||
8.5, 0.9, 12.8, 0.1,
|
||||
5.56, 9.19, 7.62, 5.45,
|
||||
1.0, 2.0, -3.0, -2.5
|
||||
)
|
||||
)
|
||||
|
||||
// calculate y value
|
||||
val b = a dot trueX
|
||||
|
||||
// check out A and b
|
||||
println("A:\n$a")
|
||||
println("b:\n$b")
|
||||
|
||||
// solve `Ax = b` system using LUP decomposition
|
||||
|
||||
// get P, L, U such that PA = LU
|
||||
val (p, l, u) = a.lu()
|
||||
|
||||
// check that P is permutation matrix
|
||||
println("P:\n$p")
|
||||
// L is lower triangular matrix and U is upper triangular matrix
|
||||
println("L:\n$l")
|
||||
println("U:\n$u")
|
||||
// and PA = LU
|
||||
println("PA:\n${p dot a}")
|
||||
println("LU:\n${l dot u}")
|
||||
|
||||
/* Ax = b;
|
||||
PAx = Pb;
|
||||
LUx = Pb;
|
||||
let y = Ux, then
|
||||
Ly = Pb -- this system can be easily solved, since the matrix L is lower triangular;
|
||||
Ux = y can be solved the same way, since the matrix L is upper triangular
|
||||
*/
|
||||
|
||||
|
||||
|
||||
// this function returns solution x of a system lx = b, l should be lower triangular
|
||||
fun solveLT(l: DoubleTensor, b: DoubleTensor): DoubleTensor {
|
||||
val n = l.shape[0]
|
||||
val x = zeros(intArrayOf(n))
|
||||
for (i in 0 until n){
|
||||
x[intArrayOf(i)] = (b[intArrayOf(i)] - l[i].dot(x).value()) / l[intArrayOf(i, i)]
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
val y = solveLT(l, p dot b)
|
||||
|
||||
// solveLT(l, b) function can be easily adapted for upper triangular matrix by the permutation matrix revMat
|
||||
// create it by placing ones on side diagonal
|
||||
val revMat = u.zeroesLike()
|
||||
val n = revMat.shape[0]
|
||||
for (i in 0 until n) {
|
||||
revMat[intArrayOf(i, n - 1 - i)] = 1.0
|
||||
}
|
||||
|
||||
// solution of system ux = b, u should be upper triangular
|
||||
fun solveUT(u: DoubleTensor, b: DoubleTensor): DoubleTensor = revMat dot solveLT(
|
||||
revMat dot u dot revMat, revMat dot b
|
||||
)
|
||||
|
||||
val x = solveUT(u, y)
|
||||
|
||||
println("True x:\n$trueX")
|
||||
println("x founded with LU method:\n$x")
|
||||
}
|
||||
}
|
@ -0,0 +1,241 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.toDoubleArray
|
||||
import kotlin.math.sqrt
|
||||
|
||||
const val seed = 100500L
|
||||
|
||||
// Simple feedforward neural network with backpropagation training
|
||||
|
||||
// interface of network layer
|
||||
interface Layer {
|
||||
fun forward(input: DoubleTensor): DoubleTensor
|
||||
fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor
|
||||
}
|
||||
|
||||
// activation layer
|
||||
open class Activation(
|
||||
val activation: (DoubleTensor) -> DoubleTensor,
|
||||
val activationDer: (DoubleTensor) -> DoubleTensor
|
||||
) : Layer {
|
||||
override fun forward(input: DoubleTensor): DoubleTensor {
|
||||
return activation(input)
|
||||
}
|
||||
|
||||
override fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor {
|
||||
return DoubleTensorAlgebra { outputError * activationDer(input) }
|
||||
}
|
||||
}
|
||||
|
||||
fun relu(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
|
||||
x.map { if (it > 0) it else 0.0 }
|
||||
}
|
||||
|
||||
fun reluDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
|
||||
x.map { if (it > 0) 1.0 else 0.0 }
|
||||
}
|
||||
|
||||
// activation layer with relu activator
|
||||
class ReLU : Activation(::relu, ::reluDer)
|
||||
|
||||
fun sigmoid(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
|
||||
1.0 / (1.0 + (-x).exp())
|
||||
}
|
||||
|
||||
fun sigmoidDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
|
||||
sigmoid(x) * (1.0 - sigmoid(x))
|
||||
}
|
||||
|
||||
// activation layer with sigmoid activator
|
||||
class Sigmoid : Activation(::sigmoid, ::sigmoidDer)
|
||||
|
||||
// dense layer
|
||||
class Dense(
|
||||
private val inputUnits: Int,
|
||||
private val outputUnits: Int,
|
||||
private val learningRate: Double = 0.1
|
||||
) : Layer {
|
||||
|
||||
private val weights: DoubleTensor = DoubleTensorAlgebra {
|
||||
randomNormal(
|
||||
intArrayOf(inputUnits, outputUnits),
|
||||
seed
|
||||
) * sqrt(2.0 / (inputUnits + outputUnits))
|
||||
}
|
||||
|
||||
private val bias: DoubleTensor = DoubleTensorAlgebra { zeros(intArrayOf(outputUnits)) }
|
||||
|
||||
override fun forward(input: DoubleTensor): DoubleTensor {
|
||||
return BroadcastDoubleTensorAlgebra { (input dot weights) + bias }
|
||||
}
|
||||
|
||||
override fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
|
||||
val gradInput = outputError dot weights.transpose()
|
||||
|
||||
val gradW = input.transpose() dot outputError
|
||||
val gradBias = outputError.mean(dim = 0, keepDim = false) * input.shape[0].toDouble()
|
||||
|
||||
weights -= learningRate * gradW
|
||||
bias -= learningRate * gradBias
|
||||
|
||||
gradInput
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// simple accuracy equal to the proportion of correct answers
|
||||
fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double {
|
||||
check(yPred.shape contentEquals yTrue.shape)
|
||||
val n = yPred.shape[0]
|
||||
var correctCnt = 0
|
||||
for (i in 0 until n) {
|
||||
if (yPred[intArrayOf(i, 0)] == yTrue[intArrayOf(i, 0)]) {
|
||||
correctCnt += 1
|
||||
}
|
||||
}
|
||||
return correctCnt.toDouble() / n.toDouble()
|
||||
}
|
||||
|
||||
// neural network class
|
||||
class NeuralNetwork(private val layers: List<Layer>) {
|
||||
private fun softMaxLoss(yPred: DoubleTensor, yTrue: DoubleTensor): DoubleTensor = BroadcastDoubleTensorAlgebra {
|
||||
|
||||
val onesForAnswers = yPred.zeroesLike()
|
||||
yTrue.toDoubleArray().forEachIndexed { index, labelDouble ->
|
||||
val label = labelDouble.toInt()
|
||||
onesForAnswers[intArrayOf(index, label)] = 1.0
|
||||
}
|
||||
|
||||
val softmaxValue = yPred.exp() / yPred.exp().sum(dim = 1, keepDim = true)
|
||||
|
||||
(-onesForAnswers + softmaxValue) / (yPred.shape[0].toDouble())
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalStdlibApi::class)
|
||||
private fun forward(x: DoubleTensor): List<DoubleTensor> {
|
||||
var input = x
|
||||
|
||||
return buildList {
|
||||
layers.forEach { layer ->
|
||||
val output = layer.forward(input)
|
||||
add(output)
|
||||
input = output
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalStdlibApi::class)
|
||||
private fun train(xTrain: DoubleTensor, yTrain: DoubleTensor) {
|
||||
val layerInputs = buildList {
|
||||
add(xTrain)
|
||||
addAll(forward(xTrain))
|
||||
}
|
||||
|
||||
var lossGrad = softMaxLoss(layerInputs.last(), yTrain)
|
||||
|
||||
layers.zip(layerInputs).reversed().forEach { (layer, input) ->
|
||||
lossGrad = layer.backward(input, lossGrad)
|
||||
}
|
||||
}
|
||||
|
||||
fun fit(xTrain: DoubleTensor, yTrain: DoubleTensor, batchSize: Int, epochs: Int) = DoubleTensorAlgebra {
|
||||
fun iterBatch(x: DoubleTensor, y: DoubleTensor): Sequence<Pair<DoubleTensor, DoubleTensor>> = sequence {
|
||||
val n = x.shape[0]
|
||||
val shuffledIndices = (0 until n).shuffled()
|
||||
for (i in 0 until n step batchSize) {
|
||||
val excerptIndices = shuffledIndices.drop(i).take(batchSize).toIntArray()
|
||||
val batch = x.rowsByIndices(excerptIndices) to y.rowsByIndices(excerptIndices)
|
||||
yield(batch)
|
||||
}
|
||||
}
|
||||
|
||||
for (epoch in 0 until epochs) {
|
||||
println("Epoch ${epoch + 1}/$epochs")
|
||||
for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
|
||||
train(xBatch, yBatch)
|
||||
}
|
||||
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}")
|
||||
}
|
||||
}
|
||||
|
||||
fun predict(x: DoubleTensor): DoubleTensor {
|
||||
return forward(x).last()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@OptIn(ExperimentalStdlibApi::class)
|
||||
fun main() {
|
||||
BroadcastDoubleTensorAlgebra {
|
||||
val features = 5
|
||||
val sampleSize = 250
|
||||
val trainSize = 180
|
||||
val testSize = sampleSize - trainSize
|
||||
|
||||
// take sample of features from normal distribution
|
||||
val x = randomNormal(intArrayOf(sampleSize, features), seed) * 2.5
|
||||
|
||||
x += fromArray(
|
||||
intArrayOf(5),
|
||||
doubleArrayOf(0.0, -1.0, -2.5, -3.0, 5.5) // rows means
|
||||
)
|
||||
|
||||
|
||||
// define class like '1' if the sum of features > 0 and '0' otherwise
|
||||
val y = fromArray(
|
||||
intArrayOf(sampleSize, 1),
|
||||
DoubleArray(sampleSize) { i ->
|
||||
if (x[i].sum() > 0.0) {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
// split train ans test
|
||||
val trainIndices = (0 until trainSize).toList().toIntArray()
|
||||
val testIndices = (trainSize until sampleSize).toList().toIntArray()
|
||||
|
||||
val xTrain = x.rowsByIndices(trainIndices)
|
||||
val yTrain = y.rowsByIndices(trainIndices)
|
||||
|
||||
val xTest = x.rowsByIndices(testIndices)
|
||||
val yTest = y.rowsByIndices(testIndices)
|
||||
|
||||
// build model
|
||||
val layers = buildList {
|
||||
add(Dense(features, 64))
|
||||
add(ReLU())
|
||||
add(Dense(64, 16))
|
||||
add(ReLU())
|
||||
add(Dense(16, 2))
|
||||
add(Sigmoid())
|
||||
}
|
||||
val model = NeuralNetwork(layers)
|
||||
|
||||
// fit it with train data
|
||||
model.fit(xTrain, yTrain, batchSize = 20, epochs = 10)
|
||||
|
||||
// make prediction
|
||||
val prediction = model.predict(xTest)
|
||||
|
||||
// process raw prediction via argMax
|
||||
val predictionLabels = prediction.argMax(1, true)
|
||||
|
||||
// find out accuracy
|
||||
val acc = accuracy(yTest, predictionLabels)
|
||||
println("Test accuracy:$acc")
|
||||
|
||||
}
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
|
||||
import kotlin.math.abs
|
||||
|
||||
// OLS estimator using SVD
|
||||
|
||||
fun main() {
|
||||
//seed for random
|
||||
val randSeed = 100500L
|
||||
|
||||
// work in context with linear operations
|
||||
DoubleTensorAlgebra {
|
||||
// take coefficient vector from normal distribution
|
||||
val alpha = randomNormal(
|
||||
intArrayOf(5),
|
||||
randSeed
|
||||
) + fromArray(
|
||||
intArrayOf(5),
|
||||
doubleArrayOf(1.0, 2.5, 3.4, 5.0, 10.1)
|
||||
)
|
||||
|
||||
println("Real alpha:\n$alpha")
|
||||
|
||||
// also take sample of size 20 from normal distribution for x
|
||||
val x = randomNormal(
|
||||
intArrayOf(20, 5),
|
||||
randSeed
|
||||
)
|
||||
|
||||
// calculate y and add gaussian noise (N(0, 0.05))
|
||||
val y = x dot alpha
|
||||
y += y.randomNormalLike(randSeed) * 0.05
|
||||
|
||||
// now restore the coefficient vector with OSL estimator with SVD
|
||||
val (u, singValues, v) = x.svd()
|
||||
|
||||
// we have to make sure the singular values of the matrix are not close to zero
|
||||
println("Singular values:\n$singValues")
|
||||
|
||||
|
||||
// inverse Sigma matrix can be restored from singular values with diagonalEmbedding function
|
||||
val sigma = diagonalEmbedding(singValues.map{ x -> if (abs(x) < 1e-3) 0.0 else 1.0/x })
|
||||
|
||||
val alphaOLS = v dot sigma dot u.transpose() dot y
|
||||
println("Estimated alpha:\n" +
|
||||
"$alphaOLS")
|
||||
|
||||
// figure out MSE of approximation
|
||||
fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double {
|
||||
require(yTrue.shape.size == 1)
|
||||
require(yTrue.shape contentEquals yPred.shape)
|
||||
|
||||
val diff = yTrue - yPred
|
||||
return diff.dot(diff).sqrt().value()
|
||||
}
|
||||
|
||||
println("MSE: ${mse(alpha, alphaOLS)}")
|
||||
}
|
||||
}
|
78
examples/src/main/kotlin/space/kscience/kmath/tensors/PCA.kt
Normal file
78
examples/src/main/kotlin/space/kscience/kmath/tensors/PCA.kt
Normal file
@ -0,0 +1,78 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
|
||||
|
||||
|
||||
// simple PCA
|
||||
|
||||
fun main(){
|
||||
val seed = 100500L
|
||||
|
||||
// work in context with broadcast methods
|
||||
BroadcastDoubleTensorAlgebra {
|
||||
|
||||
// assume x is range from 0 until 10
|
||||
val x = fromArray(
|
||||
intArrayOf(10),
|
||||
(0 until 10).toList().map { it.toDouble() }.toDoubleArray()
|
||||
)
|
||||
|
||||
// take y dependent on x with noise
|
||||
val y = 2.0 * x + (3.0 + x.randomNormalLike(seed) * 1.5)
|
||||
|
||||
println("x:\n$x")
|
||||
println("y:\n$y")
|
||||
|
||||
// stack them into single dataset
|
||||
val dataset = stack(listOf(x, y)).transpose()
|
||||
|
||||
// normalize both x and y
|
||||
val xMean = x.mean()
|
||||
val yMean = y.mean()
|
||||
|
||||
val xStd = x.std()
|
||||
val yStd = y.std()
|
||||
|
||||
val xScaled = (x - xMean) / xStd
|
||||
val yScaled = (y - yMean) / yStd
|
||||
|
||||
// save means ans standard deviations for further recovery
|
||||
val mean = fromArray(
|
||||
intArrayOf(2),
|
||||
doubleArrayOf(xMean, yMean)
|
||||
)
|
||||
println("Means:\n$mean")
|
||||
|
||||
val std = fromArray(
|
||||
intArrayOf(2),
|
||||
doubleArrayOf(xStd, yStd)
|
||||
)
|
||||
println("Standard deviations:\n$std")
|
||||
|
||||
// calculate the covariance matrix of scaled x and y
|
||||
val covMatrix = cov(listOf(xScaled, yScaled))
|
||||
println("Covariance matrix:\n$covMatrix")
|
||||
|
||||
// and find out eigenvector of it
|
||||
val (_, evecs) = covMatrix.symEig()
|
||||
val v = evecs[0]
|
||||
println("Eigenvector:\n$v")
|
||||
|
||||
// reduce dimension of dataset
|
||||
val datasetReduced = v dot stack(listOf(xScaled, yScaled))
|
||||
println("Reduced data:\n$datasetReduced")
|
||||
|
||||
// we can restore original data from reduced data.
|
||||
// for example, find 7th element of dataset
|
||||
val n = 7
|
||||
val restored = (datasetReduced[n] dot v.view(intArrayOf(1, 2))) * std + mean
|
||||
println("Original value:\n${dataset[n]}")
|
||||
println("Restored value:\n$restored")
|
||||
}
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
# Module kmath-ast
|
||||
|
||||
Abstract syntax tree expression representation and related optimizations.
|
||||
Performance and visualization extensions to MST API.
|
||||
|
||||
- [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
|
||||
- [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
|
||||
@ -16,8 +16,7 @@ The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-7`
|
||||
```gradle
|
||||
repositories {
|
||||
maven { url 'https://repo.kotlin.link' }
|
||||
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
@ -28,8 +27,7 @@ dependencies {
|
||||
```kotlin
|
||||
repositories {
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
@ -41,21 +39,26 @@ dependencies {
|
||||
|
||||
### On JVM
|
||||
|
||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
|
||||
a special implementation of `Expression<T>` with implemented `invoke` function.
|
||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
|
||||
special implementation of `Expression<T>` with implemented `invoke` function.
|
||||
|
||||
For example, the following builder:
|
||||
|
||||
```kotlin
|
||||
import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.asm.*
|
||||
|
||||
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
|
||||
```
|
||||
|
||||
… leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||
... leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||
|
||||
```java
|
||||
package space.kscience.kmath.asm.generated;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import kotlin.jvm.functions.Function2;
|
||||
import space.kscience.kmath.asm.internal.MapIntrinsics;
|
||||
import space.kscience.kmath.expressions.Expression;
|
||||
@ -65,7 +68,7 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
||||
private final Object[] constants;
|
||||
|
||||
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
|
||||
return (Double)((Function2)this.constants[0]).invoke((Double)MapIntrinsics.getOrFail(arguments, "x"), 2);
|
||||
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
|
||||
}
|
||||
|
||||
public AsmCompiledExpression_45045_0(Object[] constants) {
|
||||
@ -77,8 +80,8 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
||||
|
||||
#### Known issues
|
||||
|
||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
|
||||
class loading overhead.
|
||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
|
||||
loading overhead.
|
||||
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
||||
|
||||
### On JS
|
||||
@ -86,6 +89,10 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
||||
A similar feature is also available on JS.
|
||||
|
||||
```kotlin
|
||||
import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.estree.*
|
||||
|
||||
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
|
||||
```
|
||||
|
||||
@ -93,18 +100,22 @@ The code above returns expression implemented with such a JS function:
|
||||
|
||||
```js
|
||||
var executable = function (constants, arguments) {
|
||||
return constants[1](constants[0](arguments, "x"), 2);
|
||||
return constants[1](constants[0](arguments, "x"), 2);
|
||||
};
|
||||
```
|
||||
|
||||
JS also supports very experimental expression optimization with [WebAssembly](https://webassembly.org/) IR generation.
|
||||
Currently, only expressions inside `DoubleField` and `IntRing` are supported.
|
||||
|
||||
```kotlin
|
||||
import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.wasm.*
|
||||
|
||||
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
|
||||
```
|
||||
|
||||
An example of emitted WASM IR in the form of WAT:
|
||||
An example of emitted Wasm IR in the form of WAT:
|
||||
|
||||
```lisp
|
||||
(func $executable (param $0 f64) (result f64)
|
||||
@ -129,7 +140,9 @@ Example usage:
|
||||
```kotlin
|
||||
import space.kscience.kmath.ast.*
|
||||
import space.kscience.kmath.ast.rendering.*
|
||||
import space.kscience.kmath.misc.*
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun main() {
|
||||
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath()
|
||||
val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst)
|
||||
@ -145,13 +158,68 @@ public fun main() {
|
||||
|
||||
Result LaTeX:
|
||||
|
||||
![](http://chart.googleapis.com/chart?cht=tx&chl=e%5E%7B%5Csqrt%7Bx%7D%7D-%5Cfrac%7B%5Cfrac%7B%5Coperatorname%7Bsin%7D%5E%7B-1%7D%5C,%5Cleft(2%5C,x%5Cright)%7D%7B2%5Ctimes10%5E%7B10%7D%2Bx%5E%7B3%7D%7D%7D%7B-12%7D)
|
||||
![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{-12})
|
||||
|
||||
Result MathML (embedding MathML is not allowed by GitHub Markdown):
|
||||
|
||||
<details>
|
||||
|
||||
```html
|
||||
<mrow><msup><mrow><mi>e</mi></mrow><mrow><msqrt><mi>x</mi></msqrt></mrow></msup><mo>-</mo><mfrac><mrow><mfrac><mrow><msup><mrow><mo>sin</mo></mrow><mrow><mo>-</mo><mn>1</mn></mrow></msup><mspace width="0.167em"></mspace><mfenced open="(" close=")" separators=""><mn>2</mn><mspace width="0.167em"></mspace><mi>x</mi></mfenced></mrow><mrow><mn>2</mn><mo>×</mo><msup><mrow><mn>10</mn></mrow><mrow><mn>10</mn></mrow></msup><mo>+</mo><msup><mrow><mi>x</mi></mrow><mrow><mn>3</mn></mrow></msup></mrow></mfrac></mrow><mrow><mo>-</mo><mn>12</mn></mrow></mfrac></mrow>
|
||||
<math xmlns="https://www.w3.org/1998/Math/MathML">
|
||||
<mrow>
|
||||
<mo>exp</mo>
|
||||
<mspace width="0.167em"></mspace>
|
||||
<mfenced open="(" close=")" separators="">
|
||||
<msqrt>
|
||||
<mi>x</mi>
|
||||
</msqrt>
|
||||
</mfenced>
|
||||
<mo>-</mo>
|
||||
<mfrac>
|
||||
<mrow>
|
||||
<mfrac>
|
||||
<mrow>
|
||||
<mo>arcsin</mo>
|
||||
<mspace width="0.167em"></mspace>
|
||||
<mfenced open="(" close=")" separators="">
|
||||
<mn>2</mn>
|
||||
<mspace width="0.167em"></mspace>
|
||||
<mi>x</mi>
|
||||
</mfenced>
|
||||
</mrow>
|
||||
<mrow>
|
||||
<mn>2</mn>
|
||||
<mo>×</mo>
|
||||
<msup>
|
||||
<mrow>
|
||||
<mn>10</mn>
|
||||
</mrow>
|
||||
<mrow>
|
||||
<mn>10</mn>
|
||||
</mrow>
|
||||
</msup>
|
||||
<mo>+</mo>
|
||||
<msup>
|
||||
<mrow>
|
||||
<mi>x</mi>
|
||||
</mrow>
|
||||
<mrow>
|
||||
<mn>3</mn>
|
||||
</mrow>
|
||||
</msup>
|
||||
</mrow>
|
||||
</mfrac>
|
||||
</mrow>
|
||||
<mrow>
|
||||
<mo>-</mo>
|
||||
<mn>12</mn>
|
||||
</mrow>
|
||||
</mfrac>
|
||||
</mrow>
|
||||
</math>
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
It is also possible to create custom algorithms of render, and even add support of other markup languages
|
||||
(see API reference).
|
||||
|
@ -18,6 +18,10 @@ kotlin.js {
|
||||
}
|
||||
|
||||
kotlin.sourceSets {
|
||||
filter { it.name.contains("test", true) }
|
||||
.map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings)
|
||||
.forEach { it.useExperimentalAnnotation("space.kscience.kmath.misc.UnstableKMathAPI") }
|
||||
|
||||
commonMain {
|
||||
dependencies {
|
||||
api("com.github.h0tk3y.betterParse:better-parse:0.4.2")
|
||||
@ -54,7 +58,7 @@ tasks.dokkaHtml {
|
||||
}
|
||||
|
||||
readme {
|
||||
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
|
||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||
|
||||
feature(
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Module kmath-ast
|
||||
|
||||
Abstract syntax tree expression representation and related optimizations.
|
||||
Performance and visualization extensions to MST API.
|
||||
|
||||
${features}
|
||||
|
||||
@ -10,21 +10,26 @@ ${artifact}
|
||||
|
||||
### On JVM
|
||||
|
||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
|
||||
a special implementation of `Expression<T>` with implemented `invoke` function.
|
||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
|
||||
special implementation of `Expression<T>` with implemented `invoke` function.
|
||||
|
||||
For example, the following builder:
|
||||
|
||||
```kotlin
|
||||
import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.asm.*
|
||||
|
||||
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
|
||||
```
|
||||
|
||||
… leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||
... leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||
|
||||
```java
|
||||
package space.kscience.kmath.asm.generated;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import kotlin.jvm.functions.Function2;
|
||||
import space.kscience.kmath.asm.internal.MapIntrinsics;
|
||||
import space.kscience.kmath.expressions.Expression;
|
||||
@ -34,7 +39,7 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
||||
private final Object[] constants;
|
||||
|
||||
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
|
||||
return (Double)((Function2)this.constants[0]).invoke((Double)MapIntrinsics.getOrFail(arguments, "x"), 2);
|
||||
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
|
||||
}
|
||||
|
||||
public AsmCompiledExpression_45045_0(Object[] constants) {
|
||||
@ -46,8 +51,8 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
||||
|
||||
#### Known issues
|
||||
|
||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
|
||||
class loading overhead.
|
||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
|
||||
loading overhead.
|
||||
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
||||
|
||||
### On JS
|
||||
@ -55,6 +60,10 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
||||
A similar feature is also available on JS.
|
||||
|
||||
```kotlin
|
||||
import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.estree.*
|
||||
|
||||
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
|
||||
```
|
||||
|
||||
@ -62,18 +71,22 @@ The code above returns expression implemented with such a JS function:
|
||||
|
||||
```js
|
||||
var executable = function (constants, arguments) {
|
||||
return constants[1](constants[0](arguments, "x"), 2);
|
||||
return constants[1](constants[0](arguments, "x"), 2);
|
||||
};
|
||||
```
|
||||
|
||||
JS also supports very experimental expression optimization with [WebAssembly](https://webassembly.org/) IR generation.
|
||||
Currently, only expressions inside `DoubleField` and `IntRing` are supported.
|
||||
|
||||
```kotlin
|
||||
import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.wasm.*
|
||||
|
||||
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
|
||||
```
|
||||
|
||||
An example of emitted WASM IR in the form of WAT:
|
||||
An example of emitted Wasm IR in the form of WAT:
|
||||
|
||||
```lisp
|
||||
(func \$executable (param \$0 f64) (result f64)
|
||||
@ -98,9 +111,11 @@ Example usage:
|
||||
```kotlin
|
||||
import space.kscience.kmath.ast.*
|
||||
import space.kscience.kmath.ast.rendering.*
|
||||
import space.kscience.kmath.misc.*
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun main() {
|
||||
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath()
|
||||
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(12)+x^(2/3)".parseMath()
|
||||
val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst)
|
||||
val latex = LatexSyntaxRenderer.renderWithStringBuilder(syntax)
|
||||
println("LaTeX:")
|
||||
@ -114,13 +129,78 @@ public fun main() {
|
||||
|
||||
Result LaTeX:
|
||||
|
||||
![](http://chart.googleapis.com/chart?cht=tx&chl=e%5E%7B%5Csqrt%7Bx%7D%7D-%5Cfrac%7B%5Cfrac%7B%5Coperatorname%7Bsin%7D%5E%7B-1%7D%5C,%5Cleft(2%5C,x%5Cright)%7D%7B2%5Ctimes10%5E%7B10%7D%2Bx%5E%7B3%7D%7D%7D%7B-12%7D)
|
||||
![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{12}+x^{2/3})
|
||||
|
||||
Result MathML (embedding MathML is not allowed by GitHub Markdown):
|
||||
Result MathML (can be used with MathJax or other renderers):
|
||||
|
||||
<details>
|
||||
|
||||
```html
|
||||
<mrow><msup><mrow><mi>e</mi></mrow><mrow><msqrt><mi>x</mi></msqrt></mrow></msup><mo>-</mo><mfrac><mrow><mfrac><mrow><msup><mrow><mo>sin</mo></mrow><mrow><mo>-</mo><mn>1</mn></mrow></msup><mspace width="0.167em"></mspace><mfenced open="(" close=")" separators=""><mn>2</mn><mspace width="0.167em"></mspace><mi>x</mi></mfenced></mrow><mrow><mn>2</mn><mo>×</mo><msup><mrow><mn>10</mn></mrow><mrow><mn>10</mn></mrow></msup><mo>+</mo><msup><mrow><mi>x</mi></mrow><mrow><mn>3</mn></mrow></msup></mrow></mfrac></mrow><mrow><mo>-</mo><mn>12</mn></mrow></mfrac></mrow>
|
||||
<math xmlns="https://www.w3.org/1998/Math/MathML">
|
||||
<mrow>
|
||||
<mo>exp</mo>
|
||||
<mspace width="0.167em"></mspace>
|
||||
<mfenced open="(" close=")" separators="">
|
||||
<msqrt>
|
||||
<mi>x</mi>
|
||||
</msqrt>
|
||||
</mfenced>
|
||||
<mo>-</mo>
|
||||
<mfrac>
|
||||
<mrow>
|
||||
<mfrac>
|
||||
<mrow>
|
||||
<mo>arcsin</mo>
|
||||
<mspace width="0.167em"></mspace>
|
||||
<mfenced open="(" close=")" separators="">
|
||||
<mn>2</mn>
|
||||
<mspace width="0.167em"></mspace>
|
||||
<mi>x</mi>
|
||||
</mfenced>
|
||||
</mrow>
|
||||
<mrow>
|
||||
<mn>2</mn>
|
||||
<mo>×</mo>
|
||||
<msup>
|
||||
<mrow>
|
||||
<mn>10</mn>
|
||||
</mrow>
|
||||
<mrow>
|
||||
<mn>10</mn>
|
||||
</mrow>
|
||||
</msup>
|
||||
<mo>+</mo>
|
||||
<msup>
|
||||
<mrow>
|
||||
<mi>x</mi>
|
||||
</mrow>
|
||||
<mrow>
|
||||
<mn>3</mn>
|
||||
</mrow>
|
||||
</msup>
|
||||
</mrow>
|
||||
</mfrac>
|
||||
</mrow>
|
||||
<mrow>
|
||||
<mn>12</mn>
|
||||
</mrow>
|
||||
</mfrac>
|
||||
<mo>+</mo>
|
||||
<msup>
|
||||
<mrow>
|
||||
<mi>x</mi>
|
||||
</mrow>
|
||||
<mrow>
|
||||
<mn>2</mn>
|
||||
<mo>/</mo>
|
||||
<mn>3</mn>
|
||||
</mrow>
|
||||
</msup>
|
||||
</mrow>
|
||||
</math>
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
It is also possible to create custom algorithms of render, and even add support of other markup languages
|
||||
(see API reference).
|
||||
|
@ -29,7 +29,6 @@ import space.kscience.kmath.operations.RingOperations
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released
|
||||
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?".toRegex())
|
||||
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*".toRegex())
|
||||
private val lpar: Token by literalToken("(")
|
||||
|
@ -5,6 +5,8 @@
|
||||
|
||||
package space.kscience.kmath.ast.rendering
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
|
||||
/**
|
||||
* [SyntaxRenderer] implementation for LaTeX.
|
||||
*
|
||||
@ -23,6 +25,7 @@ package space.kscience.kmath.ast.rendering
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public object LatexSyntaxRenderer : SyntaxRenderer {
|
||||
public override fun render(node: MathSyntax, output: Appendable): Unit = output.run {
|
||||
fun render(syntax: MathSyntax) = render(syntax, output)
|
||||
@ -115,7 +118,11 @@ public object LatexSyntaxRenderer : SyntaxRenderer {
|
||||
render(node.right)
|
||||
}
|
||||
|
||||
is FractionSyntax -> {
|
||||
is FractionSyntax -> if (node.infix) {
|
||||
render(node.left)
|
||||
append('/')
|
||||
render(node.right)
|
||||
} else {
|
||||
append("\\frac{")
|
||||
render(node.left)
|
||||
append("}{")
|
||||
|
@ -5,6 +5,8 @@
|
||||
|
||||
package space.kscience.kmath.ast.rendering
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
|
||||
/**
|
||||
* [SyntaxRenderer] implementation for MathML.
|
||||
*
|
||||
@ -12,14 +14,18 @@ package space.kscience.kmath.ast.rendering
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public object MathMLSyntaxRenderer : SyntaxRenderer {
|
||||
public override fun render(node: MathSyntax, output: Appendable) {
|
||||
output.append("<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>")
|
||||
render0(node, output)
|
||||
output.append("<math xmlns=\"https://www.w3.org/1998/Math/MathML\"><mrow>")
|
||||
renderPart(node, output)
|
||||
output.append("</mrow></math>")
|
||||
}
|
||||
|
||||
private fun render0(node: MathSyntax, output: Appendable): Unit = output.run {
|
||||
/**
|
||||
* Renders a part of syntax returning a correct MathML tag not the whole MathML instance.
|
||||
*/
|
||||
public fun renderPart(node: MathSyntax, output: Appendable): Unit = output.run {
|
||||
fun tag(tagName: String, vararg attr: Pair<String, String>, block: () -> Unit = {}) {
|
||||
append('<')
|
||||
append(tagName)
|
||||
@ -44,7 +50,7 @@ public object MathMLSyntaxRenderer : SyntaxRenderer {
|
||||
append('>')
|
||||
}
|
||||
|
||||
fun render(syntax: MathSyntax) = render0(syntax, output)
|
||||
fun render(syntax: MathSyntax) = renderPart(syntax, output)
|
||||
|
||||
when (node) {
|
||||
is NumberSyntax -> tag("mn") { append(node.string) }
|
||||
@ -127,14 +133,13 @@ public object MathMLSyntaxRenderer : SyntaxRenderer {
|
||||
render(node.right)
|
||||
}
|
||||
|
||||
is FractionSyntax -> tag("mfrac") {
|
||||
tag("mrow") {
|
||||
render(node.left)
|
||||
}
|
||||
|
||||
tag("mrow") {
|
||||
render(node.right)
|
||||
}
|
||||
is FractionSyntax -> if (node.infix) {
|
||||
render(node.left)
|
||||
tag("mo") { append('/') }
|
||||
render(node.right)
|
||||
} else tag("mfrac") {
|
||||
tag("mrow") { render(node.left) }
|
||||
tag("mrow") { render(node.right) }
|
||||
}
|
||||
|
||||
is RadicalWithIndexSyntax -> tag("mroot") {
|
||||
|
@ -6,12 +6,14 @@
|
||||
package space.kscience.kmath.ast.rendering
|
||||
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
|
||||
/**
|
||||
* Renders [MST] to [MathSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun interface MathRenderer {
|
||||
/**
|
||||
* Renders [MST] to [MathSyntax].
|
||||
@ -25,6 +27,7 @@ public fun interface MathRenderer {
|
||||
* @property features The applied features.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public open class FeaturedMathRenderer(public val features: List<RenderFeature>) : MathRenderer {
|
||||
public override fun render(mst: MST): MathSyntax {
|
||||
for (feature in features) feature.render(this, mst)?.let { return it }
|
||||
@ -48,6 +51,7 @@ public open class FeaturedMathRenderer(public val features: List<RenderFeature>)
|
||||
* @property stages The applied stages.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public open class FeaturedMathRendererWithPostProcess(
|
||||
features: List<RenderFeature>,
|
||||
public val stages: List<PostProcessStage>,
|
||||
@ -85,6 +89,7 @@ public open class FeaturedMathRendererWithPostProcess(
|
||||
SquareRoot.Default,
|
||||
Exponent.Default,
|
||||
InverseTrigonometricOperations.Default,
|
||||
InverseHyperbolicOperations.Default,
|
||||
|
||||
// Fallback option for unknown operations - printing them as operator
|
||||
BinaryOperator.Default,
|
||||
@ -101,6 +106,7 @@ public open class FeaturedMathRendererWithPostProcess(
|
||||
),
|
||||
listOf(
|
||||
BetterExponent,
|
||||
BetterFraction,
|
||||
SimplifyParentheses.Default,
|
||||
BetterMultiplication,
|
||||
),
|
||||
|
@ -5,11 +5,14 @@
|
||||
|
||||
package space.kscience.kmath.ast.rendering
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
|
||||
/**
|
||||
* Mathematical typography syntax node.
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public sealed class MathSyntax {
|
||||
/**
|
||||
* The parent node of this syntax node.
|
||||
@ -22,6 +25,7 @@ public sealed class MathSyntax {
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public sealed class TerminalSyntax : MathSyntax()
|
||||
|
||||
/**
|
||||
@ -29,6 +33,7 @@ public sealed class TerminalSyntax : MathSyntax()
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public sealed class OperationSyntax : MathSyntax() {
|
||||
/**
|
||||
* The operation token.
|
||||
@ -41,6 +46,7 @@ public sealed class OperationSyntax : MathSyntax() {
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public sealed class UnarySyntax : OperationSyntax() {
|
||||
/**
|
||||
* The operand of this node.
|
||||
@ -53,6 +59,7 @@ public sealed class UnarySyntax : OperationSyntax() {
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public sealed class BinarySyntax : OperationSyntax() {
|
||||
/**
|
||||
* The left-hand side operand.
|
||||
@ -71,6 +78,7 @@ public sealed class BinarySyntax : OperationSyntax() {
|
||||
* @property string The digits of number.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class NumberSyntax(public var string: String) : TerminalSyntax()
|
||||
|
||||
/**
|
||||
@ -79,6 +87,7 @@ public data class NumberSyntax(public var string: String) : TerminalSyntax()
|
||||
* @property string The symbol.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class SymbolSyntax(public var string: String) : TerminalSyntax()
|
||||
|
||||
/**
|
||||
@ -89,14 +98,16 @@ public data class SymbolSyntax(public var string: String) : TerminalSyntax()
|
||||
* @see UnaryOperatorSyntax
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class OperatorNameSyntax(public var name: String) : TerminalSyntax()
|
||||
|
||||
/**
|
||||
* Represents a usage of special symbols.
|
||||
* Represents a usage of special symbols (e.g., *∞*).
|
||||
*
|
||||
* @property kind The kind of symbol.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class SpecialSymbolSyntax(public var kind: Kind) : TerminalSyntax() {
|
||||
/**
|
||||
* The kind of symbol.
|
||||
@ -121,6 +132,7 @@ public data class SpecialSymbolSyntax(public var kind: Kind) : TerminalSyntax()
|
||||
* @property parentheses Whether the operand should be wrapped with parentheses.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class OperandSyntax(
|
||||
public val operand: MathSyntax,
|
||||
public var parentheses: Boolean,
|
||||
@ -131,11 +143,12 @@ public data class OperandSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents unary, prefix operator syntax (like f x).
|
||||
* Represents unary, prefix operator syntax (like *f(x)*).
|
||||
*
|
||||
* @property prefix The prefix.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class UnaryOperatorSyntax(
|
||||
public override val operation: String,
|
||||
public var prefix: MathSyntax,
|
||||
@ -147,10 +160,11 @@ public data class UnaryOperatorSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents prefix, unary plus operator.
|
||||
* Represents prefix, unary plus operator (*+x*).
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class UnaryPlusSyntax(
|
||||
public override val operation: String,
|
||||
public override val operand: OperandSyntax,
|
||||
@ -161,10 +175,11 @@ public data class UnaryPlusSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents prefix, unary minus operator.
|
||||
* Represents prefix, unary minus operator (*-x*).
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class UnaryMinusSyntax(
|
||||
public override val operation: String,
|
||||
public override val operand: OperandSyntax,
|
||||
@ -175,11 +190,12 @@ public data class UnaryMinusSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents radical with a node inside it.
|
||||
* Represents radical with a node inside it (*√x*).
|
||||
*
|
||||
* @property operand The radicand.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class RadicalSyntax(
|
||||
public override val operation: String,
|
||||
public override val operand: MathSyntax,
|
||||
@ -197,6 +213,7 @@ public data class RadicalSyntax(
|
||||
* (*e<sup>x</sup>*).
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class ExponentSyntax(
|
||||
public override val operation: String,
|
||||
public override val operand: OperandSyntax,
|
||||
@ -208,12 +225,13 @@ public data class ExponentSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a syntax node with superscript (usually, for exponentiation).
|
||||
* Represents a syntax node with superscript (*x<sup>2</sup>*).
|
||||
*
|
||||
* @property left The node.
|
||||
* @property right The superscript.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class SuperscriptSyntax(
|
||||
public override val operation: String,
|
||||
public override val left: MathSyntax,
|
||||
@ -226,12 +244,13 @@ public data class SuperscriptSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a syntax node with subscript.
|
||||
* Represents a syntax node with subscript (*x<sub>i</sup>*).
|
||||
*
|
||||
* @property left The node.
|
||||
* @property right The subscript.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class SubscriptSyntax(
|
||||
public override val operation: String,
|
||||
public override val left: MathSyntax,
|
||||
@ -244,11 +263,12 @@ public data class SubscriptSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents binary, prefix operator syntax (like f(a, b)).
|
||||
* Represents binary, prefix operator syntax (like *f(a, b)*).
|
||||
*
|
||||
* @property prefix The prefix.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class BinaryOperatorSyntax(
|
||||
public override val operation: String,
|
||||
public var prefix: MathSyntax,
|
||||
@ -262,12 +282,13 @@ public data class BinaryOperatorSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents binary, infix addition.
|
||||
* Represents binary, infix addition (*42 + 42*).
|
||||
*
|
||||
* @param left The augend.
|
||||
* @param right The addend.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class BinaryPlusSyntax(
|
||||
public override val operation: String,
|
||||
public override val left: OperandSyntax,
|
||||
@ -280,12 +301,13 @@ public data class BinaryPlusSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents binary, infix subtraction.
|
||||
* Represents binary, infix subtraction (*42 - 42*).
|
||||
*
|
||||
* @param left The minuend.
|
||||
* @param right The subtrahend.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class BinaryMinusSyntax(
|
||||
public override val operation: String,
|
||||
public override val left: OperandSyntax,
|
||||
@ -302,12 +324,15 @@ public data class BinaryMinusSyntax(
|
||||
*
|
||||
* @property left The numerator.
|
||||
* @property right The denominator.
|
||||
* @property infix Whether infix (*1 / 2*) or normal (*½*) fraction should be made.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class FractionSyntax(
|
||||
public override val operation: String,
|
||||
public override val left: MathSyntax,
|
||||
public override val right: MathSyntax,
|
||||
public override val left: OperandSyntax,
|
||||
public override val right: OperandSyntax,
|
||||
public var infix: Boolean,
|
||||
) : BinarySyntax() {
|
||||
init {
|
||||
left.parent = this
|
||||
@ -316,12 +341,13 @@ public data class FractionSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents radical syntax with index.
|
||||
* Represents radical syntax with index (*<sup>3</sup>√x*).
|
||||
*
|
||||
* @property left The index.
|
||||
* @property right The radicand.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class RadicalWithIndexSyntax(
|
||||
public override val operation: String,
|
||||
public override val left: MathSyntax,
|
||||
@ -334,13 +360,14 @@ public data class RadicalWithIndexSyntax(
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents binary, infix multiplication in the form of coefficient (2 x) or with operator (x×2).
|
||||
* Represents binary, infix multiplication in the form of coefficient (*2 x*) or with operator (*x × 2*).
|
||||
*
|
||||
* @property left The multiplicand.
|
||||
* @property right The multiplier.
|
||||
* @property times whether the times (×) symbol should be used.
|
||||
* @property times Whether the times (×) symbol should be used.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public data class MultiplicationSyntax(
|
||||
public override val operation: String,
|
||||
public override val left: OperandSyntax,
|
||||
|
@ -5,12 +5,15 @@
|
||||
|
||||
package space.kscience.kmath.ast.rendering
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
|
||||
/**
|
||||
* Abstraction of writing [MathSyntax] as a string of an actual markup language. Typical implementation should
|
||||
* involve traversal of MathSyntax with handling each its subtype.
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun interface SyntaxRenderer {
|
||||
/**
|
||||
* Renders the [MathSyntax] to [output].
|
||||
@ -23,6 +26,7 @@ public fun interface SyntaxRenderer {
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun SyntaxRenderer.renderWithStringBuilder(node: MathSyntax): String {
|
||||
val sb = StringBuilder()
|
||||
render(node, sb)
|
||||
|
@ -7,6 +7,7 @@ package space.kscience.kmath.ast.rendering
|
||||
|
||||
import space.kscience.kmath.ast.rendering.FeaturedMathRenderer.RenderFeature
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.*
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
@ -15,11 +16,12 @@ import kotlin.reflect.KClass
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public object PrintSymbolic : RenderFeature {
|
||||
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
|
||||
if (node !is MST.Symbolic) return null
|
||||
return SymbolSyntax(string = node.value)
|
||||
}
|
||||
public override fun render(renderer: FeaturedMathRenderer, node: MST): SymbolSyntax? =
|
||||
if (node !is MST.Symbolic) null
|
||||
else
|
||||
SymbolSyntax(string = node.value)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -27,35 +29,38 @@ public object PrintSymbolic : RenderFeature {
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public object PrintNumeric : RenderFeature {
|
||||
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
|
||||
if (node !is MST.Numeric) return null
|
||||
return NumberSyntax(string = node.value.toString())
|
||||
}
|
||||
public override fun render(renderer: FeaturedMathRenderer, node: MST): NumberSyntax? = if (node !is MST.Numeric)
|
||||
null
|
||||
else
|
||||
NumberSyntax(string = node.value.toString())
|
||||
}
|
||||
|
||||
private fun printSignedNumberString(s: String): MathSyntax {
|
||||
if (s.startsWith('-'))
|
||||
return UnaryMinusSyntax(
|
||||
operation = GroupOperations.MINUS_OPERATION,
|
||||
operand = OperandSyntax(
|
||||
operand = NumberSyntax(string = s.removePrefix("-")),
|
||||
parentheses = true,
|
||||
),
|
||||
)
|
||||
|
||||
return NumberSyntax(string = s)
|
||||
}
|
||||
@UnstableKMathAPI
|
||||
private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-'))
|
||||
UnaryMinusSyntax(
|
||||
operation = GroupOperations.MINUS_OPERATION,
|
||||
operand = OperandSyntax(
|
||||
operand = NumberSyntax(string = s.removePrefix("-")),
|
||||
parentheses = true,
|
||||
),
|
||||
)
|
||||
else
|
||||
NumberSyntax(string = s)
|
||||
|
||||
/**
|
||||
* Special printing for numeric types which are printed in form of
|
||||
* *('-'? (DIGIT+ ('.' DIGIT+)? ('E' '-'? DIGIT+)? | 'Infinity')) | 'NaN'*.
|
||||
*
|
||||
* @property types The suitable types.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : RenderFeature {
|
||||
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
|
||||
if (node !is MST.Numeric || node.value::class !in types) return null
|
||||
|
||||
val toString = when (val v = node.value) {
|
||||
is Float -> v.multiplatformToString()
|
||||
is Double -> v.multiplatformToString()
|
||||
@ -109,12 +114,15 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
|
||||
* Special printing for numeric types which are printed in form of *'-'? DIGIT+*.
|
||||
*
|
||||
* @property types The suitable types.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class PrettyPrintIntegers(public val types: Set<KClass<out Number>>) : RenderFeature {
|
||||
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
|
||||
if (node !is MST.Numeric || node.value::class !in types) return null
|
||||
return printSignedNumberString(node.value.toString())
|
||||
}
|
||||
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? =
|
||||
if (node !is MST.Numeric || node.value::class !in types)
|
||||
null
|
||||
else
|
||||
printSignedNumberString(node.value.toString())
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
@ -129,12 +137,15 @@ public class PrettyPrintIntegers(public val types: Set<KClass<out Number>>) : Re
|
||||
* Special printing for symbols meaning Pi.
|
||||
*
|
||||
* @property symbols The allowed symbols.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class PrettyPrintPi(public val symbols: Set<String>) : RenderFeature {
|
||||
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
|
||||
if (node !is MST.Symbolic || node.value !in symbols) return null
|
||||
return SpecialSymbolSyntax(kind = SpecialSymbolSyntax.Kind.SMALL_PI)
|
||||
}
|
||||
public override fun render(renderer: FeaturedMathRenderer, node: MST): SpecialSymbolSyntax? =
|
||||
if (node !is MST.Symbolic || node.value !in symbols)
|
||||
null
|
||||
else
|
||||
SpecialSymbolSyntax(kind = SpecialSymbolSyntax.Kind.SMALL_PI)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
@ -149,17 +160,20 @@ public class PrettyPrintPi(public val symbols: Set<String>) : RenderFeature {
|
||||
* not [MST.Unary].
|
||||
*
|
||||
* @param operations the allowed operations. If `null`, any operation is accepted.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public abstract class Unary(public val operations: Collection<String>?) : RenderFeature {
|
||||
/**
|
||||
* The actual render function.
|
||||
* The actual render function specialized for [MST.Unary].
|
||||
*/
|
||||
protected abstract fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax?
|
||||
protected abstract fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax?
|
||||
|
||||
public final override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
|
||||
if (node !is MST.Unary || operations != null && node.operation !in operations) return null
|
||||
return render0(renderer, node)
|
||||
}
|
||||
public final override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? =
|
||||
if (node !is MST.Unary || operations != null && node.operation !in operations)
|
||||
null
|
||||
else
|
||||
renderUnary(renderer, node)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -167,169 +181,301 @@ public abstract class Unary(public val operations: Collection<String>?) : Render
|
||||
* not [MST.Binary].
|
||||
*
|
||||
* @property operations the allowed operations. If `null`, any operation is accepted.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public abstract class Binary(public val operations: Collection<String>?) : RenderFeature {
|
||||
/**
|
||||
* The actual render function.
|
||||
* The actual render function specialized for [MST.Binary].
|
||||
*/
|
||||
protected abstract fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax?
|
||||
protected abstract fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax?
|
||||
|
||||
public final override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
|
||||
if (node !is MST.Binary || operations != null && node.operation !in operations) return null
|
||||
return render0(renderer, node)
|
||||
return renderBinary(renderer, node)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles binary nodes by producing [BinaryPlusSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class BinaryPlus(operations: Collection<String>?) : Binary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryPlusSyntax(
|
||||
operation = node.operation,
|
||||
left = OperandSyntax(parent.render(node.left), true),
|
||||
right = OperandSyntax(parent.render(node.right), true),
|
||||
)
|
||||
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): BinaryPlusSyntax =
|
||||
BinaryPlusSyntax(
|
||||
operation = node.operation,
|
||||
left = OperandSyntax(parent.render(node.left), true),
|
||||
right = OperandSyntax(parent.render(node.right), true),
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [GroupOperations.PLUS_OPERATION].
|
||||
*/
|
||||
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles binary nodes by producing [BinaryMinusSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class BinaryMinus(operations: Collection<String>?) : Binary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryMinusSyntax(
|
||||
operation = node.operation,
|
||||
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
|
||||
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
|
||||
)
|
||||
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): BinaryMinusSyntax =
|
||||
BinaryMinusSyntax(
|
||||
operation = node.operation,
|
||||
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
|
||||
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [GroupOperations.MINUS_OPERATION].
|
||||
*/
|
||||
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles unary nodes by producing [UnaryPlusSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class UnaryPlus(operations: Collection<String>?) : Unary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryPlusSyntax(
|
||||
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): UnaryPlusSyntax = UnaryPlusSyntax(
|
||||
operation = node.operation,
|
||||
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [GroupOperations.PLUS_OPERATION].
|
||||
*/
|
||||
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles binary nodes by producing [UnaryMinusSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class UnaryMinus(operations: Collection<String>?) : Unary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryMinusSyntax(
|
||||
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): UnaryMinusSyntax = UnaryMinusSyntax(
|
||||
operation = node.operation,
|
||||
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [GroupOperations.MINUS_OPERATION].
|
||||
*/
|
||||
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles binary nodes by producing [FractionSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class Fraction(operations: Collection<String>?) : Binary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = FractionSyntax(
|
||||
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): FractionSyntax = FractionSyntax(
|
||||
operation = node.operation,
|
||||
left = parent.render(node.left),
|
||||
right = parent.render(node.right),
|
||||
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
|
||||
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
|
||||
infix = true,
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [FieldOperations.DIV_OPERATION].
|
||||
*/
|
||||
public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles binary nodes by producing [BinaryOperatorSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class BinaryOperator(operations: Collection<String>?) : Binary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryOperatorSyntax(
|
||||
operation = node.operation,
|
||||
prefix = OperatorNameSyntax(name = node.operation),
|
||||
left = parent.render(node.left),
|
||||
right = parent.render(node.right),
|
||||
)
|
||||
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): BinaryOperatorSyntax =
|
||||
BinaryOperatorSyntax(
|
||||
operation = node.operation,
|
||||
prefix = OperatorNameSyntax(name = node.operation),
|
||||
left = parent.render(node.left),
|
||||
right = parent.render(node.right),
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with `null`.
|
||||
*/
|
||||
public val Default: BinaryOperator = BinaryOperator(null)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles unary nodes by producing [UnaryOperatorSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class UnaryOperator(operations: Collection<String>?) : Unary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax(
|
||||
operation = node.operation,
|
||||
prefix = OperatorNameSyntax(node.operation),
|
||||
operand = OperandSyntax(parent.render(node.value), true),
|
||||
)
|
||||
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): UnaryOperatorSyntax =
|
||||
UnaryOperatorSyntax(
|
||||
operation = node.operation,
|
||||
prefix = OperatorNameSyntax(node.operation),
|
||||
operand = OperandSyntax(parent.render(node.value), true),
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with `null`.
|
||||
*/
|
||||
public val Default: UnaryOperator = UnaryOperator(null)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles binary nodes by producing [SuperscriptSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class Power(operations: Collection<String>?) : Binary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = SuperscriptSyntax(
|
||||
operation = node.operation,
|
||||
left = OperandSyntax(parent.render(node.left), true),
|
||||
right = OperandSyntax(parent.render(node.right), true),
|
||||
)
|
||||
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): SuperscriptSyntax =
|
||||
SuperscriptSyntax(
|
||||
operation = node.operation,
|
||||
left = OperandSyntax(parent.render(node.left), true),
|
||||
right = OperandSyntax(parent.render(node.right), true),
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [PowerOperations.POW_OPERATION].
|
||||
*/
|
||||
public val Default: Power = Power(setOf(PowerOperations.POW_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles binary nodes by producing [RadicalSyntax] with no index.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class SquareRoot(operations: Collection<String>?) : Unary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax =
|
||||
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): RadicalSyntax =
|
||||
RadicalSyntax(operation = node.operation, operand = parent.render(node.value))
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [PowerOperations.SQRT_OPERATION].
|
||||
*/
|
||||
public val Default: SquareRoot = SquareRoot(setOf(PowerOperations.SQRT_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles unary nodes by producing [ExponentSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class Exponent(operations: Collection<String>?) : Unary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = ExponentSyntax(
|
||||
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): ExponentSyntax = ExponentSyntax(
|
||||
operation = node.operation,
|
||||
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
|
||||
useOperatorForm = true,
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [ExponentialOperations.EXP_OPERATION].
|
||||
*/
|
||||
public val Default: Exponent = Exponent(setOf(ExponentialOperations.EXP_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles binary nodes by producing [MultiplicationSyntax].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class Multiplication(operations: Collection<String>?) : Binary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = MultiplicationSyntax(
|
||||
operation = node.operation,
|
||||
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
|
||||
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
|
||||
times = true,
|
||||
)
|
||||
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MultiplicationSyntax =
|
||||
MultiplicationSyntax(
|
||||
operation = node.operation,
|
||||
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
|
||||
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
|
||||
times = true,
|
||||
)
|
||||
|
||||
public companion object {
|
||||
public val Default: Multiplication = Multiplication(setOf(
|
||||
RingOperations.TIMES_OPERATION,
|
||||
))
|
||||
/**
|
||||
* The default instance configured with [RingOperations.TIMES_OPERATION].
|
||||
*/
|
||||
public val Default: Multiplication = Multiplication(setOf(RingOperations.TIMES_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles binary nodes by producing inverse [UnaryOperatorSyntax] with *arc* prefix instead of *a*.
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class InverseTrigonometricOperations(operations: Collection<String>?) : Unary(operations) {
|
||||
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax(
|
||||
operation = node.operation,
|
||||
prefix = SuperscriptSyntax(
|
||||
operation = PowerOperations.POW_OPERATION,
|
||||
left = OperatorNameSyntax(name = node.operation.removePrefix("a")),
|
||||
right = UnaryMinusSyntax(
|
||||
operation = GroupOperations.MINUS_OPERATION,
|
||||
operand = OperandSyntax(operand = NumberSyntax(string = "1"), parentheses = true),
|
||||
),
|
||||
),
|
||||
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
|
||||
)
|
||||
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): UnaryOperatorSyntax =
|
||||
UnaryOperatorSyntax(
|
||||
operation = node.operation,
|
||||
prefix = OperatorNameSyntax(name = node.operation.replaceFirst("a", "arc")),
|
||||
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [TrigonometricOperations.ACOS_OPERATION],
|
||||
* [TrigonometricOperations.ASIN_OPERATION], [TrigonometricOperations.ATAN_OPERATION].
|
||||
*/
|
||||
public val Default: InverseTrigonometricOperations = InverseTrigonometricOperations(setOf(
|
||||
TrigonometricOperations.ACOS_OPERATION,
|
||||
TrigonometricOperations.ASIN_OPERATION,
|
||||
TrigonometricOperations.ATAN_OPERATION,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles binary nodes by producing inverse [UnaryOperatorSyntax] with *ar* prefix instead of *a*.
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class InverseHyperbolicOperations(operations: Collection<String>?) : Unary(operations) {
|
||||
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): UnaryOperatorSyntax =
|
||||
UnaryOperatorSyntax(
|
||||
operation = node.operation,
|
||||
prefix = OperatorNameSyntax(name = node.operation.replaceFirst("a", "ar")),
|
||||
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
|
||||
)
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [ExponentialOperations.ACOSH_OPERATION],
|
||||
* [ExponentialOperations.ASINH_OPERATION], and [ExponentialOperations.ATANH_OPERATION].
|
||||
*/
|
||||
public val Default: InverseHyperbolicOperations = InverseHyperbolicOperations(setOf(
|
||||
ExponentialOperations.ACOSH_OPERATION,
|
||||
ExponentialOperations.ASINH_OPERATION,
|
||||
ExponentialOperations.ATANH_OPERATION,
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
package space.kscience.kmath.ast.rendering
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.FieldOperations
|
||||
import space.kscience.kmath.operations.GroupOperations
|
||||
import space.kscience.kmath.operations.PowerOperations
|
||||
@ -15,6 +16,7 @@ import space.kscience.kmath.operations.RingOperations
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public object BetterMultiplication : FeaturedMathRendererWithPostProcess.PostProcessStage {
|
||||
public override fun perform(node: MathSyntax): Unit = when (node) {
|
||||
is NumberSyntax -> Unit
|
||||
@ -81,6 +83,75 @@ public object BetterMultiplication : FeaturedMathRendererWithPostProcess.PostPro
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Chooses [FractionSyntax.infix] depending on the context.
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public object BetterFraction : FeaturedMathRendererWithPostProcess.PostProcessStage {
|
||||
private fun perform0(node: MathSyntax, infix: Boolean = false): Unit = when (node) {
|
||||
is NumberSyntax -> Unit
|
||||
is SymbolSyntax -> Unit
|
||||
is OperatorNameSyntax -> Unit
|
||||
is SpecialSymbolSyntax -> Unit
|
||||
is OperandSyntax -> perform0(node.operand, infix)
|
||||
|
||||
is UnaryOperatorSyntax -> {
|
||||
perform0(node.prefix, infix)
|
||||
perform0(node.operand, infix)
|
||||
}
|
||||
|
||||
is UnaryPlusSyntax -> perform0(node.operand, infix)
|
||||
is UnaryMinusSyntax -> perform0(node.operand, infix)
|
||||
is RadicalSyntax -> perform0(node.operand, infix)
|
||||
is ExponentSyntax -> perform0(node.operand, infix)
|
||||
|
||||
is SuperscriptSyntax -> {
|
||||
perform0(node.left, true)
|
||||
perform0(node.right, true)
|
||||
}
|
||||
|
||||
is SubscriptSyntax -> {
|
||||
perform0(node.left, true)
|
||||
perform0(node.right, true)
|
||||
}
|
||||
|
||||
is BinaryOperatorSyntax -> {
|
||||
perform0(node.prefix, infix)
|
||||
perform0(node.left, infix)
|
||||
perform0(node.right, infix)
|
||||
}
|
||||
|
||||
is BinaryPlusSyntax -> {
|
||||
perform0(node.left, infix)
|
||||
perform0(node.right, infix)
|
||||
}
|
||||
|
||||
is BinaryMinusSyntax -> {
|
||||
perform0(node.left, infix)
|
||||
perform0(node.right, infix)
|
||||
}
|
||||
|
||||
is FractionSyntax -> {
|
||||
node.infix = infix
|
||||
perform0(node.left, infix)
|
||||
perform0(node.right, infix)
|
||||
}
|
||||
|
||||
is RadicalWithIndexSyntax -> {
|
||||
perform0(node.left, true)
|
||||
perform0(node.right, true)
|
||||
}
|
||||
|
||||
is MultiplicationSyntax -> {
|
||||
perform0(node.left, infix)
|
||||
perform0(node.right, infix)
|
||||
}
|
||||
}
|
||||
|
||||
public override fun perform(node: MathSyntax): Unit = perform0(node)
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies [ExponentSyntax.useOperatorForm] to [ExponentSyntax] when the operand contains a fraction, a
|
||||
@ -88,6 +159,7 @@ public object BetterMultiplication : FeaturedMathRendererWithPostProcess.PostPro
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public object BetterExponent : FeaturedMathRendererWithPostProcess.PostProcessStage {
|
||||
private fun perform0(node: MathSyntax): Boolean {
|
||||
return when (node) {
|
||||
@ -99,7 +171,7 @@ public object BetterExponent : FeaturedMathRendererWithPostProcess.PostProcessSt
|
||||
is UnaryOperatorSyntax -> perform0(node.prefix) || perform0(node.operand)
|
||||
is UnaryPlusSyntax -> perform0(node.operand)
|
||||
is UnaryMinusSyntax -> perform0(node.operand)
|
||||
is RadicalSyntax -> perform0(node.operand)
|
||||
is RadicalSyntax -> true
|
||||
|
||||
is ExponentSyntax -> {
|
||||
val r = perform0(node.operand)
|
||||
@ -113,7 +185,7 @@ public object BetterExponent : FeaturedMathRendererWithPostProcess.PostProcessSt
|
||||
is BinaryPlusSyntax -> perform0(node.left) || perform0(node.right)
|
||||
is BinaryMinusSyntax -> perform0(node.left) || perform0(node.right)
|
||||
is FractionSyntax -> true
|
||||
is RadicalWithIndexSyntax -> perform0(node.left) || perform0(node.right)
|
||||
is RadicalWithIndexSyntax -> true
|
||||
is MultiplicationSyntax -> perform0(node.left) || perform0(node.right)
|
||||
}
|
||||
}
|
||||
@ -129,6 +201,7 @@ public object BetterExponent : FeaturedMathRendererWithPostProcess.PostProcessSt
|
||||
* @property precedenceFunction Returns the precedence number for syntax node. Higher number is lower priority.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) -> Int) :
|
||||
FeaturedMathRendererWithPostProcess.PostProcessStage {
|
||||
public override fun perform(node: MathSyntax): Unit = when (node) {
|
||||
@ -159,8 +232,11 @@ public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) ->
|
||||
val isInsideExpOperator =
|
||||
node.parent is ExponentSyntax && (node.parent as ExponentSyntax).useOperatorForm
|
||||
|
||||
val isOnOrUnderNormalFraction = node.parent is FractionSyntax && !((node.parent as FractionSyntax).infix)
|
||||
|
||||
node.parentheses = !isRightOfSuperscript
|
||||
&& (needParenthesesByPrecedence || node.parent is UnaryOperatorSyntax || isInsideExpOperator)
|
||||
&& !isOnOrUnderNormalFraction
|
||||
|
||||
perform(node.operand)
|
||||
}
|
||||
|
@ -3,12 +3,12 @@
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.wasm
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.MstField
|
||||
import space.kscience.kmath.expressions.MstRing
|
||||
import space.kscience.kmath.expressions.interpret
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.misc.Symbol.Companion.x
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.IntRing
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
@ -16,45 +16,41 @@ import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestWasmConsistencyWithInterpreter {
|
||||
internal class TestCompilerConsistencyWithInterpreter {
|
||||
@Test
|
||||
fun intRing() {
|
||||
fun intRing() = runCompilerTest {
|
||||
val mst = MstRing {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(bindSymbol(x) - (2.toByte() + (scale(
|
||||
add(number(1), number(1)),
|
||||
2.0
|
||||
2.0,
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
|
||||
number(1)
|
||||
number(1),
|
||||
) * number(2)
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
mst.interpret(IntRing, x to 3),
|
||||
mst.compile(IntRing, x to 3)
|
||||
mst.compile(IntRing, x to 3),
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun doubleField() {
|
||||
fun doubleField() = runCompilerTest {
|
||||
val mst = MstField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
number(1) / 2 + number(2.0) * one,
|
||||
) + zero
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
mst.interpret(DoubleField, x to 2.0),
|
||||
mst.compile(DoubleField, x to 2.0)
|
||||
mst.compile(DoubleField, x to 2.0),
|
||||
)
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.MstExtendedField
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.Symbol.Companion.x
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestCompilerOperations {
|
||||
@Test
|
||||
fun testUnaryPlus() = runCompilerTest {
|
||||
val expr = MstExtendedField { +bindSymbol(x) }.compileToExpression(DoubleField)
|
||||
assertEquals(2.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() = runCompilerTest {
|
||||
val expr = MstExtendedField { -bindSymbol(x) }.compileToExpression(DoubleField)
|
||||
assertEquals(-2.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() = runCompilerTest {
|
||||
val expr = MstExtendedField { bindSymbol(x) + bindSymbol(x) }.compileToExpression(DoubleField)
|
||||
assertEquals(4.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSine() = runCompilerTest {
|
||||
val expr = MstExtendedField { sin(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr(x to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCosine() = runCompilerTest {
|
||||
val expr = MstExtendedField { cos(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(1.0, expr(x to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSubtract() = runCompilerTest {
|
||||
val expr = MstExtendedField { bindSymbol(x) - bindSymbol(x) }.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() = runCompilerTest {
|
||||
val expr = MstExtendedField { bindSymbol(x) / bindSymbol(x) }.compileToExpression(DoubleField)
|
||||
assertEquals(1.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() = runCompilerTest {
|
||||
val expr = MstExtendedField { bindSymbol(x) pow 2 }.compileToExpression(DoubleField)
|
||||
assertEquals(4.0, expr(x to 2.0))
|
||||
}
|
||||
}
|
@ -3,11 +3,11 @@
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.wasm
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.MstRing
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.misc.Symbol.Companion.x
|
||||
import space.kscience.kmath.operations.IntRing
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
@ -15,20 +15,16 @@ import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
internal class TestWasmVariables {
|
||||
internal class TestCompilerVariables {
|
||||
@Test
|
||||
fun testVariable() {
|
||||
fun testVariable() = runCompilerTest {
|
||||
val expr = MstRing { bindSymbol(x) }.compileToExpression(IntRing)
|
||||
assertEquals(1, expr(x to 1))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUndefinedVariableFails() {
|
||||
fun testUndefinedVariableFails() = runCompilerTest {
|
||||
val expr = MstRing { bindSymbol(x) }.compileToExpression(IntRing)
|
||||
assertFailsWith<NoSuchElementException> { expr() }
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -13,7 +13,7 @@ import space.kscience.kmath.operations.DoubleField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class ParserTest {
|
||||
internal class TestParser {
|
||||
@Test
|
||||
fun evaluateParsedMst() {
|
||||
val mst = "2+2*(2+2)".parseMath()
|
@ -5,13 +5,12 @@
|
||||
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.ast.parseMath
|
||||
import space.kscience.kmath.expressions.evaluate
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class ParserPrecedenceTest {
|
||||
internal class TestParserPrecedence {
|
||||
@Test
|
||||
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
|
||||
|
@ -99,13 +99,17 @@ internal class TestFeatures {
|
||||
fun multiplication() = testLatex("x*1", "x\\times1")
|
||||
|
||||
@Test
|
||||
fun inverseTrigonometry() {
|
||||
testLatex("asin(x)", "\\operatorname{sin}^{-1}\\,\\left(x\\right)")
|
||||
testLatex("asinh(x)", "\\operatorname{sinh}^{-1}\\,\\left(x\\right)")
|
||||
testLatex("acos(x)", "\\operatorname{cos}^{-1}\\,\\left(x\\right)")
|
||||
testLatex("acosh(x)", "\\operatorname{cosh}^{-1}\\,\\left(x\\right)")
|
||||
testLatex("atan(x)", "\\operatorname{tan}^{-1}\\,\\left(x\\right)")
|
||||
testLatex("atanh(x)", "\\operatorname{tanh}^{-1}\\,\\left(x\\right)")
|
||||
fun inverseTrigonometric() {
|
||||
testLatex("asin(x)", "\\operatorname{arcsin}\\,\\left(x\\right)")
|
||||
testLatex("acos(x)", "\\operatorname{arccos}\\,\\left(x\\right)")
|
||||
testLatex("atan(x)", "\\operatorname{arctan}\\,\\left(x\\right)")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun inverseHyperbolic() {
|
||||
testLatex("asinh(x)", "\\operatorname{arsinh}\\,\\left(x\\right)")
|
||||
testLatex("acosh(x)", "\\operatorname{arcosh}\\,\\left(x\\right)")
|
||||
testLatex("atanh(x)", "\\operatorname{artanh}\\,\\left(x\\right)")
|
||||
}
|
||||
|
||||
// @Test
|
||||
|
@ -37,4 +37,10 @@ internal class TestStages {
|
||||
testLatex("exp(x/2)", "\\operatorname{exp}\\,\\left(\\frac{x}{2}\\right)")
|
||||
testLatex("exp(x^2)", "\\operatorname{exp}\\,\\left(x^{2}\\right)")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun fraction() {
|
||||
testLatex("x/y", "\\frac{x}{y}")
|
||||
testLatex("x^(x/y)", "x^{x/y}")
|
||||
}
|
||||
}
|
||||
|
@ -30,17 +30,17 @@ internal object TestUtils {
|
||||
)
|
||||
|
||||
internal fun testMathML(mst: MST, expectedMathML: String) = assertEquals(
|
||||
expected = "<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
|
||||
expected = "<math xmlns=\"https://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
|
||||
actual = mathML(mst),
|
||||
)
|
||||
|
||||
internal fun testMathML(expression: String, expectedMathML: String) = assertEquals(
|
||||
expected = "<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
|
||||
expected = "<math xmlns=\"https://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
|
||||
actual = mathML(expression.parseMath()),
|
||||
)
|
||||
|
||||
internal fun testMathML(expression: MathSyntax, expectedMathML: String) = assertEquals(
|
||||
expected = "<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
|
||||
expected = "<math xmlns=\"https://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
|
||||
actual = MathMLSyntaxRenderer.renderWithStringBuilder(expression),
|
||||
)
|
||||
}
|
||||
|
@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.Expression
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.misc.Symbol
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.IntRing
|
||||
|
||||
internal interface CompilerTestContext {
|
||||
fun MST.compileToExpression(algebra: IntRing): Expression<Int>
|
||||
fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int
|
||||
fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int = compile(algebra, mapOf(*arguments))
|
||||
fun MST.compileToExpression(algebra: DoubleField): Expression<Double>
|
||||
fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double
|
||||
|
||||
fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
||||
compile(algebra, mapOf(*arguments))
|
||||
}
|
||||
|
||||
internal expect inline fun runCompilerTest(action: CompilerTestContext.() -> Unit)
|
@ -3,7 +3,8 @@
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
@file:Suppress("INTERFACE_WITH_SUPERCLASS",
|
||||
@file:Suppress(
|
||||
"INTERFACE_WITH_SUPERCLASS",
|
||||
"OVERRIDING_FINAL_MEMBER",
|
||||
"RETURN_TYPE_MISMATCH_ON_OVERRIDE",
|
||||
"CONFLICTING_OVERLOADS",
|
||||
|
@ -26,7 +26,7 @@ internal sealed class WasmBuilder<T>(
|
||||
val keys: MutableList<String> = mutableListOf()
|
||||
lateinit var ctx: BinaryenModule
|
||||
|
||||
open fun visitSymbolic(mst: MST.Symbolic): ExpressionRef {
|
||||
open fun visitSymbolic(mst: Symbolic): ExpressionRef {
|
||||
try {
|
||||
algebra.bindSymbol(mst.value)
|
||||
} catch (ignored: Throwable) {
|
||||
|
@ -10,6 +10,7 @@ import space.kscience.kmath.expressions.Expression
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.Symbol
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.IntRing
|
||||
import space.kscience.kmath.wasm.internal.DoubleWasmBuilder
|
||||
@ -20,6 +21,7 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun DoubleField.expression(mst: MST): Expression<Double> =
|
||||
DoubleWasmBuilder(mst).instance
|
||||
|
||||
@ -28,6 +30,7 @@ public fun DoubleField.expression(mst: MST): Expression<Double> =
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun IntRing.expression(mst: MST): Expression<Int> =
|
||||
IntWasmBuilder(mst).instance
|
||||
|
||||
@ -36,6 +39,7 @@ public fun IntRing.expression(mst: MST): Expression<Int> =
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun MST.compileToExpression(algebra: IntRing): Expression<Int> = compileWith(algebra)
|
||||
|
||||
|
||||
@ -44,6 +48,7 @@ public fun MST.compileToExpression(algebra: IntRing): Expression<Int> = compileW
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
||||
compileToExpression(algebra).invoke(arguments)
|
||||
|
||||
@ -53,6 +58,7 @@ public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int =
|
||||
compileToExpression(algebra)(*arguments)
|
||||
|
||||
@ -61,6 +67,7 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = compileWith(algebra)
|
||||
|
||||
|
||||
@ -69,6 +76,7 @@ public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = c
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
||||
compileToExpression(algebra).invoke(arguments)
|
||||
|
||||
@ -78,5 +86,6 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
||||
compileToExpression(algebra).invoke(*arguments)
|
||||
|
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.Expression
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.misc.Symbol
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.IntRing
|
||||
import space.kscience.kmath.estree.compile as estreeCompile
|
||||
import space.kscience.kmath.estree.compileToExpression as estreeCompileToExpression
|
||||
import space.kscience.kmath.wasm.compile as wasmCompile
|
||||
import space.kscience.kmath.wasm.compileToExpression as wasmCompileToExpression
|
||||
|
||||
private object WasmCompilerTestContext : CompilerTestContext {
|
||||
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = wasmCompileToExpression(algebra)
|
||||
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = wasmCompile(algebra, arguments)
|
||||
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = wasmCompileToExpression(algebra)
|
||||
|
||||
override fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
||||
wasmCompile(algebra, arguments)
|
||||
}
|
||||
|
||||
private object ESTreeCompilerTestContext : CompilerTestContext {
|
||||
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = estreeCompileToExpression(algebra)
|
||||
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = estreeCompile(algebra, arguments)
|
||||
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = estreeCompileToExpression(algebra)
|
||||
|
||||
override fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
||||
estreeCompile(algebra, arguments)
|
||||
}
|
||||
|
||||
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) {
|
||||
action(WasmCompilerTestContext)
|
||||
action(ESTreeCompilerTestContext)
|
||||
}
|
@ -1,97 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.estree
|
||||
|
||||
import space.kscience.kmath.complex.ComplexField
|
||||
import space.kscience.kmath.complex.toComplex
|
||||
import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.ByteRing
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestESTreeConsistencyWithInterpreter {
|
||||
@Test
|
||||
fun mstSpace() {
|
||||
val mst = MstGroup {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
number(3.toByte()) - (number(2.toByte()) + (scale(
|
||||
add(number(1), number(1)),
|
||||
2.0
|
||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||
),
|
||||
|
||||
number(1)
|
||||
) + bindSymbol(x) + zero
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
mst.interpret(MstGroup, x to MST.Numeric(2)),
|
||||
mst.compile(MstGroup, x to MST.Numeric(2))
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun byteRing() {
|
||||
val mst = MstRing {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(bindSymbol(x) - (2.toByte() + (scale(
|
||||
add(number(1), number(1)),
|
||||
2.0
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
|
||||
number(1)
|
||||
) * number(2)
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
mst.interpret(ByteRing, x to 3.toByte()),
|
||||
mst.compile(ByteRing, x to 3.toByte())
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun doubleField() {
|
||||
val mst = MstField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
mst.interpret(DoubleField, x to 2.0),
|
||||
mst.compile(DoubleField, x to 2.0)
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun complexField() {
|
||||
val mst = MstField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
mst.interpret(ComplexField, x to 2.0.toComplex()),
|
||||
mst.compile(ComplexField, x to 2.0.toComplex()),
|
||||
)
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -1,42 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.estree
|
||||
|
||||
import space.kscience.kmath.expressions.MstField
|
||||
import space.kscience.kmath.expressions.MstGroup
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestESTreeOperationsSupport {
|
||||
@Test
|
||||
fun testUnaryOperationInvocation() {
|
||||
val expression = MstGroup { -bindSymbol(x) }.compileToExpression(DoubleField)
|
||||
val res = expression(x to 2.0)
|
||||
assertEquals(-2.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBinaryOperationInvocation() {
|
||||
val expression = MstGroup { -bindSymbol(x) + number(1.0) }.compileToExpression(DoubleField)
|
||||
val res = expression(x to 2.0)
|
||||
assertEquals(-1.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testConstProductInvocation() {
|
||||
val res = MstField { bindSymbol(x) * 2 }.compileToExpression(DoubleField)(x to 2.0)
|
||||
assertEquals(4.0, res)
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -1,76 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.estree
|
||||
|
||||
import space.kscience.kmath.expressions.MstExtendedField
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestESTreeSpecialization {
|
||||
@Test
|
||||
fun testUnaryPlus() {
|
||||
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(2.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(-2.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("+")(
|
||||
bindSymbol(x),
|
||||
bindSymbol(x),
|
||||
)
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(4.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSine() {
|
||||
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr(x to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSubtract() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("-")(bindSymbol(x),
|
||||
bindSymbol(x))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("/")(bindSymbol(x), bindSymbol(x))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(1.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("pow")(bindSymbol(x), number(2))
|
||||
}.compileToExpression(DoubleField)
|
||||
|
||||
assertEquals(4.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.estree
|
||||
|
||||
import space.kscience.kmath.expressions.MstRing
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.ByteRing
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
internal class TestESTreeVariables {
|
||||
@Test
|
||||
fun testVariable() {
|
||||
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
|
||||
assertEquals(1.toByte(), expr(x to 1.toByte()))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUndefinedVariableFails() {
|
||||
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
|
||||
assertFailsWith<NoSuchElementException> { expr() }
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -1,42 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.wasm
|
||||
|
||||
import space.kscience.kmath.expressions.MstField
|
||||
import space.kscience.kmath.expressions.MstGroup
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestWasmOperationsSupport {
|
||||
@Test
|
||||
fun testUnaryOperationInvocation() {
|
||||
val expression = MstGroup { -bindSymbol(x) }.compileToExpression(DoubleField)
|
||||
val res = expression(x to 2.0)
|
||||
assertEquals(-2.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBinaryOperationInvocation() {
|
||||
val expression = MstGroup { -bindSymbol(x) + number(1.0) }.compileToExpression(DoubleField)
|
||||
val res = expression(x to 2.0)
|
||||
assertEquals(-1.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testConstProductInvocation() {
|
||||
val res = MstField { bindSymbol(x) * 2 }.compileToExpression(DoubleField)(x to 2.0)
|
||||
assertEquals(4.0, res)
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -1,76 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.wasm
|
||||
|
||||
import space.kscience.kmath.expressions.MstExtendedField
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestWasmSpecialization {
|
||||
@Test
|
||||
fun testUnaryPlus() {
|
||||
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(2.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(-2.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("+")(
|
||||
bindSymbol(x),
|
||||
bindSymbol(x),
|
||||
)
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(4.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSine() {
|
||||
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr(x to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSubtract() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("-")(bindSymbol(x),
|
||||
bindSymbol(x))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("/")(bindSymbol(x), bindSymbol(x))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(1.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("pow")(bindSymbol(x), number(2))
|
||||
}.compileToExpression(DoubleField)
|
||||
|
||||
assertEquals(4.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -342,8 +342,8 @@ internal class AsmBuilder<T>(
|
||||
val MAP_INTRINSICS_TYPE: Type by lazy { getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics") }
|
||||
|
||||
/**
|
||||
* ASM Type for [kscience.kmath.expressions.Symbol].
|
||||
* ASM Type for [space.kscience.kmath.misc.Symbol].
|
||||
*/
|
||||
val SYMBOL_TYPE: Type by lazy { getObjectType("space/kscience/kmath/expressions/Symbol") }
|
||||
val SYMBOL_TYPE: Type by lazy { getObjectType("space/kscience/kmath/misc/Symbol") }
|
||||
}
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.(
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) }
|
||||
internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
|
||||
|
||||
/**
|
||||
* Creates a class name for [Expression] subclassed to implement [mst] provided.
|
||||
|
@ -1,97 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.asm
|
||||
|
||||
import space.kscience.kmath.complex.ComplexField
|
||||
import space.kscience.kmath.complex.toComplex
|
||||
import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.ByteRing
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmConsistencyWithInterpreter {
|
||||
@Test
|
||||
fun mstSpace() {
|
||||
val mst = MstGroup {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
number(3.toByte()) - (number(2.toByte()) + (scale(
|
||||
add(number(1), number(1)),
|
||||
2.0
|
||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||
),
|
||||
|
||||
number(1)
|
||||
) + bindSymbol(x) + zero
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
mst.interpret(MstGroup, x to MST.Numeric(2)),
|
||||
mst.compile(MstGroup, x to MST.Numeric(2))
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun byteRing() {
|
||||
val mst = MstRing {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(bindSymbol(x) - (2.toByte() + (scale(
|
||||
add(number(1), number(1)),
|
||||
2.0
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
|
||||
number(1)
|
||||
) * number(2)
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
mst.interpret(ByteRing, x to 3.toByte()),
|
||||
mst.compile(ByteRing, x to 3.toByte())
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun doubleField() {
|
||||
val mst = MstField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
mst.interpret(DoubleField, x to 2.0),
|
||||
mst.compile(DoubleField, x to 2.0)
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun complexField() {
|
||||
val mst = MstField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
mst.interpret(ComplexField, x to 2.0.toComplex()),
|
||||
mst.compile(ComplexField, x to 2.0.toComplex())
|
||||
)
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -1,42 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.asm
|
||||
|
||||
import space.kscience.kmath.expressions.MstField
|
||||
import space.kscience.kmath.expressions.MstGroup
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmOperationsSupport {
|
||||
@Test
|
||||
fun testUnaryOperationInvocation() {
|
||||
val expression = MstGroup { -bindSymbol(x) }.compileToExpression(DoubleField)
|
||||
val res = expression(x to 2.0)
|
||||
assertEquals(-2.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBinaryOperationInvocation() {
|
||||
val expression = MstGroup { -bindSymbol(x) + number(1.0) }.compileToExpression(DoubleField)
|
||||
val res = expression(x to 2.0)
|
||||
assertEquals(-1.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testConstProductInvocation() {
|
||||
val res = MstField { bindSymbol(x) * 2 }.compileToExpression(DoubleField)(x to 2.0)
|
||||
assertEquals(4.0, res)
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -1,76 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.asm
|
||||
|
||||
import space.kscience.kmath.expressions.MstExtendedField
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmSpecialization {
|
||||
@Test
|
||||
fun testUnaryPlus() {
|
||||
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(2.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(-2.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("+")(
|
||||
bindSymbol(x),
|
||||
bindSymbol(x),
|
||||
)
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(4.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSine() {
|
||||
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol(x)) }.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr(x to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSubtract() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("-")(bindSymbol(x),
|
||||
bindSymbol(x))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("/")(bindSymbol(x), bindSymbol(x))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(1.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() {
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("pow")(bindSymbol(x), number(2))
|
||||
}.compileToExpression(DoubleField)
|
||||
|
||||
assertEquals(4.0, expr(x to 2.0))
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.asm
|
||||
|
||||
import space.kscience.kmath.expressions.MstRing
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.ByteRing
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
internal class TestAsmVariables {
|
||||
@Test
|
||||
fun testVariable() {
|
||||
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
|
||||
assertEquals(1.toByte(), expr(x to 1.toByte()))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUndefinedVariableFails() {
|
||||
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
|
||||
assertFailsWith<NoSuchElementException> { expr() }
|
||||
}
|
||||
|
||||
private companion object {
|
||||
private val x by symbol
|
||||
}
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.Expression
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.misc.Symbol
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.IntRing
|
||||
import space.kscience.kmath.asm.compile as asmCompile
|
||||
import space.kscience.kmath.asm.compileToExpression as asmCompileToExpression
|
||||
|
||||
private object AsmCompilerTestContext : CompilerTestContext {
|
||||
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = asmCompileToExpression(algebra)
|
||||
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = asmCompile(algebra, arguments)
|
||||
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = asmCompileToExpression(algebra)
|
||||
|
||||
override fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
||||
asmCompile(algebra, arguments)
|
||||
}
|
||||
|
||||
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) = action(AsmCompilerTestContext)
|
@ -14,8 +14,7 @@ The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-de
|
||||
```gradle
|
||||
repositories {
|
||||
maven { url 'https://repo.kotlin.link' }
|
||||
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
@ -26,8 +25,7 @@ dependencies {
|
||||
```kotlin
|
||||
repositories {
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
|
@ -21,8 +21,7 @@ The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-7
|
||||
```gradle
|
||||
repositories {
|
||||
maven { url 'https://repo.kotlin.link' }
|
||||
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
@ -33,8 +32,7 @@ dependencies {
|
||||
```kotlin
|
||||
repositories {
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
|
@ -265,6 +265,8 @@ public final class space/kscience/kmath/expressions/MstExtendedField : space/ksc
|
||||
public fun sin (Lspace/kscience/kmath/expressions/MST;)Lspace/kscience/kmath/expressions/MST$Unary;
|
||||
public synthetic fun sinh (Ljava/lang/Object;)Ljava/lang/Object;
|
||||
public fun sinh (Lspace/kscience/kmath/expressions/MST;)Lspace/kscience/kmath/expressions/MST$Unary;
|
||||
public synthetic fun sqrt (Ljava/lang/Object;)Ljava/lang/Object;
|
||||
public fun sqrt (Lspace/kscience/kmath/expressions/MST;)Lspace/kscience/kmath/expressions/MST;
|
||||
public synthetic fun tan (Ljava/lang/Object;)Ljava/lang/Object;
|
||||
public fun tan (Lspace/kscience/kmath/expressions/MST;)Lspace/kscience/kmath/expressions/MST$Unary;
|
||||
public synthetic fun tanh (Ljava/lang/Object;)Ljava/lang/Object;
|
||||
@ -745,7 +747,7 @@ public final class space/kscience/kmath/nd/BufferAlgebraNDKt {
|
||||
public static final fun ring (Lspace/kscience/kmath/nd/AlgebraND$Companion;Lspace/kscience/kmath/operations/Ring;Lkotlin/jvm/functions/Function2;[I)Lspace/kscience/kmath/nd/BufferedRingND;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
|
||||
public class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
|
||||
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/Buffer;)V
|
||||
public fun elements ()Lkotlin/sequences/Sequence;
|
||||
public fun get ([I)Ljava/lang/Object;
|
||||
@ -786,10 +788,9 @@ public final class space/kscience/kmath/nd/DefaultStrides : space/kscience/kmath
|
||||
public fun equals (Ljava/lang/Object;)Z
|
||||
public fun getLinearSize ()I
|
||||
public fun getShape ()[I
|
||||
public fun getStrides ()Ljava/util/List;
|
||||
public fun getStrides ()[I
|
||||
public fun hashCode ()I
|
||||
public fun index (I)[I
|
||||
public fun offset ([I)I
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/nd/DefaultStrides$Companion {
|
||||
@ -873,6 +874,22 @@ public abstract interface class space/kscience/kmath/nd/GroupND : space/kscience
|
||||
public final class space/kscience/kmath/nd/GroupND$Companion {
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/nd/MutableBufferND : space/kscience/kmath/nd/BufferND, space/kscience/kmath/nd/MutableStructureND {
|
||||
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/MutableBuffer;)V
|
||||
public final fun getMutableBuffer ()Lspace/kscience/kmath/structures/MutableBuffer;
|
||||
public fun set ([ILjava/lang/Object;)V
|
||||
}
|
||||
|
||||
public abstract interface class space/kscience/kmath/nd/MutableStructure1D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure1D, space/kscience/kmath/structures/MutableBuffer {
|
||||
public fun set ([ILjava/lang/Object;)V
|
||||
}
|
||||
|
||||
public abstract interface class space/kscience/kmath/nd/MutableStructure2D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure2D {
|
||||
public fun getColumns ()Ljava/util/List;
|
||||
public fun getRows ()Ljava/util/List;
|
||||
public abstract fun set (IILjava/lang/Object;)V
|
||||
}
|
||||
|
||||
public abstract interface class space/kscience/kmath/nd/MutableStructureND : space/kscience/kmath/nd/StructureND {
|
||||
public abstract fun set ([ILjava/lang/Object;)V
|
||||
}
|
||||
@ -912,10 +929,10 @@ public final class space/kscience/kmath/nd/ShortRingNDKt {
|
||||
public abstract interface class space/kscience/kmath/nd/Strides {
|
||||
public abstract fun getLinearSize ()I
|
||||
public abstract fun getShape ()[I
|
||||
public abstract fun getStrides ()Ljava/util/List;
|
||||
public abstract fun getStrides ()[I
|
||||
public abstract fun index (I)[I
|
||||
public fun indices ()Lkotlin/sequences/Sequence;
|
||||
public abstract fun offset ([I)I
|
||||
public fun offset ([I)I
|
||||
}
|
||||
|
||||
public abstract interface class space/kscience/kmath/nd/Structure1D : space/kscience/kmath/nd/StructureND, space/kscience/kmath/structures/Buffer {
|
||||
@ -929,6 +946,7 @@ public final class space/kscience/kmath/nd/Structure1D$Companion {
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/nd/Structure1DKt {
|
||||
public static final fun as1D (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructure1D;
|
||||
public static final fun as1D (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/Structure1D;
|
||||
public static final fun asND (Lspace/kscience/kmath/structures/Buffer;)Lspace/kscience/kmath/nd/Structure1D;
|
||||
}
|
||||
@ -949,6 +967,7 @@ public final class space/kscience/kmath/nd/Structure2D$Companion {
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/nd/Structure2DKt {
|
||||
public static final fun as2D (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructure2D;
|
||||
public static final fun as2D (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/Structure2D;
|
||||
}
|
||||
|
||||
@ -988,14 +1007,7 @@ public abstract interface class space/kscience/kmath/operations/Algebra {
|
||||
public fun unaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function1;
|
||||
}
|
||||
|
||||
public abstract interface class space/kscience/kmath/operations/AlgebraElement {
|
||||
public abstract fun getContext ()Lspace/kscience/kmath/operations/Algebra;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/operations/AlgebraElementsKt {
|
||||
public static final fun div (Lspace/kscience/kmath/operations/AlgebraElement;Lspace/kscience/kmath/operations/AlgebraElement;)Lspace/kscience/kmath/operations/AlgebraElement;
|
||||
public static final fun plus (Lspace/kscience/kmath/operations/AlgebraElement;Lspace/kscience/kmath/operations/AlgebraElement;)Lspace/kscience/kmath/operations/AlgebraElement;
|
||||
public static final fun times (Lspace/kscience/kmath/operations/AlgebraElement;Lspace/kscience/kmath/operations/AlgebraElement;)Lspace/kscience/kmath/operations/AlgebraElement;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/operations/AlgebraExtensionsKt {
|
||||
|
@ -135,6 +135,7 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
|
||||
public override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg)
|
||||
public override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg)
|
||||
public override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
||||
public override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
|
||||
|
||||
public override fun scale(a: MST, value: Double): MST =
|
||||
binaryOperation(GroupOperations.PLUS_OPERATION, a, number(value))
|
||||
|
@ -19,6 +19,7 @@ import kotlin.reflect.KClass
|
||||
* @param T the type of items.
|
||||
*/
|
||||
public typealias Matrix<T> = Structure2D<T>
|
||||
public typealias MutableMatrix<T> = MutableStructure2D<T>
|
||||
|
||||
/**
|
||||
* Alias or using [Buffer] as a point/vector in a many-dimensional space.
|
||||
|
@ -7,6 +7,8 @@ package space.kscience.kmath.nd
|
||||
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.BufferFactory
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
import space.kscience.kmath.structures.MutableBufferFactory
|
||||
|
||||
/**
|
||||
* Represents [StructureND] over [Buffer].
|
||||
@ -15,7 +17,7 @@ import space.kscience.kmath.structures.BufferFactory
|
||||
* @param strides The strides to access elements of [Buffer] by linear indices.
|
||||
* @param buffer The underlying buffer.
|
||||
*/
|
||||
public class BufferND<T>(
|
||||
public open class BufferND<T>(
|
||||
public val strides: Strides,
|
||||
public val buffer: Buffer<T>,
|
||||
) : StructureND<T> {
|
||||
@ -50,4 +52,35 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
||||
val strides = DefaultStrides(shape)
|
||||
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [MutableStructureND] over [MutableBuffer].
|
||||
*
|
||||
* @param T the type of items.
|
||||
* @param strides The strides to access elements of [MutableBuffer] by linear indices.
|
||||
* @param mutableBuffer The underlying buffer.
|
||||
*/
|
||||
public class MutableBufferND<T>(
|
||||
strides: Strides,
|
||||
public val mutableBuffer: MutableBuffer<T>,
|
||||
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
|
||||
override fun set(index: IntArray, value: T) {
|
||||
mutableBuffer[strides.offset(index)] = value
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND]
|
||||
*/
|
||||
public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
|
||||
factory: MutableBufferFactory<R> = MutableBuffer.Companion::auto,
|
||||
crossinline transform: (T) -> R,
|
||||
): MutableBufferND<R> {
|
||||
return if (this is MutableBufferND<T>)
|
||||
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) })
|
||||
else {
|
||||
val strides = DefaultStrides(shape)
|
||||
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||
}
|
||||
}
|
@ -6,6 +6,8 @@
|
||||
package space.kscience.kmath.nd
|
||||
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
import space.kscience.kmath.structures.asMutableBuffer
|
||||
import space.kscience.kmath.structures.asSequence
|
||||
import kotlin.jvm.JvmInline
|
||||
|
||||
@ -25,6 +27,16 @@ public interface Structure1D<T> : StructureND<T>, Buffer<T> {
|
||||
public companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* A mutable structure that is guaranteed to be one-dimensional
|
||||
*/
|
||||
public interface MutableStructure1D<T> : Structure1D<T>, MutableStructureND<T>, MutableBuffer<T> {
|
||||
public override operator fun set(index: IntArray, value: T) {
|
||||
require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
|
||||
set(index[0], value)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A 1D wrapper for nd-structure
|
||||
*/
|
||||
@ -37,6 +49,23 @@ private value class Structure1DWrapper<T>(val structure: StructureND<T>) : Struc
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
/**
|
||||
* A 1D wrapper for a mutable nd-structure
|
||||
*/
|
||||
private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure1D<T> {
|
||||
override val shape: IntArray get() = structure.shape
|
||||
override val size: Int get() = structure.shape[0]
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
|
||||
override fun get(index: Int): T = structure[index]
|
||||
override fun set(index: Int, value: T) {
|
||||
structure[intArrayOf(index)] = value
|
||||
}
|
||||
|
||||
override fun copy(): MutableBuffer<T> =
|
||||
structure.elements().map { it.second }.toMutableList().asMutableBuffer()
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A structure wrapper for buffer
|
||||
@ -52,6 +81,21 @@ private value class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T> {
|
||||
override operator fun get(index: Int): T = buffer[index]
|
||||
}
|
||||
|
||||
internal class MutableBuffer1DWrapper<T>(val buffer: MutableBuffer<T>) : MutableStructure1D<T> {
|
||||
override val shape: IntArray get() = intArrayOf(buffer.size)
|
||||
override val size: Int get() = buffer.size
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||
buffer.asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
|
||||
|
||||
override operator fun get(index: Int): T = buffer[index]
|
||||
override fun set(index: Int, value: T) {
|
||||
buffer[index] = value
|
||||
}
|
||||
|
||||
override fun copy(): MutableBuffer<T> = buffer.copy()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [StructureND] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
@ -62,6 +106,11 @@ public fun <T> StructureND<T>.as1D(): Structure1D<T> = this as? Structure1D<T> ?
|
||||
}
|
||||
} else error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||
|
||||
public fun <T> MutableStructureND<T>.as1D(): MutableStructure1D<T> =
|
||||
this as? MutableStructure1D<T> ?: if (shape.size == 1) {
|
||||
MutableStructure1DWrapper(this)
|
||||
} else error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||
|
||||
/**
|
||||
* Represent this buffer as 1D structure
|
||||
*/
|
||||
@ -75,3 +124,4 @@ internal fun <T : Any> Structure1D<T>.unwrap(): Buffer<T> = when {
|
||||
this is Structure1DWrapper && structure is BufferND<T> -> structure.buffer
|
||||
else -> this
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,7 @@ package space.kscience.kmath.nd
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.VirtualBuffer
|
||||
import space.kscience.kmath.structures.MutableListBuffer
|
||||
import kotlin.jvm.JvmInline
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
@ -63,6 +64,32 @@ public interface Structure2D<T> : StructureND<T> {
|
||||
public companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents mutable [Structure2D].
|
||||
*/
|
||||
public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
|
||||
/**
|
||||
* Inserts an item at the specified indices.
|
||||
*
|
||||
* @param i the first index.
|
||||
* @param j the second index.
|
||||
* @param value the value.
|
||||
*/
|
||||
public operator fun set(i: Int, j: Int, value: T)
|
||||
|
||||
/**
|
||||
* The buffer of rows of this structure. It gets elements from the structure dynamically.
|
||||
*/
|
||||
override val rows: List<MutableStructure1D<T>>
|
||||
get() = List(rowNum) { i -> MutableBuffer1DWrapper(MutableListBuffer(colNum) { j -> get(i, j) })}
|
||||
|
||||
/**
|
||||
* The buffer of columns of this structure. It gets elements from the structure dynamically.
|
||||
*/
|
||||
override val columns: List<MutableStructure1D<T>>
|
||||
get() = List(colNum) { j -> MutableBuffer1DWrapper(MutableListBuffer(rowNum) { i -> get(i, j) }) }
|
||||
}
|
||||
|
||||
/**
|
||||
* A 2D wrapper for nd-structure
|
||||
*/
|
||||
@ -81,6 +108,33 @@ private value class Structure2DWrapper<T>(val structure: StructureND<T>) : Struc
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
/**
|
||||
* A 2D wrapper for a mutable nd-structure
|
||||
*/
|
||||
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>): MutableStructure2D<T>
|
||||
{
|
||||
override val shape: IntArray get() = structure.shape
|
||||
|
||||
override val rowNum: Int get() = shape[0]
|
||||
override val colNum: Int get() = shape[1]
|
||||
|
||||
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||
|
||||
override fun set(index: IntArray, value: T) {
|
||||
structure[index] = value
|
||||
}
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: T){
|
||||
structure[intArrayOf(i, j)] = value
|
||||
}
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
|
||||
override fun equals(other: Any?): Boolean = false
|
||||
|
||||
override fun hashCode(): Int = 0
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [StructureND] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
@ -89,9 +143,18 @@ public fun <T> StructureND<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?
|
||||
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) {
|
||||
2 -> MutableStructure2DWrapper(this)
|
||||
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
/**
|
||||
* Expose inner [StructureND] if possible
|
||||
*/
|
||||
internal fun <T> Structure2D<T>.unwrap(): StructureND<T> =
|
||||
if (this is Structure2DWrapper) structure
|
||||
else this
|
||||
else this
|
||||
|
||||
internal fun <T> MutableStructure2D<T>.unwrap(): MutableStructureND<T> =
|
||||
if (this is MutableStructure2DWrapper) structure else this
|
||||
|
||||
|
@ -184,12 +184,15 @@ public interface Strides {
|
||||
/**
|
||||
* Array strides
|
||||
*/
|
||||
public val strides: List<Int>
|
||||
public val strides: IntArray
|
||||
|
||||
/**
|
||||
* Get linear index from multidimensional index
|
||||
*/
|
||||
public fun offset(index: IntArray): Int
|
||||
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
||||
value * strides[i]
|
||||
}.sum()
|
||||
|
||||
/**
|
||||
* Get multidimensional from linear
|
||||
@ -221,7 +224,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
||||
/**
|
||||
* Strides for memory access
|
||||
*/
|
||||
override val strides: List<Int> by lazy {
|
||||
override val strides: IntArray by lazy {
|
||||
sequence {
|
||||
var current = 1
|
||||
yield(1)
|
||||
@ -230,14 +233,9 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
||||
current *= it
|
||||
yield(current)
|
||||
}
|
||||
}.toList()
|
||||
}.toList().toIntArray()
|
||||
}
|
||||
|
||||
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
||||
value * strides[i]
|
||||
}.sum()
|
||||
|
||||
override fun index(offset: Int): IntArray {
|
||||
val res = IntArray(shape.size)
|
||||
var current = offset
|
||||
|
@ -13,6 +13,8 @@ import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
* @param C the type of mathematical context for this element.
|
||||
* @param T the type wrapped by this wrapper.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public interface AlgebraElement<T, C : Algebra<T>> {
|
||||
/**
|
||||
* The context this element belongs to.
|
||||
@ -45,6 +47,7 @@ public interface AlgebraElement<T, C : Algebra<T>> {
|
||||
* @return the difference.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public operator fun <T : AlgebraElement<T, S>, S : NumbersAddOperations<T>> T.minus(b: T): T =
|
||||
context.add(this, context.run { -b })
|
||||
|
||||
@ -55,6 +58,8 @@ public operator fun <T : AlgebraElement<T, S>, S : NumbersAddOperations<T>> T.mi
|
||||
* @param b the addend.
|
||||
* @return the sum.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public operator fun <T : AlgebraElement<T, S>, S : Ring<T>> T.plus(b: T): T =
|
||||
context.add(this, b)
|
||||
|
||||
@ -71,6 +76,8 @@ public operator fun <T : AlgebraElement<T, S>, S : Ring<T>> T.plus(b: T): T =
|
||||
* @param b the multiplier.
|
||||
* @return the product.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public operator fun <T : AlgebraElement<T, R>, R : Ring<T>> T.times(b: T): T =
|
||||
context.multiply(this, b)
|
||||
|
||||
@ -81,6 +88,8 @@ public operator fun <T : AlgebraElement<T, R>, R : Ring<T>> T.times(b: T): T =
|
||||
* @param b the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public operator fun <T : AlgebraElement<T, F>, F : Field<T>> T.div(b: T): T =
|
||||
context.divide(this, b)
|
||||
|
||||
@ -93,6 +102,7 @@ public operator fun <T : AlgebraElement<T, F>, F : Field<T>> T.div(b: T): T =
|
||||
* @param S the type of space.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public interface GroupElement<T : GroupElement<T, S>, S : Group<T>> : AlgebraElement<T, S>
|
||||
|
||||
/**
|
||||
@ -103,6 +113,7 @@ public interface GroupElement<T : GroupElement<T, S>, S : Group<T>> : AlgebraEle
|
||||
* @param R the type of ring.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public interface RingElement<T : RingElement<T, R>, R : Ring<T>> : GroupElement<T, R>
|
||||
|
||||
/**
|
||||
@ -113,4 +124,5 @@ public interface RingElement<T : RingElement<T, R>, R : Ring<T>> : GroupElement<
|
||||
* @param F the type of field.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public interface FieldElement<T : FieldElement<T, F>, F : Field<T>> : RingElement<T, F>
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public interface FieldElement<T : FieldElement<T, F>, F : Field<T>> : RingElement<T, F>
|
||||
|
@ -80,36 +80,42 @@ public interface TrigonometricOperations<T> : Algebra<T> {
|
||||
* Computes the sine of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
|
||||
|
||||
/**
|
||||
* Computes the cosine of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
|
||||
|
||||
/**
|
||||
* Computes the tangent of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
|
||||
|
||||
/**
|
||||
* Computes the inverse sine of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
|
||||
|
||||
/**
|
||||
* Computes the inverse cosine of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
|
||||
|
||||
/**
|
||||
* Computes the inverse tangent of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
|
||||
|
||||
/**
|
||||
@ -154,18 +160,21 @@ public interface PowerOperations<T> : Algebra<T> {
|
||||
* @return the base raised to the power.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public infix fun <T : AlgebraElement<T, out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
|
||||
|
||||
/**
|
||||
* Computes the square root of the value [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5
|
||||
|
||||
/**
|
||||
* Computes the square of the value [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
|
||||
|
||||
/**
|
||||
@ -261,12 +270,14 @@ public interface ExponentialOperations<T> : Algebra<T> {
|
||||
* The identifier of exponential function.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
|
||||
|
||||
/**
|
||||
* The identifier of natural logarithm.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg)
|
||||
|
||||
|
||||
@ -280,30 +291,35 @@ public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> sinh(arg: T): T
|
||||
* Computes the hyperbolic cosine of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> cosh(arg: T): T = arg.context.cosh(arg)
|
||||
|
||||
/**
|
||||
* Computes the hyperbolic tangent of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> tanh(arg: T): T = arg.context.tanh(arg)
|
||||
|
||||
/**
|
||||
* Computes the inverse hyperbolic sine of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> asinh(arg: T): T = arg.context.asinh(arg)
|
||||
|
||||
/**
|
||||
* Computes the inverse hyperbolic cosine of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> acosh(arg: T): T = arg.context.acosh(arg)
|
||||
|
||||
/**
|
||||
* Computes the inverse hyperbolic tangent of [arg].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
|
||||
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> atanh(arg: T): T = arg.context.atanh(arg)
|
||||
|
||||
/**
|
||||
|
@ -232,7 +232,7 @@ public value class MutableListBuffer<T>(public val list: MutableList<T>) : Mutab
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an [ListBuffer] that wraps the original list.
|
||||
* Returns an [MutableListBuffer] that wraps the original list.
|
||||
*/
|
||||
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this)
|
||||
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
EJML based linear algebra implementation.
|
||||
|
||||
- [ejml-vector](src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt) : The Point implementation using SimpleMatrix.
|
||||
- [ejml-matrix](src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : The Matrix implementation using SimpleMatrix.
|
||||
- [ejml-linear-space](src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : The LinearSpace implementation using SimpleMatrix.
|
||||
- [ejml-vector](src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt) : Point implementations.
|
||||
- [ejml-matrix](src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : Matrix implementation.
|
||||
- [ejml-linear-space](src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : LinearSpace implementations.
|
||||
|
||||
|
||||
## Artifact:
|
||||
@ -15,8 +15,7 @@ The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-7
|
||||
```gradle
|
||||
repositories {
|
||||
maven { url 'https://repo.kotlin.link' }
|
||||
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
@ -27,8 +26,7 @@ dependencies {
|
||||
```kotlin
|
||||
repositories {
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
|
@ -4,7 +4,7 @@ plugins {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
api("org.ejml:ejml-simple:0.40")
|
||||
api("org.ejml:ejml-ddense:0.40")
|
||||
api(project(":kmath-core"))
|
||||
}
|
||||
|
||||
@ -14,19 +14,19 @@ readme {
|
||||
|
||||
feature(
|
||||
id = "ejml-vector",
|
||||
description = "The Point implementation using SimpleMatrix.",
|
||||
description = "Point implementations.",
|
||||
ref = "src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "ejml-matrix",
|
||||
description = "The Matrix implementation using SimpleMatrix.",
|
||||
description = "Matrix implementation.",
|
||||
ref = "src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "ejml-linear-space",
|
||||
description = "The LinearSpace implementation using SimpleMatrix.",
|
||||
description = "LinearSpace implementations.",
|
||||
ref = "src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt"
|
||||
)
|
||||
}
|
||||
|
@ -5,45 +5,71 @@
|
||||
|
||||
package space.kscience.kmath.ejml
|
||||
|
||||
import org.ejml.data.DMatrix
|
||||
import org.ejml.data.DMatrixD1
|
||||
import org.ejml.data.DMatrixRMaj
|
||||
import org.ejml.dense.row.CommonOps_DDRM
|
||||
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
import space.kscience.kmath.linear.*
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.StructureFeature
|
||||
import space.kscience.kmath.nd.getFeature
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.Ring
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.reflect.cast
|
||||
|
||||
/**
|
||||
* Represents context of basic operations operating with [EjmlMatrix].
|
||||
* [LinearSpace] implementation specialized for a certain EJML type.
|
||||
*
|
||||
* @param T the type of items in the matrices.
|
||||
* @param A the element context type.
|
||||
* @param M the EJML matrix type.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public abstract class EjmlLinearSpace<T : Any, out A : Ring<T>, M : org.ejml.data.Matrix> : LinearSpace<T, A> {
|
||||
/**
|
||||
* Converts this matrix to EJML one.
|
||||
*/
|
||||
public abstract fun Matrix<T>.toEjml(): EjmlMatrix<T, M>
|
||||
|
||||
/**
|
||||
* Converts this vector to EJML one.
|
||||
*/
|
||||
public abstract fun Point<T>.toEjml(): EjmlVector<T, M>
|
||||
|
||||
public abstract override fun buildMatrix(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: A.(i: Int, j: Int) -> T,
|
||||
): EjmlMatrix<T, M>
|
||||
|
||||
public abstract override fun buildVector(size: Int, initializer: A.(Int) -> T): EjmlVector<T, M>
|
||||
}
|
||||
|
||||
/**
|
||||
* [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and
|
||||
* [DMatrixRMaj] matrices.
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public object EjmlLinearSpace : LinearSpace<Double, DoubleField> {
|
||||
public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, DoubleField, DMatrixRMaj>() {
|
||||
/**
|
||||
* The [DoubleField] reference.
|
||||
*/
|
||||
public override val elementAlgebra: DoubleField get() = DoubleField
|
||||
|
||||
/**
|
||||
* Converts this matrix to EJML one.
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun Matrix<Double>.toEjml(): EjmlMatrix = when (val matrix = origin) {
|
||||
is EjmlMatrix -> matrix
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public override fun Matrix<Double>.toEjml(): EjmlDoubleMatrix<DMatrixRMaj> = when {
|
||||
this is EjmlDoubleMatrix<*> && origin is DMatrixRMaj -> this as EjmlDoubleMatrix<DMatrixRMaj>
|
||||
else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts this vector to EJML one.
|
||||
*/
|
||||
public fun Point<Double>.toEjml(): EjmlVector = when (this) {
|
||||
is EjmlVector -> this
|
||||
else -> EjmlVector(SimpleMatrix(size, 1).also {
|
||||
(0 until it.numRows()).forEach { row -> it[row, 0] = get(row) }
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public override fun Point<Double>.toEjml(): EjmlDoubleVector<DMatrixRMaj> = when {
|
||||
this is EjmlDoubleVector<*> && origin is DMatrixRMaj -> this as EjmlDoubleVector<DMatrixRMaj>
|
||||
else -> EjmlDoubleVector(DMatrixRMaj(size, 1).also {
|
||||
(0 until it.numRows).forEach { row -> it[row, 0] = get(row) }
|
||||
})
|
||||
}
|
||||
|
||||
@ -51,159 +77,178 @@ public object EjmlLinearSpace : LinearSpace<Double, DoubleField> {
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: DoubleField.(i: Int, j: Int) -> Double,
|
||||
): EjmlMatrix = EjmlMatrix(SimpleMatrix(rows, columns).also {
|
||||
): EjmlDoubleMatrix<DMatrixRMaj> = EjmlDoubleMatrix(DMatrixRMaj(rows, columns).also {
|
||||
(0 until rows).forEach { row ->
|
||||
(0 until columns).forEach { col -> it[row, col] = DoubleField.initializer(row, col) }
|
||||
(0 until columns).forEach { col -> it[row, col] = elementAlgebra.initializer(row, col) }
|
||||
}
|
||||
})
|
||||
|
||||
public override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): Point<Double> =
|
||||
EjmlVector(SimpleMatrix(size, 1).also {
|
||||
(0 until it.numRows()).forEach { row -> it[row, 0] = DoubleField.initializer(row) }
|
||||
})
|
||||
public override fun buildVector(
|
||||
size: Int,
|
||||
initializer: DoubleField.(Int) -> Double,
|
||||
): EjmlDoubleVector<DMatrixRMaj> = EjmlDoubleVector(DMatrixRMaj(size, 1).also {
|
||||
(0 until it.numRows).forEach { row -> it[row, 0] = elementAlgebra.initializer(row) }
|
||||
})
|
||||
|
||||
private fun SimpleMatrix.wrapMatrix() = EjmlMatrix(this)
|
||||
private fun SimpleMatrix.wrapVector() = EjmlVector(this)
|
||||
private fun <T : DMatrix> T.wrapMatrix() = EjmlDoubleMatrix(this)
|
||||
private fun <T : DMatrixD1> T.wrapVector() = EjmlDoubleVector(this)
|
||||
|
||||
public override fun Matrix<Double>.unaryMinus(): Matrix<Double> = this * (-1.0)
|
||||
|
||||
public override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlMatrix =
|
||||
EjmlMatrix(toEjml().origin.mult(other.toEjml().origin))
|
||||
public override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlDoubleMatrix<DMatrixRMaj> {
|
||||
val out = DMatrixRMaj(1, 1)
|
||||
CommonOps_DDRM.mult(toEjml().origin, other.toEjml().origin, out)
|
||||
return out.wrapMatrix()
|
||||
}
|
||||
|
||||
public override fun Matrix<Double>.dot(vector: Point<Double>): EjmlVector =
|
||||
EjmlVector(toEjml().origin.mult(vector.toEjml().origin))
|
||||
public override fun Matrix<Double>.dot(vector: Point<Double>): EjmlDoubleVector<DMatrixRMaj> {
|
||||
val out = DMatrixRMaj(1, 1)
|
||||
CommonOps_DDRM.mult(toEjml().origin, vector.toEjml().origin, out)
|
||||
return out.wrapVector()
|
||||
}
|
||||
|
||||
public override operator fun Matrix<Double>.minus(other: Matrix<Double>): EjmlMatrix =
|
||||
(toEjml().origin - other.toEjml().origin).wrapMatrix()
|
||||
public override operator fun Matrix<Double>.minus(other: Matrix<Double>): EjmlDoubleMatrix<DMatrixRMaj> {
|
||||
val out = DMatrixRMaj(1, 1)
|
||||
CommonOps_DDRM.subtract(toEjml().origin, other.toEjml().origin, out)
|
||||
return out.wrapMatrix()
|
||||
}
|
||||
|
||||
public override operator fun Matrix<Double>.times(value: Double): EjmlMatrix =
|
||||
toEjml().origin.scale(value).wrapMatrix()
|
||||
public override operator fun Matrix<Double>.times(value: Double): EjmlDoubleMatrix<DMatrixRMaj> {
|
||||
val res = this.toEjml().origin.copy()
|
||||
CommonOps_DDRM.scale(value, res)
|
||||
return res.wrapMatrix()
|
||||
}
|
||||
|
||||
public override fun Point<Double>.unaryMinus(): EjmlVector =
|
||||
toEjml().origin.negative().wrapVector()
|
||||
public override fun Point<Double>.unaryMinus(): EjmlDoubleVector<DMatrixRMaj> {
|
||||
val out = toEjml().origin.copy()
|
||||
CommonOps_DDRM.changeSign(out)
|
||||
return out.wrapVector()
|
||||
}
|
||||
|
||||
public override fun Matrix<Double>.plus(other: Matrix<Double>): EjmlMatrix =
|
||||
(toEjml().origin + other.toEjml().origin).wrapMatrix()
|
||||
public override fun Matrix<Double>.plus(other: Matrix<Double>): EjmlDoubleMatrix<DMatrixRMaj> {
|
||||
val out = DMatrixRMaj(1, 1)
|
||||
CommonOps_DDRM.add(toEjml().origin, other.toEjml().origin, out)
|
||||
return out.wrapMatrix()
|
||||
}
|
||||
|
||||
public override fun Point<Double>.plus(other: Point<Double>): EjmlVector =
|
||||
(toEjml().origin + other.toEjml().origin).wrapVector()
|
||||
public override fun Point<Double>.plus(other: Point<Double>): EjmlDoubleVector<DMatrixRMaj> {
|
||||
val out = DMatrixRMaj(1, 1)
|
||||
CommonOps_DDRM.add(toEjml().origin, other.toEjml().origin, out)
|
||||
return out.wrapVector()
|
||||
}
|
||||
|
||||
public override fun Point<Double>.minus(other: Point<Double>): EjmlVector =
|
||||
(toEjml().origin - other.toEjml().origin).wrapVector()
|
||||
public override fun Point<Double>.minus(other: Point<Double>): EjmlDoubleVector<DMatrixRMaj> {
|
||||
val out = DMatrixRMaj(1, 1)
|
||||
CommonOps_DDRM.subtract(toEjml().origin, other.toEjml().origin, out)
|
||||
return out.wrapVector()
|
||||
}
|
||||
|
||||
public override fun Double.times(m: Matrix<Double>): EjmlMatrix =
|
||||
m.toEjml().origin.scale(this).wrapMatrix()
|
||||
public override fun Double.times(m: Matrix<Double>): EjmlDoubleMatrix<DMatrixRMaj> = m * this
|
||||
|
||||
public override fun Point<Double>.times(value: Double): EjmlVector =
|
||||
toEjml().origin.scale(value).wrapVector()
|
||||
public override fun Point<Double>.times(value: Double): EjmlDoubleVector<DMatrixRMaj> {
|
||||
val res = this.toEjml().origin.copy()
|
||||
CommonOps_DDRM.scale(value, res)
|
||||
return res.wrapVector()
|
||||
}
|
||||
|
||||
public override fun Double.times(v: Point<Double>): EjmlVector =
|
||||
v.toEjml().origin.scale(this).wrapVector()
|
||||
public override fun Double.times(v: Point<Double>): EjmlDoubleVector<DMatrixRMaj> = v * this
|
||||
|
||||
@UnstableKMathAPI
|
||||
public override fun <F : StructureFeature> getFeature(structure: Matrix<Double>, type: KClass<out F>): F? {
|
||||
//Return the feature if it is intrinsic to the structure
|
||||
// Return the feature if it is intrinsic to the structure
|
||||
structure.getFeature(type)?.let { return it }
|
||||
|
||||
val origin = structure.toEjml().origin
|
||||
|
||||
return when (type) {
|
||||
InverseMatrixFeature::class -> object : InverseMatrixFeature<Double> {
|
||||
override val inverse: Matrix<Double> by lazy { EjmlMatrix(origin.invert()) }
|
||||
override val inverse: Matrix<Double> by lazy {
|
||||
val res = origin.copy()
|
||||
CommonOps_DDRM.invert(res)
|
||||
EjmlDoubleMatrix(res)
|
||||
}
|
||||
}
|
||||
|
||||
DeterminantFeature::class -> object : DeterminantFeature<Double> {
|
||||
override val determinant: Double by lazy(origin::determinant)
|
||||
override val determinant: Double by lazy { CommonOps_DDRM.det(DMatrixRMaj(origin)) }
|
||||
}
|
||||
|
||||
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
|
||||
private val svd by lazy {
|
||||
DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false)
|
||||
.apply { decompose(origin.ddrm.copy()) }
|
||||
DecompositionFactory_DDRM.svd(origin.numRows, origin.numCols, true, true, false)
|
||||
.apply { decompose(origin.copy()) }
|
||||
}
|
||||
|
||||
override val u: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) }
|
||||
override val s: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) }
|
||||
override val v: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) }
|
||||
override val u: Matrix<Double> by lazy { EjmlDoubleMatrix(svd.getU(null, false)) }
|
||||
override val s: Matrix<Double> by lazy { EjmlDoubleMatrix(svd.getW(null)) }
|
||||
override val v: Matrix<Double> by lazy { EjmlDoubleMatrix(svd.getV(null, false)) }
|
||||
override val singularValues: Point<Double> by lazy { DoubleBuffer(svd.singularValues) }
|
||||
}
|
||||
|
||||
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
|
||||
private val qr by lazy {
|
||||
DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) }
|
||||
DecompositionFactory_DDRM.qr().apply { decompose(origin.copy()) }
|
||||
}
|
||||
|
||||
override val q: Matrix<Double> by lazy {
|
||||
EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) + OrthogonalFeature
|
||||
EjmlDoubleMatrix(qr.getQ(null, false)) + OrthogonalFeature
|
||||
}
|
||||
|
||||
override val r: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) + UFeature }
|
||||
override val r: Matrix<Double> by lazy { EjmlDoubleMatrix(qr.getR(null, false)) + UFeature }
|
||||
}
|
||||
|
||||
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
|
||||
override val l: Matrix<Double> by lazy {
|
||||
val cholesky =
|
||||
DecompositionFactory_DDRM.chol(structure.rowNum, true).apply { decompose(origin.ddrm.copy()) }
|
||||
DecompositionFactory_DDRM.chol(structure.rowNum, true).apply { decompose(origin.copy()) }
|
||||
|
||||
EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature
|
||||
EjmlDoubleMatrix(cholesky.getT(null)) + LFeature
|
||||
}
|
||||
}
|
||||
|
||||
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
|
||||
private val lup by lazy {
|
||||
DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols())
|
||||
.apply { decompose(origin.ddrm.copy()) }
|
||||
DecompositionFactory_DDRM.lu(origin.numRows, origin.numCols).apply { decompose(origin.copy()) }
|
||||
}
|
||||
|
||||
override val l: Matrix<Double> by lazy {
|
||||
EjmlMatrix(SimpleMatrix(lup.getLower(null))) + LFeature
|
||||
EjmlDoubleMatrix(lup.getLower(null)) + LFeature
|
||||
}
|
||||
|
||||
override val u: Matrix<Double> by lazy {
|
||||
EjmlMatrix(SimpleMatrix(lup.getUpper(null))) + UFeature
|
||||
EjmlDoubleMatrix(lup.getUpper(null)) + UFeature
|
||||
}
|
||||
|
||||
override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) }
|
||||
override val p: Matrix<Double> by lazy { EjmlDoubleMatrix(lup.getRowPivot(null)) }
|
||||
}
|
||||
|
||||
else -> null
|
||||
}?.let(type::cast)
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves for *x* in the following equation: *x = [a] <sup>-1</sup> · [b]*.
|
||||
*
|
||||
* @param a the base matrix.
|
||||
* @param b n by p matrix.
|
||||
* @return the solution for 'x' that is n by p.
|
||||
*/
|
||||
public fun solve(a: Matrix<Double>, b: Matrix<Double>): EjmlDoubleMatrix<DMatrixRMaj> {
|
||||
val res = DMatrixRMaj(1, 1)
|
||||
CommonOps_DDRM.solve(DMatrixRMaj(a.toEjml().origin), DMatrixRMaj(b.toEjml().origin), res)
|
||||
return EjmlDoubleMatrix(res)
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves for *x* in the following equation: *x = [a] <sup>-1</sup> · [b]*.
|
||||
*
|
||||
* @param a the base matrix.
|
||||
* @param b n by p vector.
|
||||
* @return the solution for 'x' that is n by p.
|
||||
*/
|
||||
public fun solve(a: Matrix<Double>, b: Point<Double>): EjmlDoubleVector<DMatrixRMaj> {
|
||||
val res = DMatrixRMaj(1, 1)
|
||||
CommonOps_DDRM.solve(DMatrixRMaj(a.toEjml().origin), DMatrixRMaj(b.toEjml().origin), res)
|
||||
return EjmlDoubleVector(res)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves for *x* in the following equation: *x = [a] <sup>-1</sup> · [b]*.
|
||||
*
|
||||
* @param a the base matrix.
|
||||
* @param b n by p matrix.
|
||||
* @return the solution for 'x' that is n by p.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public fun EjmlLinearSpace.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
|
||||
EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin))
|
||||
|
||||
/**
|
||||
* Solves for *x* in the following equation: *x = [a] <sup>-1</sup> · [b]*.
|
||||
*
|
||||
* @param a the base matrix.
|
||||
* @param b n by p vector.
|
||||
* @return the solution for 'x' that is n by p.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public fun EjmlLinearSpace.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
|
||||
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
|
||||
|
||||
/**
|
||||
* Inverts this matrix.
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun EjmlMatrix.inverted(): EjmlMatrix = getFeature<InverseMatrixFeature<Double>>()!!.inverse as EjmlMatrix
|
||||
|
||||
/**
|
||||
* Inverts the given matrix.
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public fun EjmlLinearSpace.inverse(matrix: Matrix<Double>): Matrix<Double> = matrix.toEjml().inverted()
|
@ -5,18 +5,28 @@
|
||||
|
||||
package space.kscience.kmath.ejml
|
||||
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
import space.kscience.kmath.linear.Matrix
|
||||
import org.ejml.data.DMatrix
|
||||
import org.ejml.data.Matrix
|
||||
import space.kscience.kmath.nd.Structure2D
|
||||
|
||||
/**
|
||||
* The matrix implementation over EJML [SimpleMatrix].
|
||||
* [space.kscience.kmath.linear.Matrix] implementation based on EJML [Matrix].
|
||||
*
|
||||
* @property origin the underlying [SimpleMatrix].
|
||||
* @param T the type of elements contained in the buffer.
|
||||
* @param M the type of EJML matrix.
|
||||
* @property origin The underlying EJML matrix.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public class EjmlMatrix(public val origin: SimpleMatrix) : Matrix<Double> {
|
||||
public override val rowNum: Int get() = origin.numRows()
|
||||
public override val colNum: Int get() = origin.numCols()
|
||||
public abstract class EjmlMatrix<T, out M : Matrix>(public open val origin: M) : Structure2D<T> {
|
||||
public override val rowNum: Int get() = origin.numRows
|
||||
public override val colNum: Int get() = origin.numCols
|
||||
}
|
||||
|
||||
/**
|
||||
* [EjmlMatrix] specialization for [Double].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public class EjmlDoubleMatrix<out M : DMatrix>(public override val origin: M) : EjmlMatrix<Double, M>(origin) {
|
||||
public override operator fun get(i: Int, j: Int): Double = origin[i, j]
|
||||
}
|
||||
|
@ -5,35 +5,41 @@
|
||||
|
||||
package space.kscience.kmath.ejml
|
||||
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
import org.ejml.data.DMatrixD1
|
||||
import org.ejml.data.Matrix
|
||||
import space.kscience.kmath.linear.Point
|
||||
|
||||
/**
|
||||
* Represents point over EJML [SimpleMatrix].
|
||||
* [Point] implementation based on EJML [Matrix].
|
||||
*
|
||||
* @property origin the underlying [SimpleMatrix].
|
||||
* @param T the type of elements contained in the buffer.
|
||||
* @param M the type of EJML matrix.
|
||||
* @property origin The underlying matrix.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public class EjmlVector internal constructor(public val origin: SimpleMatrix) : Point<Double> {
|
||||
public abstract class EjmlVector<out T, out M : Matrix>(public open val origin: M) : Point<T> {
|
||||
public override val size: Int
|
||||
get() = origin.numRows()
|
||||
get() = origin.numRows
|
||||
|
||||
init {
|
||||
require(origin.numCols() == 1) { "Only single column matrices are allowed" }
|
||||
}
|
||||
|
||||
public override operator fun get(index: Int): Double = origin[index]
|
||||
|
||||
public override operator fun iterator(): Iterator<Double> = object : Iterator<Double> {
|
||||
public override operator fun iterator(): Iterator<T> = object : Iterator<T> {
|
||||
private var cursor: Int = 0
|
||||
|
||||
override fun next(): Double {
|
||||
override fun next(): T {
|
||||
cursor += 1
|
||||
return origin[cursor - 1]
|
||||
return this@EjmlVector[cursor - 1]
|
||||
}
|
||||
|
||||
override fun hasNext(): Boolean = cursor < origin.numCols() * origin.numRows()
|
||||
override fun hasNext(): Boolean = cursor < origin.numCols * origin.numRows
|
||||
}
|
||||
|
||||
public override fun toString(): String = "EjmlVector(origin=$origin)"
|
||||
}
|
||||
|
||||
/**
|
||||
* [EjmlVector] specialization for [Double].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public class EjmlDoubleVector<out M : DMatrixD1>(public override val origin: M) : EjmlVector<Double, M>(origin) {
|
||||
public override operator fun get(index: Int): Double = origin[index]
|
||||
}
|
||||
|
@ -5,12 +5,15 @@
|
||||
|
||||
package space.kscience.kmath.ejml
|
||||
|
||||
import org.ejml.data.DMatrixRMaj
|
||||
import org.ejml.dense.row.CommonOps_DDRM
|
||||
import org.ejml.dense.row.RandomMatrices_DDRM
|
||||
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
import space.kscience.kmath.linear.*
|
||||
import space.kscience.kmath.linear.DeterminantFeature
|
||||
import space.kscience.kmath.linear.LupDecompositionFeature
|
||||
import space.kscience.kmath.linear.getFeature
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.getFeature
|
||||
import kotlin.random.Random
|
||||
import kotlin.random.asJavaRandom
|
||||
import kotlin.test.*
|
||||
@ -22,65 +25,59 @@ fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T
|
||||
internal class EjmlMatrixTest {
|
||||
private val random = Random(0)
|
||||
|
||||
private val randomMatrix: SimpleMatrix
|
||||
private val randomMatrix: DMatrixRMaj
|
||||
get() {
|
||||
val s = random.nextInt(2, 100)
|
||||
return SimpleMatrix.random_DDRM(s, s, 0.0, 10.0, random.asJavaRandom())
|
||||
val d = DMatrixRMaj(s, s)
|
||||
RandomMatrices_DDRM.fillUniform(d, random.asJavaRandom())
|
||||
return d
|
||||
}
|
||||
|
||||
@Test
|
||||
fun rowNum() {
|
||||
val m = randomMatrix
|
||||
assertEquals(m.numRows(), EjmlMatrix(m).rowNum)
|
||||
assertEquals(m.numRows, EjmlDoubleMatrix(m).rowNum)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun colNum() {
|
||||
val m = randomMatrix
|
||||
assertEquals(m.numCols(), EjmlMatrix(m).rowNum)
|
||||
assertEquals(m.numCols, EjmlDoubleMatrix(m).rowNum)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun shape() {
|
||||
val m = randomMatrix
|
||||
val w = EjmlMatrix(m)
|
||||
assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList())
|
||||
val w = EjmlDoubleMatrix(m)
|
||||
assertContentEquals(intArrayOf(m.numRows, m.numCols), w.shape)
|
||||
}
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
@Test
|
||||
fun features() {
|
||||
val m = randomMatrix
|
||||
val w = EjmlMatrix(m)
|
||||
val det: DeterminantFeature<Double> = EjmlLinearSpace.getFeature(w) ?: fail()
|
||||
assertEquals(m.determinant(), det.determinant)
|
||||
val lup: LupDecompositionFeature<Double> = EjmlLinearSpace.getFeature(w) ?: fail()
|
||||
val w = EjmlDoubleMatrix(m)
|
||||
val det: DeterminantFeature<Double> = EjmlLinearSpaceDDRM.getFeature(w) ?: fail()
|
||||
assertEquals(CommonOps_DDRM.det(m), det.determinant)
|
||||
val lup: LupDecompositionFeature<Double> = EjmlLinearSpaceDDRM.getFeature(w) ?: fail()
|
||||
|
||||
val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols())
|
||||
.also { it.decompose(m.ddrm.copy()) }
|
||||
val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows, m.numCols)
|
||||
.also { it.decompose(m.copy()) }
|
||||
|
||||
assertMatrixEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))), lup.l)
|
||||
assertMatrixEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))), lup.u)
|
||||
assertMatrixEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))), lup.p)
|
||||
}
|
||||
|
||||
private object SomeFeature : MatrixFeature {}
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
@Test
|
||||
fun suggestFeature() {
|
||||
assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature<SomeFeature>())
|
||||
assertMatrixEquals(EjmlDoubleMatrix(ludecompositionF64.getLower(null)), lup.l)
|
||||
assertMatrixEquals(EjmlDoubleMatrix(ludecompositionF64.getUpper(null)), lup.u)
|
||||
assertMatrixEquals(EjmlDoubleMatrix(ludecompositionF64.getRowPivot(null)), lup.p)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun get() {
|
||||
val m = randomMatrix
|
||||
assertEquals(m[0, 0], EjmlMatrix(m)[0, 0])
|
||||
assertEquals(m[0, 0], EjmlDoubleMatrix(m)[0, 0])
|
||||
}
|
||||
|
||||
@Test
|
||||
fun origin() {
|
||||
val m = randomMatrix
|
||||
assertSame(m, EjmlMatrix(m).origin)
|
||||
assertSame(m, EjmlDoubleMatrix(m).origin)
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,8 @@
|
||||
|
||||
package space.kscience.kmath.ejml
|
||||
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
import org.ejml.data.DMatrixRMaj
|
||||
import org.ejml.dense.row.RandomMatrices_DDRM
|
||||
import kotlin.random.Random
|
||||
import kotlin.random.asJavaRandom
|
||||
import kotlin.test.Test
|
||||
@ -15,30 +16,34 @@ import kotlin.test.assertSame
|
||||
internal class EjmlVectorTest {
|
||||
private val random = Random(0)
|
||||
|
||||
private val randomMatrix: SimpleMatrix
|
||||
get() = SimpleMatrix.random_DDRM(random.nextInt(2, 100), 1, 0.0, 10.0, random.asJavaRandom())
|
||||
private val randomMatrix: DMatrixRMaj
|
||||
get() {
|
||||
val d = DMatrixRMaj(random.nextInt(2, 100), 1)
|
||||
RandomMatrices_DDRM.fillUniform(d, random.asJavaRandom())
|
||||
return d
|
||||
}
|
||||
|
||||
@Test
|
||||
fun size() {
|
||||
val m = randomMatrix
|
||||
val w = EjmlVector(m)
|
||||
assertEquals(m.numRows(), w.size)
|
||||
val w = EjmlDoubleVector(m)
|
||||
assertEquals(m.numRows, w.size)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun get() {
|
||||
val m = randomMatrix
|
||||
val w = EjmlVector(m)
|
||||
val w = EjmlDoubleVector(m)
|
||||
assertEquals(m[0, 0], w[0])
|
||||
}
|
||||
|
||||
@Test
|
||||
fun iterator() {
|
||||
val m = randomMatrix
|
||||
val w = EjmlVector(m)
|
||||
val w = EjmlDoubleVector(m)
|
||||
|
||||
assertEquals(
|
||||
m.iterator(true, 0, 0, m.numRows() - 1, 0).asSequence().toList(),
|
||||
m.iterator(true, 0, 0, m.numRows - 1, 0).asSequence().toList(),
|
||||
w.iterator().asSequence().toList()
|
||||
)
|
||||
}
|
||||
@ -46,7 +51,7 @@ internal class EjmlVectorTest {
|
||||
@Test
|
||||
fun origin() {
|
||||
val m = randomMatrix
|
||||
val w = EjmlVector(m)
|
||||
val w = EjmlDoubleVector(m)
|
||||
assertSame(m, w.origin)
|
||||
}
|
||||
}
|
||||
|
@ -15,8 +15,7 @@ The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-d
|
||||
```gradle
|
||||
repositories {
|
||||
maven { url 'https://repo.kotlin.link' }
|
||||
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
@ -27,8 +26,7 @@ dependencies {
|
||||
```kotlin
|
||||
repositories {
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
|
@ -17,8 +17,7 @@ The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-
|
||||
```gradle
|
||||
repositories {
|
||||
maven { url 'https://repo.kotlin.link' }
|
||||
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
@ -29,8 +28,7 @@ dependencies {
|
||||
```kotlin
|
||||
repositories {
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
|
19
kmath-jupyter/build.gradle.kts
Normal file
19
kmath-jupyter/build.gradle.kts
Normal file
@ -0,0 +1,19 @@
|
||||
plugins {
|
||||
id("ru.mipt.npm.gradle.jvm")
|
||||
kotlin("jupyter.api")
|
||||
}
|
||||
|
||||
dependencies {
|
||||
api(project(":kmath-ast"))
|
||||
api(project(":kmath-complex"))
|
||||
api(project(":kmath-for-real"))
|
||||
implementation("org.jetbrains.kotlinx:kotlinx-html-jvm:0.7.3")
|
||||
}
|
||||
|
||||
readme {
|
||||
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||
}
|
||||
|
||||
kotlin.sourceSets.all {
|
||||
languageSettings.useExperimentalAnnotation("space.kscience.kmath.misc.UnstableKMathAPI")
|
||||
}
|
@ -0,0 +1,120 @@
|
||||
package space.kscience.kmath.jupyter
|
||||
|
||||
import kotlinx.html.Unsafe
|
||||
import kotlinx.html.div
|
||||
import kotlinx.html.stream.createHTML
|
||||
import kotlinx.html.unsafe
|
||||
import org.jetbrains.kotlinx.jupyter.api.DisplayResult
|
||||
import org.jetbrains.kotlinx.jupyter.api.HTML
|
||||
import org.jetbrains.kotlinx.jupyter.api.annotations.JupyterLibrary
|
||||
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess
|
||||
import space.kscience.kmath.ast.rendering.MathMLSyntaxRenderer
|
||||
import space.kscience.kmath.ast.rendering.renderWithStringBuilder
|
||||
import space.kscience.kmath.complex.Complex
|
||||
import space.kscience.kmath.nd.Structure2D
|
||||
import space.kscience.kmath.operations.GroupOperations
|
||||
import space.kscience.kmath.operations.RingOperations
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.asSequence
|
||||
|
||||
@JupyterLibrary
|
||||
internal class KMathJupyter : JupyterIntegration() {
|
||||
private val mathRender = FeaturedMathRendererWithPostProcess.Default
|
||||
private val syntaxRender = MathMLSyntaxRenderer
|
||||
|
||||
override fun Builder.onLoaded() {
|
||||
import(
|
||||
"space.kscience.kmath.ast.*",
|
||||
"space.kscience.kmath.ast.rendering.*",
|
||||
"space.kscience.kmath.operations.*",
|
||||
"space.kscience.kmath.expressions.*",
|
||||
"space.kscience.kmath.misc.*",
|
||||
"space.kscience.kmath.real.*",
|
||||
)
|
||||
|
||||
fun MST.toDisplayResult(): DisplayResult = HTML(createHTML().div {
|
||||
unsafe {
|
||||
+syntaxRender.renderWithStringBuilder(mathRender.render(this@toDisplayResult))
|
||||
}
|
||||
})
|
||||
|
||||
render<MST> { it.toDisplayResult() }
|
||||
render<Number> { MST.Numeric(it).toDisplayResult() }
|
||||
|
||||
fun Unsafe.appendCellValue(it: Any?) {
|
||||
when (it) {
|
||||
is Number -> {
|
||||
val s = StringBuilder()
|
||||
syntaxRender.renderPart(mathRender.render(MST.Numeric(it)), s)
|
||||
+s.toString()
|
||||
}
|
||||
is MST -> {
|
||||
val s = StringBuilder()
|
||||
syntaxRender.renderPart(mathRender.render(it), s)
|
||||
+s.toString()
|
||||
}
|
||||
else -> {
|
||||
+"<ms>"
|
||||
+it.toString()
|
||||
+"</ms>"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
render<Structure2D<*>> { structure ->
|
||||
HTML(createHTML().div {
|
||||
unsafe {
|
||||
+"<math xmlns=\"https://www.w3.org/1998/Math/MathML\">"
|
||||
+"<mrow>"
|
||||
+"<mfenced open=\"[\" close=\"]\" separators=\"\">"
|
||||
+"<mtable>"
|
||||
structure.rows.forEach { row ->
|
||||
+"<mtr>"
|
||||
row.asSequence().forEach {
|
||||
+"<mtd>"
|
||||
appendCellValue(it)
|
||||
+"</mtd>"
|
||||
}
|
||||
+"</mtr>"
|
||||
}
|
||||
+"</mtable>"
|
||||
+"</mfenced>"
|
||||
+"</mrow>"
|
||||
+"</math>"
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
render<Buffer<*>> { buffer ->
|
||||
HTML(createHTML().div {
|
||||
unsafe {
|
||||
+"<math xmlns=\"https://www.w3.org/1998/Math/MathML\">"
|
||||
+"<mrow>"
|
||||
+"<mfenced open=\"[\" close=\"]\" separators=\"\">"
|
||||
+"<mtable>"
|
||||
buffer.asSequence().forEach {
|
||||
+"<mtr>"
|
||||
+"<mtd>"
|
||||
appendCellValue(it)
|
||||
+"</mtd>"
|
||||
+"</mtr>"
|
||||
}
|
||||
+"</mtable>"
|
||||
+"</mfenced>"
|
||||
+"</mrow>"
|
||||
+"</math>"
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
render<Complex> {
|
||||
MST.Binary(
|
||||
operation = GroupOperations.PLUS_OPERATION,
|
||||
left = MST.Numeric(it.re),
|
||||
right = MST.Binary(RingOperations.TIMES_OPERATION, MST.Numeric(it.im), MST.Symbolic("i")),
|
||||
).toDisplayResult()
|
||||
}
|
||||
}
|
||||
}
|
@ -4,8 +4,8 @@ plugins {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation("com.github.breandan:kaliningraph:0.1.4")
|
||||
implementation("com.github.breandan:kotlingrad:0.4.0")
|
||||
api("com.github.breandan:kaliningraph:0.1.4")
|
||||
api("com.github.breandan:kotlingrad:0.4.5")
|
||||
api(project(":kmath-ast"))
|
||||
}
|
||||
|
||||
|
@ -15,8 +15,7 @@ The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-7
|
||||
```gradle
|
||||
repositories {
|
||||
maven { url 'https://repo.kotlin.link' }
|
||||
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
@ -27,8 +26,7 @@ dependencies {
|
||||
```kotlin
|
||||
repositories {
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
|
40
kmath-tensors/README.md
Normal file
40
kmath-tensors/README.md
Normal file
@ -0,0 +1,40 @@
|
||||
# Module kmath-tensors
|
||||
|
||||
Common operations on tensors, the API consists of:
|
||||
|
||||
- [TensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic algebra operations on tensors (plus, dot, etc.)
|
||||
- [TensorPartialDivisionAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt) : Emulates an algebra over a field
|
||||
- [LinearOpsTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Linear algebra operations including LU, QR, Cholesky LL and SVD decompositions
|
||||
- [AnalyticTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt) : Element-wise analytic operations
|
||||
|
||||
The library offers a multiplatform implementation for this interface over the `Double`'s. As a highlight, the user can find:
|
||||
- [BroadcastDoubleTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic algebra operations implemented with broadcasting.
|
||||
- [DoubleLinearOpsTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt) : Contains the power method for SVD and the spectrum of symmetric matrices.
|
||||
## Artifact:
|
||||
|
||||
The Maven coordinates of this project are `space.kscience:kmath-tensors:0.3.0-dev-7`.
|
||||
|
||||
**Gradle:**
|
||||
```gradle
|
||||
repositories {
|
||||
maven { url 'https://repo.kotlin.link' }
|
||||
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'space.kscience:kmath-tensors:0.3.0-dev-7'
|
||||
}
|
||||
```
|
||||
**Gradle Kotlin DSL:**
|
||||
```kotlin
|
||||
repositories {
|
||||
maven("https://repo.kotlin.link")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation("space.kscience:kmath-tensors:0.3.0-dev-7")
|
||||
}
|
||||
```
|
43
kmath-tensors/build.gradle.kts
Normal file
43
kmath-tensors/build.gradle.kts
Normal file
@ -0,0 +1,43 @@
|
||||
plugins {
|
||||
id("ru.mipt.npm.gradle.mpp")
|
||||
}
|
||||
|
||||
kotlin.sourceSets {
|
||||
all {
|
||||
languageSettings.useExperimentalAnnotation("space.kscience.kmath.misc.UnstableKMathAPI")
|
||||
}
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api(project(":kmath-stat"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tasks.dokkaHtml {
|
||||
dependsOn(tasks.build)
|
||||
}
|
||||
|
||||
readme {
|
||||
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||
|
||||
feature(
|
||||
id = "tensor algebra",
|
||||
description = "Basic linear algebra operations on tensors (plus, dot, etc.)",
|
||||
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "tensor algebra with broadcasting",
|
||||
description = "Basic linear algebra operations implemented with broadcasting.",
|
||||
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "linear algebra operations",
|
||||
description = "Advanced linear algebra operations like LU decomposition, SVD, etc.",
|
||||
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt"
|
||||
)
|
||||
|
||||
}
|
7
kmath-tensors/docs/README-TEMPLATE.md
Normal file
7
kmath-tensors/docs/README-TEMPLATE.md
Normal file
@ -0,0 +1,7 @@
|
||||
# Module kmath-tensors
|
||||
|
||||
Common linear algebra operations on tensors.
|
||||
|
||||
${features}
|
||||
|
||||
${artifact}
|
@ -0,0 +1,131 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.api
|
||||
|
||||
|
||||
/**
|
||||
* Analytic operations on [Tensor].
|
||||
*
|
||||
* @param T the type of items closed under analytic functions in the tensors.
|
||||
*/
|
||||
public interface AnalyticTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
|
||||
|
||||
/**
|
||||
* @return the mean of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.mean(): T
|
||||
|
||||
/**
|
||||
* Returns the mean of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the mean of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* @return the standard deviation of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.std(): T
|
||||
|
||||
/**
|
||||
* Returns the standard deviation of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the standard deviation of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* @return the variance of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.variance(): T
|
||||
|
||||
/**
|
||||
* Returns the variance of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the variance of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* Returns the covariance matrix M of given vectors.
|
||||
*
|
||||
* M[i, j] contains covariance of i-th and j-th given vectors
|
||||
*
|
||||
* @param tensors the [List] of 1-dimensional tensors with same shape
|
||||
* @return the covariance matrix
|
||||
*/
|
||||
public fun cov(tensors: List<Tensor<T>>): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.exp.html
|
||||
public fun Tensor<T>.exp(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.log.html
|
||||
public fun Tensor<T>.ln(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html
|
||||
public fun Tensor<T>.sqrt(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
|
||||
public fun Tensor<T>.cos(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
|
||||
public fun Tensor<T>.acos(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
|
||||
public fun Tensor<T>.cosh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
|
||||
public fun Tensor<T>.acosh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
|
||||
public fun Tensor<T>.sin(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
|
||||
public fun Tensor<T>.asin(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
|
||||
public fun Tensor<T>.sinh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
|
||||
public fun Tensor<T>.asinh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
|
||||
public fun Tensor<T>.tan(): Tensor<T>
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
|
||||
public fun Tensor<T>.atan(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
|
||||
public fun Tensor<T>.tanh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
|
||||
public fun Tensor<T>.atanh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil
|
||||
public fun Tensor<T>.ceil(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
|
||||
public fun Tensor<T>.floor(): Tensor<T>
|
||||
|
||||
}
|
@ -0,0 +1,97 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.api
|
||||
|
||||
/**
|
||||
* Common linear algebra operations. Operates on [Tensor].
|
||||
*
|
||||
* @param T the type of items closed under division in the tensors.
|
||||
*/
|
||||
public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
|
||||
|
||||
/**
|
||||
* Computes the determinant of a square matrix input, or of each square matrix in a batched input.
|
||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
|
||||
*
|
||||
* @return the determinant.
|
||||
*/
|
||||
public fun Tensor<T>.det(): Tensor<T>
|
||||
|
||||
/**
|
||||
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
|
||||
* Given a square matrix `A`, return the matrix `AInv` satisfying
|
||||
* `A dot AInv = AInv dot A = eye(a.shape[0])`.
|
||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
||||
*
|
||||
* @return the multiplicative inverse of a matrix.
|
||||
*/
|
||||
public fun Tensor<T>.inv(): Tensor<T>
|
||||
|
||||
/**
|
||||
* Cholesky decomposition.
|
||||
*
|
||||
* Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices)
|
||||
* positive-definite matrix or the Cholesky decompositions for a batch of such matrices.
|
||||
* Each decomposition has the form:
|
||||
* Given a tensor `input`, return the tensor `L` satisfying `input = L dot L.H`,
|
||||
* where L is a lower-triangular matrix and L.H is the conjugate transpose of L,
|
||||
* which is just a transpose for the case of real-valued input matrices.
|
||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky
|
||||
*
|
||||
* @return the batch of L matrices.
|
||||
*/
|
||||
public fun Tensor<T>.cholesky(): Tensor<T>
|
||||
|
||||
/**
|
||||
* QR decomposition.
|
||||
*
|
||||
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
|
||||
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q dot R``,
|
||||
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
|
||||
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
|
||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
||||
*
|
||||
* @return pair of Q and R tensors.
|
||||
*/
|
||||
public fun Tensor<T>.qr(): Pair<Tensor<T>, Tensor<T>>
|
||||
|
||||
/**
|
||||
* LUP decomposition
|
||||
*
|
||||
* Computes the LUP decomposition of a matrix or a batch of matrices.
|
||||
* Given a tensor `input`, return tensors (P, L, U) satisfying `P dot input = L dot U`,
|
||||
* with `P` being a permutation matrix or batch of matrices,
|
||||
* `L` being a lower triangular matrix or batch of matrices,
|
||||
* `U` being an upper triangular matrix or batch of matrices.
|
||||
*
|
||||
* * @return triple of P, L and U tensors
|
||||
*/
|
||||
public fun Tensor<T>.lu(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
|
||||
|
||||
/**
|
||||
* Singular Value Decomposition.
|
||||
*
|
||||
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
|
||||
* The singular value decomposition is represented as a triple `(U, S, V)`,
|
||||
* such that `input = U dot diagonalEmbedding(S) dot V.H`,
|
||||
* where V.H is the conjugate transpose of V.
|
||||
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
|
||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
||||
*
|
||||
* @return triple `(U, S, V)`.
|
||||
*/
|
||||
public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
|
||||
|
||||
/**
|
||||
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
|
||||
* represented by a pair (eigenvalues, eigenvectors).
|
||||
* For more information: https://pytorch.org/docs/stable/generated/torch.symeig.html
|
||||
*
|
||||
* @return a pair (eigenvalues, eigenvectors)
|
||||
*/
|
||||
public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
|
||||
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
package space.kscience.kmath.tensors.api
|
||||
|
||||
import space.kscience.kmath.nd.MutableStructureND
|
||||
|
||||
public typealias Tensor<T> = MutableStructureND<T>
|
@ -0,0 +1,327 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.api
|
||||
|
||||
import space.kscience.kmath.operations.Algebra
|
||||
|
||||
/**
|
||||
* Algebra over a ring on [Tensor].
|
||||
* For more information: https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||
*
|
||||
* @param T the type of items in the tensors.
|
||||
*/
|
||||
public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
||||
|
||||
/**
|
||||
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
|
||||
*
|
||||
* @return a nullable value of a potentially scalar tensor.
|
||||
*/
|
||||
public fun Tensor<T>.valueOrNull(): T?
|
||||
|
||||
/**
|
||||
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
|
||||
*
|
||||
* @return the value of a scalar tensor.
|
||||
*/
|
||||
public fun Tensor<T>.value(): T
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is added to this value.
|
||||
* The resulting tensor is returned.
|
||||
*
|
||||
* @param other tensor to be added.
|
||||
* @return the sum of this value and tensor [other].
|
||||
*/
|
||||
public operator fun T.plus(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Adds the scalar [value] to each element of this tensor and returns a new resulting tensor.
|
||||
*
|
||||
* @param value the number to be added to each element of this tensor.
|
||||
* @return the sum of this tensor and [value].
|
||||
*/
|
||||
public operator fun Tensor<T>.plus(value: T): Tensor<T>
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is added to each element of this tensor.
|
||||
* The resulting tensor is returned.
|
||||
*
|
||||
* @param other tensor to be added.
|
||||
* @return the sum of this tensor and [other].
|
||||
*/
|
||||
public operator fun Tensor<T>.plus(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Adds the scalar [value] to each element of this tensor.
|
||||
*
|
||||
* @param value the number to be added to each element of this tensor.
|
||||
*/
|
||||
public operator fun Tensor<T>.plusAssign(value: T): Unit
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is added to each element of this tensor.
|
||||
*
|
||||
* @param other tensor to be added.
|
||||
*/
|
||||
public operator fun Tensor<T>.plusAssign(other: Tensor<T>): Unit
|
||||
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is subtracted from this value.
|
||||
* The resulting tensor is returned.
|
||||
*
|
||||
* @param other tensor to be subtracted.
|
||||
* @return the difference between this value and tensor [other].
|
||||
*/
|
||||
public operator fun T.minus(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor.
|
||||
*
|
||||
* @param value the number to be subtracted from each element of this tensor.
|
||||
* @return the difference between this tensor and [value].
|
||||
*/
|
||||
public operator fun Tensor<T>.minus(value: T): Tensor<T>
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is subtracted from each element of this tensor.
|
||||
* The resulting tensor is returned.
|
||||
*
|
||||
* @param other tensor to be subtracted.
|
||||
* @return the difference between this tensor and [other].
|
||||
*/
|
||||
public operator fun Tensor<T>.minus(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Subtracts the scalar [value] from each element of this tensor.
|
||||
*
|
||||
* @param value the number to be subtracted from each element of this tensor.
|
||||
*/
|
||||
public operator fun Tensor<T>.minusAssign(value: T): Unit
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is subtracted from each element of this tensor.
|
||||
*
|
||||
* @param other tensor to be subtracted.
|
||||
*/
|
||||
public operator fun Tensor<T>.minusAssign(other: Tensor<T>): Unit
|
||||
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is multiplied by this value.
|
||||
* The resulting tensor is returned.
|
||||
*
|
||||
* @param other tensor to be multiplied.
|
||||
* @return the product of this value and tensor [other].
|
||||
*/
|
||||
public operator fun T.times(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor.
|
||||
*
|
||||
* @param value the number to be multiplied by each element of this tensor.
|
||||
* @return the product of this tensor and [value].
|
||||
*/
|
||||
public operator fun Tensor<T>.times(value: T): Tensor<T>
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is multiplied by each element of this tensor.
|
||||
* The resulting tensor is returned.
|
||||
*
|
||||
* @param other tensor to be multiplied.
|
||||
* @return the product of this tensor and [other].
|
||||
*/
|
||||
public operator fun Tensor<T>.times(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Multiplies the scalar [value] by each element of this tensor.
|
||||
*
|
||||
* @param value the number to be multiplied by each element of this tensor.
|
||||
*/
|
||||
public operator fun Tensor<T>.timesAssign(value: T): Unit
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is multiplied by each element of this tensor.
|
||||
*
|
||||
* @param other tensor to be multiplied.
|
||||
*/
|
||||
public operator fun Tensor<T>.timesAssign(other: Tensor<T>): Unit
|
||||
|
||||
/**
|
||||
* Numerical negative, element-wise.
|
||||
*
|
||||
* @return tensor negation of the original tensor.
|
||||
*/
|
||||
public operator fun Tensor<T>.unaryMinus(): Tensor<T>
|
||||
|
||||
/**
|
||||
* Returns the tensor at index i
|
||||
* For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
||||
*
|
||||
* @param i index of the extractable tensor
|
||||
* @return subtensor of the original tensor with index [i]
|
||||
*/
|
||||
public operator fun Tensor<T>.get(i: Int): Tensor<T>
|
||||
|
||||
/**
|
||||
* Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
|
||||
* For more information: https://pytorch.org/docs/stable/generated/torch.transpose.html
|
||||
*
|
||||
* @param i the first dimension to be transposed
|
||||
* @param j the second dimension to be transposed
|
||||
* @return transposed tensor
|
||||
*/
|
||||
public fun Tensor<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T>
|
||||
|
||||
/**
|
||||
* Returns a new tensor with the same data as the self tensor but of a different shape.
|
||||
* The returned tensor shares the same data and must have the same number of elements, but may have a different size
|
||||
* For more information: https://pytorch.org/docs/stable/tensor_view.html
|
||||
*
|
||||
* @param shape the desired size
|
||||
* @return tensor with new shape
|
||||
*/
|
||||
public fun Tensor<T>.view(shape: IntArray): Tensor<T>
|
||||
|
||||
/**
|
||||
* View this tensor as the same size as [other].
|
||||
* ``this.viewAs(other) is equivalent to this.view(other.shape)``.
|
||||
* For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
||||
*
|
||||
* @param other the result tensor has the same size as other.
|
||||
* @return the result tensor with the same size as other.
|
||||
*/
|
||||
public fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Matrix product of two tensors.
|
||||
*
|
||||
* The behavior depends on the dimensionality of the tensors as follows:
|
||||
* 1. If both tensors are 1-dimensional, the dot product (scalar) is returned.
|
||||
*
|
||||
* 2. If both arguments are 2-dimensional, the matrix-matrix product is returned.
|
||||
*
|
||||
* 3. If the first argument is 1-dimensional and the second argument is 2-dimensional,
|
||||
* a 1 is prepended to its dimension for the purpose of the matrix multiply.
|
||||
* After the matrix multiply, the prepended dimension is removed.
|
||||
*
|
||||
* 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
|
||||
* the matrix-vector product is returned.
|
||||
*
|
||||
* 5. If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2),
|
||||
* then a batched matrix multiply is returned. If the first argument is 1-dimensional,
|
||||
* a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.
|
||||
* If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix
|
||||
* multiple and removed after.
|
||||
* The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable).
|
||||
* For example, if `input` is a (j × 1 × n × n) tensor and `other` is a
|
||||
* (k × n × n) tensor, out will be a (j × k × n × n) tensor.
|
||||
*
|
||||
* For more information: https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||
*
|
||||
* @param other tensor to be multiplied
|
||||
* @return mathematical product of two tensors
|
||||
*/
|
||||
public infix fun Tensor<T>.dot(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
|
||||
* are filled by [diagonalEntries].
|
||||
* To facilitate creating batched diagonal matrices,
|
||||
* the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
|
||||
*
|
||||
* The argument [offset] controls which diagonal to consider:
|
||||
* 1. If [offset] = 0, it is the main diagonal.
|
||||
* 1. If [offset] > 0, it is above the main diagonal.
|
||||
* 1. If [offset] < 0, it is below the main diagonal.
|
||||
*
|
||||
* The size of the new matrix will be calculated
|
||||
* to make the specified diagonal of the size of the last input dimension.
|
||||
* For more information: https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
||||
*
|
||||
* @param diagonalEntries the input tensor. Must be at least 1-dimensional.
|
||||
* @param offset which diagonal to consider. Default: 0 (main diagonal).
|
||||
* @param dim1 first dimension with respect to which to take diagonal. Default: -2.
|
||||
* @param dim2 second dimension with respect to which to take diagonal. Default: -1.
|
||||
*
|
||||
* @return tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
|
||||
* are filled by [diagonalEntries]
|
||||
*/
|
||||
public fun diagonalEmbedding(
|
||||
diagonalEntries: Tensor<T>,
|
||||
offset: Int = 0,
|
||||
dim1: Int = -2,
|
||||
dim2: Int = -1
|
||||
): Tensor<T>
|
||||
|
||||
/**
|
||||
* @return the sum of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.sum(): T
|
||||
|
||||
/**
|
||||
* Returns the sum of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the sum of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* @return the minimum value of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.min(): T
|
||||
|
||||
/**
|
||||
* Returns the minimum value of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the minimum value of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* Returns the maximum value of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.max(): T
|
||||
|
||||
/**
|
||||
* Returns the maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
}
|
@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.api
|
||||
|
||||
/**
|
||||
* Algebra over a field with partial division on [Tensor].
|
||||
* For more information: https://proofwiki.org/wiki/Definition:Division_Algebra
|
||||
*
|
||||
* @param T the type of items closed under division in the tensors.
|
||||
*/
|
||||
public interface TensorPartialDivisionAlgebra<T> : TensorAlgebra<T> {
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is divided by this value.
|
||||
* The resulting tensor is returned.
|
||||
*
|
||||
* @param other tensor to divide by.
|
||||
* @return the division of this value by the tensor [other].
|
||||
*/
|
||||
public operator fun T.div(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Divide by the scalar [value] each element of this tensor returns a new resulting tensor.
|
||||
*
|
||||
* @param value the number to divide by each element of this tensor.
|
||||
* @return the division of this tensor by the [value].
|
||||
*/
|
||||
public operator fun Tensor<T>.div(value: T): Tensor<T>
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is divided by each element of this tensor.
|
||||
* The resulting tensor is returned.
|
||||
*
|
||||
* @param other tensor to be divided by.
|
||||
* @return the division of this tensor by [other].
|
||||
*/
|
||||
public operator fun Tensor<T>.div(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Divides by the scalar [value] each element of this tensor.
|
||||
*
|
||||
* @param value the number to divide by each element of this tensor.
|
||||
*/
|
||||
public operator fun Tensor<T>.divAssign(value: T)
|
||||
|
||||
/**
|
||||
* Each element of this tensor is divided by each element of the [other] tensor.
|
||||
*
|
||||
* @param other tensor to be divide by.
|
||||
*/
|
||||
public operator fun Tensor<T>.divAssign(other: Tensor<T>)
|
||||
}
|
@ -0,0 +1,93 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.core.internal.array
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastTensors
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastTo
|
||||
import space.kscience.kmath.tensors.core.internal.tensor
|
||||
|
||||
/**
|
||||
* Basic linear algebra operations implemented with broadcasting.
|
||||
* For more information: https://pytorch.org/docs/stable/notes/broadcasting.html
|
||||
*/
|
||||
public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
|
||||
override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor {
|
||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||
newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.plusAssign(other: Tensor<Double>) {
|
||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
|
||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor {
|
||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||
newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.minusAssign(other: Tensor<Double>) {
|
||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
|
||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor {
|
||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||
newThis.mutableBuffer.array()[newThis.bufferStart + i] *
|
||||
newOther.mutableBuffer.array()[newOther.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.timesAssign(other: Tensor<Double>) {
|
||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
|
||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor {
|
||||
val broadcast = broadcastTensors(tensor, other.tensor)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
|
||||
newThis.mutableBuffer.array()[newOther.bufferStart + i] /
|
||||
newOther.mutableBuffer.array()[newOther.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
|
||||
val newOther = broadcastTo(other.tensor, tensor.shape)
|
||||
for (i in 0 until tensor.linearStructure.linearSize) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
||||
newOther.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.nd.Strides
|
||||
import space.kscience.kmath.structures.*
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.core.internal.TensorLinearStructure
|
||||
|
||||
/**
|
||||
* Represents [Tensor] over a [MutableBuffer] intended to be used through [DoubleTensor] and [IntTensor]
|
||||
*/
|
||||
public open class BufferedTensor<T> internal constructor(
|
||||
override val shape: IntArray,
|
||||
internal val mutableBuffer: MutableBuffer<T>,
|
||||
internal val bufferStart: Int
|
||||
) : Tensor<T> {
|
||||
|
||||
/**
|
||||
* Buffer strides based on [TensorLinearStructure] implementation
|
||||
*/
|
||||
public val linearStructure: Strides
|
||||
get() = TensorLinearStructure(shape)
|
||||
|
||||
/**
|
||||
* Number of elements in tensor
|
||||
*/
|
||||
public val numElements: Int
|
||||
get() = linearStructure.linearSize
|
||||
|
||||
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)]
|
||||
|
||||
override fun set(index: IntArray, value: T) {
|
||||
mutableBuffer[bufferStart + linearStructure.offset(index)] = value
|
||||
}
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = linearStructure.indices().map {
|
||||
it to this[it]
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import space.kscience.kmath.tensors.core.internal.toPrettyString
|
||||
|
||||
/**
|
||||
* Default [BufferedTensor] implementation for [Double] values
|
||||
*/
|
||||
public class DoubleTensor internal constructor(
|
||||
shape: IntArray,
|
||||
buffer: DoubleArray,
|
||||
offset: Int = 0
|
||||
) : BufferedTensor<Double>(shape, DoubleBuffer(buffer), offset) {
|
||||
override fun toString(): String = toPrettyString()
|
||||
}
|
@ -0,0 +1,937 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.nd.as1D
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.core.internal.dotHelper
|
||||
import space.kscience.kmath.tensors.core.internal.getRandomNormals
|
||||
import space.kscience.kmath.tensors.core.internal.*
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
|
||||
import space.kscience.kmath.tensors.core.internal.checkBufferShapeConsistency
|
||||
import space.kscience.kmath.tensors.core.internal.checkEmptyDoubleBuffer
|
||||
import space.kscience.kmath.tensors.core.internal.checkEmptyShape
|
||||
import space.kscience.kmath.tensors.core.internal.checkShapesCompatible
|
||||
import space.kscience.kmath.tensors.core.internal.checkSquareMatrix
|
||||
import space.kscience.kmath.tensors.core.internal.checkTranspose
|
||||
import space.kscience.kmath.tensors.core.internal.checkView
|
||||
import space.kscience.kmath.tensors.core.internal.minusIndexFrom
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
* Implementation of basic operations over double tensors and basic algebra operations on them.
|
||||
*/
|
||||
public open class DoubleTensorAlgebra :
|
||||
TensorPartialDivisionAlgebra<Double>,
|
||||
AnalyticTensorAlgebra<Double>,
|
||||
LinearOpsTensorAlgebra<Double> {
|
||||
|
||||
public companion object : DoubleTensorAlgebra()
|
||||
|
||||
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart] else null
|
||||
|
||||
override fun Tensor<Double>.value(): Double =
|
||||
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
|
||||
|
||||
/**
|
||||
* Constructs a tensor with the specified shape and data.
|
||||
*
|
||||
* @param shape the desired shape for the tensor.
|
||||
* @param buffer one-dimensional data array.
|
||||
* @return tensor with the [shape] shape and [buffer] data.
|
||||
*/
|
||||
public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor {
|
||||
checkEmptyShape(shape)
|
||||
checkEmptyDoubleBuffer(buffer)
|
||||
checkBufferShapeConsistency(shape, buffer)
|
||||
return DoubleTensor(shape, buffer, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a tensor with the specified shape and initializer.
|
||||
*
|
||||
* @param shape the desired shape for the tensor.
|
||||
* @param initializer mapping tensor indices to values.
|
||||
* @return tensor with the [shape] shape and data generated by the [initializer].
|
||||
*/
|
||||
public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor =
|
||||
fromArray(
|
||||
shape,
|
||||
TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray()
|
||||
)
|
||||
|
||||
override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
|
||||
val lastShape = tensor.shape.drop(1).toIntArray()
|
||||
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
||||
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
|
||||
return DoubleTensor(newShape, tensor.mutableBuffer.array(), newStart)
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a tensor of a given shape and fills all elements with a given value.
|
||||
*
|
||||
* @param value the value to fill the output tensor with.
|
||||
* @param shape array of integers defining the shape of the output tensor.
|
||||
* @return tensor with the [shape] shape and filled with [value].
|
||||
*/
|
||||
public fun full(value: Double, shape: IntArray): DoubleTensor {
|
||||
checkEmptyShape(shape)
|
||||
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
|
||||
return DoubleTensor(shape, buffer)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a tensor with the same shape as `input` filled with [value].
|
||||
*
|
||||
* @param value the value to fill the output tensor with.
|
||||
* @return tensor with the `input` tensor shape and filled with [value].
|
||||
*/
|
||||
public fun Tensor<Double>.fullLike(value: Double): DoubleTensor {
|
||||
val shape = tensor.shape
|
||||
val buffer = DoubleArray(tensor.numElements) { value }
|
||||
return DoubleTensor(shape, buffer)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a tensor filled with the scalar value 0.0, with the shape defined by the variable argument [shape].
|
||||
*
|
||||
* @param shape array of integers defining the shape of the output tensor.
|
||||
* @return tensor filled with the scalar value 0.0, with the [shape] shape.
|
||||
*/
|
||||
public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
|
||||
|
||||
/**
|
||||
* Returns a tensor filled with the scalar value 0.0, with the same shape as a given array.
|
||||
*
|
||||
* @return tensor filled with the scalar value 0.0, with the same shape as `input` tensor.
|
||||
*/
|
||||
public fun Tensor<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0)
|
||||
|
||||
/**
|
||||
* Returns a tensor filled with the scalar value 1.0, with the shape defined by the variable argument [shape].
|
||||
*
|
||||
* @param shape array of integers defining the shape of the output tensor.
|
||||
* @return tensor filled with the scalar value 1.0, with the [shape] shape.
|
||||
*/
|
||||
public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
|
||||
|
||||
/**
|
||||
* Returns a tensor filled with the scalar value 1.0, with the same shape as a given array.
|
||||
*
|
||||
* @return tensor filled with the scalar value 1.0, with the same shape as `input` tensor.
|
||||
*/
|
||||
public fun Tensor<Double>.onesLike(): DoubleTensor = tensor.fullLike(1.0)
|
||||
|
||||
/**
|
||||
* Returns a 2-D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere.
|
||||
*
|
||||
* @param n the number of rows and columns
|
||||
* @return a 2-D tensor with ones on the diagonal and zeros elsewhere.
|
||||
*/
|
||||
public fun eye(n: Int): DoubleTensor {
|
||||
val shape = intArrayOf(n, n)
|
||||
val buffer = DoubleArray(n * n) { 0.0 }
|
||||
val res = DoubleTensor(shape, buffer)
|
||||
for (i in 0 until n) {
|
||||
res[intArrayOf(i, i)] = 1.0
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a copy of the tensor.
|
||||
*
|
||||
* @return a copy of the `input` tensor with a copied buffer.
|
||||
*/
|
||||
public fun Tensor<Double>.copy(): DoubleTensor {
|
||||
return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
|
||||
}
|
||||
|
||||
override fun Double.plus(other: Tensor<Double>): DoubleTensor {
|
||||
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
|
||||
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this
|
||||
}
|
||||
return DoubleTensor(other.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.plus(value: Double): DoubleTensor = value + tensor
|
||||
|
||||
override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor {
|
||||
checkShapesCompatible(tensor, other.tensor)
|
||||
val resBuffer = DoubleArray(tensor.numElements) { i ->
|
||||
tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i]
|
||||
}
|
||||
return DoubleTensor(tensor.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.plusAssign(value: Double) {
|
||||
for (i in 0 until tensor.numElements) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] += value
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.plusAssign(other: Tensor<Double>) {
|
||||
checkShapesCompatible(tensor, other.tensor)
|
||||
for (i in 0 until tensor.numElements) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
|
||||
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Double.minus(other: Tensor<Double>): DoubleTensor {
|
||||
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
|
||||
this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(other.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.minus(value: Double): DoubleTensor {
|
||||
val resBuffer = DoubleArray(tensor.numElements) { i ->
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] - value
|
||||
}
|
||||
return DoubleTensor(tensor.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor {
|
||||
checkShapesCompatible(tensor, other)
|
||||
val resBuffer = DoubleArray(tensor.numElements) { i ->
|
||||
tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i]
|
||||
}
|
||||
return DoubleTensor(tensor.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.minusAssign(value: Double) {
|
||||
for (i in 0 until tensor.numElements) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.minusAssign(other: Tensor<Double>) {
|
||||
checkShapesCompatible(tensor, other)
|
||||
for (i in 0 until tensor.numElements) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
|
||||
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Double.times(other: Tensor<Double>): DoubleTensor {
|
||||
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
|
||||
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this
|
||||
}
|
||||
return DoubleTensor(other.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.times(value: Double): DoubleTensor = value * tensor
|
||||
|
||||
override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor {
|
||||
checkShapesCompatible(tensor, other)
|
||||
val resBuffer = DoubleArray(tensor.numElements) { i ->
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] *
|
||||
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(tensor.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.timesAssign(value: Double) {
|
||||
for (i in 0 until tensor.numElements) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.timesAssign(other: Tensor<Double>) {
|
||||
checkShapesCompatible(tensor, other)
|
||||
for (i in 0 until tensor.numElements) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
|
||||
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Double.div(other: Tensor<Double>): DoubleTensor {
|
||||
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
|
||||
this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(other.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.div(value: Double): DoubleTensor {
|
||||
val resBuffer = DoubleArray(tensor.numElements) { i ->
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] / value
|
||||
}
|
||||
return DoubleTensor(shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor {
|
||||
checkShapesCompatible(tensor, other)
|
||||
val resBuffer = DoubleArray(tensor.numElements) { i ->
|
||||
tensor.mutableBuffer.array()[other.tensor.bufferStart + i] /
|
||||
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(tensor.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.divAssign(value: Double) {
|
||||
for (i in 0 until tensor.numElements) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= value
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
|
||||
checkShapesCompatible(tensor, other)
|
||||
for (i in 0 until tensor.numElements) {
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
||||
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.unaryMinus(): DoubleTensor {
|
||||
val resBuffer = DoubleArray(tensor.numElements) { i ->
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus()
|
||||
}
|
||||
return DoubleTensor(tensor.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.transpose(i: Int, j: Int): DoubleTensor {
|
||||
val ii = tensor.minusIndex(i)
|
||||
val jj = tensor.minusIndex(j)
|
||||
checkTranspose(tensor.dimension, ii, jj)
|
||||
val n = tensor.numElements
|
||||
val resBuffer = DoubleArray(n)
|
||||
|
||||
val resShape = tensor.shape.copyOf()
|
||||
resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] }
|
||||
|
||||
val resTensor = DoubleTensor(resShape, resBuffer)
|
||||
|
||||
for (offset in 0 until n) {
|
||||
val oldMultiIndex = tensor.linearStructure.index(offset)
|
||||
val newMultiIndex = oldMultiIndex.copyOf()
|
||||
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
|
||||
|
||||
val linearIndex = resTensor.linearStructure.offset(newMultiIndex)
|
||||
resTensor.mutableBuffer.array()[linearIndex] =
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + offset]
|
||||
}
|
||||
return resTensor
|
||||
}
|
||||
|
||||
|
||||
override fun Tensor<Double>.view(shape: IntArray): DoubleTensor {
|
||||
checkView(tensor, shape)
|
||||
return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.viewAs(other: Tensor<Double>): DoubleTensor =
|
||||
tensor.view(other.shape)
|
||||
|
||||
override infix fun Tensor<Double>.dot(other: Tensor<Double>): DoubleTensor {
|
||||
if (tensor.shape.size == 1 && other.shape.size == 1) {
|
||||
return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
|
||||
}
|
||||
|
||||
var newThis = tensor.copy()
|
||||
var newOther = other.copy()
|
||||
|
||||
var penultimateDim = false
|
||||
var lastDim = false
|
||||
if (tensor.shape.size == 1) {
|
||||
penultimateDim = true
|
||||
newThis = tensor.view(intArrayOf(1) + tensor.shape)
|
||||
}
|
||||
if (other.shape.size == 1) {
|
||||
lastDim = true
|
||||
newOther = other.tensor.view(other.shape + intArrayOf(1))
|
||||
}
|
||||
|
||||
val broadcastTensors = broadcastOuterTensors(newThis.tensor, newOther.tensor)
|
||||
newThis = broadcastTensors[0]
|
||||
newOther = broadcastTensors[1]
|
||||
|
||||
val l = newThis.shape[newThis.shape.size - 2]
|
||||
val m1 = newThis.shape[newThis.shape.size - 1]
|
||||
val m2 = newOther.shape[newOther.shape.size - 2]
|
||||
val n = newOther.shape[newOther.shape.size - 1]
|
||||
check(m1 == m2) {
|
||||
"Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)"
|
||||
}
|
||||
|
||||
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
|
||||
val resSize = resShape.reduce { acc, i -> acc * i }
|
||||
val resTensor = DoubleTensor(resShape, DoubleArray(resSize))
|
||||
|
||||
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
|
||||
val (a, b) = ab
|
||||
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
|
||||
}
|
||||
|
||||
if (penultimateDim) {
|
||||
return resTensor.view(
|
||||
resTensor.shape.dropLast(2).toIntArray() +
|
||||
intArrayOf(resTensor.shape.last())
|
||||
)
|
||||
}
|
||||
if (lastDim) {
|
||||
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
|
||||
}
|
||||
return resTensor
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(diagonalEntries: Tensor<Double>, offset: Int, dim1: Int, dim2: Int):
|
||||
DoubleTensor {
|
||||
val n = diagonalEntries.shape.size
|
||||
val d1 = minusIndexFrom(n + 1, dim1)
|
||||
val d2 = minusIndexFrom(n + 1, dim2)
|
||||
|
||||
check(d1 != d2) {
|
||||
"Diagonal dimensions cannot be identical $d1, $d2"
|
||||
}
|
||||
check(d1 <= n && d2 <= n) {
|
||||
"Dimension out of range"
|
||||
}
|
||||
|
||||
var lessDim = d1
|
||||
var greaterDim = d2
|
||||
var realOffset = offset
|
||||
if (lessDim > greaterDim) {
|
||||
realOffset *= -1
|
||||
lessDim = greaterDim.also { greaterDim = lessDim }
|
||||
}
|
||||
|
||||
val resShape = diagonalEntries.shape.slice(0 until lessDim).toIntArray() +
|
||||
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
|
||||
diagonalEntries.shape.slice(lessDim until greaterDim - 1).toIntArray() +
|
||||
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
|
||||
diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray()
|
||||
val resTensor = zeros(resShape)
|
||||
|
||||
for (i in 0 until diagonalEntries.tensor.numElements) {
|
||||
val multiIndex = diagonalEntries.tensor.linearStructure.index(i)
|
||||
|
||||
var offset1 = 0
|
||||
var offset2 = abs(realOffset)
|
||||
if (realOffset < 0) {
|
||||
offset1 = offset2.also { offset2 = offset1 }
|
||||
}
|
||||
val diagonalMultiIndex = multiIndex.slice(0 until lessDim).toIntArray() +
|
||||
intArrayOf(multiIndex[n - 1] + offset1) +
|
||||
multiIndex.slice(lessDim until greaterDim - 1).toIntArray() +
|
||||
intArrayOf(multiIndex[n - 1] + offset2) +
|
||||
multiIndex.slice(greaterDim - 1 until n - 1).toIntArray()
|
||||
|
||||
resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex]
|
||||
}
|
||||
|
||||
return resTensor.tensor
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies the [transform] function to each element of the tensor and returns the resulting modified tensor.
|
||||
*
|
||||
* @param transform the function to be applied to each element of the tensor.
|
||||
* @return the resulting tensor after applying the function.
|
||||
*/
|
||||
public fun Tensor<Double>.map(transform: (Double) -> Double): DoubleTensor {
|
||||
return DoubleTensor(
|
||||
tensor.shape,
|
||||
tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(),
|
||||
tensor.bufferStart
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Compares element-wise two tensors with a specified precision.
|
||||
*
|
||||
* @param other the tensor to compare with `input` tensor.
|
||||
* @param epsilon permissible error when comparing two Double values.
|
||||
* @return true if two tensors have the same shape and elements, false otherwise.
|
||||
*/
|
||||
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean =
|
||||
tensor.eq(other) { x, y -> abs(x - y) < epsilon }
|
||||
|
||||
/**
|
||||
* Compares element-wise two tensors.
|
||||
* Comparison of two Double values occurs with 1e-5 precision.
|
||||
*
|
||||
* @param other the tensor to compare with `input` tensor.
|
||||
* @return true if two tensors have the same shape and elements, false otherwise.
|
||||
*/
|
||||
public infix fun Tensor<Double>.eq(other: Tensor<Double>): Boolean = tensor.eq(other, 1e-5)
|
||||
|
||||
private fun Tensor<Double>.eq(
|
||||
other: Tensor<Double>,
|
||||
eqFunction: (Double, Double) -> Boolean
|
||||
): Boolean {
|
||||
checkShapesCompatible(tensor, other)
|
||||
val n = tensor.numElements
|
||||
if (n != other.tensor.numElements) {
|
||||
return false
|
||||
}
|
||||
for (i in 0 until n) {
|
||||
if (!eqFunction(
|
||||
tensor.mutableBuffer[tensor.bufferStart + i],
|
||||
other.tensor.mutableBuffer[other.tensor.bufferStart + i]
|
||||
)
|
||||
) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a tensor of random numbers drawn from normal distributions with 0.0 mean and 1.0 standard deviation.
|
||||
*
|
||||
* @param shape the desired shape for the output tensor.
|
||||
* @param seed the random seed of the pseudo-random number generator.
|
||||
* @return tensor of a given shape filled with numbers from the normal distribution
|
||||
* with 0.0 mean and 1.0 standard deviation.
|
||||
*/
|
||||
public fun randomNormal(shape: IntArray, seed: Long = 0): DoubleTensor =
|
||||
DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed))
|
||||
|
||||
/**
|
||||
* Returns a tensor with the same shape as `input` of random numbers drawn from normal distributions
|
||||
* with 0.0 mean and 1.0 standard deviation.
|
||||
*
|
||||
* @param seed the random seed of the pseudo-random number generator.
|
||||
* @return tensor with the same shape as `input` filled with numbers from the normal distribution
|
||||
* with 0.0 mean and 1.0 standard deviation.
|
||||
*/
|
||||
public fun Tensor<Double>.randomNormalLike(seed: Long = 0): DoubleTensor =
|
||||
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
|
||||
|
||||
/**
|
||||
* Concatenates a sequence of tensors with equal shapes along the first dimension.
|
||||
*
|
||||
* @param tensors the [List] of tensors with same shapes to concatenate
|
||||
* @return tensor with concatenation result
|
||||
*/
|
||||
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
|
||||
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
||||
val shape = tensors[0].shape
|
||||
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
|
||||
val resShape = intArrayOf(tensors.size) + shape
|
||||
val resBuffer = tensors.flatMap {
|
||||
it.tensor.mutableBuffer.array().drop(it.tensor.bufferStart).take(it.tensor.numElements)
|
||||
}.toDoubleArray()
|
||||
return DoubleTensor(resShape, resBuffer, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds tensor from rows of input tensor
|
||||
*
|
||||
* @param indices the [IntArray] of 1-dimensional indices
|
||||
* @return tensor with rows corresponding to rows by [indices]
|
||||
*/
|
||||
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor {
|
||||
return stack(indices.map { this[it] })
|
||||
}
|
||||
|
||||
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
|
||||
foldFunction(tensor.toDoubleArray())
|
||||
|
||||
internal fun Tensor<Double>.foldDim(
|
||||
foldFunction: (DoubleArray) -> Double,
|
||||
dim: Int,
|
||||
keepDim: Boolean
|
||||
): DoubleTensor {
|
||||
check(dim < dimension) { "Dimension $dim out of range $dimension" }
|
||||
val resShape = if (keepDim) {
|
||||
shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
|
||||
} else {
|
||||
shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
|
||||
}
|
||||
val resNumElements = resShape.reduce(Int::times)
|
||||
val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0)
|
||||
for (index in resTensor.linearStructure.indices()) {
|
||||
val prefix = index.take(dim).toIntArray()
|
||||
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
|
||||
resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i ->
|
||||
tensor[prefix + intArrayOf(i) + suffix]
|
||||
})
|
||||
}
|
||||
|
||||
return resTensor
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
|
||||
|
||||
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x -> x.sum() }, dim, keepDim)
|
||||
|
||||
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
|
||||
|
||||
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
|
||||
|
||||
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
|
||||
|
||||
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
|
||||
|
||||
override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x ->
|
||||
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
|
||||
}, dim, keepDim)
|
||||
|
||||
|
||||
override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
|
||||
|
||||
override fun Tensor<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim(
|
||||
{ arr ->
|
||||
check(dim < dimension) { "Dimension $dim out of range $dimension" }
|
||||
arr.sum() / shape[dim]
|
||||
},
|
||||
dim,
|
||||
keepDim
|
||||
)
|
||||
|
||||
override fun Tensor<Double>.std(): Double = this.fold { arr ->
|
||||
val mean = arr.sum() / tensor.numElements
|
||||
sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1))
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
|
||||
{ arr ->
|
||||
check(dim < dimension) { "Dimension $dim out of range $dimension" }
|
||||
val mean = arr.sum() / shape[dim]
|
||||
sqrt(arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1))
|
||||
},
|
||||
dim,
|
||||
keepDim
|
||||
)
|
||||
|
||||
override fun Tensor<Double>.variance(): Double = this.fold { arr ->
|
||||
val mean = arr.sum() / tensor.numElements
|
||||
arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
|
||||
{ arr ->
|
||||
check(dim < dimension) { "Dimension $dim out of range $dimension" }
|
||||
val mean = arr.sum() / shape[dim]
|
||||
arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1)
|
||||
},
|
||||
dim,
|
||||
keepDim
|
||||
)
|
||||
|
||||
private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
|
||||
val n = x.shape[0]
|
||||
return ((x - x.mean()) * (y - y.mean())).mean() * n / (n - 1)
|
||||
}
|
||||
|
||||
override fun cov(tensors: List<Tensor<Double>>): DoubleTensor {
|
||||
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
||||
val n = tensors.size
|
||||
val m = tensors[0].shape[0]
|
||||
check(tensors.all { it.shape contentEquals intArrayOf(m) }) { "Tensors must have same shapes" }
|
||||
val resTensor = DoubleTensor(
|
||||
intArrayOf(n, n),
|
||||
DoubleArray(n * n) { 0.0 }
|
||||
)
|
||||
for (i in 0 until n) {
|
||||
for (j in 0 until n) {
|
||||
resTensor[intArrayOf(i, j)] = cov(tensors[i].tensor, tensors[j].tensor)
|
||||
}
|
||||
}
|
||||
return resTensor
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.exp(): DoubleTensor = tensor.map(::exp)
|
||||
|
||||
override fun Tensor<Double>.ln(): DoubleTensor = tensor.map(::ln)
|
||||
|
||||
override fun Tensor<Double>.sqrt(): DoubleTensor = tensor.map(::sqrt)
|
||||
|
||||
override fun Tensor<Double>.cos(): DoubleTensor = tensor.map(::cos)
|
||||
|
||||
override fun Tensor<Double>.acos(): DoubleTensor = tensor.map(::acos)
|
||||
|
||||
override fun Tensor<Double>.cosh(): DoubleTensor = tensor.map(::cosh)
|
||||
|
||||
override fun Tensor<Double>.acosh(): DoubleTensor = tensor.map(::acosh)
|
||||
|
||||
override fun Tensor<Double>.sin(): DoubleTensor = tensor.map(::sin)
|
||||
|
||||
override fun Tensor<Double>.asin(): DoubleTensor = tensor.map(::asin)
|
||||
|
||||
override fun Tensor<Double>.sinh(): DoubleTensor = tensor.map(::sinh)
|
||||
|
||||
override fun Tensor<Double>.asinh(): DoubleTensor = tensor.map(::asinh)
|
||||
|
||||
override fun Tensor<Double>.tan(): DoubleTensor = tensor.map(::tan)
|
||||
|
||||
override fun Tensor<Double>.atan(): DoubleTensor = tensor.map(::atan)
|
||||
|
||||
override fun Tensor<Double>.tanh(): DoubleTensor = tensor.map(::tanh)
|
||||
|
||||
override fun Tensor<Double>.atanh(): DoubleTensor = tensor.map(::atanh)
|
||||
|
||||
override fun Tensor<Double>.ceil(): DoubleTensor = tensor.map(::ceil)
|
||||
|
||||
override fun Tensor<Double>.floor(): DoubleTensor = tensor.map(::floor)
|
||||
|
||||
override fun Tensor<Double>.inv(): DoubleTensor = invLU(1e-9)
|
||||
|
||||
override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9)
|
||||
|
||||
/**
|
||||
* Computes the LU factorization of a matrix or batches of matrices `input`.
|
||||
* Returns a tuple containing the LU factorization and pivots of `input`.
|
||||
*
|
||||
* @param epsilon permissible error when comparing the determinant of a matrix with zero
|
||||
* @return pair of `factorization` and `pivots`.
|
||||
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
|
||||
* The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
|
||||
*/
|
||||
public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
|
||||
computeLU(tensor, epsilon)
|
||||
?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
|
||||
|
||||
/**
|
||||
* Computes the LU factorization of a matrix or batches of matrices `input`.
|
||||
* Returns a tuple containing the LU factorization and pivots of `input`.
|
||||
* Uses an error of ``1e-9`` when calculating whether a matrix is degenerate.
|
||||
*
|
||||
* @return pair of `factorization` and `pivots`.
|
||||
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
|
||||
* The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
|
||||
*/
|
||||
public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
|
||||
|
||||
/**
|
||||
* Unpacks the data and pivots from a LU factorization of a tensor.
|
||||
* Given a tensor [luTensor], return tensors (P, L, U) satisfying ``P * luTensor = L * U``,
|
||||
* with `P` being a permutation matrix or batch of matrices,
|
||||
* `L` being a lower triangular matrix or batch of matrices,
|
||||
* `U` being an upper triangular matrix or batch of matrices.
|
||||
*
|
||||
* @param luTensor the packed LU factorization data
|
||||
* @param pivotsTensor the packed LU factorization pivots
|
||||
* @return triple of P, L and U tensors
|
||||
*/
|
||||
public fun luPivot(
|
||||
luTensor: Tensor<Double>,
|
||||
pivotsTensor: Tensor<Int>
|
||||
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
checkSquareMatrix(luTensor.shape)
|
||||
check(
|
||||
luTensor.shape.dropLast(2).toIntArray() contentEquals pivotsTensor.shape.dropLast(1).toIntArray() ||
|
||||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
|
||||
) { "Inappropriate shapes of input tensors" }
|
||||
|
||||
val n = luTensor.shape.last()
|
||||
val pTensor = luTensor.zeroesLike()
|
||||
pTensor
|
||||
.matrixSequence()
|
||||
.zip(pivotsTensor.tensor.vectorSequence())
|
||||
.forEach { (p, pivot) -> pivInit(p.as2D(), pivot.as1D(), n) }
|
||||
|
||||
val lTensor = luTensor.zeroesLike()
|
||||
val uTensor = luTensor.zeroesLike()
|
||||
|
||||
lTensor.matrixSequence()
|
||||
.zip(uTensor.matrixSequence())
|
||||
.zip(luTensor.tensor.matrixSequence())
|
||||
.forEach { (pairLU, lu) ->
|
||||
val (l, u) = pairLU
|
||||
luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n)
|
||||
}
|
||||
|
||||
return Triple(pTensor, lTensor, uTensor)
|
||||
}
|
||||
|
||||
/**
|
||||
* QR decomposition.
|
||||
*
|
||||
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
|
||||
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``,
|
||||
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
|
||||
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
|
||||
*
|
||||
* @param epsilon permissible error when comparing tensors for equality.
|
||||
* Used when checking the positive definiteness of the input matrix or matrices.
|
||||
* @return pair of Q and R tensors.
|
||||
*/
|
||||
public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor {
|
||||
checkSquareMatrix(shape)
|
||||
checkPositiveDefinite(tensor, epsilon)
|
||||
|
||||
val n = shape.last()
|
||||
val lTensor = zeroesLike()
|
||||
|
||||
for ((a, l) in tensor.matrixSequence().zip(lTensor.matrixSequence()))
|
||||
for (i in 0 until n) choleskyHelper(a.as2D(), l.as2D(), n)
|
||||
|
||||
return lTensor
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.cholesky(): DoubleTensor = cholesky(1e-6)
|
||||
|
||||
override fun Tensor<Double>.qr(): Pair<DoubleTensor, DoubleTensor> {
|
||||
checkSquareMatrix(shape)
|
||||
val qTensor = zeroesLike()
|
||||
val rTensor = zeroesLike()
|
||||
tensor.matrixSequence()
|
||||
.zip(
|
||||
(qTensor.matrixSequence()
|
||||
.zip(rTensor.matrixSequence()))
|
||||
).forEach { (matrix, qr) ->
|
||||
val (q, r) = qr
|
||||
qrHelper(matrix.asTensor(), q.asTensor(), r.as2D())
|
||||
}
|
||||
|
||||
return qTensor to rTensor
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
|
||||
svd(epsilon = 1e-10)
|
||||
|
||||
/**
|
||||
* Singular Value Decomposition.
|
||||
*
|
||||
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
|
||||
* The singular value decomposition is represented as a triple `(U, S, V)`,
|
||||
* such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``.
|
||||
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
|
||||
*
|
||||
* @param epsilon permissible error when calculating the dot product of vectors,
|
||||
* i.e. the precision with which the cosine approaches 1 in an iterative algorithm.
|
||||
* @return triple `(U, S, V)`.
|
||||
*/
|
||||
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
val size = tensor.dimension
|
||||
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
||||
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
|
||||
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
|
||||
val sTensor = zeros(commonShape + intArrayOf(min(n, m)))
|
||||
val vTensor = zeros(commonShape + intArrayOf(min(n, m), m))
|
||||
|
||||
tensor.matrixSequence()
|
||||
.zip(
|
||||
uTensor.matrixSequence()
|
||||
.zip(
|
||||
sTensor.vectorSequence()
|
||||
.zip(vTensor.matrixSequence())
|
||||
)
|
||||
).forEach { (matrix, USV) ->
|
||||
val matrixSize = matrix.shape.reduce { acc, i -> acc * i }
|
||||
val curMatrix = DoubleTensor(
|
||||
matrix.shape,
|
||||
matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize)
|
||||
.toDoubleArray()
|
||||
)
|
||||
svdHelper(curMatrix, USV, m, n, epsilon)
|
||||
}
|
||||
|
||||
return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
|
||||
symEig(epsilon = 1e-15)
|
||||
|
||||
/**
|
||||
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
|
||||
* represented by a pair (eigenvalues, eigenvectors).
|
||||
*
|
||||
* @param epsilon permissible error when comparing tensors for equality
|
||||
* and when the cosine approaches 1 in the SVD algorithm.
|
||||
* @return a pair (eigenvalues, eigenvectors)
|
||||
*/
|
||||
public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
|
||||
checkSymmetric(tensor, epsilon)
|
||||
val (u, s, v) = tensor.svd(epsilon)
|
||||
val shp = s.shape + intArrayOf(1)
|
||||
val utv = u.transpose() dot v
|
||||
val n = s.shape.last()
|
||||
for (matrix in utv.matrixSequence())
|
||||
cleanSymHelper(matrix.as2D(), n)
|
||||
|
||||
val eig = (utv dot s.view(shp)).view(s.shape)
|
||||
return eig to v
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the determinant of a square matrix input, or of each square matrix in a batched input
|
||||
* using LU factorization algorithm.
|
||||
*
|
||||
* @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero
|
||||
* @return the determinant.
|
||||
*/
|
||||
public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
|
||||
|
||||
checkSquareMatrix(tensor.shape)
|
||||
val luTensor = tensor.copy()
|
||||
val pivotsTensor = tensor.setUpPivots()
|
||||
|
||||
val n = shape.size
|
||||
|
||||
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
|
||||
detTensorShape[n - 2] = 1
|
||||
val resBuffer = DoubleArray(detTensorShape.reduce(Int::times)) { 0.0 }
|
||||
|
||||
val detTensor = DoubleTensor(
|
||||
detTensorShape,
|
||||
resBuffer
|
||||
)
|
||||
|
||||
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (lu, pivots) ->
|
||||
resBuffer[index] = if (luHelper(lu.as2D(), pivots.as1D(), epsilon))
|
||||
0.0 else luMatrixDet(lu.as2D(), pivots.as1D())
|
||||
}
|
||||
|
||||
return detTensor
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input
|
||||
* using LU factorization algorithm.
|
||||
* Given a square matrix `a`, return the matrix `aInv` satisfying
|
||||
* ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``.
|
||||
*
|
||||
* @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero
|
||||
* @return the multiplicative inverse of a matrix.
|
||||
*/
|
||||
public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
|
||||
val (luTensor, pivotsTensor) = luFactor(epsilon)
|
||||
val invTensor = luTensor.zeroesLike()
|
||||
|
||||
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
||||
for ((luP, invMatrix) in seq) {
|
||||
val (lu, pivots) = luP
|
||||
luMatrixInv(lu.as2D(), pivots.as1D(), invMatrix.as2D())
|
||||
}
|
||||
|
||||
return invTensor
|
||||
}
|
||||
|
||||
/**
|
||||
* LUP decomposition
|
||||
*
|
||||
* Computes the LUP decomposition of a matrix or a batch of matrices.
|
||||
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
|
||||
* with `P` being a permutation matrix or batch of matrices,
|
||||
* `L` being a lower triangular matrix or batch of matrices,
|
||||
* `U` being an upper triangular matrix or batch of matrices.
|
||||
*
|
||||
* @param epsilon permissible error when comparing the determinant of a matrix with zero
|
||||
* @return triple of P, L and U tensors
|
||||
*/
|
||||
public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
val (lu, pivots) = tensor.luFactor(epsilon)
|
||||
return luPivot(lu, pivots)
|
||||
}
|
||||
|
||||
override fun Tensor<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,17 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.structures.IntBuffer
|
||||
|
||||
/**
|
||||
* Default [BufferedTensor] implementation for [Int] values
|
||||
*/
|
||||
public class IntTensor internal constructor(
|
||||
shape: IntArray,
|
||||
buffer: IntArray,
|
||||
offset: Int = 0
|
||||
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
|
@ -0,0 +1,57 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core.internal
|
||||
|
||||
import space.kscience.kmath.nd.Strides
|
||||
import kotlin.math.max
|
||||
|
||||
|
||||
internal fun stridesFromShape(shape: IntArray): IntArray {
|
||||
val nDim = shape.size
|
||||
val res = IntArray(nDim)
|
||||
if (nDim == 0)
|
||||
return res
|
||||
|
||||
var current = nDim - 1
|
||||
res[current] = 1
|
||||
|
||||
while (current > 0) {
|
||||
res[current - 1] = max(1, shape[current]) * res[current]
|
||||
current--
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||
val res = IntArray(nDim)
|
||||
var current = offset
|
||||
var strideIndex = 0
|
||||
|
||||
while (strideIndex < nDim) {
|
||||
res[strideIndex] = (current / strides[strideIndex])
|
||||
current %= strides[strideIndex]
|
||||
strideIndex++
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
/**
|
||||
* This [Strides] implementation follows the last dimension first convention
|
||||
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
|
||||
*
|
||||
* @param shape the shape of the tensor.
|
||||
*/
|
||||
internal class TensorLinearStructure(override val shape: IntArray) : Strides {
|
||||
override val strides: IntArray
|
||||
get() = stridesFromShape(shape)
|
||||
|
||||
override fun index(offset: Int): IntArray =
|
||||
indexFromOffset(offset, strides, shape.size)
|
||||
|
||||
override val linearSize: Int
|
||||
get() = shape.reduce(Int::times)
|
||||
|
||||
}
|
@ -0,0 +1,146 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core.internal
|
||||
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import kotlin.math.max
|
||||
|
||||
internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTensor, linearSize: Int) {
|
||||
for (linearIndex in 0 until linearSize) {
|
||||
val totalMultiIndex = resTensor.linearStructure.index(linearIndex)
|
||||
val curMultiIndex = tensor.shape.copyOf()
|
||||
|
||||
val offset = totalMultiIndex.size - curMultiIndex.size
|
||||
|
||||
for (i in curMultiIndex.indices) {
|
||||
if (curMultiIndex[i] != 1) {
|
||||
curMultiIndex[i] = totalMultiIndex[i + offset]
|
||||
} else {
|
||||
curMultiIndex[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
val curLinearIndex = tensor.linearStructure.offset(curMultiIndex)
|
||||
resTensor.mutableBuffer.array()[linearIndex] =
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart + curLinearIndex]
|
||||
}
|
||||
}
|
||||
|
||||
internal fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
||||
var totalDim = 0
|
||||
for (shape in shapes) {
|
||||
totalDim = max(totalDim, shape.size)
|
||||
}
|
||||
|
||||
val totalShape = IntArray(totalDim) { 0 }
|
||||
for (shape in shapes) {
|
||||
for (i in shape.indices) {
|
||||
val curDim = shape[i]
|
||||
val offset = totalDim - shape.size
|
||||
totalShape[i + offset] = max(totalShape[i + offset], curDim)
|
||||
}
|
||||
}
|
||||
|
||||
for (shape in shapes) {
|
||||
for (i in shape.indices) {
|
||||
val curDim = shape[i]
|
||||
val offset = totalDim - shape.size
|
||||
check(curDim == 1 || totalShape[i + offset] == curDim) {
|
||||
"Shapes are not compatible and cannot be broadcast"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return totalShape
|
||||
}
|
||||
|
||||
internal fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): DoubleTensor {
|
||||
require(tensor.shape.size <= newShape.size) {
|
||||
"Tensor is not compatible with the new shape"
|
||||
}
|
||||
|
||||
val n = newShape.reduce { acc, i -> acc * i }
|
||||
val resTensor = DoubleTensor(newShape, DoubleArray(n))
|
||||
|
||||
for (i in tensor.shape.indices) {
|
||||
val curDim = tensor.shape[i]
|
||||
val offset = newShape.size - tensor.shape.size
|
||||
check(curDim == 1 || newShape[i + offset] == curDim) {
|
||||
"Tensor is not compatible with the new shape and cannot be broadcast"
|
||||
}
|
||||
}
|
||||
|
||||
multiIndexBroadCasting(tensor, resTensor, n)
|
||||
return resTensor
|
||||
}
|
||||
|
||||
internal fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
|
||||
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
||||
val n = totalShape.reduce { acc, i -> acc * i }
|
||||
|
||||
return tensors.map { tensor ->
|
||||
val resTensor = DoubleTensor(totalShape, DoubleArray(n))
|
||||
multiIndexBroadCasting(tensor, resTensor, n)
|
||||
resTensor
|
||||
}
|
||||
}
|
||||
|
||||
internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
|
||||
val onlyTwoDims = tensors.asSequence().onEach {
|
||||
require(it.shape.size >= 2) {
|
||||
"Tensors must have at least 2 dimensions"
|
||||
}
|
||||
}.any { it.shape.size != 2 }
|
||||
|
||||
if (!onlyTwoDims) {
|
||||
return tensors.asList()
|
||||
}
|
||||
|
||||
val totalShape = broadcastShapes(*(tensors.map { it.shape.sliceArray(0..it.shape.size - 3) }).toTypedArray())
|
||||
val n = totalShape.reduce { acc, i -> acc * i }
|
||||
|
||||
return buildList {
|
||||
for (tensor in tensors) {
|
||||
val matrixShape = tensor.shape.sliceArray(tensor.shape.size - 2 until tensor.shape.size).copyOf()
|
||||
val matrixSize = matrixShape[0] * matrixShape[1]
|
||||
val matrix = DoubleTensor(matrixShape, DoubleArray(matrixSize))
|
||||
|
||||
val outerTensor = DoubleTensor(totalShape, DoubleArray(n))
|
||||
val resTensor = DoubleTensor(totalShape + matrixShape, DoubleArray(n * matrixSize))
|
||||
|
||||
for (linearIndex in 0 until n) {
|
||||
val totalMultiIndex = outerTensor.linearStructure.index(linearIndex)
|
||||
var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf()
|
||||
curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) { 1 } + curMultiIndex
|
||||
|
||||
val newTensor = DoubleTensor(curMultiIndex + matrixShape, tensor.mutableBuffer.array())
|
||||
|
||||
for (i in curMultiIndex.indices) {
|
||||
if (curMultiIndex[i] != 1) {
|
||||
curMultiIndex[i] = totalMultiIndex[i]
|
||||
} else {
|
||||
curMultiIndex[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
for (i in 0 until matrixSize) {
|
||||
val curLinearIndex = newTensor.linearStructure.offset(
|
||||
curMultiIndex +
|
||||
matrix.linearStructure.index(i)
|
||||
)
|
||||
val newLinearIndex = resTensor.linearStructure.offset(
|
||||
totalMultiIndex +
|
||||
matrix.linearStructure.index(i)
|
||||
)
|
||||
|
||||
resTensor.mutableBuffer.array()[resTensor.bufferStart + newLinearIndex] =
|
||||
newTensor.mutableBuffer.array()[newTensor.bufferStart + curLinearIndex]
|
||||
}
|
||||
}
|
||||
add(resTensor)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,64 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core.internal
|
||||
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
|
||||
|
||||
internal fun checkEmptyShape(shape: IntArray) =
|
||||
check(shape.isNotEmpty()) {
|
||||
"Illegal empty shape provided"
|
||||
}
|
||||
|
||||
internal fun checkEmptyDoubleBuffer(buffer: DoubleArray) =
|
||||
check(buffer.isNotEmpty()) {
|
||||
"Illegal empty buffer provided"
|
||||
}
|
||||
|
||||
internal fun checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray) =
|
||||
check(buffer.size == shape.reduce(Int::times)) {
|
||||
"Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided"
|
||||
}
|
||||
|
||||
internal fun <T> checkShapesCompatible(a: Tensor<T>, b: Tensor<T>) =
|
||||
check(a.shape contentEquals b.shape) {
|
||||
"Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} "
|
||||
}
|
||||
|
||||
internal fun checkTranspose(dim: Int, i: Int, j: Int) =
|
||||
check((i < dim) and (j < dim)) {
|
||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||
}
|
||||
|
||||
internal fun <T> checkView(a: Tensor<T>, shape: IntArray) =
|
||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
||||
|
||||
internal fun checkSquareMatrix(shape: IntArray) {
|
||||
val n = shape.size
|
||||
check(n >= 2) {
|
||||
"Expected tensor with 2 or more dimensions, got size $n instead"
|
||||
}
|
||||
check(shape[n - 1] == shape[n - 2]) {
|
||||
"Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices"
|
||||
}
|
||||
}
|
||||
|
||||
internal fun DoubleTensorAlgebra.checkSymmetric(
|
||||
tensor: Tensor<Double>, epsilon: Double = 1e-6
|
||||
) =
|
||||
check(tensor.eq(tensor.transpose(), epsilon)) {
|
||||
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
|
||||
}
|
||||
|
||||
internal fun DoubleTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor, epsilon: Double = 1e-6) {
|
||||
checkSymmetric(tensor, epsilon)
|
||||
for (mat in tensor.matrixSequence())
|
||||
check(mat.asTensor().detLU().value() > 0.0) {
|
||||
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
|
||||
}
|
||||
}
|
@ -0,0 +1,342 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core.internal
|
||||
|
||||
import space.kscience.kmath.nd.MutableStructure1D
|
||||
import space.kscience.kmath.nd.MutableStructure2D
|
||||
import space.kscience.kmath.nd.as1D
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.*
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.valueOrNull
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.min
|
||||
import kotlin.math.sign
|
||||
import kotlin.math.sqrt
|
||||
|
||||
|
||||
internal fun <T> BufferedTensor<T>.vectorSequence(): Sequence<BufferedTensor<T>> = sequence {
|
||||
val n = shape.size
|
||||
val vectorOffset = shape[n - 1]
|
||||
val vectorShape = intArrayOf(shape.last())
|
||||
for (offset in 0 until numElements step vectorOffset) {
|
||||
val vector = BufferedTensor(vectorShape, mutableBuffer, bufferStart + offset)
|
||||
yield(vector)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTensor<T>> = sequence {
|
||||
val n = shape.size
|
||||
check(n >= 2) { "Expected tensor with 2 or more dimensions, got size $n" }
|
||||
val matrixOffset = shape[n - 1] * shape[n - 2]
|
||||
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1])
|
||||
for (offset in 0 until numElements step matrixOffset) {
|
||||
val matrix = BufferedTensor(matrixShape, mutableBuffer, bufferStart + offset)
|
||||
yield(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun dotHelper(
|
||||
a: MutableStructure2D<Double>,
|
||||
b: MutableStructure2D<Double>,
|
||||
res: MutableStructure2D<Double>,
|
||||
l: Int, m: Int, n: Int
|
||||
) {
|
||||
for (i in 0 until l) {
|
||||
for (j in 0 until n) {
|
||||
var curr = 0.0
|
||||
for (k in 0 until m) {
|
||||
curr += a[i, k] * b[k, j]
|
||||
}
|
||||
res[i, j] = curr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun luHelper(
|
||||
lu: MutableStructure2D<Double>,
|
||||
pivots: MutableStructure1D<Int>,
|
||||
epsilon: Double
|
||||
): Boolean {
|
||||
|
||||
val m = lu.rowNum
|
||||
|
||||
for (row in 0..m) pivots[row] = row
|
||||
|
||||
for (i in 0 until m) {
|
||||
var maxVal = 0.0
|
||||
var maxInd = i
|
||||
|
||||
for (k in i until m) {
|
||||
val absA = abs(lu[k, i])
|
||||
if (absA > maxVal) {
|
||||
maxVal = absA
|
||||
maxInd = k
|
||||
}
|
||||
}
|
||||
|
||||
if (abs(maxVal) < epsilon)
|
||||
return true // matrix is singular
|
||||
|
||||
if (maxInd != i) {
|
||||
|
||||
val j = pivots[i]
|
||||
pivots[i] = pivots[maxInd]
|
||||
pivots[maxInd] = j
|
||||
|
||||
for (k in 0 until m) {
|
||||
val tmp = lu[i, k]
|
||||
lu[i, k] = lu[maxInd, k]
|
||||
lu[maxInd, k] = tmp
|
||||
}
|
||||
|
||||
pivots[m] += 1
|
||||
|
||||
}
|
||||
|
||||
for (j in i + 1 until m) {
|
||||
lu[j, i] /= lu[i, i]
|
||||
for (k in i + 1 until m) {
|
||||
lu[j, k] -= lu[j, i] * lu[i, k]
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
internal fun <T> BufferedTensor<T>.setUpPivots(): IntTensor {
|
||||
val n = this.shape.size
|
||||
val m = this.shape.last()
|
||||
val pivotsShape = IntArray(n - 1) { i -> this.shape[i] }
|
||||
pivotsShape[n - 2] = m + 1
|
||||
|
||||
return IntTensor(
|
||||
pivotsShape,
|
||||
IntArray(pivotsShape.reduce(Int::times)) { 0 }
|
||||
)
|
||||
}
|
||||
|
||||
internal fun DoubleTensorAlgebra.computeLU(
|
||||
tensor: DoubleTensor,
|
||||
epsilon: Double
|
||||
): Pair<DoubleTensor, IntTensor>? {
|
||||
|
||||
checkSquareMatrix(tensor.shape)
|
||||
val luTensor = tensor.copy()
|
||||
val pivotsTensor = tensor.setUpPivots()
|
||||
|
||||
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
|
||||
if (luHelper(lu.as2D(), pivots.as1D(), epsilon))
|
||||
return null
|
||||
|
||||
return Pair(luTensor, pivotsTensor)
|
||||
}
|
||||
|
||||
internal fun pivInit(
|
||||
p: MutableStructure2D<Double>,
|
||||
pivot: MutableStructure1D<Int>,
|
||||
n: Int
|
||||
) {
|
||||
for (i in 0 until n) {
|
||||
p[i, pivot[i]] = 1.0
|
||||
}
|
||||
}
|
||||
|
||||
internal fun luPivotHelper(
|
||||
l: MutableStructure2D<Double>,
|
||||
u: MutableStructure2D<Double>,
|
||||
lu: MutableStructure2D<Double>,
|
||||
n: Int
|
||||
) {
|
||||
for (i in 0 until n) {
|
||||
for (j in 0 until n) {
|
||||
if (i == j) {
|
||||
l[i, j] = 1.0
|
||||
}
|
||||
if (j < i) {
|
||||
l[i, j] = lu[i, j]
|
||||
}
|
||||
if (j >= i) {
|
||||
u[i, j] = lu[i, j]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun choleskyHelper(
|
||||
a: MutableStructure2D<Double>,
|
||||
l: MutableStructure2D<Double>,
|
||||
n: Int
|
||||
) {
|
||||
for (i in 0 until n) {
|
||||
for (j in 0 until i) {
|
||||
var h = a[i, j]
|
||||
for (k in 0 until j) {
|
||||
h -= l[i, k] * l[j, k]
|
||||
}
|
||||
l[i, j] = h / l[j, j]
|
||||
}
|
||||
var h = a[i, i]
|
||||
for (j in 0 until i) {
|
||||
h -= l[i, j] * l[i, j]
|
||||
}
|
||||
l[i, i] = sqrt(h)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun luMatrixDet(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>): Double {
|
||||
if (lu[0, 0] == 0.0) {
|
||||
return 0.0
|
||||
}
|
||||
val m = lu.shape[0]
|
||||
val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
||||
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
||||
}
|
||||
|
||||
internal fun luMatrixInv(
|
||||
lu: MutableStructure2D<Double>,
|
||||
pivots: MutableStructure1D<Int>,
|
||||
invMatrix: MutableStructure2D<Double>
|
||||
) {
|
||||
val m = lu.shape[0]
|
||||
|
||||
for (j in 0 until m) {
|
||||
for (i in 0 until m) {
|
||||
if (pivots[i] == j) {
|
||||
invMatrix[i, j] = 1.0
|
||||
}
|
||||
|
||||
for (k in 0 until i) {
|
||||
invMatrix[i, j] -= lu[i, k] * invMatrix[k, j]
|
||||
}
|
||||
}
|
||||
|
||||
for (i in m - 1 downTo 0) {
|
||||
for (k in i + 1 until m) {
|
||||
invMatrix[i, j] -= lu[i, k] * invMatrix[k, j]
|
||||
}
|
||||
invMatrix[i, j] /= lu[i, i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun DoubleTensorAlgebra.qrHelper(
|
||||
matrix: DoubleTensor,
|
||||
q: DoubleTensor,
|
||||
r: MutableStructure2D<Double>
|
||||
) {
|
||||
checkSquareMatrix(matrix.shape)
|
||||
val n = matrix.shape[0]
|
||||
val qM = q.as2D()
|
||||
val matrixT = matrix.transpose(0, 1)
|
||||
val qT = q.transpose(0, 1)
|
||||
|
||||
for (j in 0 until n) {
|
||||
val v = matrixT[j]
|
||||
val vv = v.as1D()
|
||||
if (j > 0) {
|
||||
for (i in 0 until j) {
|
||||
r[i, j] = (qT[i] dot matrixT[j]).value()
|
||||
for (k in 0 until n) {
|
||||
val qTi = qT[i].as1D()
|
||||
vv[k] = vv[k] - r[i, j] * qTi[k]
|
||||
}
|
||||
}
|
||||
}
|
||||
r[j, j] = DoubleTensorAlgebra { (v dot v).sqrt().value() }
|
||||
for (i in 0 until n) {
|
||||
qM[i, j] = vv[i] / r[j, j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun DoubleTensorAlgebra.svd1d(a: DoubleTensor, epsilon: Double = 1e-10): DoubleTensor {
|
||||
val (n, m) = a.shape
|
||||
var v: DoubleTensor
|
||||
val b: DoubleTensor
|
||||
if (n > m) {
|
||||
b = a.transpose(0, 1).dot(a)
|
||||
v = DoubleTensor(intArrayOf(m), getRandomUnitVector(m, 0))
|
||||
} else {
|
||||
b = a.dot(a.transpose(0, 1))
|
||||
v = DoubleTensor(intArrayOf(n), getRandomUnitVector(n, 0))
|
||||
}
|
||||
|
||||
var lastV: DoubleTensor
|
||||
while (true) {
|
||||
lastV = v
|
||||
v = b.dot(lastV)
|
||||
val norm = DoubleTensorAlgebra { (v dot v).sqrt().value() }
|
||||
v = v.times(1.0 / norm)
|
||||
if (abs(v.dot(lastV).value()) > 1 - epsilon) {
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun DoubleTensorAlgebra.svdHelper(
|
||||
matrix: DoubleTensor,
|
||||
USV: Pair<BufferedTensor<Double>, Pair<BufferedTensor<Double>, BufferedTensor<Double>>>,
|
||||
m: Int, n: Int, epsilon: Double
|
||||
) {
|
||||
val res = ArrayList<Triple<Double, DoubleTensor, DoubleTensor>>(0)
|
||||
val (matrixU, SV) = USV
|
||||
val (matrixS, matrixV) = SV
|
||||
|
||||
for (k in 0 until min(n, m)) {
|
||||
var a = matrix.copy()
|
||||
for ((singularValue, u, v) in res.slice(0 until k)) {
|
||||
val outerProduct = DoubleArray(u.shape[0] * v.shape[0])
|
||||
for (i in 0 until u.shape[0]) {
|
||||
for (j in 0 until v.shape[0]) {
|
||||
outerProduct[i * v.shape[0] + j] = u[i].value() * v[j].value()
|
||||
}
|
||||
}
|
||||
a = a - singularValue.times(DoubleTensor(intArrayOf(u.shape[0], v.shape[0]), outerProduct))
|
||||
}
|
||||
var v: DoubleTensor
|
||||
var u: DoubleTensor
|
||||
var norm: Double
|
||||
if (n > m) {
|
||||
v = svd1d(a, epsilon)
|
||||
u = matrix.dot(v)
|
||||
norm = DoubleTensorAlgebra { (u dot u).sqrt().value() }
|
||||
u = u.times(1.0 / norm)
|
||||
} else {
|
||||
u = svd1d(a, epsilon)
|
||||
v = matrix.transpose(0, 1).dot(u)
|
||||
norm = DoubleTensorAlgebra { (v dot v).sqrt().value() }
|
||||
v = v.times(1.0 / norm)
|
||||
}
|
||||
|
||||
res.add(Triple(norm, u, v))
|
||||
}
|
||||
|
||||
val s = res.map { it.first }.toDoubleArray()
|
||||
val uBuffer = res.map { it.second }.flatMap { it.mutableBuffer.array().toList() }.toDoubleArray()
|
||||
val vBuffer = res.map { it.third }.flatMap { it.mutableBuffer.array().toList() }.toDoubleArray()
|
||||
for (i in uBuffer.indices) {
|
||||
matrixU.mutableBuffer.array()[matrixU.bufferStart + i] = uBuffer[i]
|
||||
}
|
||||
for (i in s.indices) {
|
||||
matrixS.mutableBuffer.array()[matrixS.bufferStart + i] = s[i]
|
||||
}
|
||||
for (i in vBuffer.indices) {
|
||||
matrixV.mutableBuffer.array()[matrixV.bufferStart + i] = vBuffer[i]
|
||||
}
|
||||
}
|
||||
|
||||
internal fun cleanSymHelper(matrix: MutableStructure2D<Double>, n: Int) {
|
||||
for (i in 0 until n)
|
||||
for (j in 0 until n) {
|
||||
if (i == j) {
|
||||
matrix[i, j] = sign(matrix[i, j])
|
||||
} else {
|
||||
matrix[i, j] = 0.0
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,44 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core.internal
|
||||
|
||||
import space.kscience.kmath.nd.MutableBufferND
|
||||
import space.kscience.kmath.structures.asMutableBuffer
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.core.BufferedTensor
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import space.kscience.kmath.tensors.core.IntTensor
|
||||
|
||||
internal fun BufferedTensor<Int>.asTensor(): IntTensor =
|
||||
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
|
||||
|
||||
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor =
|
||||
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
|
||||
|
||||
internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
|
||||
BufferedTensor(
|
||||
this.shape,
|
||||
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0
|
||||
)
|
||||
|
||||
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
||||
is BufferedTensor<T> -> this
|
||||
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
|
||||
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
|
||||
else -> this.copyToBufferedTensor()
|
||||
}
|
||||
|
||||
internal val Tensor<Double>.tensor: DoubleTensor
|
||||
get() = when (this) {
|
||||
is DoubleTensor -> this
|
||||
else -> this.toBufferedTensor().asTensor()
|
||||
}
|
||||
|
||||
internal val Tensor<Int>.tensor: IntTensor
|
||||
get() = when (this) {
|
||||
is IntTensor -> this
|
||||
else -> this.toBufferedTensor().asTensor()
|
||||
}
|
@ -0,0 +1,124 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core.internal
|
||||
|
||||
import space.kscience.kmath.nd.as1D
|
||||
import space.kscience.kmath.samplers.GaussianSampler
|
||||
import space.kscience.kmath.stat.RandomGenerator
|
||||
import space.kscience.kmath.structures.*
|
||||
import space.kscience.kmath.tensors.core.BufferedTensor
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer] or copy the data.
|
||||
*/
|
||||
internal fun Buffer<Int>.array(): IntArray = when (this) {
|
||||
is IntBuffer -> array
|
||||
else -> this.toIntArray()
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data.
|
||||
*/
|
||||
internal fun Buffer<Double>.array(): DoubleArray = when (this) {
|
||||
is DoubleBuffer -> array
|
||||
else -> this.toDoubleArray()
|
||||
}
|
||||
|
||||
internal fun getRandomNormals(n: Int, seed: Long): DoubleArray {
|
||||
val distribution = GaussianSampler(0.0, 1.0)
|
||||
val generator = RandomGenerator.default(seed)
|
||||
return distribution.sample(generator).nextBufferBlocking(n).toDoubleArray()
|
||||
}
|
||||
|
||||
internal fun getRandomUnitVector(n: Int, seed: Long): DoubleArray {
|
||||
val unnorm = getRandomNormals(n, seed)
|
||||
val norm = sqrt(unnorm.sumOf { it * it })
|
||||
return unnorm.map { it / norm }.toDoubleArray()
|
||||
}
|
||||
|
||||
internal fun minusIndexFrom(n: Int, i: Int): Int = if (i >= 0) i else {
|
||||
val ii = n + i
|
||||
check(ii >= 0) {
|
||||
"Out of bound index $i for tensor of dim $n"
|
||||
}
|
||||
ii
|
||||
}
|
||||
|
||||
internal fun <T> BufferedTensor<T>.minusIndex(i: Int): Int = minusIndexFrom(this.dimension, i)
|
||||
|
||||
internal fun format(value: Double, digits: Int = 4): String = buildString {
|
||||
val res = buildString {
|
||||
val ten = 10.0
|
||||
val approxOrder = if (value == 0.0) 0 else ceil(log10(abs(value))).toInt()
|
||||
val order = if (
|
||||
((value % ten) == 0.0) ||
|
||||
(value == 1.0) ||
|
||||
((1 / value) % ten == 0.0)
|
||||
) approxOrder else approxOrder - 1
|
||||
val lead = value / ten.pow(order)
|
||||
if (value >= 0.0) append(' ')
|
||||
append(round(lead * ten.pow(digits)) / ten.pow(digits))
|
||||
when {
|
||||
order == 0 -> Unit
|
||||
order > 0 -> {
|
||||
append("e+")
|
||||
append(order)
|
||||
}
|
||||
else -> {
|
||||
append('e')
|
||||
append(order)
|
||||
}
|
||||
}
|
||||
}
|
||||
val fLength = digits + 6
|
||||
append(res)
|
||||
repeat(fLength - res.length) { append(' ') }
|
||||
}
|
||||
|
||||
internal fun DoubleTensor.toPrettyString(): String = buildString {
|
||||
var offset = 0
|
||||
val shape = this@toPrettyString.shape
|
||||
val linearStructure = this@toPrettyString.linearStructure
|
||||
val vectorSize = shape.last()
|
||||
append("DoubleTensor(\n")
|
||||
var charOffset = 3
|
||||
for (vector in vectorSequence()) {
|
||||
repeat(charOffset) { append(' ') }
|
||||
val index = linearStructure.index(offset)
|
||||
for (ind in index.reversed()) {
|
||||
if (ind != 0) {
|
||||
break
|
||||
}
|
||||
append('[')
|
||||
charOffset += 1
|
||||
}
|
||||
|
||||
val values = vector.as1D().toMutableList().map(::format)
|
||||
|
||||
values.joinTo(this, separator = ", ")
|
||||
|
||||
append(']')
|
||||
charOffset -= 1
|
||||
|
||||
index.reversed().zip(shape.reversed()).drop(1).forEach { (ind, maxInd) ->
|
||||
if (ind != maxInd - 1) {
|
||||
return@forEach
|
||||
}
|
||||
append(']')
|
||||
charOffset -= 1
|
||||
}
|
||||
|
||||
offset += vectorSize
|
||||
if (this@toPrettyString.numElements == offset) {
|
||||
break
|
||||
}
|
||||
|
||||
append(",\n")
|
||||
}
|
||||
append("\n)")
|
||||
}
|
@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.core.internal.tensor
|
||||
|
||||
/**
|
||||
* Casts [Tensor] of [Double] to [DoubleTensor]
|
||||
*/
|
||||
public fun Tensor<Double>.toDoubleTensor(): DoubleTensor = this.tensor
|
||||
|
||||
/**
|
||||
* Casts [Tensor] of [Int] to [IntTensor]
|
||||
*/
|
||||
public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor
|
||||
|
||||
/**
|
||||
* Returns [DoubleArray] of tensor elements
|
||||
*/
|
||||
public fun DoubleTensor.toDoubleArray(): DoubleArray {
|
||||
return DoubleArray(numElements) { i ->
|
||||
mutableBuffer[bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns [IntArray] of tensor elements
|
||||
*/
|
||||
public fun IntTensor.toIntArray(): IntArray {
|
||||
return IntArray(numElements) { i ->
|
||||
mutableBuffer[bufferStart + i]
|
||||
}
|
||||
}
|
@ -0,0 +1,105 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.internal.*
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal class TestBroadcasting {
|
||||
|
||||
@Test
|
||||
fun testBroadcastShapes() = DoubleTensorAlgebra {
|
||||
assertTrue(
|
||||
broadcastShapes(
|
||||
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
||||
) contentEquals intArrayOf(1, 2, 3)
|
||||
)
|
||||
|
||||
assertTrue(
|
||||
broadcastShapes(
|
||||
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7), intArrayOf(5, 1, 7)
|
||||
) contentEquals intArrayOf(5, 6, 7)
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBroadcastTo() = DoubleTensorAlgebra {
|
||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
|
||||
val res = broadcastTo(tensor2, tensor1.shape)
|
||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||
assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBroadcastTensors() = DoubleTensorAlgebra {
|
||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
||||
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
||||
|
||||
assertTrue(res[0].mutableBuffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res[1].mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
assertTrue(res[2].mutableBuffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBroadcastOuterTensors() = DoubleTensorAlgebra {
|
||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
||||
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals intArrayOf(1, 1, 3))
|
||||
assertTrue(res[2].shape contentEquals intArrayOf(1, 1, 1))
|
||||
|
||||
assertTrue(res[0].mutableBuffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res[1].mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0))
|
||||
assertTrue(res[2].mutableBuffer.array() contentEquals doubleArrayOf(500.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBroadcastOuterTensorsShapes() = DoubleTensorAlgebra {
|
||||
val tensor1 = fromArray(intArrayOf(2, 1, 3, 2, 3), DoubleArray(2 * 1 * 3 * 2 * 3) {0.0})
|
||||
val tensor2 = fromArray(intArrayOf(4, 2, 5, 1, 3, 3), DoubleArray(4 * 2 * 5 * 1 * 3 * 3) {0.0})
|
||||
val tensor3 = fromArray(intArrayOf(1, 1), doubleArrayOf(500.0))
|
||||
|
||||
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertTrue(res[0].shape contentEquals intArrayOf(4, 2, 5, 3, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals intArrayOf(4, 2, 5, 3, 3, 3))
|
||||
assertTrue(res[2].shape contentEquals intArrayOf(4, 2, 5, 3, 1, 1))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinusTensor() = BroadcastDoubleTensorAlgebra.invoke {
|
||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
||||
val tensor21 = tensor2 - tensor1
|
||||
val tensor31 = tensor3 - tensor1
|
||||
val tensor32 = tensor3 - tensor2
|
||||
|
||||
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
|
||||
assertTrue(tensor21.mutableBuffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||
|
||||
assertTrue(tensor31.shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(
|
||||
tensor31.mutableBuffer.array()
|
||||
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
|
||||
)
|
||||
|
||||
assertTrue(tensor32.shape contentEquals intArrayOf(1, 1, 3))
|
||||
assertTrue(tensor32.mutableBuffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,158 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.math.*
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal class TestDoubleAnalyticTensorAlgebra {
|
||||
|
||||
val shape = intArrayOf(2, 1, 3, 2)
|
||||
val buffer = doubleArrayOf(
|
||||
27.1, 20.0, 19.84,
|
||||
23.123, 3.0, 2.0,
|
||||
|
||||
3.23, 133.7, 25.3,
|
||||
100.3, 11.0, 12.012
|
||||
)
|
||||
val tensor = DoubleTensor(shape, buffer)
|
||||
|
||||
fun DoubleArray.fmap(transform: (Double) -> Double): DoubleArray {
|
||||
return this.map(transform).toDoubleArray()
|
||||
}
|
||||
|
||||
fun expectedTensor(transform: (Double) -> Double): DoubleTensor {
|
||||
return DoubleTensor(shape, buffer.fmap(transform))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testExp() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.exp() eq expectedTensor(::exp) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLog() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.ln() eq expectedTensor(::ln) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSqrt() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.sqrt() eq expectedTensor(::sqrt) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCos() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.cos() eq expectedTensor(::cos) }
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun testCosh() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.cosh() eq expectedTensor(::cosh) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAcosh() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.acosh() eq expectedTensor(::acosh) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSin() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.sin() eq expectedTensor(::sin) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSinh() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.sinh() eq expectedTensor(::sinh) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAsinh() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.asinh() eq expectedTensor(::asinh) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTan() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.tan() eq expectedTensor(::tan) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAtan() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.atan() eq expectedTensor(::atan) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTanh() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.tanh() eq expectedTensor(::tanh) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCeil() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.ceil() eq expectedTensor(::ceil) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFloor() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.floor() eq expectedTensor(::floor) }
|
||||
}
|
||||
|
||||
val shape2 = intArrayOf(2, 2)
|
||||
val buffer2 = doubleArrayOf(
|
||||
1.0, 2.0,
|
||||
-3.0, 4.0
|
||||
)
|
||||
val tensor2 = DoubleTensor(shape2, buffer2)
|
||||
|
||||
@Test
|
||||
fun testMin() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor2.min() == -3.0 }
|
||||
assertTrue { tensor2.min(0, true) eq fromArray(
|
||||
intArrayOf(1, 2),
|
||||
doubleArrayOf(-3.0, 2.0)
|
||||
)}
|
||||
assertTrue { tensor2.min(1, false) eq fromArray(
|
||||
intArrayOf(2),
|
||||
doubleArrayOf(1.0, -3.0)
|
||||
)}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMax() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor2.max() == 4.0 }
|
||||
assertTrue { tensor2.max(0, true) eq fromArray(
|
||||
intArrayOf(1, 2),
|
||||
doubleArrayOf(1.0, 4.0)
|
||||
)}
|
||||
assertTrue { tensor2.max(1, false) eq fromArray(
|
||||
intArrayOf(2),
|
||||
doubleArrayOf(2.0, 4.0)
|
||||
)}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSum() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor2.sum() == 4.0 }
|
||||
assertTrue { tensor2.sum(0, true) eq fromArray(
|
||||
intArrayOf(1, 2),
|
||||
doubleArrayOf(-2.0, 6.0)
|
||||
)}
|
||||
assertTrue { tensor2.sum(1, false) eq fromArray(
|
||||
intArrayOf(2),
|
||||
doubleArrayOf(3.0, 1.0)
|
||||
)}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMean() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor2.mean() == 1.0 }
|
||||
assertTrue { tensor2.mean(0, true) eq fromArray(
|
||||
intArrayOf(1, 2),
|
||||
doubleArrayOf(-1.0, 3.0)
|
||||
)}
|
||||
assertTrue { tensor2.mean(1, false) eq fromArray(
|
||||
intArrayOf(2),
|
||||
doubleArrayOf(1.5, 0.5)
|
||||
)}
|
||||
}
|
||||
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user