Merge remote-tracking branch 'origin/dev' into dev

This commit is contained in:
Alexander Nozik 2021-05-08 11:35:15 +03:00
commit 0622be2494
105 changed files with 5323 additions and 1092 deletions

View File

@ -13,9 +13,11 @@ jobs:
- name: Checkout the repo
uses: actions/checkout@v2
- name: Set up JDK 11
uses: actions/setup-java@v1
uses: DeLaGuardo/setup-graalvm@4.0
with:
java-version: 11
graalvm: 21.1.0
java: java11
arch: amd64
- name: Add msys to path
if: matrix.os == 'windows-latest'
run: SETX PATH "%PATH%;C:\msys64\mingw64\bin"

View File

@ -12,9 +12,11 @@ jobs:
- name: Checkout the repo
uses: actions/checkout@v2
- name: Set up JDK 11
uses: actions/setup-java@v1
uses: DeLaGuardo/setup-graalvm@4.0
with:
java-version: 11
graalvm: 21.1.0
java: java11
arch: amd64
- name: Cache gradle
uses: actions/cache@v2
with:
@ -30,9 +32,7 @@ jobs:
restore-keys: |
${{ runner.os }}-gradle-
- name: Build
run: |
./gradlew dokkaHtmlMultiModule --no-daemon --no-parallel --stacktrace
mv build/dokka/htmlMultiModule/-modules.html build/dokka/htmlMultiModule/index.html
run: ./gradlew dokkaHtmlMultiModule --no-daemon --no-parallel --stacktrace
- name: Deploy to GitHub Pages
uses: JamesIves/github-pages-deploy-action@4.1.0
with:

View File

@ -18,9 +18,11 @@ jobs:
- name: Checkout the repo
uses: actions/checkout@v2
- name: Set up JDK 11
uses: actions/setup-java@v1
uses: DeLaGuardo/setup-graalvm@4.0
with:
java-version: 11
graalvm: 21.1.0
java: java11
arch: amd64
- name: Add msys to path
if: matrix.os == 'windows-latest'
run: SETX PATH "%PATH%;C:\msys64\mingw64\bin"

View File

@ -10,7 +10,8 @@
- Blocking chains and Statistics
- Multiplatform integration
- Integration for any Field element
- Extendend operations for ND4J fields
- Extended operations for ND4J fields
- Jupyter Notebook integration module (kmath-jupyter)
### Changed
- Exponential operations merged with hyperbolic functions
@ -24,6 +25,7 @@
- Redesign MST. Remove MSTExpression.
- Move MST to core
- Separated benchmarks and examples
- Rewritten EJML module without ejml-simple
### Deprecated

View File

@ -91,7 +91,7 @@ KMath is a modular library. Different modules provide different features with di
* ### [kmath-ast](kmath-ast)
>
>
> **Maturity**: PROTOTYPE
> **Maturity**: EXPERIMENTAL
>
> **Features:**
> - [expression-language](kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
@ -154,9 +154,9 @@ performance calculations to code generation.
> **Maturity**: PROTOTYPE
>
> **Features:**
> - [ejml-vector](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt) : The Point implementation using SimpleMatrix.
> - [ejml-matrix](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : The Matrix implementation using SimpleMatrix.
> - [ejml-linear-space](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : The LinearSpace implementation using SimpleMatrix.
> - [ejml-vector](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt) : Point implementations.
> - [ejml-matrix](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : Matrix implementation.
> - [ejml-linear-space](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : LinearSpace implementations.
<hr/>
@ -200,6 +200,12 @@ One can still use generic algebras though.
> **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-jupyter](kmath-jupyter)
>
>
> **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-kotlingrad](kmath-kotlingrad)
>
>
@ -230,6 +236,18 @@ One can still use generic algebras though.
> **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-tensors](kmath-tensors)
>
>
> **Maturity**: PROTOTYPE
>
> **Features:**
> - [tensor algebra](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic linear algebra operations on tensors (plus, dot, etc.)
> - [tensor algebra with broadcasting](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
> - [linear algebra operations](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.
<hr/>
* ### [kmath-viktor](kmath-viktor)
>
>

View File

@ -9,14 +9,10 @@ sourceSets.register("benchmarks")
repositories {
mavenCentral()
jcenter()
maven("https://repo.kotlin.link")
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://jitpack.io")
maven {
setUrl("http://logicrunch.research.it.uu.se/maven/")
maven("http://logicrunch.research.it.uu.se/maven") {
isAllowInsecureProtocol = true
}
}

View File

@ -10,7 +10,7 @@ import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope
import kotlinx.benchmark.State
import space.kscience.kmath.commons.linear.CMLinearSpace
import space.kscience.kmath.ejml.EjmlLinearSpace
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
import space.kscience.kmath.linear.LinearSpace
import space.kscience.kmath.linear.invoke
import space.kscience.kmath.operations.DoubleField
@ -29,8 +29,8 @@ internal class DotBenchmark {
val cmMatrix1 = CMLinearSpace { matrix1.toCM() }
val cmMatrix2 = CMLinearSpace { matrix2.toCM() }
val ejmlMatrix1 = EjmlLinearSpace { matrix1.toEjml() }
val ejmlMatrix2 = EjmlLinearSpace { matrix2.toEjml() }
val ejmlMatrix1 = EjmlLinearSpaceDDRM { matrix1.toEjml() }
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
}
@Benchmark
@ -42,14 +42,14 @@ internal class DotBenchmark {
@Benchmark
fun ejmlDot(blackhole: Blackhole) {
EjmlLinearSpace {
EjmlLinearSpaceDDRM {
blackhole.consume(ejmlMatrix1 dot ejmlMatrix2)
}
}
@Benchmark
fun ejmlDotWithConversion(blackhole: Blackhole) {
EjmlLinearSpace {
EjmlLinearSpaceDDRM {
blackhole.consume(matrix1 dot matrix2)
}
}

View File

@ -11,25 +11,26 @@ import kotlinx.benchmark.Scope
import kotlinx.benchmark.State
import space.kscience.kmath.commons.linear.CMLinearSpace
import space.kscience.kmath.commons.linear.inverse
import space.kscience.kmath.ejml.EjmlLinearSpace
import space.kscience.kmath.ejml.inverse
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
import space.kscience.kmath.linear.InverseMatrixFeature
import space.kscience.kmath.linear.LinearSpace
import space.kscience.kmath.linear.inverseWithLup
import space.kscience.kmath.linear.invoke
import space.kscience.kmath.nd.getFeature
import kotlin.random.Random
@State(Scope.Benchmark)
internal class MatrixInverseBenchmark {
companion object {
val random = Random(1224)
const val dim = 100
private companion object {
private val random = Random(1224)
private const val dim = 100
private val space = LinearSpace.real
//creating invertible matrix
val u = space.buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
val l = space.buildMatrix(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 }
val matrix = space { l dot u }
private val u = space.buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
private val l = space.buildMatrix(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 }
private val matrix = space { l dot u }
}
@Benchmark
@ -46,8 +47,8 @@ internal class MatrixInverseBenchmark {
@Benchmark
fun ejmlInverse(blackhole: Blackhole) {
with(EjmlLinearSpace) {
blackhole.consume(inverse(matrix))
with(EjmlLinearSpaceDDRM) {
blackhole.consume(matrix.getFeature<InverseMatrixFeature<Double>>()?.inverse)
}
}
}

View File

@ -1,17 +1,16 @@
plugins {
id("ru.mipt.npm.gradle.project")
kotlin("jupyter.api") apply false
}
allprojects {
repositories {
jcenter()
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://jitpack.io")
maven("http://logicrunch.research.it.uu.se/maven/") {
maven("http://logicrunch.research.it.uu.se/maven") {
isAllowInsecureProtocol = true
}
maven("https://maven.pkg.jetbrains.space/public/p/kotlinx-html/maven")
mavenCentral()
}
@ -23,22 +22,16 @@ subprojects {
if (name.startsWith("kmath")) apply<MavenPublishPlugin>()
afterEvaluate {
tasks.withType<org.jetbrains.dokka.gradle.DokkaTask> {
dokkaSourceSets.all {
val readmeFile = File(this@subprojects.projectDir, "./README.md")
if (readmeFile.exists())
includes.setFrom(includes + readmeFile.absolutePath)
tasks.withType<org.jetbrains.dokka.gradle.DokkaTaskPartial> {
dependsOn(tasks.getByName("assemble"))
arrayOf(
"http://ejml.org/javadoc/",
"https://commons.apache.org/proper/commons-math/javadocs/api-3.6.1/",
"https://deeplearning4j.org/api/latest/"
).map { java.net.URL("${it}package-list") to java.net.URL(it) }.forEach { (a, b) ->
externalDocumentationLink {
packageListUrl.set(a)
url.set(b)
}
}
dokkaSourceSets.all {
val readmeFile = File(this@subprojects.projectDir, "README.md")
if (readmeFile.exists()) includes.setFrom(includes + readmeFile.absolutePath)
externalDocumentationLink("http://ejml.org/javadoc/")
externalDocumentationLink("https://commons.apache.org/proper/commons-math/javadocs/api-3.6.1/")
externalDocumentationLink("https://deeplearning4j.org/api/latest/")
externalDocumentationLink("https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/")
}
}
}

View File

@ -6,8 +6,7 @@ The Maven coordinates of this project are `${group}:${name}:${version}`.
```gradle
repositories {
maven { url 'https://repo.kotlin.link' }
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
mavenCentral()
}
dependencies {
@ -18,8 +17,7 @@ dependencies {
```kotlin
repositories {
maven("https://repo.kotlin.link")
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
mavenCentral()
}
dependencies {

View File

@ -4,14 +4,11 @@ plugins {
repositories {
mavenCentral()
jcenter()
maven("https://repo.kotlin.link")
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://jitpack.io")
maven{
setUrl("http://logicrunch.research.it.uu.se/maven/")
maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers")
maven("http://logicrunch.research.it.uu.se/maven") {
isAllowInsecureProtocol = true
}
}
@ -28,6 +25,7 @@ dependencies {
implementation(project(":kmath-dimensions"))
implementation(project(":kmath-ejml"))
implementation(project(":kmath-nd4j"))
implementation(project(":kmath-tensors"))
implementation(project(":kmath-for-real"))

View File

@ -0,0 +1,46 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
// Dataset normalization
fun main() {
// work in context with broadcast methods
BroadcastDoubleTensorAlgebra {
// take dataset of 5-element vectors from normal distribution
val dataset = randomNormal(intArrayOf(100, 5)) * 1.5 // all elements from N(0, 1.5)
dataset += fromArray(
intArrayOf(5),
doubleArrayOf(0.0, 1.0, 1.5, 3.0, 5.0) // rows means
)
// find out mean and standard deviation of each column
val mean = dataset.mean(0, false)
val std = dataset.std(0, false)
println("Mean:\n$mean")
println("Standard deviation:\n$std")
// also we can calculate other statistic as minimum and maximum of rows
println("Minimum:\n${dataset.min(0, false)}")
println("Maximum:\n${dataset.max(0, false)}")
// now we can scale dataset with mean normalization
val datasetScaled = (dataset - mean) / std
// find out mean and std of scaled dataset
println("Mean of scaled:\n${datasetScaled.mean(0, false)}")
println("Mean of scaled:\n${datasetScaled.std(0, false)}")
}
}

View File

@ -0,0 +1,97 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
// solving linear system with LUP decomposition
fun main () {
// work in context with linear operations
BroadcastDoubleTensorAlgebra {
// set true value of x
val trueX = fromArray(
intArrayOf(4),
doubleArrayOf(-2.0, 1.5, 6.8, -2.4)
)
// and A matrix
val a = fromArray(
intArrayOf(4, 4),
doubleArrayOf(
0.5, 10.5, 4.5, 1.0,
8.5, 0.9, 12.8, 0.1,
5.56, 9.19, 7.62, 5.45,
1.0, 2.0, -3.0, -2.5
)
)
// calculate y value
val b = a dot trueX
// check out A and b
println("A:\n$a")
println("b:\n$b")
// solve `Ax = b` system using LUP decomposition
// get P, L, U such that PA = LU
val (p, l, u) = a.lu()
// check that P is permutation matrix
println("P:\n$p")
// L is lower triangular matrix and U is upper triangular matrix
println("L:\n$l")
println("U:\n$u")
// and PA = LU
println("PA:\n${p dot a}")
println("LU:\n${l dot u}")
/* Ax = b;
PAx = Pb;
LUx = Pb;
let y = Ux, then
Ly = Pb -- this system can be easily solved, since the matrix L is lower triangular;
Ux = y can be solved the same way, since the matrix L is upper triangular
*/
// this function returns solution x of a system lx = b, l should be lower triangular
fun solveLT(l: DoubleTensor, b: DoubleTensor): DoubleTensor {
val n = l.shape[0]
val x = zeros(intArrayOf(n))
for (i in 0 until n){
x[intArrayOf(i)] = (b[intArrayOf(i)] - l[i].dot(x).value()) / l[intArrayOf(i, i)]
}
return x
}
val y = solveLT(l, p dot b)
// solveLT(l, b) function can be easily adapted for upper triangular matrix by the permutation matrix revMat
// create it by placing ones on side diagonal
val revMat = u.zeroesLike()
val n = revMat.shape[0]
for (i in 0 until n) {
revMat[intArrayOf(i, n - 1 - i)] = 1.0
}
// solution of system ux = b, u should be upper triangular
fun solveUT(u: DoubleTensor, b: DoubleTensor): DoubleTensor = revMat dot solveLT(
revMat dot u dot revMat, revMat dot b
)
val x = solveUT(u, y)
println("True x:\n$trueX")
println("x founded with LU method:\n$x")
}
}

View File

@ -0,0 +1,241 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.toDoubleArray
import kotlin.math.sqrt
const val seed = 100500L
// Simple feedforward neural network with backpropagation training
// interface of network layer
interface Layer {
fun forward(input: DoubleTensor): DoubleTensor
fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor
}
// activation layer
open class Activation(
val activation: (DoubleTensor) -> DoubleTensor,
val activationDer: (DoubleTensor) -> DoubleTensor
) : Layer {
override fun forward(input: DoubleTensor): DoubleTensor {
return activation(input)
}
override fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor {
return DoubleTensorAlgebra { outputError * activationDer(input) }
}
}
fun relu(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
x.map { if (it > 0) it else 0.0 }
}
fun reluDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
x.map { if (it > 0) 1.0 else 0.0 }
}
// activation layer with relu activator
class ReLU : Activation(::relu, ::reluDer)
fun sigmoid(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
1.0 / (1.0 + (-x).exp())
}
fun sigmoidDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
sigmoid(x) * (1.0 - sigmoid(x))
}
// activation layer with sigmoid activator
class Sigmoid : Activation(::sigmoid, ::sigmoidDer)
// dense layer
class Dense(
private val inputUnits: Int,
private val outputUnits: Int,
private val learningRate: Double = 0.1
) : Layer {
private val weights: DoubleTensor = DoubleTensorAlgebra {
randomNormal(
intArrayOf(inputUnits, outputUnits),
seed
) * sqrt(2.0 / (inputUnits + outputUnits))
}
private val bias: DoubleTensor = DoubleTensorAlgebra { zeros(intArrayOf(outputUnits)) }
override fun forward(input: DoubleTensor): DoubleTensor {
return BroadcastDoubleTensorAlgebra { (input dot weights) + bias }
}
override fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
val gradInput = outputError dot weights.transpose()
val gradW = input.transpose() dot outputError
val gradBias = outputError.mean(dim = 0, keepDim = false) * input.shape[0].toDouble()
weights -= learningRate * gradW
bias -= learningRate * gradBias
gradInput
}
}
// simple accuracy equal to the proportion of correct answers
fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double {
check(yPred.shape contentEquals yTrue.shape)
val n = yPred.shape[0]
var correctCnt = 0
for (i in 0 until n) {
if (yPred[intArrayOf(i, 0)] == yTrue[intArrayOf(i, 0)]) {
correctCnt += 1
}
}
return correctCnt.toDouble() / n.toDouble()
}
// neural network class
class NeuralNetwork(private val layers: List<Layer>) {
private fun softMaxLoss(yPred: DoubleTensor, yTrue: DoubleTensor): DoubleTensor = BroadcastDoubleTensorAlgebra {
val onesForAnswers = yPred.zeroesLike()
yTrue.toDoubleArray().forEachIndexed { index, labelDouble ->
val label = labelDouble.toInt()
onesForAnswers[intArrayOf(index, label)] = 1.0
}
val softmaxValue = yPred.exp() / yPred.exp().sum(dim = 1, keepDim = true)
(-onesForAnswers + softmaxValue) / (yPred.shape[0].toDouble())
}
@OptIn(ExperimentalStdlibApi::class)
private fun forward(x: DoubleTensor): List<DoubleTensor> {
var input = x
return buildList {
layers.forEach { layer ->
val output = layer.forward(input)
add(output)
input = output
}
}
}
@OptIn(ExperimentalStdlibApi::class)
private fun train(xTrain: DoubleTensor, yTrain: DoubleTensor) {
val layerInputs = buildList {
add(xTrain)
addAll(forward(xTrain))
}
var lossGrad = softMaxLoss(layerInputs.last(), yTrain)
layers.zip(layerInputs).reversed().forEach { (layer, input) ->
lossGrad = layer.backward(input, lossGrad)
}
}
fun fit(xTrain: DoubleTensor, yTrain: DoubleTensor, batchSize: Int, epochs: Int) = DoubleTensorAlgebra {
fun iterBatch(x: DoubleTensor, y: DoubleTensor): Sequence<Pair<DoubleTensor, DoubleTensor>> = sequence {
val n = x.shape[0]
val shuffledIndices = (0 until n).shuffled()
for (i in 0 until n step batchSize) {
val excerptIndices = shuffledIndices.drop(i).take(batchSize).toIntArray()
val batch = x.rowsByIndices(excerptIndices) to y.rowsByIndices(excerptIndices)
yield(batch)
}
}
for (epoch in 0 until epochs) {
println("Epoch ${epoch + 1}/$epochs")
for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
train(xBatch, yBatch)
}
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}")
}
}
fun predict(x: DoubleTensor): DoubleTensor {
return forward(x).last()
}
}
@OptIn(ExperimentalStdlibApi::class)
fun main() {
BroadcastDoubleTensorAlgebra {
val features = 5
val sampleSize = 250
val trainSize = 180
val testSize = sampleSize - trainSize
// take sample of features from normal distribution
val x = randomNormal(intArrayOf(sampleSize, features), seed) * 2.5
x += fromArray(
intArrayOf(5),
doubleArrayOf(0.0, -1.0, -2.5, -3.0, 5.5) // rows means
)
// define class like '1' if the sum of features > 0 and '0' otherwise
val y = fromArray(
intArrayOf(sampleSize, 1),
DoubleArray(sampleSize) { i ->
if (x[i].sum() > 0.0) {
1.0
} else {
0.0
}
}
)
// split train ans test
val trainIndices = (0 until trainSize).toList().toIntArray()
val testIndices = (trainSize until sampleSize).toList().toIntArray()
val xTrain = x.rowsByIndices(trainIndices)
val yTrain = y.rowsByIndices(trainIndices)
val xTest = x.rowsByIndices(testIndices)
val yTest = y.rowsByIndices(testIndices)
// build model
val layers = buildList {
add(Dense(features, 64))
add(ReLU())
add(Dense(64, 16))
add(ReLU())
add(Dense(16, 2))
add(Sigmoid())
}
val model = NeuralNetwork(layers)
// fit it with train data
model.fit(xTrain, yTrain, batchSize = 20, epochs = 10)
// make prediction
val prediction = model.predict(xTest)
// process raw prediction via argMax
val predictionLabels = prediction.argMax(1, true)
// find out accuracy
val acc = accuracy(yTest, predictionLabels)
println("Test accuracy:$acc")
}
}

View File

@ -0,0 +1,68 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import kotlin.math.abs
// OLS estimator using SVD
fun main() {
//seed for random
val randSeed = 100500L
// work in context with linear operations
DoubleTensorAlgebra {
// take coefficient vector from normal distribution
val alpha = randomNormal(
intArrayOf(5),
randSeed
) + fromArray(
intArrayOf(5),
doubleArrayOf(1.0, 2.5, 3.4, 5.0, 10.1)
)
println("Real alpha:\n$alpha")
// also take sample of size 20 from normal distribution for x
val x = randomNormal(
intArrayOf(20, 5),
randSeed
)
// calculate y and add gaussian noise (N(0, 0.05))
val y = x dot alpha
y += y.randomNormalLike(randSeed) * 0.05
// now restore the coefficient vector with OSL estimator with SVD
val (u, singValues, v) = x.svd()
// we have to make sure the singular values of the matrix are not close to zero
println("Singular values:\n$singValues")
// inverse Sigma matrix can be restored from singular values with diagonalEmbedding function
val sigma = diagonalEmbedding(singValues.map{ x -> if (abs(x) < 1e-3) 0.0 else 1.0/x })
val alphaOLS = v dot sigma dot u.transpose() dot y
println("Estimated alpha:\n" +
"$alphaOLS")
// figure out MSE of approximation
fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double {
require(yTrue.shape.size == 1)
require(yTrue.shape contentEquals yPred.shape)
val diff = yTrue - yPred
return diff.dot(diff).sqrt().value()
}
println("MSE: ${mse(alpha, alphaOLS)}")
}
}

View File

@ -0,0 +1,78 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
// simple PCA
fun main(){
val seed = 100500L
// work in context with broadcast methods
BroadcastDoubleTensorAlgebra {
// assume x is range from 0 until 10
val x = fromArray(
intArrayOf(10),
(0 until 10).toList().map { it.toDouble() }.toDoubleArray()
)
// take y dependent on x with noise
val y = 2.0 * x + (3.0 + x.randomNormalLike(seed) * 1.5)
println("x:\n$x")
println("y:\n$y")
// stack them into single dataset
val dataset = stack(listOf(x, y)).transpose()
// normalize both x and y
val xMean = x.mean()
val yMean = y.mean()
val xStd = x.std()
val yStd = y.std()
val xScaled = (x - xMean) / xStd
val yScaled = (y - yMean) / yStd
// save means ans standard deviations for further recovery
val mean = fromArray(
intArrayOf(2),
doubleArrayOf(xMean, yMean)
)
println("Means:\n$mean")
val std = fromArray(
intArrayOf(2),
doubleArrayOf(xStd, yStd)
)
println("Standard deviations:\n$std")
// calculate the covariance matrix of scaled x and y
val covMatrix = cov(listOf(xScaled, yScaled))
println("Covariance matrix:\n$covMatrix")
// and find out eigenvector of it
val (_, evecs) = covMatrix.symEig()
val v = evecs[0]
println("Eigenvector:\n$v")
// reduce dimension of dataset
val datasetReduced = v dot stack(listOf(xScaled, yScaled))
println("Reduced data:\n$datasetReduced")
// we can restore original data from reduced data.
// for example, find 7th element of dataset
val n = 7
val restored = (datasetReduced[n] dot v.view(intArrayOf(1, 2))) * std + mean
println("Original value:\n${dataset[n]}")
println("Restored value:\n$restored")
}
}

View File

@ -1,6 +1,6 @@
# Module kmath-ast
Abstract syntax tree expression representation and related optimizations.
Performance and visualization extensions to MST API.
- [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
- [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
@ -16,8 +16,7 @@ The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-7`
```gradle
repositories {
maven { url 'https://repo.kotlin.link' }
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
mavenCentral()
}
dependencies {
@ -28,8 +27,7 @@ dependencies {
```kotlin
repositories {
maven("https://repo.kotlin.link")
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
mavenCentral()
}
dependencies {
@ -41,21 +39,26 @@ dependencies {
### On JVM
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
a special implementation of `Expression<T>` with implemented `invoke` function.
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
special implementation of `Expression<T>` with implemented `invoke` function.
For example, the following builder:
```kotlin
import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.asm.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
```
… leads to generation of bytecode, which can be decompiled to the following Java class:
... leads to generation of bytecode, which can be decompiled to the following Java class:
```java
package space.kscience.kmath.asm.generated;
import java.util.Map;
import kotlin.jvm.functions.Function2;
import space.kscience.kmath.asm.internal.MapIntrinsics;
import space.kscience.kmath.expressions.Expression;
@ -65,7 +68,7 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
private final Object[] constants;
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
return (Double)((Function2)this.constants[0]).invoke((Double)MapIntrinsics.getOrFail(arguments, "x"), 2);
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
}
public AsmCompiledExpression_45045_0(Object[] constants) {
@ -77,8 +80,8 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
#### Known issues
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
class loading overhead.
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
loading overhead.
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
### On JS
@ -86,6 +89,10 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
A similar feature is also available on JS.
```kotlin
import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.estree.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
```
@ -93,18 +100,22 @@ The code above returns expression implemented with such a JS function:
```js
var executable = function (constants, arguments) {
return constants[1](constants[0](arguments, "x"), 2);
return constants[1](constants[0](arguments, "x"), 2);
};
```
JS also supports very experimental expression optimization with [WebAssembly](https://webassembly.org/) IR generation.
Currently, only expressions inside `DoubleField` and `IntRing` are supported.
```kotlin
import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.wasm.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
```
An example of emitted WASM IR in the form of WAT:
An example of emitted Wasm IR in the form of WAT:
```lisp
(func $executable (param $0 f64) (result f64)
@ -129,7 +140,9 @@ Example usage:
```kotlin
import space.kscience.kmath.ast.*
import space.kscience.kmath.ast.rendering.*
import space.kscience.kmath.misc.*
@OptIn(UnstableKMathAPI::class)
public fun main() {
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath()
val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst)
@ -145,13 +158,68 @@ public fun main() {
Result LaTeX:
![](http://chart.googleapis.com/chart?cht=tx&chl=e%5E%7B%5Csqrt%7Bx%7D%7D-%5Cfrac%7B%5Cfrac%7B%5Coperatorname%7Bsin%7D%5E%7B-1%7D%5C,%5Cleft(2%5C,x%5Cright)%7D%7B2%5Ctimes10%5E%7B10%7D%2Bx%5E%7B3%7D%7D%7D%7B-12%7D)
![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{-12})
Result MathML (embedding MathML is not allowed by GitHub Markdown):
<details>
```html
<mrow><msup><mrow><mi>e</mi></mrow><mrow><msqrt><mi>x</mi></msqrt></mrow></msup><mo>-</mo><mfrac><mrow><mfrac><mrow><msup><mrow><mo>sin</mo></mrow><mrow><mo>-</mo><mn>1</mn></mrow></msup><mspace width="0.167em"></mspace><mfenced open="(" close=")" separators=""><mn>2</mn><mspace width="0.167em"></mspace><mi>x</mi></mfenced></mrow><mrow><mn>2</mn><mo>&times;</mo><msup><mrow><mn>10</mn></mrow><mrow><mn>10</mn></mrow></msup><mo>+</mo><msup><mrow><mi>x</mi></mrow><mrow><mn>3</mn></mrow></msup></mrow></mfrac></mrow><mrow><mo>-</mo><mn>12</mn></mrow></mfrac></mrow>
<math xmlns="https://www.w3.org/1998/Math/MathML">
<mrow>
<mo>exp</mo>
<mspace width="0.167em"></mspace>
<mfenced open="(" close=")" separators="">
<msqrt>
<mi>x</mi>
</msqrt>
</mfenced>
<mo>-</mo>
<mfrac>
<mrow>
<mfrac>
<mrow>
<mo>arcsin</mo>
<mspace width="0.167em"></mspace>
<mfenced open="(" close=")" separators="">
<mn>2</mn>
<mspace width="0.167em"></mspace>
<mi>x</mi>
</mfenced>
</mrow>
<mrow>
<mn>2</mn>
<mo>&times;</mo>
<msup>
<mrow>
<mn>10</mn>
</mrow>
<mrow>
<mn>10</mn>
</mrow>
</msup>
<mo>+</mo>
<msup>
<mrow>
<mi>x</mi>
</mrow>
<mrow>
<mn>3</mn>
</mrow>
</msup>
</mrow>
</mfrac>
</mrow>
<mrow>
<mo>-</mo>
<mn>12</mn>
</mrow>
</mfrac>
</mrow>
</math>
```
</details>
It is also possible to create custom algorithms of render, and even add support of other markup languages
(see API reference).

View File

@ -18,6 +18,10 @@ kotlin.js {
}
kotlin.sourceSets {
filter { it.name.contains("test", true) }
.map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings)
.forEach { it.useExperimentalAnnotation("space.kscience.kmath.misc.UnstableKMathAPI") }
commonMain {
dependencies {
api("com.github.h0tk3y.betterParse:better-parse:0.4.2")
@ -54,7 +58,7 @@ tasks.dokkaHtml {
}
readme {
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
feature(

View File

@ -1,6 +1,6 @@
# Module kmath-ast
Abstract syntax tree expression representation and related optimizations.
Performance and visualization extensions to MST API.
${features}
@ -10,21 +10,26 @@ ${artifact}
### On JVM
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
a special implementation of `Expression<T>` with implemented `invoke` function.
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
special implementation of `Expression<T>` with implemented `invoke` function.
For example, the following builder:
```kotlin
import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.asm.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
```
… leads to generation of bytecode, which can be decompiled to the following Java class:
... leads to generation of bytecode, which can be decompiled to the following Java class:
```java
package space.kscience.kmath.asm.generated;
import java.util.Map;
import kotlin.jvm.functions.Function2;
import space.kscience.kmath.asm.internal.MapIntrinsics;
import space.kscience.kmath.expressions.Expression;
@ -34,7 +39,7 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
private final Object[] constants;
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
return (Double)((Function2)this.constants[0]).invoke((Double)MapIntrinsics.getOrFail(arguments, "x"), 2);
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
}
public AsmCompiledExpression_45045_0(Object[] constants) {
@ -46,8 +51,8 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
#### Known issues
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
class loading overhead.
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
loading overhead.
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
### On JS
@ -55,6 +60,10 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
A similar feature is also available on JS.
```kotlin
import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.estree.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
```
@ -62,18 +71,22 @@ The code above returns expression implemented with such a JS function:
```js
var executable = function (constants, arguments) {
return constants[1](constants[0](arguments, "x"), 2);
return constants[1](constants[0](arguments, "x"), 2);
};
```
JS also supports very experimental expression optimization with [WebAssembly](https://webassembly.org/) IR generation.
Currently, only expressions inside `DoubleField` and `IntRing` are supported.
```kotlin
import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.wasm.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField)
```
An example of emitted WASM IR in the form of WAT:
An example of emitted Wasm IR in the form of WAT:
```lisp
(func \$executable (param \$0 f64) (result f64)
@ -98,9 +111,11 @@ Example usage:
```kotlin
import space.kscience.kmath.ast.*
import space.kscience.kmath.ast.rendering.*
import space.kscience.kmath.misc.*
@OptIn(UnstableKMathAPI::class)
public fun main() {
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath()
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(12)+x^(2/3)".parseMath()
val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst)
val latex = LatexSyntaxRenderer.renderWithStringBuilder(syntax)
println("LaTeX:")
@ -114,13 +129,78 @@ public fun main() {
Result LaTeX:
![](http://chart.googleapis.com/chart?cht=tx&chl=e%5E%7B%5Csqrt%7Bx%7D%7D-%5Cfrac%7B%5Cfrac%7B%5Coperatorname%7Bsin%7D%5E%7B-1%7D%5C,%5Cleft(2%5C,x%5Cright)%7D%7B2%5Ctimes10%5E%7B10%7D%2Bx%5E%7B3%7D%7D%7D%7B-12%7D)
![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{12}+x^{2/3})
Result MathML (embedding MathML is not allowed by GitHub Markdown):
Result MathML (can be used with MathJax or other renderers):
<details>
```html
<mrow><msup><mrow><mi>e</mi></mrow><mrow><msqrt><mi>x</mi></msqrt></mrow></msup><mo>-</mo><mfrac><mrow><mfrac><mrow><msup><mrow><mo>sin</mo></mrow><mrow><mo>-</mo><mn>1</mn></mrow></msup><mspace width="0.167em"></mspace><mfenced open="(" close=")" separators=""><mn>2</mn><mspace width="0.167em"></mspace><mi>x</mi></mfenced></mrow><mrow><mn>2</mn><mo>&times;</mo><msup><mrow><mn>10</mn></mrow><mrow><mn>10</mn></mrow></msup><mo>+</mo><msup><mrow><mi>x</mi></mrow><mrow><mn>3</mn></mrow></msup></mrow></mfrac></mrow><mrow><mo>-</mo><mn>12</mn></mrow></mfrac></mrow>
<math xmlns="https://www.w3.org/1998/Math/MathML">
<mrow>
<mo>exp</mo>
<mspace width="0.167em"></mspace>
<mfenced open="(" close=")" separators="">
<msqrt>
<mi>x</mi>
</msqrt>
</mfenced>
<mo>-</mo>
<mfrac>
<mrow>
<mfrac>
<mrow>
<mo>arcsin</mo>
<mspace width="0.167em"></mspace>
<mfenced open="(" close=")" separators="">
<mn>2</mn>
<mspace width="0.167em"></mspace>
<mi>x</mi>
</mfenced>
</mrow>
<mrow>
<mn>2</mn>
<mo>&times;</mo>
<msup>
<mrow>
<mn>10</mn>
</mrow>
<mrow>
<mn>10</mn>
</mrow>
</msup>
<mo>+</mo>
<msup>
<mrow>
<mi>x</mi>
</mrow>
<mrow>
<mn>3</mn>
</mrow>
</msup>
</mrow>
</mfrac>
</mrow>
<mrow>
<mn>12</mn>
</mrow>
</mfrac>
<mo>+</mo>
<msup>
<mrow>
<mi>x</mi>
</mrow>
<mrow>
<mn>2</mn>
<mo>/</mo>
<mn>3</mn>
</mrow>
</msup>
</mrow>
</math>
```
</details>
It is also possible to create custom algorithms of render, and even add support of other markup languages
(see API reference).

View File

@ -29,7 +29,6 @@ import space.kscience.kmath.operations.RingOperations
* @author Iaroslav Postovalov
*/
public object ArithmeticsEvaluator : Grammar<MST>() {
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?".toRegex())
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*".toRegex())
private val lpar: Token by literalToken("(")

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.misc.UnstableKMathAPI
/**
* [SyntaxRenderer] implementation for LaTeX.
*
@ -23,6 +25,7 @@ package space.kscience.kmath.ast.rendering
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public object LatexSyntaxRenderer : SyntaxRenderer {
public override fun render(node: MathSyntax, output: Appendable): Unit = output.run {
fun render(syntax: MathSyntax) = render(syntax, output)
@ -115,7 +118,11 @@ public object LatexSyntaxRenderer : SyntaxRenderer {
render(node.right)
}
is FractionSyntax -> {
is FractionSyntax -> if (node.infix) {
render(node.left)
append('/')
render(node.right)
} else {
append("\\frac{")
render(node.left)
append("}{")

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.misc.UnstableKMathAPI
/**
* [SyntaxRenderer] implementation for MathML.
*
@ -12,14 +14,18 @@ package space.kscience.kmath.ast.rendering
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public object MathMLSyntaxRenderer : SyntaxRenderer {
public override fun render(node: MathSyntax, output: Appendable) {
output.append("<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>")
render0(node, output)
output.append("<math xmlns=\"https://www.w3.org/1998/Math/MathML\"><mrow>")
renderPart(node, output)
output.append("</mrow></math>")
}
private fun render0(node: MathSyntax, output: Appendable): Unit = output.run {
/**
* Renders a part of syntax returning a correct MathML tag not the whole MathML instance.
*/
public fun renderPart(node: MathSyntax, output: Appendable): Unit = output.run {
fun tag(tagName: String, vararg attr: Pair<String, String>, block: () -> Unit = {}) {
append('<')
append(tagName)
@ -44,7 +50,7 @@ public object MathMLSyntaxRenderer : SyntaxRenderer {
append('>')
}
fun render(syntax: MathSyntax) = render0(syntax, output)
fun render(syntax: MathSyntax) = renderPart(syntax, output)
when (node) {
is NumberSyntax -> tag("mn") { append(node.string) }
@ -127,14 +133,13 @@ public object MathMLSyntaxRenderer : SyntaxRenderer {
render(node.right)
}
is FractionSyntax -> tag("mfrac") {
tag("mrow") {
render(node.left)
}
tag("mrow") {
render(node.right)
}
is FractionSyntax -> if (node.infix) {
render(node.left)
tag("mo") { append('/') }
render(node.right)
} else tag("mfrac") {
tag("mrow") { render(node.left) }
tag("mrow") { render(node.right) }
}
is RadicalWithIndexSyntax -> tag("mroot") {

View File

@ -6,12 +6,14 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.misc.UnstableKMathAPI
/**
* Renders [MST] to [MathSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun interface MathRenderer {
/**
* Renders [MST] to [MathSyntax].
@ -25,6 +27,7 @@ public fun interface MathRenderer {
* @property features The applied features.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public open class FeaturedMathRenderer(public val features: List<RenderFeature>) : MathRenderer {
public override fun render(mst: MST): MathSyntax {
for (feature in features) feature.render(this, mst)?.let { return it }
@ -48,6 +51,7 @@ public open class FeaturedMathRenderer(public val features: List<RenderFeature>)
* @property stages The applied stages.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public open class FeaturedMathRendererWithPostProcess(
features: List<RenderFeature>,
public val stages: List<PostProcessStage>,
@ -85,6 +89,7 @@ public open class FeaturedMathRendererWithPostProcess(
SquareRoot.Default,
Exponent.Default,
InverseTrigonometricOperations.Default,
InverseHyperbolicOperations.Default,
// Fallback option for unknown operations - printing them as operator
BinaryOperator.Default,
@ -101,6 +106,7 @@ public open class FeaturedMathRendererWithPostProcess(
),
listOf(
BetterExponent,
BetterFraction,
SimplifyParentheses.Default,
BetterMultiplication,
),

View File

@ -5,11 +5,14 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.misc.UnstableKMathAPI
/**
* Mathematical typography syntax node.
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public sealed class MathSyntax {
/**
* The parent node of this syntax node.
@ -22,6 +25,7 @@ public sealed class MathSyntax {
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public sealed class TerminalSyntax : MathSyntax()
/**
@ -29,6 +33,7 @@ public sealed class TerminalSyntax : MathSyntax()
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public sealed class OperationSyntax : MathSyntax() {
/**
* The operation token.
@ -41,6 +46,7 @@ public sealed class OperationSyntax : MathSyntax() {
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public sealed class UnarySyntax : OperationSyntax() {
/**
* The operand of this node.
@ -53,6 +59,7 @@ public sealed class UnarySyntax : OperationSyntax() {
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public sealed class BinarySyntax : OperationSyntax() {
/**
* The left-hand side operand.
@ -71,6 +78,7 @@ public sealed class BinarySyntax : OperationSyntax() {
* @property string The digits of number.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class NumberSyntax(public var string: String) : TerminalSyntax()
/**
@ -79,6 +87,7 @@ public data class NumberSyntax(public var string: String) : TerminalSyntax()
* @property string The symbol.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class SymbolSyntax(public var string: String) : TerminalSyntax()
/**
@ -89,14 +98,16 @@ public data class SymbolSyntax(public var string: String) : TerminalSyntax()
* @see UnaryOperatorSyntax
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class OperatorNameSyntax(public var name: String) : TerminalSyntax()
/**
* Represents a usage of special symbols.
* Represents a usage of special symbols (e.g., *&infin;*).
*
* @property kind The kind of symbol.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class SpecialSymbolSyntax(public var kind: Kind) : TerminalSyntax() {
/**
* The kind of symbol.
@ -121,6 +132,7 @@ public data class SpecialSymbolSyntax(public var kind: Kind) : TerminalSyntax()
* @property parentheses Whether the operand should be wrapped with parentheses.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class OperandSyntax(
public val operand: MathSyntax,
public var parentheses: Boolean,
@ -131,11 +143,12 @@ public data class OperandSyntax(
}
/**
* Represents unary, prefix operator syntax (like f x).
* Represents unary, prefix operator syntax (like *f(x)*).
*
* @property prefix The prefix.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class UnaryOperatorSyntax(
public override val operation: String,
public var prefix: MathSyntax,
@ -147,10 +160,11 @@ public data class UnaryOperatorSyntax(
}
/**
* Represents prefix, unary plus operator.
* Represents prefix, unary plus operator (*+x*).
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class UnaryPlusSyntax(
public override val operation: String,
public override val operand: OperandSyntax,
@ -161,10 +175,11 @@ public data class UnaryPlusSyntax(
}
/**
* Represents prefix, unary minus operator.
* Represents prefix, unary minus operator (*-x*).
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class UnaryMinusSyntax(
public override val operation: String,
public override val operand: OperandSyntax,
@ -175,11 +190,12 @@ public data class UnaryMinusSyntax(
}
/**
* Represents radical with a node inside it.
* Represents radical with a node inside it (*&radic;x*).
*
* @property operand The radicand.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class RadicalSyntax(
public override val operation: String,
public override val operand: MathSyntax,
@ -197,6 +213,7 @@ public data class RadicalSyntax(
* (*e<sup>x</sup>*).
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class ExponentSyntax(
public override val operation: String,
public override val operand: OperandSyntax,
@ -208,12 +225,13 @@ public data class ExponentSyntax(
}
/**
* Represents a syntax node with superscript (usually, for exponentiation).
* Represents a syntax node with superscript (*x<sup>2</sup>*).
*
* @property left The node.
* @property right The superscript.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class SuperscriptSyntax(
public override val operation: String,
public override val left: MathSyntax,
@ -226,12 +244,13 @@ public data class SuperscriptSyntax(
}
/**
* Represents a syntax node with subscript.
* Represents a syntax node with subscript (*x<sub>i</sup>*).
*
* @property left The node.
* @property right The subscript.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class SubscriptSyntax(
public override val operation: String,
public override val left: MathSyntax,
@ -244,11 +263,12 @@ public data class SubscriptSyntax(
}
/**
* Represents binary, prefix operator syntax (like f(a, b)).
* Represents binary, prefix operator syntax (like *f(a, b)*).
*
* @property prefix The prefix.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class BinaryOperatorSyntax(
public override val operation: String,
public var prefix: MathSyntax,
@ -262,12 +282,13 @@ public data class BinaryOperatorSyntax(
}
/**
* Represents binary, infix addition.
* Represents binary, infix addition (*42 + 42*).
*
* @param left The augend.
* @param right The addend.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class BinaryPlusSyntax(
public override val operation: String,
public override val left: OperandSyntax,
@ -280,12 +301,13 @@ public data class BinaryPlusSyntax(
}
/**
* Represents binary, infix subtraction.
* Represents binary, infix subtraction (*42 - 42*).
*
* @param left The minuend.
* @param right The subtrahend.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class BinaryMinusSyntax(
public override val operation: String,
public override val left: OperandSyntax,
@ -302,12 +324,15 @@ public data class BinaryMinusSyntax(
*
* @property left The numerator.
* @property right The denominator.
* @property infix Whether infix (*1 / 2*) or normal (*&frac12;*) fraction should be made.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class FractionSyntax(
public override val operation: String,
public override val left: MathSyntax,
public override val right: MathSyntax,
public override val left: OperandSyntax,
public override val right: OperandSyntax,
public var infix: Boolean,
) : BinarySyntax() {
init {
left.parent = this
@ -316,12 +341,13 @@ public data class FractionSyntax(
}
/**
* Represents radical syntax with index.
* Represents radical syntax with index (*<sup>3</sup>&radic;x*).
*
* @property left The index.
* @property right The radicand.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class RadicalWithIndexSyntax(
public override val operation: String,
public override val left: MathSyntax,
@ -334,13 +360,14 @@ public data class RadicalWithIndexSyntax(
}
/**
* Represents binary, infix multiplication in the form of coefficient (2 x) or with operator (x&times;2).
* Represents binary, infix multiplication in the form of coefficient (*2 x*) or with operator (*x &times; 2*).
*
* @property left The multiplicand.
* @property right The multiplier.
* @property times whether the times (&times;) symbol should be used.
* @property times Whether the times (&times;) symbol should be used.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public data class MultiplicationSyntax(
public override val operation: String,
public override val left: OperandSyntax,

View File

@ -5,12 +5,15 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.misc.UnstableKMathAPI
/**
* Abstraction of writing [MathSyntax] as a string of an actual markup language. Typical implementation should
* involve traversal of MathSyntax with handling each its subtype.
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun interface SyntaxRenderer {
/**
* Renders the [MathSyntax] to [output].
@ -23,6 +26,7 @@ public fun interface SyntaxRenderer {
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun SyntaxRenderer.renderWithStringBuilder(node: MathSyntax): String {
val sb = StringBuilder()
render(node, sb)

View File

@ -7,6 +7,7 @@ package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.FeaturedMathRenderer.RenderFeature
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import kotlin.reflect.KClass
@ -15,11 +16,12 @@ import kotlin.reflect.KClass
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public object PrintSymbolic : RenderFeature {
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Symbolic) return null
return SymbolSyntax(string = node.value)
}
public override fun render(renderer: FeaturedMathRenderer, node: MST): SymbolSyntax? =
if (node !is MST.Symbolic) null
else
SymbolSyntax(string = node.value)
}
/**
@ -27,35 +29,38 @@ public object PrintSymbolic : RenderFeature {
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public object PrintNumeric : RenderFeature {
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Numeric) return null
return NumberSyntax(string = node.value.toString())
}
public override fun render(renderer: FeaturedMathRenderer, node: MST): NumberSyntax? = if (node !is MST.Numeric)
null
else
NumberSyntax(string = node.value.toString())
}
private fun printSignedNumberString(s: String): MathSyntax {
if (s.startsWith('-'))
return UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION,
operand = OperandSyntax(
operand = NumberSyntax(string = s.removePrefix("-")),
parentheses = true,
),
)
return NumberSyntax(string = s)
}
@UnstableKMathAPI
private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-'))
UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION,
operand = OperandSyntax(
operand = NumberSyntax(string = s.removePrefix("-")),
parentheses = true,
),
)
else
NumberSyntax(string = s)
/**
* Special printing for numeric types which are printed in form of
* *('-'? (DIGIT+ ('.' DIGIT+)? ('E' '-'? DIGIT+)? | 'Infinity')) | 'NaN'*.
*
* @property types The suitable types.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : RenderFeature {
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Numeric || node.value::class !in types) return null
val toString = when (val v = node.value) {
is Float -> v.multiplatformToString()
is Double -> v.multiplatformToString()
@ -109,12 +114,15 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
* Special printing for numeric types which are printed in form of *'-'? DIGIT+*.
*
* @property types The suitable types.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class PrettyPrintIntegers(public val types: Set<KClass<out Number>>) : RenderFeature {
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Numeric || node.value::class !in types) return null
return printSignedNumberString(node.value.toString())
}
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? =
if (node !is MST.Numeric || node.value::class !in types)
null
else
printSignedNumberString(node.value.toString())
public companion object {
/**
@ -129,12 +137,15 @@ public class PrettyPrintIntegers(public val types: Set<KClass<out Number>>) : Re
* Special printing for symbols meaning Pi.
*
* @property symbols The allowed symbols.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class PrettyPrintPi(public val symbols: Set<String>) : RenderFeature {
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Symbolic || node.value !in symbols) return null
return SpecialSymbolSyntax(kind = SpecialSymbolSyntax.Kind.SMALL_PI)
}
public override fun render(renderer: FeaturedMathRenderer, node: MST): SpecialSymbolSyntax? =
if (node !is MST.Symbolic || node.value !in symbols)
null
else
SpecialSymbolSyntax(kind = SpecialSymbolSyntax.Kind.SMALL_PI)
public companion object {
/**
@ -149,17 +160,20 @@ public class PrettyPrintPi(public val symbols: Set<String>) : RenderFeature {
* not [MST.Unary].
*
* @param operations the allowed operations. If `null`, any operation is accepted.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public abstract class Unary(public val operations: Collection<String>?) : RenderFeature {
/**
* The actual render function.
* The actual render function specialized for [MST.Unary].
*/
protected abstract fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax?
protected abstract fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax?
public final override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Unary || operations != null && node.operation !in operations) return null
return render0(renderer, node)
}
public final override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? =
if (node !is MST.Unary || operations != null && node.operation !in operations)
null
else
renderUnary(renderer, node)
}
/**
@ -167,169 +181,301 @@ public abstract class Unary(public val operations: Collection<String>?) : Render
* not [MST.Binary].
*
* @property operations the allowed operations. If `null`, any operation is accepted.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public abstract class Binary(public val operations: Collection<String>?) : RenderFeature {
/**
* The actual render function.
* The actual render function specialized for [MST.Binary].
*/
protected abstract fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax?
protected abstract fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax?
public final override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Binary || operations != null && node.operation !in operations) return null
return render0(renderer, node)
return renderBinary(renderer, node)
}
}
/**
* Handles binary nodes by producing [BinaryPlusSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class BinaryPlus(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryPlusSyntax(
operation = node.operation,
left = OperandSyntax(parent.render(node.left), true),
right = OperandSyntax(parent.render(node.right), true),
)
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): BinaryPlusSyntax =
BinaryPlusSyntax(
operation = node.operation,
left = OperandSyntax(parent.render(node.left), true),
right = OperandSyntax(parent.render(node.right), true),
)
public companion object {
/**
* The default instance configured with [GroupOperations.PLUS_OPERATION].
*/
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION))
}
}
/**
* Handles binary nodes by producing [BinaryMinusSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class BinaryMinus(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryMinusSyntax(
operation = node.operation,
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
)
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): BinaryMinusSyntax =
BinaryMinusSyntax(
operation = node.operation,
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
)
public companion object {
/**
* The default instance configured with [GroupOperations.MINUS_OPERATION].
*/
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION))
}
}
/**
* Handles unary nodes by producing [UnaryPlusSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class UnaryPlus(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryPlusSyntax(
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): UnaryPlusSyntax = UnaryPlusSyntax(
operation = node.operation,
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
)
public companion object {
/**
* The default instance configured with [GroupOperations.PLUS_OPERATION].
*/
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION))
}
}
/**
* Handles binary nodes by producing [UnaryMinusSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class UnaryMinus(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryMinusSyntax(
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): UnaryMinusSyntax = UnaryMinusSyntax(
operation = node.operation,
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
)
public companion object {
/**
* The default instance configured with [GroupOperations.MINUS_OPERATION].
*/
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION))
}
}
/**
* Handles binary nodes by producing [FractionSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class Fraction(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = FractionSyntax(
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): FractionSyntax = FractionSyntax(
operation = node.operation,
left = parent.render(node.left),
right = parent.render(node.right),
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
infix = true,
)
public companion object {
/**
* The default instance configured with [FieldOperations.DIV_OPERATION].
*/
public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION))
}
}
/**
* Handles binary nodes by producing [BinaryOperatorSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class BinaryOperator(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryOperatorSyntax(
operation = node.operation,
prefix = OperatorNameSyntax(name = node.operation),
left = parent.render(node.left),
right = parent.render(node.right),
)
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): BinaryOperatorSyntax =
BinaryOperatorSyntax(
operation = node.operation,
prefix = OperatorNameSyntax(name = node.operation),
left = parent.render(node.left),
right = parent.render(node.right),
)
public companion object {
/**
* The default instance configured with `null`.
*/
public val Default: BinaryOperator = BinaryOperator(null)
}
}
/**
* Handles unary nodes by producing [UnaryOperatorSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class UnaryOperator(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax(
operation = node.operation,
prefix = OperatorNameSyntax(node.operation),
operand = OperandSyntax(parent.render(node.value), true),
)
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): UnaryOperatorSyntax =
UnaryOperatorSyntax(
operation = node.operation,
prefix = OperatorNameSyntax(node.operation),
operand = OperandSyntax(parent.render(node.value), true),
)
public companion object {
/**
* The default instance configured with `null`.
*/
public val Default: UnaryOperator = UnaryOperator(null)
}
}
/**
* Handles binary nodes by producing [SuperscriptSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class Power(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = SuperscriptSyntax(
operation = node.operation,
left = OperandSyntax(parent.render(node.left), true),
right = OperandSyntax(parent.render(node.right), true),
)
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): SuperscriptSyntax =
SuperscriptSyntax(
operation = node.operation,
left = OperandSyntax(parent.render(node.left), true),
right = OperandSyntax(parent.render(node.right), true),
)
public companion object {
/**
* The default instance configured with [PowerOperations.POW_OPERATION].
*/
public val Default: Power = Power(setOf(PowerOperations.POW_OPERATION))
}
}
/**
* Handles binary nodes by producing [RadicalSyntax] with no index.
*/
@UnstableKMathAPI
public class SquareRoot(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax =
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): RadicalSyntax =
RadicalSyntax(operation = node.operation, operand = parent.render(node.value))
public companion object {
/**
* The default instance configured with [PowerOperations.SQRT_OPERATION].
*/
public val Default: SquareRoot = SquareRoot(setOf(PowerOperations.SQRT_OPERATION))
}
}
/**
* Handles unary nodes by producing [ExponentSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class Exponent(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = ExponentSyntax(
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): ExponentSyntax = ExponentSyntax(
operation = node.operation,
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
useOperatorForm = true,
)
public companion object {
/**
* The default instance configured with [ExponentialOperations.EXP_OPERATION].
*/
public val Default: Exponent = Exponent(setOf(ExponentialOperations.EXP_OPERATION))
}
}
/**
* Handles binary nodes by producing [MultiplicationSyntax].
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class Multiplication(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = MultiplicationSyntax(
operation = node.operation,
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
times = true,
)
public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MultiplicationSyntax =
MultiplicationSyntax(
operation = node.operation,
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
times = true,
)
public companion object {
public val Default: Multiplication = Multiplication(setOf(
RingOperations.TIMES_OPERATION,
))
/**
* The default instance configured with [RingOperations.TIMES_OPERATION].
*/
public val Default: Multiplication = Multiplication(setOf(RingOperations.TIMES_OPERATION))
}
}
/**
* Handles binary nodes by producing inverse [UnaryOperatorSyntax] with *arc* prefix instead of *a*.
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class InverseTrigonometricOperations(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax(
operation = node.operation,
prefix = SuperscriptSyntax(
operation = PowerOperations.POW_OPERATION,
left = OperatorNameSyntax(name = node.operation.removePrefix("a")),
right = UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION,
operand = OperandSyntax(operand = NumberSyntax(string = "1"), parentheses = true),
),
),
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
)
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): UnaryOperatorSyntax =
UnaryOperatorSyntax(
operation = node.operation,
prefix = OperatorNameSyntax(name = node.operation.replaceFirst("a", "arc")),
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
)
public companion object {
/**
* The default instance configured with [TrigonometricOperations.ACOS_OPERATION],
* [TrigonometricOperations.ASIN_OPERATION], [TrigonometricOperations.ATAN_OPERATION].
*/
public val Default: InverseTrigonometricOperations = InverseTrigonometricOperations(setOf(
TrigonometricOperations.ACOS_OPERATION,
TrigonometricOperations.ASIN_OPERATION,
TrigonometricOperations.ATAN_OPERATION,
))
}
}
/**
* Handles binary nodes by producing inverse [UnaryOperatorSyntax] with *ar* prefix instead of *a*.
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class InverseHyperbolicOperations(operations: Collection<String>?) : Unary(operations) {
public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): UnaryOperatorSyntax =
UnaryOperatorSyntax(
operation = node.operation,
prefix = OperatorNameSyntax(name = node.operation.replaceFirst("a", "ar")),
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
)
public companion object {
/**
* The default instance configured with [ExponentialOperations.ACOSH_OPERATION],
* [ExponentialOperations.ASINH_OPERATION], and [ExponentialOperations.ATANH_OPERATION].
*/
public val Default: InverseHyperbolicOperations = InverseHyperbolicOperations(setOf(
ExponentialOperations.ACOSH_OPERATION,
ExponentialOperations.ASINH_OPERATION,
ExponentialOperations.ATANH_OPERATION,

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.FieldOperations
import space.kscience.kmath.operations.GroupOperations
import space.kscience.kmath.operations.PowerOperations
@ -15,6 +16,7 @@ import space.kscience.kmath.operations.RingOperations
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public object BetterMultiplication : FeaturedMathRendererWithPostProcess.PostProcessStage {
public override fun perform(node: MathSyntax): Unit = when (node) {
is NumberSyntax -> Unit
@ -81,6 +83,75 @@ public object BetterMultiplication : FeaturedMathRendererWithPostProcess.PostPro
}
}
/**
* Chooses [FractionSyntax.infix] depending on the context.
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public object BetterFraction : FeaturedMathRendererWithPostProcess.PostProcessStage {
private fun perform0(node: MathSyntax, infix: Boolean = false): Unit = when (node) {
is NumberSyntax -> Unit
is SymbolSyntax -> Unit
is OperatorNameSyntax -> Unit
is SpecialSymbolSyntax -> Unit
is OperandSyntax -> perform0(node.operand, infix)
is UnaryOperatorSyntax -> {
perform0(node.prefix, infix)
perform0(node.operand, infix)
}
is UnaryPlusSyntax -> perform0(node.operand, infix)
is UnaryMinusSyntax -> perform0(node.operand, infix)
is RadicalSyntax -> perform0(node.operand, infix)
is ExponentSyntax -> perform0(node.operand, infix)
is SuperscriptSyntax -> {
perform0(node.left, true)
perform0(node.right, true)
}
is SubscriptSyntax -> {
perform0(node.left, true)
perform0(node.right, true)
}
is BinaryOperatorSyntax -> {
perform0(node.prefix, infix)
perform0(node.left, infix)
perform0(node.right, infix)
}
is BinaryPlusSyntax -> {
perform0(node.left, infix)
perform0(node.right, infix)
}
is BinaryMinusSyntax -> {
perform0(node.left, infix)
perform0(node.right, infix)
}
is FractionSyntax -> {
node.infix = infix
perform0(node.left, infix)
perform0(node.right, infix)
}
is RadicalWithIndexSyntax -> {
perform0(node.left, true)
perform0(node.right, true)
}
is MultiplicationSyntax -> {
perform0(node.left, infix)
perform0(node.right, infix)
}
}
public override fun perform(node: MathSyntax): Unit = perform0(node)
}
/**
* Applies [ExponentSyntax.useOperatorForm] to [ExponentSyntax] when the operand contains a fraction, a
@ -88,6 +159,7 @@ public object BetterMultiplication : FeaturedMathRendererWithPostProcess.PostPro
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public object BetterExponent : FeaturedMathRendererWithPostProcess.PostProcessStage {
private fun perform0(node: MathSyntax): Boolean {
return when (node) {
@ -99,7 +171,7 @@ public object BetterExponent : FeaturedMathRendererWithPostProcess.PostProcessSt
is UnaryOperatorSyntax -> perform0(node.prefix) || perform0(node.operand)
is UnaryPlusSyntax -> perform0(node.operand)
is UnaryMinusSyntax -> perform0(node.operand)
is RadicalSyntax -> perform0(node.operand)
is RadicalSyntax -> true
is ExponentSyntax -> {
val r = perform0(node.operand)
@ -113,7 +185,7 @@ public object BetterExponent : FeaturedMathRendererWithPostProcess.PostProcessSt
is BinaryPlusSyntax -> perform0(node.left) || perform0(node.right)
is BinaryMinusSyntax -> perform0(node.left) || perform0(node.right)
is FractionSyntax -> true
is RadicalWithIndexSyntax -> perform0(node.left) || perform0(node.right)
is RadicalWithIndexSyntax -> true
is MultiplicationSyntax -> perform0(node.left) || perform0(node.right)
}
}
@ -129,6 +201,7 @@ public object BetterExponent : FeaturedMathRendererWithPostProcess.PostProcessSt
* @property precedenceFunction Returns the precedence number for syntax node. Higher number is lower priority.
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) -> Int) :
FeaturedMathRendererWithPostProcess.PostProcessStage {
public override fun perform(node: MathSyntax): Unit = when (node) {
@ -159,8 +232,11 @@ public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) ->
val isInsideExpOperator =
node.parent is ExponentSyntax && (node.parent as ExponentSyntax).useOperatorForm
val isOnOrUnderNormalFraction = node.parent is FractionSyntax && !((node.parent as FractionSyntax).infix)
node.parentheses = !isRightOfSuperscript
&& (needParenthesesByPrecedence || node.parent is UnaryOperatorSyntax || isInsideExpOperator)
&& !isOnOrUnderNormalFraction
perform(node.operand)
}

View File

@ -3,12 +3,12 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.wasm
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.bindSymbol
@ -16,45 +16,41 @@ import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestWasmConsistencyWithInterpreter {
internal class TestCompilerConsistencyWithInterpreter {
@Test
fun intRing() {
fun intRing() = runCompilerTest {
val mst = MstRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol(x) - (2.toByte() + (scale(
add(number(1), number(1)),
2.0
2.0,
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
number(1),
) * number(2)
}
assertEquals(
mst.interpret(IntRing, x to 3),
mst.compile(IntRing, x to 3)
mst.compile(IntRing, x to 3),
)
}
@Test
fun doubleField() {
fun doubleField() = runCompilerTest {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
number(1) / 2 + number(2.0) * one,
) + zero
}
assertEquals(
mst.interpret(DoubleField, x to 2.0),
mst.compile(DoubleField, x to 2.0)
mst.compile(DoubleField, x to 2.0),
)
}
private companion object {
private val x by symbol
}
}

View File

@ -0,0 +1,65 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestCompilerOperations {
@Test
fun testUnaryPlus() = runCompilerTest {
val expr = MstExtendedField { +bindSymbol(x) }.compileToExpression(DoubleField)
assertEquals(2.0, expr(x to 2.0))
}
@Test
fun testUnaryMinus() = runCompilerTest {
val expr = MstExtendedField { -bindSymbol(x) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr(x to 2.0))
}
@Test
fun testAdd() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) + bindSymbol(x) }.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
@Test
fun testSine() = runCompilerTest {
val expr = MstExtendedField { sin(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test
fun testCosine() = runCompilerTest {
val expr = MstExtendedField { cos(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 0.0))
}
@Test
fun testSubtract() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) - bindSymbol(x) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 2.0))
}
@Test
fun testDivide() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) / bindSymbol(x) }.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 2.0))
}
@Test
fun testPower() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) pow 2 }.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
}

View File

@ -3,11 +3,11 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.wasm
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
@ -15,20 +15,16 @@ import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
internal class TestWasmVariables {
internal class TestCompilerVariables {
@Test
fun testVariable() {
fun testVariable() = runCompilerTest {
val expr = MstRing { bindSymbol(x) }.compileToExpression(IntRing)
assertEquals(1, expr(x to 1))
}
@Test
fun testUndefinedVariableFails() {
fun testUndefinedVariableFails() = runCompilerTest {
val expr = MstRing { bindSymbol(x) }.compileToExpression(IntRing)
assertFailsWith<NoSuchElementException> { expr() }
}
private companion object {
private val x by symbol
}
}

View File

@ -13,7 +13,7 @@ import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ParserTest {
internal class TestParser {
@Test
fun evaluateParsedMst() {
val mst = "2+2*(2+2)".parseMath()

View File

@ -5,13 +5,12 @@
package space.kscience.kmath.ast
import space.kscience.kmath.ast.parseMath
import space.kscience.kmath.expressions.evaluate
import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ParserPrecedenceTest {
internal class TestParserPrecedence {
@Test
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))

View File

@ -99,13 +99,17 @@ internal class TestFeatures {
fun multiplication() = testLatex("x*1", "x\\times1")
@Test
fun inverseTrigonometry() {
testLatex("asin(x)", "\\operatorname{sin}^{-1}\\,\\left(x\\right)")
testLatex("asinh(x)", "\\operatorname{sinh}^{-1}\\,\\left(x\\right)")
testLatex("acos(x)", "\\operatorname{cos}^{-1}\\,\\left(x\\right)")
testLatex("acosh(x)", "\\operatorname{cosh}^{-1}\\,\\left(x\\right)")
testLatex("atan(x)", "\\operatorname{tan}^{-1}\\,\\left(x\\right)")
testLatex("atanh(x)", "\\operatorname{tanh}^{-1}\\,\\left(x\\right)")
fun inverseTrigonometric() {
testLatex("asin(x)", "\\operatorname{arcsin}\\,\\left(x\\right)")
testLatex("acos(x)", "\\operatorname{arccos}\\,\\left(x\\right)")
testLatex("atan(x)", "\\operatorname{arctan}\\,\\left(x\\right)")
}
@Test
fun inverseHyperbolic() {
testLatex("asinh(x)", "\\operatorname{arsinh}\\,\\left(x\\right)")
testLatex("acosh(x)", "\\operatorname{arcosh}\\,\\left(x\\right)")
testLatex("atanh(x)", "\\operatorname{artanh}\\,\\left(x\\right)")
}
// @Test

View File

@ -37,4 +37,10 @@ internal class TestStages {
testLatex("exp(x/2)", "\\operatorname{exp}\\,\\left(\\frac{x}{2}\\right)")
testLatex("exp(x^2)", "\\operatorname{exp}\\,\\left(x^{2}\\right)")
}
@Test
fun fraction() {
testLatex("x/y", "\\frac{x}{y}")
testLatex("x^(x/y)", "x^{x/y}")
}
}

View File

@ -30,17 +30,17 @@ internal object TestUtils {
)
internal fun testMathML(mst: MST, expectedMathML: String) = assertEquals(
expected = "<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
expected = "<math xmlns=\"https://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
actual = mathML(mst),
)
internal fun testMathML(expression: String, expectedMathML: String) = assertEquals(
expected = "<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
expected = "<math xmlns=\"https://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
actual = mathML(expression.parseMath()),
)
internal fun testMathML(expression: MathSyntax, expectedMathML: String) = assertEquals(
expected = "<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
expected = "<math xmlns=\"https://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
actual = MathMLSyntaxRenderer.renderWithStringBuilder(expression),
)
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
internal interface CompilerTestContext {
fun MST.compileToExpression(algebra: IntRing): Expression<Int>
fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int
fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int = compile(algebra, mapOf(*arguments))
fun MST.compileToExpression(algebra: DoubleField): Expression<Double>
fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double
fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
compile(algebra, mapOf(*arguments))
}
internal expect inline fun runCompilerTest(action: CompilerTestContext.() -> Unit)

View File

@ -3,7 +3,8 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
@file:Suppress("INTERFACE_WITH_SUPERCLASS",
@file:Suppress(
"INTERFACE_WITH_SUPERCLASS",
"OVERRIDING_FINAL_MEMBER",
"RETURN_TYPE_MISMATCH_ON_OVERRIDE",
"CONFLICTING_OVERLOADS",

View File

@ -26,7 +26,7 @@ internal sealed class WasmBuilder<T>(
val keys: MutableList<String> = mutableListOf()
lateinit var ctx: BinaryenModule
open fun visitSymbolic(mst: MST.Symbolic): ExpressionRef {
open fun visitSymbolic(mst: Symbolic): ExpressionRef {
try {
algebra.bindSymbol(mst.value)
} catch (ignored: Throwable) {

View File

@ -10,6 +10,7 @@ import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.wasm.internal.DoubleWasmBuilder
@ -20,6 +21,7 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun DoubleField.expression(mst: MST): Expression<Double> =
DoubleWasmBuilder(mst).instance
@ -28,6 +30,7 @@ public fun DoubleField.expression(mst: MST): Expression<Double> =
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun IntRing.expression(mst: MST): Expression<Int> =
IntWasmBuilder(mst).instance
@ -36,6 +39,7 @@ public fun IntRing.expression(mst: MST): Expression<Int> =
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun MST.compileToExpression(algebra: IntRing): Expression<Int> = compileWith(algebra)
@ -44,6 +48,7 @@ public fun MST.compileToExpression(algebra: IntRing): Expression<Int> = compileW
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
compileToExpression(algebra).invoke(arguments)
@ -53,6 +58,7 @@ public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int =
compileToExpression(algebra)(*arguments)
@ -61,6 +67,7 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = compileWith(algebra)
@ -69,6 +76,7 @@ public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = c
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
compileToExpression(algebra).invoke(arguments)
@ -78,5 +86,6 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
compileToExpression(algebra).invoke(*arguments)

View File

@ -0,0 +1,39 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.estree.compile as estreeCompile
import space.kscience.kmath.estree.compileToExpression as estreeCompileToExpression
import space.kscience.kmath.wasm.compile as wasmCompile
import space.kscience.kmath.wasm.compileToExpression as wasmCompileToExpression
private object WasmCompilerTestContext : CompilerTestContext {
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = wasmCompileToExpression(algebra)
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = wasmCompile(algebra, arguments)
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = wasmCompileToExpression(algebra)
override fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
wasmCompile(algebra, arguments)
}
private object ESTreeCompilerTestContext : CompilerTestContext {
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = estreeCompileToExpression(algebra)
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = estreeCompile(algebra, arguments)
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = estreeCompileToExpression(algebra)
override fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
estreeCompile(algebra, arguments)
}
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) {
action(WasmCompilerTestContext)
action(ESTreeCompilerTestContext)
}

View File

@ -1,97 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.estree
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.complex.toComplex
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestESTreeConsistencyWithInterpreter {
@Test
fun mstSpace() {
val mst = MstGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
add(number(1), number(1)),
2.0
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + bindSymbol(x) + zero
}
assertEquals(
mst.interpret(MstGroup, x to MST.Numeric(2)),
mst.compile(MstGroup, x to MST.Numeric(2))
)
}
@Test
fun byteRing() {
val mst = MstRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol(x) - (2.toByte() + (scale(
add(number(1), number(1)),
2.0
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}
assertEquals(
mst.interpret(ByteRing, x to 3.toByte()),
mst.compile(ByteRing, x to 3.toByte())
)
}
@Test
fun doubleField() {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}
assertEquals(
mst.interpret(DoubleField, x to 2.0),
mst.compile(DoubleField, x to 2.0)
)
}
@Test
fun complexField() {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}
assertEquals(
mst.interpret(ComplexField, x to 2.0.toComplex()),
mst.compile(ComplexField, x to 2.0.toComplex()),
)
}
private companion object {
private val x by symbol
}
}

View File

@ -1,42 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.estree
import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.MstGroup
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestESTreeOperationsSupport {
@Test
fun testUnaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) + number(1.0) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = MstField { bindSymbol(x) * 2 }.compileToExpression(DoubleField)(x to 2.0)
assertEquals(4.0, res)
}
private companion object {
private val x by symbol
}
}

View File

@ -1,76 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.estree
import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestESTreeSpecialization {
@Test
fun testUnaryPlus() {
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(2.0, expr(x to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr(x to 2.0))
}
@Test
fun testAdd() {
val expr = MstExtendedField {
binaryOperationFunction("+")(
bindSymbol(x),
bindSymbol(x),
)
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
@Test
fun testSine() {
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test
fun testSubtract() {
val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol(x),
bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 2.0))
}
@Test
fun testDivide() {
val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol(x), bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 2.0))
}
@Test
fun testPower() {
val expr = MstExtendedField {
binaryOperationFunction("pow")(bindSymbol(x), number(2))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
private companion object {
private val x by symbol
}
}

View File

@ -1,34 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.estree
import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
internal class TestESTreeVariables {
@Test
fun testVariable() {
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
assertEquals(1.toByte(), expr(x to 1.toByte()))
}
@Test
fun testUndefinedVariableFails() {
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
assertFailsWith<NoSuchElementException> { expr() }
}
private companion object {
private val x by symbol
}
}

View File

@ -1,42 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.wasm
import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.MstGroup
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestWasmOperationsSupport {
@Test
fun testUnaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) + number(1.0) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = MstField { bindSymbol(x) * 2 }.compileToExpression(DoubleField)(x to 2.0)
assertEquals(4.0, res)
}
private companion object {
private val x by symbol
}
}

View File

@ -1,76 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.wasm
import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestWasmSpecialization {
@Test
fun testUnaryPlus() {
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(2.0, expr(x to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr(x to 2.0))
}
@Test
fun testAdd() {
val expr = MstExtendedField {
binaryOperationFunction("+")(
bindSymbol(x),
bindSymbol(x),
)
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
@Test
fun testSine() {
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test
fun testSubtract() {
val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol(x),
bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 2.0))
}
@Test
fun testDivide() {
val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol(x), bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 2.0))
}
@Test
fun testPower() {
val expr = MstExtendedField {
binaryOperationFunction("pow")(bindSymbol(x), number(2))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
private companion object {
private val x by symbol
}
}

View File

@ -342,8 +342,8 @@ internal class AsmBuilder<T>(
val MAP_INTRINSICS_TYPE: Type by lazy { getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics") }
/**
* ASM Type for [kscience.kmath.expressions.Symbol].
* ASM Type for [space.kscience.kmath.misc.Symbol].
*/
val SYMBOL_TYPE: Type by lazy { getObjectType("space/kscience/kmath/expressions/Symbol") }
val SYMBOL_TYPE: Type by lazy { getObjectType("space/kscience/kmath/misc/Symbol") }
}
}

View File

@ -52,7 +52,7 @@ internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.(
*
* @author Iaroslav Postovalov
*/
internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) }
internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
/**
* Creates a class name for [Expression] subclassed to implement [mst] provided.

View File

@ -1,97 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.asm
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.complex.toComplex
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmConsistencyWithInterpreter {
@Test
fun mstSpace() {
val mst = MstGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
add(number(1), number(1)),
2.0
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + bindSymbol(x) + zero
}
assertEquals(
mst.interpret(MstGroup, x to MST.Numeric(2)),
mst.compile(MstGroup, x to MST.Numeric(2))
)
}
@Test
fun byteRing() {
val mst = MstRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol(x) - (2.toByte() + (scale(
add(number(1), number(1)),
2.0
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}
assertEquals(
mst.interpret(ByteRing, x to 3.toByte()),
mst.compile(ByteRing, x to 3.toByte())
)
}
@Test
fun doubleField() {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}
assertEquals(
mst.interpret(DoubleField, x to 2.0),
mst.compile(DoubleField, x to 2.0)
)
}
@Test
fun complexField() {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}
assertEquals(
mst.interpret(ComplexField, x to 2.0.toComplex()),
mst.compile(ComplexField, x to 2.0.toComplex())
)
}
private companion object {
private val x by symbol
}
}

View File

@ -1,42 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.asm
import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.MstGroup
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmOperationsSupport {
@Test
fun testUnaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) + number(1.0) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = MstField { bindSymbol(x) * 2 }.compileToExpression(DoubleField)(x to 2.0)
assertEquals(4.0, res)
}
private companion object {
private val x by symbol
}
}

View File

@ -1,76 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.asm
import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmSpecialization {
@Test
fun testUnaryPlus() {
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(2.0, expr(x to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr(x to 2.0))
}
@Test
fun testAdd() {
val expr = MstExtendedField {
binaryOperationFunction("+")(
bindSymbol(x),
bindSymbol(x),
)
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
@Test
fun testSine() {
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test
fun testSubtract() {
val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol(x),
bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 2.0))
}
@Test
fun testDivide() {
val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol(x), bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 2.0))
}
@Test
fun testPower() {
val expr = MstExtendedField {
binaryOperationFunction("pow")(bindSymbol(x), number(2))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
private companion object {
private val x by symbol
}
}

View File

@ -1,34 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.asm
import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
internal class TestAsmVariables {
@Test
fun testVariable() {
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
assertEquals(1.toByte(), expr(x to 1.toByte()))
}
@Test
fun testUndefinedVariableFails() {
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
assertFailsWith<NoSuchElementException> { expr() }
}
private companion object {
private val x by symbol
}
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.asm.compile as asmCompile
import space.kscience.kmath.asm.compileToExpression as asmCompileToExpression
private object AsmCompilerTestContext : CompilerTestContext {
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = asmCompileToExpression(algebra)
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = asmCompile(algebra, arguments)
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = asmCompileToExpression(algebra)
override fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
asmCompile(algebra, arguments)
}
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) = action(AsmCompilerTestContext)

View File

@ -14,8 +14,7 @@ The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-de
```gradle
repositories {
maven { url 'https://repo.kotlin.link' }
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
mavenCentral()
}
dependencies {
@ -26,8 +25,7 @@ dependencies {
```kotlin
repositories {
maven("https://repo.kotlin.link")
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
mavenCentral()
}
dependencies {

View File

@ -21,8 +21,7 @@ The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-7
```gradle
repositories {
maven { url 'https://repo.kotlin.link' }
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
mavenCentral()
}
dependencies {
@ -33,8 +32,7 @@ dependencies {
```kotlin
repositories {
maven("https://repo.kotlin.link")
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
mavenCentral()
}
dependencies {

View File

@ -265,6 +265,8 @@ public final class space/kscience/kmath/expressions/MstExtendedField : space/ksc
public fun sin (Lspace/kscience/kmath/expressions/MST;)Lspace/kscience/kmath/expressions/MST$Unary;
public synthetic fun sinh (Ljava/lang/Object;)Ljava/lang/Object;
public fun sinh (Lspace/kscience/kmath/expressions/MST;)Lspace/kscience/kmath/expressions/MST$Unary;
public synthetic fun sqrt (Ljava/lang/Object;)Ljava/lang/Object;
public fun sqrt (Lspace/kscience/kmath/expressions/MST;)Lspace/kscience/kmath/expressions/MST;
public synthetic fun tan (Ljava/lang/Object;)Ljava/lang/Object;
public fun tan (Lspace/kscience/kmath/expressions/MST;)Lspace/kscience/kmath/expressions/MST$Unary;
public synthetic fun tanh (Ljava/lang/Object;)Ljava/lang/Object;
@ -745,7 +747,7 @@ public final class space/kscience/kmath/nd/BufferAlgebraNDKt {
public static final fun ring (Lspace/kscience/kmath/nd/AlgebraND$Companion;Lspace/kscience/kmath/operations/Ring;Lkotlin/jvm/functions/Function2;[I)Lspace/kscience/kmath/nd/BufferedRingND;
}
public final class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
public class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/Buffer;)V
public fun elements ()Lkotlin/sequences/Sequence;
public fun get ([I)Ljava/lang/Object;
@ -786,10 +788,9 @@ public final class space/kscience/kmath/nd/DefaultStrides : space/kscience/kmath
public fun equals (Ljava/lang/Object;)Z
public fun getLinearSize ()I
public fun getShape ()[I
public fun getStrides ()Ljava/util/List;
public fun getStrides ()[I
public fun hashCode ()I
public fun index (I)[I
public fun offset ([I)I
}
public final class space/kscience/kmath/nd/DefaultStrides$Companion {
@ -873,6 +874,22 @@ public abstract interface class space/kscience/kmath/nd/GroupND : space/kscience
public final class space/kscience/kmath/nd/GroupND$Companion {
}
public final class space/kscience/kmath/nd/MutableBufferND : space/kscience/kmath/nd/BufferND, space/kscience/kmath/nd/MutableStructureND {
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/MutableBuffer;)V
public final fun getMutableBuffer ()Lspace/kscience/kmath/structures/MutableBuffer;
public fun set ([ILjava/lang/Object;)V
}
public abstract interface class space/kscience/kmath/nd/MutableStructure1D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure1D, space/kscience/kmath/structures/MutableBuffer {
public fun set ([ILjava/lang/Object;)V
}
public abstract interface class space/kscience/kmath/nd/MutableStructure2D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure2D {
public fun getColumns ()Ljava/util/List;
public fun getRows ()Ljava/util/List;
public abstract fun set (IILjava/lang/Object;)V
}
public abstract interface class space/kscience/kmath/nd/MutableStructureND : space/kscience/kmath/nd/StructureND {
public abstract fun set ([ILjava/lang/Object;)V
}
@ -912,10 +929,10 @@ public final class space/kscience/kmath/nd/ShortRingNDKt {
public abstract interface class space/kscience/kmath/nd/Strides {
public abstract fun getLinearSize ()I
public abstract fun getShape ()[I
public abstract fun getStrides ()Ljava/util/List;
public abstract fun getStrides ()[I
public abstract fun index (I)[I
public fun indices ()Lkotlin/sequences/Sequence;
public abstract fun offset ([I)I
public fun offset ([I)I
}
public abstract interface class space/kscience/kmath/nd/Structure1D : space/kscience/kmath/nd/StructureND, space/kscience/kmath/structures/Buffer {
@ -929,6 +946,7 @@ public final class space/kscience/kmath/nd/Structure1D$Companion {
}
public final class space/kscience/kmath/nd/Structure1DKt {
public static final fun as1D (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructure1D;
public static final fun as1D (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/Structure1D;
public static final fun asND (Lspace/kscience/kmath/structures/Buffer;)Lspace/kscience/kmath/nd/Structure1D;
}
@ -949,6 +967,7 @@ public final class space/kscience/kmath/nd/Structure2D$Companion {
}
public final class space/kscience/kmath/nd/Structure2DKt {
public static final fun as2D (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructure2D;
public static final fun as2D (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/Structure2D;
}
@ -988,14 +1007,7 @@ public abstract interface class space/kscience/kmath/operations/Algebra {
public fun unaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function1;
}
public abstract interface class space/kscience/kmath/operations/AlgebraElement {
public abstract fun getContext ()Lspace/kscience/kmath/operations/Algebra;
}
public final class space/kscience/kmath/operations/AlgebraElementsKt {
public static final fun div (Lspace/kscience/kmath/operations/AlgebraElement;Lspace/kscience/kmath/operations/AlgebraElement;)Lspace/kscience/kmath/operations/AlgebraElement;
public static final fun plus (Lspace/kscience/kmath/operations/AlgebraElement;Lspace/kscience/kmath/operations/AlgebraElement;)Lspace/kscience/kmath/operations/AlgebraElement;
public static final fun times (Lspace/kscience/kmath/operations/AlgebraElement;Lspace/kscience/kmath/operations/AlgebraElement;)Lspace/kscience/kmath/operations/AlgebraElement;
}
public final class space/kscience/kmath/operations/AlgebraExtensionsKt {

View File

@ -135,6 +135,7 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
public override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg)
public override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg)
public override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
public override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
public override fun scale(a: MST, value: Double): MST =
binaryOperation(GroupOperations.PLUS_OPERATION, a, number(value))

View File

@ -19,6 +19,7 @@ import kotlin.reflect.KClass
* @param T the type of items.
*/
public typealias Matrix<T> = Structure2D<T>
public typealias MutableMatrix<T> = MutableStructure2D<T>
/**
* Alias or using [Buffer] as a point/vector in a many-dimensional space.

View File

@ -7,6 +7,8 @@ package space.kscience.kmath.nd
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.MutableBufferFactory
/**
* Represents [StructureND] over [Buffer].
@ -15,7 +17,7 @@ import space.kscience.kmath.structures.BufferFactory
* @param strides The strides to access elements of [Buffer] by linear indices.
* @param buffer The underlying buffer.
*/
public class BufferND<T>(
public open class BufferND<T>(
public val strides: Strides,
public val buffer: Buffer<T>,
) : StructureND<T> {
@ -50,4 +52,35 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
val strides = DefaultStrides(shape)
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
}
}
/**
* Represents [MutableStructureND] over [MutableBuffer].
*
* @param T the type of items.
* @param strides The strides to access elements of [MutableBuffer] by linear indices.
* @param mutableBuffer The underlying buffer.
*/
public class MutableBufferND<T>(
strides: Strides,
public val mutableBuffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
override fun set(index: IntArray, value: T) {
mutableBuffer[strides.offset(index)] = value
}
}
/**
* Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND]
*/
public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
factory: MutableBufferFactory<R> = MutableBuffer.Companion::auto,
crossinline transform: (T) -> R,
): MutableBufferND<R> {
return if (this is MutableBufferND<T>)
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) })
else {
val strides = DefaultStrides(shape)
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
}
}

View File

@ -6,6 +6,8 @@
package space.kscience.kmath.nd
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.structures.asSequence
import kotlin.jvm.JvmInline
@ -25,6 +27,16 @@ public interface Structure1D<T> : StructureND<T>, Buffer<T> {
public companion object
}
/**
* A mutable structure that is guaranteed to be one-dimensional
*/
public interface MutableStructure1D<T> : Structure1D<T>, MutableStructureND<T>, MutableBuffer<T> {
public override operator fun set(index: IntArray, value: T) {
require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
set(index[0], value)
}
}
/**
* A 1D wrapper for nd-structure
*/
@ -37,6 +49,23 @@ private value class Structure1DWrapper<T>(val structure: StructureND<T>) : Struc
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
}
/**
* A 1D wrapper for a mutable nd-structure
*/
private class MutableStructure1DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure1D<T> {
override val shape: IntArray get() = structure.shape
override val size: Int get() = structure.shape[0]
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
override fun get(index: Int): T = structure[index]
override fun set(index: Int, value: T) {
structure[intArrayOf(index)] = value
}
override fun copy(): MutableBuffer<T> =
structure.elements().map { it.second }.toMutableList().asMutableBuffer()
}
/**
* A structure wrapper for buffer
@ -52,6 +81,21 @@ private value class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T> {
override operator fun get(index: Int): T = buffer[index]
}
internal class MutableBuffer1DWrapper<T>(val buffer: MutableBuffer<T>) : MutableStructure1D<T> {
override val shape: IntArray get() = intArrayOf(buffer.size)
override val size: Int get() = buffer.size
override fun elements(): Sequence<Pair<IntArray, T>> =
buffer.asSequence().mapIndexed { index, value -> intArrayOf(index) to value }
override operator fun get(index: Int): T = buffer[index]
override fun set(index: Int, value: T) {
buffer[index] = value
}
override fun copy(): MutableBuffer<T> = buffer.copy()
}
/**
* Represent a [StructureND] as [Structure1D]. Throw error in case of dimension mismatch
*/
@ -62,6 +106,11 @@ public fun <T> StructureND<T>.as1D(): Structure1D<T> = this as? Structure1D<T> ?
}
} else error("Can't create 1d-structure from ${shape.size}d-structure")
public fun <T> MutableStructureND<T>.as1D(): MutableStructure1D<T> =
this as? MutableStructure1D<T> ?: if (shape.size == 1) {
MutableStructure1DWrapper(this)
} else error("Can't create 1d-structure from ${shape.size}d-structure")
/**
* Represent this buffer as 1D structure
*/
@ -75,3 +124,4 @@ internal fun <T : Any> Structure1D<T>.unwrap(): Buffer<T> = when {
this is Structure1DWrapper && structure is BufferND<T> -> structure.buffer
else -> this
}

View File

@ -8,6 +8,7 @@ package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.VirtualBuffer
import space.kscience.kmath.structures.MutableListBuffer
import kotlin.jvm.JvmInline
import kotlin.reflect.KClass
@ -63,6 +64,32 @@ public interface Structure2D<T> : StructureND<T> {
public companion object
}
/**
* Represents mutable [Structure2D].
*/
public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
/**
* Inserts an item at the specified indices.
*
* @param i the first index.
* @param j the second index.
* @param value the value.
*/
public operator fun set(i: Int, j: Int, value: T)
/**
* The buffer of rows of this structure. It gets elements from the structure dynamically.
*/
override val rows: List<MutableStructure1D<T>>
get() = List(rowNum) { i -> MutableBuffer1DWrapper(MutableListBuffer(colNum) { j -> get(i, j) })}
/**
* The buffer of columns of this structure. It gets elements from the structure dynamically.
*/
override val columns: List<MutableStructure1D<T>>
get() = List(colNum) { j -> MutableBuffer1DWrapper(MutableListBuffer(rowNum) { i -> get(i, j) }) }
}
/**
* A 2D wrapper for nd-structure
*/
@ -81,6 +108,33 @@ private value class Structure2DWrapper<T>(val structure: StructureND<T>) : Struc
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
}
/**
* A 2D wrapper for a mutable nd-structure
*/
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>): MutableStructure2D<T>
{
override val shape: IntArray get() = structure.shape
override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1]
override operator fun get(i: Int, j: Int): T = structure[i, j]
override fun set(index: IntArray, value: T) {
structure[index] = value
}
override operator fun set(i: Int, j: Int, value: T){
structure[intArrayOf(i, j)] = value
}
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
}
/**
* Represent a [StructureND] as [Structure1D]. Throw error in case of dimension mismatch
*/
@ -89,9 +143,18 @@ public fun <T> StructureND<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
}
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) {
2 -> MutableStructure2DWrapper(this)
else -> error("Can't create 2d-structure from ${shape.size}d-structure")
}
/**
* Expose inner [StructureND] if possible
*/
internal fun <T> Structure2D<T>.unwrap(): StructureND<T> =
if (this is Structure2DWrapper) structure
else this
else this
internal fun <T> MutableStructure2D<T>.unwrap(): MutableStructureND<T> =
if (this is MutableStructure2DWrapper) structure else this

View File

@ -184,12 +184,15 @@ public interface Strides {
/**
* Array strides
*/
public val strides: List<Int>
public val strides: IntArray
/**
* Get linear index from multidimensional index
*/
public fun offset(index: IntArray): Int
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
/**
* Get multidimensional from linear
@ -221,7 +224,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
/**
* Strides for memory access
*/
override val strides: List<Int> by lazy {
override val strides: IntArray by lazy {
sequence {
var current = 1
yield(1)
@ -230,14 +233,9 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
current *= it
yield(current)
}
}.toList()
}.toList().toIntArray()
}
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
override fun index(offset: Int): IntArray {
val res = IntArray(shape.size)
var current = offset

View File

@ -13,6 +13,8 @@ import space.kscience.kmath.misc.UnstableKMathAPI
* @param C the type of mathematical context for this element.
* @param T the type wrapped by this wrapper.
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public interface AlgebraElement<T, C : Algebra<T>> {
/**
* The context this element belongs to.
@ -45,6 +47,7 @@ public interface AlgebraElement<T, C : Algebra<T>> {
* @return the difference.
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public operator fun <T : AlgebraElement<T, S>, S : NumbersAddOperations<T>> T.minus(b: T): T =
context.add(this, context.run { -b })
@ -55,6 +58,8 @@ public operator fun <T : AlgebraElement<T, S>, S : NumbersAddOperations<T>> T.mi
* @param b the addend.
* @return the sum.
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public operator fun <T : AlgebraElement<T, S>, S : Ring<T>> T.plus(b: T): T =
context.add(this, b)
@ -71,6 +76,8 @@ public operator fun <T : AlgebraElement<T, S>, S : Ring<T>> T.plus(b: T): T =
* @param b the multiplier.
* @return the product.
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public operator fun <T : AlgebraElement<T, R>, R : Ring<T>> T.times(b: T): T =
context.multiply(this, b)
@ -81,6 +88,8 @@ public operator fun <T : AlgebraElement<T, R>, R : Ring<T>> T.times(b: T): T =
* @param b the divisor.
* @return the quotient.
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public operator fun <T : AlgebraElement<T, F>, F : Field<T>> T.div(b: T): T =
context.divide(this, b)
@ -93,6 +102,7 @@ public operator fun <T : AlgebraElement<T, F>, F : Field<T>> T.div(b: T): T =
* @param S the type of space.
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public interface GroupElement<T : GroupElement<T, S>, S : Group<T>> : AlgebraElement<T, S>
/**
@ -103,6 +113,7 @@ public interface GroupElement<T : GroupElement<T, S>, S : Group<T>> : AlgebraEle
* @param R the type of ring.
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public interface RingElement<T : RingElement<T, R>, R : Ring<T>> : GroupElement<T, R>
/**
@ -113,4 +124,5 @@ public interface RingElement<T : RingElement<T, R>, R : Ring<T>> : GroupElement<
* @param F the type of field.
*/
@UnstableKMathAPI
public interface FieldElement<T : FieldElement<T, F>, F : Field<T>> : RingElement<T, F>
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public interface FieldElement<T : FieldElement<T, F>, F : Field<T>> : RingElement<T, F>

View File

@ -80,36 +80,42 @@ public interface TrigonometricOperations<T> : Algebra<T> {
* Computes the sine of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
/**
* Computes the cosine of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
/**
* Computes the tangent of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
/**
* Computes the inverse sine of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
/**
* Computes the inverse cosine of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
/**
* Computes the inverse tangent of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out TrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
/**
@ -154,18 +160,21 @@ public interface PowerOperations<T> : Algebra<T> {
* @return the base raised to the power.
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public infix fun <T : AlgebraElement<T, out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
/**
* Computes the square root of the value [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out PowerOperations<T>>> sqrt(arg: T): T = arg pow 0.5
/**
* Computes the square of the value [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
/**
@ -261,12 +270,14 @@ public interface ExponentialOperations<T> : Algebra<T> {
* The identifier of exponential function.
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
/**
* The identifier of natural logarithm.
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg)
@ -280,30 +291,35 @@ public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> sinh(arg: T): T
* Computes the hyperbolic cosine of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> cosh(arg: T): T = arg.context.cosh(arg)
/**
* Computes the hyperbolic tangent of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> tanh(arg: T): T = arg.context.tanh(arg)
/**
* Computes the inverse hyperbolic sine of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> asinh(arg: T): T = arg.context.asinh(arg)
/**
* Computes the inverse hyperbolic cosine of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> acosh(arg: T): T = arg.context.acosh(arg)
/**
* Computes the inverse hyperbolic tangent of [arg].
*/
@UnstableKMathAPI
@Deprecated("AlgebraElements are considered odd and will be removed in future releases.")
public fun <T : AlgebraElement<T, out ExponentialOperations<T>>> atanh(arg: T): T = arg.context.atanh(arg)
/**

View File

@ -232,7 +232,7 @@ public value class MutableListBuffer<T>(public val list: MutableList<T>) : Mutab
}
/**
* Returns an [ListBuffer] that wraps the original list.
* Returns an [MutableListBuffer] that wraps the original list.
*/
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this)

View File

@ -2,9 +2,9 @@
EJML based linear algebra implementation.
- [ejml-vector](src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt) : The Point implementation using SimpleMatrix.
- [ejml-matrix](src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : The Matrix implementation using SimpleMatrix.
- [ejml-linear-space](src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : The LinearSpace implementation using SimpleMatrix.
- [ejml-vector](src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt) : Point implementations.
- [ejml-matrix](src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : Matrix implementation.
- [ejml-linear-space](src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : LinearSpace implementations.
## Artifact:
@ -15,8 +15,7 @@ The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-7
```gradle
repositories {
maven { url 'https://repo.kotlin.link' }
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
mavenCentral()
}
dependencies {
@ -27,8 +26,7 @@ dependencies {
```kotlin
repositories {
maven("https://repo.kotlin.link")
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
mavenCentral()
}
dependencies {

View File

@ -4,7 +4,7 @@ plugins {
}
dependencies {
api("org.ejml:ejml-simple:0.40")
api("org.ejml:ejml-ddense:0.40")
api(project(":kmath-core"))
}
@ -14,19 +14,19 @@ readme {
feature(
id = "ejml-vector",
description = "The Point implementation using SimpleMatrix.",
description = "Point implementations.",
ref = "src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt"
)
feature(
id = "ejml-matrix",
description = "The Matrix implementation using SimpleMatrix.",
description = "Matrix implementation.",
ref = "src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt"
)
feature(
id = "ejml-linear-space",
description = "The LinearSpace implementation using SimpleMatrix.",
description = "LinearSpace implementations.",
ref = "src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt"
)
}

View File

@ -5,45 +5,71 @@
package space.kscience.kmath.ejml
import org.ejml.data.DMatrix
import org.ejml.data.DMatrixD1
import org.ejml.data.DMatrixRMaj
import org.ejml.dense.row.CommonOps_DDRM
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix
import space.kscience.kmath.linear.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.StructureFeature
import space.kscience.kmath.nd.getFeature
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.reflect.KClass
import kotlin.reflect.cast
/**
* Represents context of basic operations operating with [EjmlMatrix].
* [LinearSpace] implementation specialized for a certain EJML type.
*
* @param T the type of items in the matrices.
* @param A the element context type.
* @param M the EJML matrix type.
* @author Iaroslav Postovalov
*/
public abstract class EjmlLinearSpace<T : Any, out A : Ring<T>, M : org.ejml.data.Matrix> : LinearSpace<T, A> {
/**
* Converts this matrix to EJML one.
*/
public abstract fun Matrix<T>.toEjml(): EjmlMatrix<T, M>
/**
* Converts this vector to EJML one.
*/
public abstract fun Point<T>.toEjml(): EjmlVector<T, M>
public abstract override fun buildMatrix(
rows: Int,
columns: Int,
initializer: A.(i: Int, j: Int) -> T,
): EjmlMatrix<T, M>
public abstract override fun buildVector(size: Int, initializer: A.(Int) -> T): EjmlVector<T, M>
}
/**
* [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and
* [DMatrixRMaj] matrices.
*
* @author Iaroslav Postovalov
* @author Alexander Nozik
*/
public object EjmlLinearSpace : LinearSpace<Double, DoubleField> {
public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, DoubleField, DMatrixRMaj>() {
/**
* The [DoubleField] reference.
*/
public override val elementAlgebra: DoubleField get() = DoubleField
/**
* Converts this matrix to EJML one.
*/
@OptIn(UnstableKMathAPI::class)
public fun Matrix<Double>.toEjml(): EjmlMatrix = when (val matrix = origin) {
is EjmlMatrix -> matrix
@Suppress("UNCHECKED_CAST")
public override fun Matrix<Double>.toEjml(): EjmlDoubleMatrix<DMatrixRMaj> = when {
this is EjmlDoubleMatrix<*> && origin is DMatrixRMaj -> this as EjmlDoubleMatrix<DMatrixRMaj>
else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) }
}
/**
* Converts this vector to EJML one.
*/
public fun Point<Double>.toEjml(): EjmlVector = when (this) {
is EjmlVector -> this
else -> EjmlVector(SimpleMatrix(size, 1).also {
(0 until it.numRows()).forEach { row -> it[row, 0] = get(row) }
@Suppress("UNCHECKED_CAST")
public override fun Point<Double>.toEjml(): EjmlDoubleVector<DMatrixRMaj> = when {
this is EjmlDoubleVector<*> && origin is DMatrixRMaj -> this as EjmlDoubleVector<DMatrixRMaj>
else -> EjmlDoubleVector(DMatrixRMaj(size, 1).also {
(0 until it.numRows).forEach { row -> it[row, 0] = get(row) }
})
}
@ -51,159 +77,178 @@ public object EjmlLinearSpace : LinearSpace<Double, DoubleField> {
rows: Int,
columns: Int,
initializer: DoubleField.(i: Int, j: Int) -> Double,
): EjmlMatrix = EjmlMatrix(SimpleMatrix(rows, columns).also {
): EjmlDoubleMatrix<DMatrixRMaj> = EjmlDoubleMatrix(DMatrixRMaj(rows, columns).also {
(0 until rows).forEach { row ->
(0 until columns).forEach { col -> it[row, col] = DoubleField.initializer(row, col) }
(0 until columns).forEach { col -> it[row, col] = elementAlgebra.initializer(row, col) }
}
})
public override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): Point<Double> =
EjmlVector(SimpleMatrix(size, 1).also {
(0 until it.numRows()).forEach { row -> it[row, 0] = DoubleField.initializer(row) }
})
public override fun buildVector(
size: Int,
initializer: DoubleField.(Int) -> Double,
): EjmlDoubleVector<DMatrixRMaj> = EjmlDoubleVector(DMatrixRMaj(size, 1).also {
(0 until it.numRows).forEach { row -> it[row, 0] = elementAlgebra.initializer(row) }
})
private fun SimpleMatrix.wrapMatrix() = EjmlMatrix(this)
private fun SimpleMatrix.wrapVector() = EjmlVector(this)
private fun <T : DMatrix> T.wrapMatrix() = EjmlDoubleMatrix(this)
private fun <T : DMatrixD1> T.wrapVector() = EjmlDoubleVector(this)
public override fun Matrix<Double>.unaryMinus(): Matrix<Double> = this * (-1.0)
public override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlMatrix =
EjmlMatrix(toEjml().origin.mult(other.toEjml().origin))
public override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlDoubleMatrix<DMatrixRMaj> {
val out = DMatrixRMaj(1, 1)
CommonOps_DDRM.mult(toEjml().origin, other.toEjml().origin, out)
return out.wrapMatrix()
}
public override fun Matrix<Double>.dot(vector: Point<Double>): EjmlVector =
EjmlVector(toEjml().origin.mult(vector.toEjml().origin))
public override fun Matrix<Double>.dot(vector: Point<Double>): EjmlDoubleVector<DMatrixRMaj> {
val out = DMatrixRMaj(1, 1)
CommonOps_DDRM.mult(toEjml().origin, vector.toEjml().origin, out)
return out.wrapVector()
}
public override operator fun Matrix<Double>.minus(other: Matrix<Double>): EjmlMatrix =
(toEjml().origin - other.toEjml().origin).wrapMatrix()
public override operator fun Matrix<Double>.minus(other: Matrix<Double>): EjmlDoubleMatrix<DMatrixRMaj> {
val out = DMatrixRMaj(1, 1)
CommonOps_DDRM.subtract(toEjml().origin, other.toEjml().origin, out)
return out.wrapMatrix()
}
public override operator fun Matrix<Double>.times(value: Double): EjmlMatrix =
toEjml().origin.scale(value).wrapMatrix()
public override operator fun Matrix<Double>.times(value: Double): EjmlDoubleMatrix<DMatrixRMaj> {
val res = this.toEjml().origin.copy()
CommonOps_DDRM.scale(value, res)
return res.wrapMatrix()
}
public override fun Point<Double>.unaryMinus(): EjmlVector =
toEjml().origin.negative().wrapVector()
public override fun Point<Double>.unaryMinus(): EjmlDoubleVector<DMatrixRMaj> {
val out = toEjml().origin.copy()
CommonOps_DDRM.changeSign(out)
return out.wrapVector()
}
public override fun Matrix<Double>.plus(other: Matrix<Double>): EjmlMatrix =
(toEjml().origin + other.toEjml().origin).wrapMatrix()
public override fun Matrix<Double>.plus(other: Matrix<Double>): EjmlDoubleMatrix<DMatrixRMaj> {
val out = DMatrixRMaj(1, 1)
CommonOps_DDRM.add(toEjml().origin, other.toEjml().origin, out)
return out.wrapMatrix()
}
public override fun Point<Double>.plus(other: Point<Double>): EjmlVector =
(toEjml().origin + other.toEjml().origin).wrapVector()
public override fun Point<Double>.plus(other: Point<Double>): EjmlDoubleVector<DMatrixRMaj> {
val out = DMatrixRMaj(1, 1)
CommonOps_DDRM.add(toEjml().origin, other.toEjml().origin, out)
return out.wrapVector()
}
public override fun Point<Double>.minus(other: Point<Double>): EjmlVector =
(toEjml().origin - other.toEjml().origin).wrapVector()
public override fun Point<Double>.minus(other: Point<Double>): EjmlDoubleVector<DMatrixRMaj> {
val out = DMatrixRMaj(1, 1)
CommonOps_DDRM.subtract(toEjml().origin, other.toEjml().origin, out)
return out.wrapVector()
}
public override fun Double.times(m: Matrix<Double>): EjmlMatrix =
m.toEjml().origin.scale(this).wrapMatrix()
public override fun Double.times(m: Matrix<Double>): EjmlDoubleMatrix<DMatrixRMaj> = m * this
public override fun Point<Double>.times(value: Double): EjmlVector =
toEjml().origin.scale(value).wrapVector()
public override fun Point<Double>.times(value: Double): EjmlDoubleVector<DMatrixRMaj> {
val res = this.toEjml().origin.copy()
CommonOps_DDRM.scale(value, res)
return res.wrapVector()
}
public override fun Double.times(v: Point<Double>): EjmlVector =
v.toEjml().origin.scale(this).wrapVector()
public override fun Double.times(v: Point<Double>): EjmlDoubleVector<DMatrixRMaj> = v * this
@UnstableKMathAPI
public override fun <F : StructureFeature> getFeature(structure: Matrix<Double>, type: KClass<out F>): F? {
//Return the feature if it is intrinsic to the structure
// Return the feature if it is intrinsic to the structure
structure.getFeature(type)?.let { return it }
val origin = structure.toEjml().origin
return when (type) {
InverseMatrixFeature::class -> object : InverseMatrixFeature<Double> {
override val inverse: Matrix<Double> by lazy { EjmlMatrix(origin.invert()) }
override val inverse: Matrix<Double> by lazy {
val res = origin.copy()
CommonOps_DDRM.invert(res)
EjmlDoubleMatrix(res)
}
}
DeterminantFeature::class -> object : DeterminantFeature<Double> {
override val determinant: Double by lazy(origin::determinant)
override val determinant: Double by lazy { CommonOps_DDRM.det(DMatrixRMaj(origin)) }
}
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
private val svd by lazy {
DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false)
.apply { decompose(origin.ddrm.copy()) }
DecompositionFactory_DDRM.svd(origin.numRows, origin.numCols, true, true, false)
.apply { decompose(origin.copy()) }
}
override val u: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) }
override val s: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) }
override val v: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) }
override val u: Matrix<Double> by lazy { EjmlDoubleMatrix(svd.getU(null, false)) }
override val s: Matrix<Double> by lazy { EjmlDoubleMatrix(svd.getW(null)) }
override val v: Matrix<Double> by lazy { EjmlDoubleMatrix(svd.getV(null, false)) }
override val singularValues: Point<Double> by lazy { DoubleBuffer(svd.singularValues) }
}
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
private val qr by lazy {
DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) }
DecompositionFactory_DDRM.qr().apply { decompose(origin.copy()) }
}
override val q: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) + OrthogonalFeature
EjmlDoubleMatrix(qr.getQ(null, false)) + OrthogonalFeature
}
override val r: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) + UFeature }
override val r: Matrix<Double> by lazy { EjmlDoubleMatrix(qr.getR(null, false)) + UFeature }
}
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
override val l: Matrix<Double> by lazy {
val cholesky =
DecompositionFactory_DDRM.chol(structure.rowNum, true).apply { decompose(origin.ddrm.copy()) }
DecompositionFactory_DDRM.chol(structure.rowNum, true).apply { decompose(origin.copy()) }
EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature
EjmlDoubleMatrix(cholesky.getT(null)) + LFeature
}
}
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
private val lup by lazy {
DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols())
.apply { decompose(origin.ddrm.copy()) }
DecompositionFactory_DDRM.lu(origin.numRows, origin.numCols).apply { decompose(origin.copy()) }
}
override val l: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(lup.getLower(null))) + LFeature
EjmlDoubleMatrix(lup.getLower(null)) + LFeature
}
override val u: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(lup.getUpper(null))) + UFeature
EjmlDoubleMatrix(lup.getUpper(null)) + UFeature
}
override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) }
override val p: Matrix<Double> by lazy { EjmlDoubleMatrix(lup.getRowPivot(null)) }
}
else -> null
}?.let(type::cast)
}
/**
* Solves for *x* in the following equation: *x = [a] <sup>-1</sup> &middot; [b]*.
*
* @param a the base matrix.
* @param b n by p matrix.
* @return the solution for 'x' that is n by p.
*/
public fun solve(a: Matrix<Double>, b: Matrix<Double>): EjmlDoubleMatrix<DMatrixRMaj> {
val res = DMatrixRMaj(1, 1)
CommonOps_DDRM.solve(DMatrixRMaj(a.toEjml().origin), DMatrixRMaj(b.toEjml().origin), res)
return EjmlDoubleMatrix(res)
}
/**
* Solves for *x* in the following equation: *x = [a] <sup>-1</sup> &middot; [b]*.
*
* @param a the base matrix.
* @param b n by p vector.
* @return the solution for 'x' that is n by p.
*/
public fun solve(a: Matrix<Double>, b: Point<Double>): EjmlDoubleVector<DMatrixRMaj> {
val res = DMatrixRMaj(1, 1)
CommonOps_DDRM.solve(DMatrixRMaj(a.toEjml().origin), DMatrixRMaj(b.toEjml().origin), res)
return EjmlDoubleVector(res)
}
}
/**
* Solves for *x* in the following equation: *x = [a] <sup>-1</sup> &middot; [b]*.
*
* @param a the base matrix.
* @param b n by p matrix.
* @return the solution for 'x' that is n by p.
* @author Iaroslav Postovalov
*/
public fun EjmlLinearSpace.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin))
/**
* Solves for *x* in the following equation: *x = [a] <sup>-1</sup> &middot; [b]*.
*
* @param a the base matrix.
* @param b n by p vector.
* @return the solution for 'x' that is n by p.
* @author Iaroslav Postovalov
*/
public fun EjmlLinearSpace.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
/**
* Inverts this matrix.
*
* @author Alexander Nozik
*/
@OptIn(UnstableKMathAPI::class)
public fun EjmlMatrix.inverted(): EjmlMatrix = getFeature<InverseMatrixFeature<Double>>()!!.inverse as EjmlMatrix
/**
* Inverts the given matrix.
*
* @author Alexander Nozik
*/
public fun EjmlLinearSpace.inverse(matrix: Matrix<Double>): Matrix<Double> = matrix.toEjml().inverted()

View File

@ -5,18 +5,28 @@
package space.kscience.kmath.ejml
import org.ejml.simple.SimpleMatrix
import space.kscience.kmath.linear.Matrix
import org.ejml.data.DMatrix
import org.ejml.data.Matrix
import space.kscience.kmath.nd.Structure2D
/**
* The matrix implementation over EJML [SimpleMatrix].
* [space.kscience.kmath.linear.Matrix] implementation based on EJML [Matrix].
*
* @property origin the underlying [SimpleMatrix].
* @param T the type of elements contained in the buffer.
* @param M the type of EJML matrix.
* @property origin The underlying EJML matrix.
* @author Iaroslav Postovalov
*/
public class EjmlMatrix(public val origin: SimpleMatrix) : Matrix<Double> {
public override val rowNum: Int get() = origin.numRows()
public override val colNum: Int get() = origin.numCols()
public abstract class EjmlMatrix<T, out M : Matrix>(public open val origin: M) : Structure2D<T> {
public override val rowNum: Int get() = origin.numRows
public override val colNum: Int get() = origin.numCols
}
/**
* [EjmlMatrix] specialization for [Double].
*
* @author Iaroslav Postovalov
*/
public class EjmlDoubleMatrix<out M : DMatrix>(public override val origin: M) : EjmlMatrix<Double, M>(origin) {
public override operator fun get(i: Int, j: Int): Double = origin[i, j]
}

View File

@ -5,35 +5,41 @@
package space.kscience.kmath.ejml
import org.ejml.simple.SimpleMatrix
import org.ejml.data.DMatrixD1
import org.ejml.data.Matrix
import space.kscience.kmath.linear.Point
/**
* Represents point over EJML [SimpleMatrix].
* [Point] implementation based on EJML [Matrix].
*
* @property origin the underlying [SimpleMatrix].
* @param T the type of elements contained in the buffer.
* @param M the type of EJML matrix.
* @property origin The underlying matrix.
* @author Iaroslav Postovalov
*/
public class EjmlVector internal constructor(public val origin: SimpleMatrix) : Point<Double> {
public abstract class EjmlVector<out T, out M : Matrix>(public open val origin: M) : Point<T> {
public override val size: Int
get() = origin.numRows()
get() = origin.numRows
init {
require(origin.numCols() == 1) { "Only single column matrices are allowed" }
}
public override operator fun get(index: Int): Double = origin[index]
public override operator fun iterator(): Iterator<Double> = object : Iterator<Double> {
public override operator fun iterator(): Iterator<T> = object : Iterator<T> {
private var cursor: Int = 0
override fun next(): Double {
override fun next(): T {
cursor += 1
return origin[cursor - 1]
return this@EjmlVector[cursor - 1]
}
override fun hasNext(): Boolean = cursor < origin.numCols() * origin.numRows()
override fun hasNext(): Boolean = cursor < origin.numCols * origin.numRows
}
public override fun toString(): String = "EjmlVector(origin=$origin)"
}
/**
* [EjmlVector] specialization for [Double].
*
* @author Iaroslav Postovalov
*/
public class EjmlDoubleVector<out M : DMatrixD1>(public override val origin: M) : EjmlVector<Double, M>(origin) {
public override operator fun get(index: Int): Double = origin[index]
}

View File

@ -5,12 +5,15 @@
package space.kscience.kmath.ejml
import org.ejml.data.DMatrixRMaj
import org.ejml.dense.row.CommonOps_DDRM
import org.ejml.dense.row.RandomMatrices_DDRM
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix
import space.kscience.kmath.linear.*
import space.kscience.kmath.linear.DeterminantFeature
import space.kscience.kmath.linear.LupDecompositionFeature
import space.kscience.kmath.linear.getFeature
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.getFeature
import kotlin.random.Random
import kotlin.random.asJavaRandom
import kotlin.test.*
@ -22,65 +25,59 @@ fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T
internal class EjmlMatrixTest {
private val random = Random(0)
private val randomMatrix: SimpleMatrix
private val randomMatrix: DMatrixRMaj
get() {
val s = random.nextInt(2, 100)
return SimpleMatrix.random_DDRM(s, s, 0.0, 10.0, random.asJavaRandom())
val d = DMatrixRMaj(s, s)
RandomMatrices_DDRM.fillUniform(d, random.asJavaRandom())
return d
}
@Test
fun rowNum() {
val m = randomMatrix
assertEquals(m.numRows(), EjmlMatrix(m).rowNum)
assertEquals(m.numRows, EjmlDoubleMatrix(m).rowNum)
}
@Test
fun colNum() {
val m = randomMatrix
assertEquals(m.numCols(), EjmlMatrix(m).rowNum)
assertEquals(m.numCols, EjmlDoubleMatrix(m).rowNum)
}
@Test
fun shape() {
val m = randomMatrix
val w = EjmlMatrix(m)
assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList())
val w = EjmlDoubleMatrix(m)
assertContentEquals(intArrayOf(m.numRows, m.numCols), w.shape)
}
@OptIn(UnstableKMathAPI::class)
@Test
fun features() {
val m = randomMatrix
val w = EjmlMatrix(m)
val det: DeterminantFeature<Double> = EjmlLinearSpace.getFeature(w) ?: fail()
assertEquals(m.determinant(), det.determinant)
val lup: LupDecompositionFeature<Double> = EjmlLinearSpace.getFeature(w) ?: fail()
val w = EjmlDoubleMatrix(m)
val det: DeterminantFeature<Double> = EjmlLinearSpaceDDRM.getFeature(w) ?: fail()
assertEquals(CommonOps_DDRM.det(m), det.determinant)
val lup: LupDecompositionFeature<Double> = EjmlLinearSpaceDDRM.getFeature(w) ?: fail()
val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols())
.also { it.decompose(m.ddrm.copy()) }
val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows, m.numCols)
.also { it.decompose(m.copy()) }
assertMatrixEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))), lup.l)
assertMatrixEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))), lup.u)
assertMatrixEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))), lup.p)
}
private object SomeFeature : MatrixFeature {}
@OptIn(UnstableKMathAPI::class)
@Test
fun suggestFeature() {
assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature<SomeFeature>())
assertMatrixEquals(EjmlDoubleMatrix(ludecompositionF64.getLower(null)), lup.l)
assertMatrixEquals(EjmlDoubleMatrix(ludecompositionF64.getUpper(null)), lup.u)
assertMatrixEquals(EjmlDoubleMatrix(ludecompositionF64.getRowPivot(null)), lup.p)
}
@Test
fun get() {
val m = randomMatrix
assertEquals(m[0, 0], EjmlMatrix(m)[0, 0])
assertEquals(m[0, 0], EjmlDoubleMatrix(m)[0, 0])
}
@Test
fun origin() {
val m = randomMatrix
assertSame(m, EjmlMatrix(m).origin)
assertSame(m, EjmlDoubleMatrix(m).origin)
}
}

View File

@ -5,7 +5,8 @@
package space.kscience.kmath.ejml
import org.ejml.simple.SimpleMatrix
import org.ejml.data.DMatrixRMaj
import org.ejml.dense.row.RandomMatrices_DDRM
import kotlin.random.Random
import kotlin.random.asJavaRandom
import kotlin.test.Test
@ -15,30 +16,34 @@ import kotlin.test.assertSame
internal class EjmlVectorTest {
private val random = Random(0)
private val randomMatrix: SimpleMatrix
get() = SimpleMatrix.random_DDRM(random.nextInt(2, 100), 1, 0.0, 10.0, random.asJavaRandom())
private val randomMatrix: DMatrixRMaj
get() {
val d = DMatrixRMaj(random.nextInt(2, 100), 1)
RandomMatrices_DDRM.fillUniform(d, random.asJavaRandom())
return d
}
@Test
fun size() {
val m = randomMatrix
val w = EjmlVector(m)
assertEquals(m.numRows(), w.size)
val w = EjmlDoubleVector(m)
assertEquals(m.numRows, w.size)
}
@Test
fun get() {
val m = randomMatrix
val w = EjmlVector(m)
val w = EjmlDoubleVector(m)
assertEquals(m[0, 0], w[0])
}
@Test
fun iterator() {
val m = randomMatrix
val w = EjmlVector(m)
val w = EjmlDoubleVector(m)
assertEquals(
m.iterator(true, 0, 0, m.numRows() - 1, 0).asSequence().toList(),
m.iterator(true, 0, 0, m.numRows - 1, 0).asSequence().toList(),
w.iterator().asSequence().toList()
)
}
@ -46,7 +51,7 @@ internal class EjmlVectorTest {
@Test
fun origin() {
val m = randomMatrix
val w = EjmlVector(m)
val w = EjmlDoubleVector(m)
assertSame(m, w.origin)
}
}

View File

@ -15,8 +15,7 @@ The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-d
```gradle
repositories {
maven { url 'https://repo.kotlin.link' }
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
mavenCentral()
}
dependencies {
@ -27,8 +26,7 @@ dependencies {
```kotlin
repositories {
maven("https://repo.kotlin.link")
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
mavenCentral()
}
dependencies {

View File

@ -17,8 +17,7 @@ The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-
```gradle
repositories {
maven { url 'https://repo.kotlin.link' }
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
mavenCentral()
}
dependencies {
@ -29,8 +28,7 @@ dependencies {
```kotlin
repositories {
maven("https://repo.kotlin.link")
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
mavenCentral()
}
dependencies {

View File

@ -0,0 +1,19 @@
plugins {
id("ru.mipt.npm.gradle.jvm")
kotlin("jupyter.api")
}
dependencies {
api(project(":kmath-ast"))
api(project(":kmath-complex"))
api(project(":kmath-for-real"))
implementation("org.jetbrains.kotlinx:kotlinx-html-jvm:0.7.3")
}
readme {
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
}
kotlin.sourceSets.all {
languageSettings.useExperimentalAnnotation("space.kscience.kmath.misc.UnstableKMathAPI")
}

View File

@ -0,0 +1,120 @@
package space.kscience.kmath.jupyter
import kotlinx.html.Unsafe
import kotlinx.html.div
import kotlinx.html.stream.createHTML
import kotlinx.html.unsafe
import org.jetbrains.kotlinx.jupyter.api.DisplayResult
import org.jetbrains.kotlinx.jupyter.api.HTML
import org.jetbrains.kotlinx.jupyter.api.annotations.JupyterLibrary
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess
import space.kscience.kmath.ast.rendering.MathMLSyntaxRenderer
import space.kscience.kmath.ast.rendering.renderWithStringBuilder
import space.kscience.kmath.complex.Complex
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.operations.GroupOperations
import space.kscience.kmath.operations.RingOperations
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.asSequence
@JupyterLibrary
internal class KMathJupyter : JupyterIntegration() {
private val mathRender = FeaturedMathRendererWithPostProcess.Default
private val syntaxRender = MathMLSyntaxRenderer
override fun Builder.onLoaded() {
import(
"space.kscience.kmath.ast.*",
"space.kscience.kmath.ast.rendering.*",
"space.kscience.kmath.operations.*",
"space.kscience.kmath.expressions.*",
"space.kscience.kmath.misc.*",
"space.kscience.kmath.real.*",
)
fun MST.toDisplayResult(): DisplayResult = HTML(createHTML().div {
unsafe {
+syntaxRender.renderWithStringBuilder(mathRender.render(this@toDisplayResult))
}
})
render<MST> { it.toDisplayResult() }
render<Number> { MST.Numeric(it).toDisplayResult() }
fun Unsafe.appendCellValue(it: Any?) {
when (it) {
is Number -> {
val s = StringBuilder()
syntaxRender.renderPart(mathRender.render(MST.Numeric(it)), s)
+s.toString()
}
is MST -> {
val s = StringBuilder()
syntaxRender.renderPart(mathRender.render(it), s)
+s.toString()
}
else -> {
+"<ms>"
+it.toString()
+"</ms>"
}
}
}
render<Structure2D<*>> { structure ->
HTML(createHTML().div {
unsafe {
+"<math xmlns=\"https://www.w3.org/1998/Math/MathML\">"
+"<mrow>"
+"<mfenced open=\"[\" close=\"]\" separators=\"\">"
+"<mtable>"
structure.rows.forEach { row ->
+"<mtr>"
row.asSequence().forEach {
+"<mtd>"
appendCellValue(it)
+"</mtd>"
}
+"</mtr>"
}
+"</mtable>"
+"</mfenced>"
+"</mrow>"
+"</math>"
}
})
}
render<Buffer<*>> { buffer ->
HTML(createHTML().div {
unsafe {
+"<math xmlns=\"https://www.w3.org/1998/Math/MathML\">"
+"<mrow>"
+"<mfenced open=\"[\" close=\"]\" separators=\"\">"
+"<mtable>"
buffer.asSequence().forEach {
+"<mtr>"
+"<mtd>"
appendCellValue(it)
+"</mtd>"
+"</mtr>"
}
+"</mtable>"
+"</mfenced>"
+"</mrow>"
+"</math>"
}
})
}
render<Complex> {
MST.Binary(
operation = GroupOperations.PLUS_OPERATION,
left = MST.Numeric(it.re),
right = MST.Binary(RingOperations.TIMES_OPERATION, MST.Numeric(it.im), MST.Symbolic("i")),
).toDisplayResult()
}
}
}

View File

@ -4,8 +4,8 @@ plugins {
}
dependencies {
implementation("com.github.breandan:kaliningraph:0.1.4")
implementation("com.github.breandan:kotlingrad:0.4.0")
api("com.github.breandan:kaliningraph:0.1.4")
api("com.github.breandan:kotlingrad:0.4.5")
api(project(":kmath-ast"))
}

View File

@ -15,8 +15,7 @@ The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-7
```gradle
repositories {
maven { url 'https://repo.kotlin.link' }
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
mavenCentral()
}
dependencies {
@ -27,8 +26,7 @@ dependencies {
```kotlin
repositories {
maven("https://repo.kotlin.link")
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
mavenCentral()
}
dependencies {

40
kmath-tensors/README.md Normal file
View File

@ -0,0 +1,40 @@
# Module kmath-tensors
Common operations on tensors, the API consists of:
- [TensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic algebra operations on tensors (plus, dot, etc.)
- [TensorPartialDivisionAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt) : Emulates an algebra over a field
- [LinearOpsTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Linear algebra operations including LU, QR, Cholesky LL and SVD decompositions
- [AnalyticTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt) : Element-wise analytic operations
The library offers a multiplatform implementation for this interface over the `Double`'s. As a highlight, the user can find:
- [BroadcastDoubleTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic algebra operations implemented with broadcasting.
- [DoubleLinearOpsTensorAlgebra](src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt) : Contains the power method for SVD and the spectrum of symmetric matrices.
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-tensors:0.3.0-dev-7`.
**Gradle:**
```gradle
repositories {
maven { url 'https://repo.kotlin.link' }
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
}
dependencies {
implementation 'space.kscience:kmath-tensors:0.3.0-dev-7'
}
```
**Gradle Kotlin DSL:**
```kotlin
repositories {
maven("https://repo.kotlin.link")
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
}
dependencies {
implementation("space.kscience:kmath-tensors:0.3.0-dev-7")
}
```

View File

@ -0,0 +1,43 @@
plugins {
id("ru.mipt.npm.gradle.mpp")
}
kotlin.sourceSets {
all {
languageSettings.useExperimentalAnnotation("space.kscience.kmath.misc.UnstableKMathAPI")
}
commonMain {
dependencies {
api(project(":kmath-core"))
api(project(":kmath-stat"))
}
}
}
tasks.dokkaHtml {
dependsOn(tasks.build)
}
readme {
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
feature(
id = "tensor algebra",
description = "Basic linear algebra operations on tensors (plus, dot, etc.)",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt"
)
feature(
id = "tensor algebra with broadcasting",
description = "Basic linear algebra operations implemented with broadcasting.",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt"
)
feature(
id = "linear algebra operations",
description = "Advanced linear algebra operations like LU decomposition, SVD, etc.",
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt"
)
}

View File

@ -0,0 +1,7 @@
# Module kmath-tensors
Common linear algebra operations on tensors.
${features}
${artifact}

View File

@ -0,0 +1,131 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
/**
* Analytic operations on [Tensor].
*
* @param T the type of items closed under analytic functions in the tensors.
*/
public interface AnalyticTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
/**
* @return the mean of all elements in the input tensor.
*/
public fun Tensor<T>.mean(): T
/**
* Returns the mean of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the mean of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the standard deviation of all elements in the input tensor.
*/
public fun Tensor<T>.std(): T
/**
* Returns the standard deviation of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the standard deviation of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the variance of all elements in the input tensor.
*/
public fun Tensor<T>.variance(): T
/**
* Returns the variance of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the variance of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T>
/**
* Returns the covariance matrix M of given vectors.
*
* M[i, j] contains covariance of i-th and j-th given vectors
*
* @param tensors the [List] of 1-dimensional tensors with same shape
* @return the covariance matrix
*/
public fun cov(tensors: List<Tensor<T>>): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.exp.html
public fun Tensor<T>.exp(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.log.html
public fun Tensor<T>.ln(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html
public fun Tensor<T>.sqrt(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
public fun Tensor<T>.cos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
public fun Tensor<T>.acos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
public fun Tensor<T>.cosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
public fun Tensor<T>.acosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
public fun Tensor<T>.sin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
public fun Tensor<T>.asin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
public fun Tensor<T>.sinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
public fun Tensor<T>.asinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
public fun Tensor<T>.tan(): Tensor<T>
//https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
public fun Tensor<T>.atan(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
public fun Tensor<T>.tanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
public fun Tensor<T>.atanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil
public fun Tensor<T>.ceil(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
public fun Tensor<T>.floor(): Tensor<T>
}

View File

@ -0,0 +1,97 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
/**
* Common linear algebra operations. Operates on [Tensor].
*
* @param T the type of items closed under division in the tensors.
*/
public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
/**
* Computes the determinant of a square matrix input, or of each square matrix in a batched input.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
*
* @return the determinant.
*/
public fun Tensor<T>.det(): Tensor<T>
/**
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
* Given a square matrix `A`, return the matrix `AInv` satisfying
* `A dot AInv = AInv dot A = eye(a.shape[0])`.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
*
* @return the multiplicative inverse of a matrix.
*/
public fun Tensor<T>.inv(): Tensor<T>
/**
* Cholesky decomposition.
*
* Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices)
* positive-definite matrix or the Cholesky decompositions for a batch of such matrices.
* Each decomposition has the form:
* Given a tensor `input`, return the tensor `L` satisfying `input = L dot L.H`,
* where L is a lower-triangular matrix and L.H is the conjugate transpose of L,
* which is just a transpose for the case of real-valued input matrices.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky
*
* @return the batch of L matrices.
*/
public fun Tensor<T>.cholesky(): Tensor<T>
/**
* QR decomposition.
*
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q dot R``,
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
*
* @return pair of Q and R tensors.
*/
public fun Tensor<T>.qr(): Pair<Tensor<T>, Tensor<T>>
/**
* LUP decomposition
*
* Computes the LUP decomposition of a matrix or a batch of matrices.
* Given a tensor `input`, return tensors (P, L, U) satisfying `P dot input = L dot U`,
* with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices.
*
* * @return triple of P, L and U tensors
*/
public fun Tensor<T>.lu(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
/**
* Singular Value Decomposition.
*
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
* The singular value decomposition is represented as a triple `(U, S, V)`,
* such that `input = U dot diagonalEmbedding(S) dot V.H`,
* where V.H is the conjugate transpose of V.
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
*
* @return triple `(U, S, V)`.
*/
public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
/**
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
* represented by a pair (eigenvalues, eigenvectors).
* For more information: https://pytorch.org/docs/stable/generated/torch.symeig.html
*
* @return a pair (eigenvalues, eigenvectors)
*/
public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
}

View File

@ -0,0 +1,5 @@
package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.MutableStructureND
public typealias Tensor<T> = MutableStructureND<T>

View File

@ -0,0 +1,327 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
import space.kscience.kmath.operations.Algebra
/**
* Algebra over a ring on [Tensor].
* For more information: https://proofwiki.org/wiki/Definition:Algebra_over_Ring
*
* @param T the type of items in the tensors.
*/
public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
/**
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
*
* @return a nullable value of a potentially scalar tensor.
*/
public fun Tensor<T>.valueOrNull(): T?
/**
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
*
* @return the value of a scalar tensor.
*/
public fun Tensor<T>.value(): T
/**
* Each element of the tensor [other] is added to this value.
* The resulting tensor is returned.
*
* @param other tensor to be added.
* @return the sum of this value and tensor [other].
*/
public operator fun T.plus(other: Tensor<T>): Tensor<T>
/**
* Adds the scalar [value] to each element of this tensor and returns a new resulting tensor.
*
* @param value the number to be added to each element of this tensor.
* @return the sum of this tensor and [value].
*/
public operator fun Tensor<T>.plus(value: T): Tensor<T>
/**
* Each element of the tensor [other] is added to each element of this tensor.
* The resulting tensor is returned.
*
* @param other tensor to be added.
* @return the sum of this tensor and [other].
*/
public operator fun Tensor<T>.plus(other: Tensor<T>): Tensor<T>
/**
* Adds the scalar [value] to each element of this tensor.
*
* @param value the number to be added to each element of this tensor.
*/
public operator fun Tensor<T>.plusAssign(value: T): Unit
/**
* Each element of the tensor [other] is added to each element of this tensor.
*
* @param other tensor to be added.
*/
public operator fun Tensor<T>.plusAssign(other: Tensor<T>): Unit
/**
* Each element of the tensor [other] is subtracted from this value.
* The resulting tensor is returned.
*
* @param other tensor to be subtracted.
* @return the difference between this value and tensor [other].
*/
public operator fun T.minus(other: Tensor<T>): Tensor<T>
/**
* Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor.
*
* @param value the number to be subtracted from each element of this tensor.
* @return the difference between this tensor and [value].
*/
public operator fun Tensor<T>.minus(value: T): Tensor<T>
/**
* Each element of the tensor [other] is subtracted from each element of this tensor.
* The resulting tensor is returned.
*
* @param other tensor to be subtracted.
* @return the difference between this tensor and [other].
*/
public operator fun Tensor<T>.minus(other: Tensor<T>): Tensor<T>
/**
* Subtracts the scalar [value] from each element of this tensor.
*
* @param value the number to be subtracted from each element of this tensor.
*/
public operator fun Tensor<T>.minusAssign(value: T): Unit
/**
* Each element of the tensor [other] is subtracted from each element of this tensor.
*
* @param other tensor to be subtracted.
*/
public operator fun Tensor<T>.minusAssign(other: Tensor<T>): Unit
/**
* Each element of the tensor [other] is multiplied by this value.
* The resulting tensor is returned.
*
* @param other tensor to be multiplied.
* @return the product of this value and tensor [other].
*/
public operator fun T.times(other: Tensor<T>): Tensor<T>
/**
* Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor.
*
* @param value the number to be multiplied by each element of this tensor.
* @return the product of this tensor and [value].
*/
public operator fun Tensor<T>.times(value: T): Tensor<T>
/**
* Each element of the tensor [other] is multiplied by each element of this tensor.
* The resulting tensor is returned.
*
* @param other tensor to be multiplied.
* @return the product of this tensor and [other].
*/
public operator fun Tensor<T>.times(other: Tensor<T>): Tensor<T>
/**
* Multiplies the scalar [value] by each element of this tensor.
*
* @param value the number to be multiplied by each element of this tensor.
*/
public operator fun Tensor<T>.timesAssign(value: T): Unit
/**
* Each element of the tensor [other] is multiplied by each element of this tensor.
*
* @param other tensor to be multiplied.
*/
public operator fun Tensor<T>.timesAssign(other: Tensor<T>): Unit
/**
* Numerical negative, element-wise.
*
* @return tensor negation of the original tensor.
*/
public operator fun Tensor<T>.unaryMinus(): Tensor<T>
/**
* Returns the tensor at index i
* For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html
*
* @param i index of the extractable tensor
* @return subtensor of the original tensor with index [i]
*/
public operator fun Tensor<T>.get(i: Int): Tensor<T>
/**
* Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
* For more information: https://pytorch.org/docs/stable/generated/torch.transpose.html
*
* @param i the first dimension to be transposed
* @param j the second dimension to be transposed
* @return transposed tensor
*/
public fun Tensor<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T>
/**
* Returns a new tensor with the same data as the self tensor but of a different shape.
* The returned tensor shares the same data and must have the same number of elements, but may have a different size
* For more information: https://pytorch.org/docs/stable/tensor_view.html
*
* @param shape the desired size
* @return tensor with new shape
*/
public fun Tensor<T>.view(shape: IntArray): Tensor<T>
/**
* View this tensor as the same size as [other].
* ``this.viewAs(other) is equivalent to this.view(other.shape)``.
* For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html
*
* @param other the result tensor has the same size as other.
* @return the result tensor with the same size as other.
*/
public fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T>
/**
* Matrix product of two tensors.
*
* The behavior depends on the dimensionality of the tensors as follows:
* 1. If both tensors are 1-dimensional, the dot product (scalar) is returned.
*
* 2. If both arguments are 2-dimensional, the matrix-matrix product is returned.
*
* 3. If the first argument is 1-dimensional and the second argument is 2-dimensional,
* a 1 is prepended to its dimension for the purpose of the matrix multiply.
* After the matrix multiply, the prepended dimension is removed.
*
* 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
* the matrix-vector product is returned.
*
* 5. If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2),
* then a batched matrix multiply is returned. If the first argument is 1-dimensional,
* a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.
* If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix
* multiple and removed after.
* The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable).
* For example, if `input` is a (j &times; 1 &times; n &times; n) tensor and `other` is a
* (k &times; n &times; n) tensor, out will be a (j &times; k &times; n &times; n) tensor.
*
* For more information: https://pytorch.org/docs/stable/generated/torch.matmul.html
*
* @param other tensor to be multiplied
* @return mathematical product of two tensors
*/
public infix fun Tensor<T>.dot(other: Tensor<T>): Tensor<T>
/**
* Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
* are filled by [diagonalEntries].
* To facilitate creating batched diagonal matrices,
* the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
*
* The argument [offset] controls which diagonal to consider:
* 1. If [offset] = 0, it is the main diagonal.
* 1. If [offset] > 0, it is above the main diagonal.
* 1. If [offset] < 0, it is below the main diagonal.
*
* The size of the new matrix will be calculated
* to make the specified diagonal of the size of the last input dimension.
* For more information: https://pytorch.org/docs/stable/generated/torch.diag_embed.html
*
* @param diagonalEntries the input tensor. Must be at least 1-dimensional.
* @param offset which diagonal to consider. Default: 0 (main diagonal).
* @param dim1 first dimension with respect to which to take diagonal. Default: -2.
* @param dim2 second dimension with respect to which to take diagonal. Default: -1.
*
* @return tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
* are filled by [diagonalEntries]
*/
public fun diagonalEmbedding(
diagonalEntries: Tensor<T>,
offset: Int = 0,
dim1: Int = -2,
dim2: Int = -1
): Tensor<T>
/**
* @return the sum of all elements in the input tensor.
*/
public fun Tensor<T>.sum(): T
/**
* Returns the sum of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the sum of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the minimum value of all elements in the input tensor.
*/
public fun Tensor<T>.min(): T
/**
* Returns the minimum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the minimum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
/**
* Returns the maximum value of all elements in the input tensor.
*/
public fun Tensor<T>.max(): T
/**
* Returns the maximum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the maximum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
/**
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
}

View File

@ -0,0 +1,55 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
/**
* Algebra over a field with partial division on [Tensor].
* For more information: https://proofwiki.org/wiki/Definition:Division_Algebra
*
* @param T the type of items closed under division in the tensors.
*/
public interface TensorPartialDivisionAlgebra<T> : TensorAlgebra<T> {
/**
* Each element of the tensor [other] is divided by this value.
* The resulting tensor is returned.
*
* @param other tensor to divide by.
* @return the division of this value by the tensor [other].
*/
public operator fun T.div(other: Tensor<T>): Tensor<T>
/**
* Divide by the scalar [value] each element of this tensor returns a new resulting tensor.
*
* @param value the number to divide by each element of this tensor.
* @return the division of this tensor by the [value].
*/
public operator fun Tensor<T>.div(value: T): Tensor<T>
/**
* Each element of the tensor [other] is divided by each element of this tensor.
* The resulting tensor is returned.
*
* @param other tensor to be divided by.
* @return the division of this tensor by [other].
*/
public operator fun Tensor<T>.div(other: Tensor<T>): Tensor<T>
/**
* Divides by the scalar [value] each element of this tensor.
*
* @param value the number to divide by each element of this tensor.
*/
public operator fun Tensor<T>.divAssign(value: T)
/**
* Each element of this tensor is divided by each element of the [other] tensor.
*
* @param other tensor to be divide by.
*/
public operator fun Tensor<T>.divAssign(other: Tensor<T>)
}

View File

@ -0,0 +1,93 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.array
import space.kscience.kmath.tensors.core.internal.broadcastTensors
import space.kscience.kmath.tensors.core.internal.broadcastTo
import space.kscience.kmath.tensors.core.internal.tensor
/**
* Basic linear algebra operations implemented with broadcasting.
* For more information: https://pytorch.org/docs/stable/notes/broadcasting.html
*/
public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun Tensor<Double>.plusAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun Tensor<Double>.minusAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[newThis.bufferStart + i] *
newOther.mutableBuffer.array()[newOther.bufferStart + i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun Tensor<Double>.timesAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i ->
newThis.mutableBuffer.array()[newOther.bufferStart + i] /
newOther.mutableBuffer.array()[newOther.bufferStart + i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
newOther.mutableBuffer.array()[tensor.bufferStart + i]
}
}
}

View File

@ -0,0 +1,38 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.Strides
import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.TensorLinearStructure
/**
* Represents [Tensor] over a [MutableBuffer] intended to be used through [DoubleTensor] and [IntTensor]
*/
public open class BufferedTensor<T> internal constructor(
override val shape: IntArray,
internal val mutableBuffer: MutableBuffer<T>,
internal val bufferStart: Int
) : Tensor<T> {
/**
* Buffer strides based on [TensorLinearStructure] implementation
*/
public val linearStructure: Strides
get() = TensorLinearStructure(shape)
/**
* Number of elements in tensor
*/
public val numElements: Int
get() = linearStructure.linearSize
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)]
override fun set(index: IntArray, value: T) {
mutableBuffer[bufferStart + linearStructure.offset(index)] = value
}
override fun elements(): Sequence<Pair<IntArray, T>> = linearStructure.indices().map {
it to this[it]
}
}

View File

@ -0,0 +1,20 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.tensors.core.internal.toPrettyString
/**
* Default [BufferedTensor] implementation for [Double] values
*/
public class DoubleTensor internal constructor(
shape: IntArray,
buffer: DoubleArray,
offset: Int = 0
) : BufferedTensor<Double>(shape, DoubleBuffer(buffer), offset) {
override fun toString(): String = toPrettyString()
}

View File

@ -0,0 +1,937 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.dotHelper
import space.kscience.kmath.tensors.core.internal.getRandomNormals
import space.kscience.kmath.tensors.core.internal.*
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
import space.kscience.kmath.tensors.core.internal.checkBufferShapeConsistency
import space.kscience.kmath.tensors.core.internal.checkEmptyDoubleBuffer
import space.kscience.kmath.tensors.core.internal.checkEmptyShape
import space.kscience.kmath.tensors.core.internal.checkShapesCompatible
import space.kscience.kmath.tensors.core.internal.checkSquareMatrix
import space.kscience.kmath.tensors.core.internal.checkTranspose
import space.kscience.kmath.tensors.core.internal.checkView
import space.kscience.kmath.tensors.core.internal.minusIndexFrom
import kotlin.math.*
/**
* Implementation of basic operations over double tensors and basic algebra operations on them.
*/
public open class DoubleTensorAlgebra :
TensorPartialDivisionAlgebra<Double>,
AnalyticTensorAlgebra<Double>,
LinearOpsTensorAlgebra<Double> {
public companion object : DoubleTensorAlgebra()
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
tensor.mutableBuffer.array()[tensor.bufferStart] else null
override fun Tensor<Double>.value(): Double =
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
/**
* Constructs a tensor with the specified shape and data.
*
* @param shape the desired shape for the tensor.
* @param buffer one-dimensional data array.
* @return tensor with the [shape] shape and [buffer] data.
*/
public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor {
checkEmptyShape(shape)
checkEmptyDoubleBuffer(buffer)
checkBufferShapeConsistency(shape, buffer)
return DoubleTensor(shape, buffer, 0)
}
/**
* Constructs a tensor with the specified shape and initializer.
*
* @param shape the desired shape for the tensor.
* @param initializer mapping tensor indices to values.
* @return tensor with the [shape] shape and data generated by the [initializer].
*/
public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor =
fromArray(
shape,
TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray()
)
override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
val lastShape = tensor.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
return DoubleTensor(newShape, tensor.mutableBuffer.array(), newStart)
}
/**
* Creates a tensor of a given shape and fills all elements with a given value.
*
* @param value the value to fill the output tensor with.
* @param shape array of integers defining the shape of the output tensor.
* @return tensor with the [shape] shape and filled with [value].
*/
public fun full(value: Double, shape: IntArray): DoubleTensor {
checkEmptyShape(shape)
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
return DoubleTensor(shape, buffer)
}
/**
* Returns a tensor with the same shape as `input` filled with [value].
*
* @param value the value to fill the output tensor with.
* @return tensor with the `input` tensor shape and filled with [value].
*/
public fun Tensor<Double>.fullLike(value: Double): DoubleTensor {
val shape = tensor.shape
val buffer = DoubleArray(tensor.numElements) { value }
return DoubleTensor(shape, buffer)
}
/**
* Returns a tensor filled with the scalar value 0.0, with the shape defined by the variable argument [shape].
*
* @param shape array of integers defining the shape of the output tensor.
* @return tensor filled with the scalar value 0.0, with the [shape] shape.
*/
public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
/**
* Returns a tensor filled with the scalar value 0.0, with the same shape as a given array.
*
* @return tensor filled with the scalar value 0.0, with the same shape as `input` tensor.
*/
public fun Tensor<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0)
/**
* Returns a tensor filled with the scalar value 1.0, with the shape defined by the variable argument [shape].
*
* @param shape array of integers defining the shape of the output tensor.
* @return tensor filled with the scalar value 1.0, with the [shape] shape.
*/
public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
/**
* Returns a tensor filled with the scalar value 1.0, with the same shape as a given array.
*
* @return tensor filled with the scalar value 1.0, with the same shape as `input` tensor.
*/
public fun Tensor<Double>.onesLike(): DoubleTensor = tensor.fullLike(1.0)
/**
* Returns a 2-D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere.
*
* @param n the number of rows and columns
* @return a 2-D tensor with ones on the diagonal and zeros elsewhere.
*/
public fun eye(n: Int): DoubleTensor {
val shape = intArrayOf(n, n)
val buffer = DoubleArray(n * n) { 0.0 }
val res = DoubleTensor(shape, buffer)
for (i in 0 until n) {
res[intArrayOf(i, i)] = 1.0
}
return res
}
/**
* Return a copy of the tensor.
*
* @return a copy of the `input` tensor with a copied buffer.
*/
public fun Tensor<Double>.copy(): DoubleTensor {
return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
}
override fun Double.plus(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this
}
return DoubleTensor(other.shape, resBuffer)
}
override fun Tensor<Double>.plus(value: Double): DoubleTensor = value + tensor
override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other.tensor)
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i]
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.plusAssign(value: Double) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += value
}
}
override fun Tensor<Double>.plusAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other.tensor)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Double.minus(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
}
return DoubleTensor(other.shape, resBuffer)
}
override fun Tensor<Double>.minus(value: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] - value
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i]
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.minusAssign(value: Double) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value
}
}
override fun Tensor<Double>.minusAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Double.times(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this
}
return DoubleTensor(other.shape, resBuffer)
}
override fun Tensor<Double>.times(value: Double): DoubleTensor = value * tensor
override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] *
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.timesAssign(value: Double) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value
}
}
override fun Tensor<Double>.timesAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Double.div(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
}
return DoubleTensor(other.shape, resBuffer)
}
override fun Tensor<Double>.div(value: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] / value
}
return DoubleTensor(shape, resBuffer)
}
override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[other.tensor.bufferStart + i] /
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.divAssign(value: Double) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= value
}
}
override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Tensor<Double>.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus()
}
return DoubleTensor(tensor.shape, resBuffer)
}
override fun Tensor<Double>.transpose(i: Int, j: Int): DoubleTensor {
val ii = tensor.minusIndex(i)
val jj = tensor.minusIndex(j)
checkTranspose(tensor.dimension, ii, jj)
val n = tensor.numElements
val resBuffer = DoubleArray(n)
val resShape = tensor.shape.copyOf()
resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] }
val resTensor = DoubleTensor(resShape, resBuffer)
for (offset in 0 until n) {
val oldMultiIndex = tensor.linearStructure.index(offset)
val newMultiIndex = oldMultiIndex.copyOf()
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
val linearIndex = resTensor.linearStructure.offset(newMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + offset]
}
return resTensor
}
override fun Tensor<Double>.view(shape: IntArray): DoubleTensor {
checkView(tensor, shape)
return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart)
}
override fun Tensor<Double>.viewAs(other: Tensor<Double>): DoubleTensor =
tensor.view(other.shape)
override infix fun Tensor<Double>.dot(other: Tensor<Double>): DoubleTensor {
if (tensor.shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
}
var newThis = tensor.copy()
var newOther = other.copy()
var penultimateDim = false
var lastDim = false
if (tensor.shape.size == 1) {
penultimateDim = true
newThis = tensor.view(intArrayOf(1) + tensor.shape)
}
if (other.shape.size == 1) {
lastDim = true
newOther = other.tensor.view(other.shape + intArrayOf(1))
}
val broadcastTensors = broadcastOuterTensors(newThis.tensor, newOther.tensor)
newThis = broadcastTensors[0]
newOther = broadcastTensors[1]
val l = newThis.shape[newThis.shape.size - 2]
val m1 = newThis.shape[newThis.shape.size - 1]
val m2 = newOther.shape[newOther.shape.size - 2]
val n = newOther.shape[newOther.shape.size - 1]
check(m1 == m2) {
"Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)"
}
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
val resSize = resShape.reduce { acc, i -> acc * i }
val resTensor = DoubleTensor(resShape, DoubleArray(resSize))
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
val (a, b) = ab
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
}
if (penultimateDim) {
return resTensor.view(
resTensor.shape.dropLast(2).toIntArray() +
intArrayOf(resTensor.shape.last())
)
}
if (lastDim) {
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
}
return resTensor
}
override fun diagonalEmbedding(diagonalEntries: Tensor<Double>, offset: Int, dim1: Int, dim2: Int):
DoubleTensor {
val n = diagonalEntries.shape.size
val d1 = minusIndexFrom(n + 1, dim1)
val d2 = minusIndexFrom(n + 1, dim2)
check(d1 != d2) {
"Diagonal dimensions cannot be identical $d1, $d2"
}
check(d1 <= n && d2 <= n) {
"Dimension out of range"
}
var lessDim = d1
var greaterDim = d2
var realOffset = offset
if (lessDim > greaterDim) {
realOffset *= -1
lessDim = greaterDim.also { greaterDim = lessDim }
}
val resShape = diagonalEntries.shape.slice(0 until lessDim).toIntArray() +
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
diagonalEntries.shape.slice(lessDim until greaterDim - 1).toIntArray() +
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray()
val resTensor = zeros(resShape)
for (i in 0 until diagonalEntries.tensor.numElements) {
val multiIndex = diagonalEntries.tensor.linearStructure.index(i)
var offset1 = 0
var offset2 = abs(realOffset)
if (realOffset < 0) {
offset1 = offset2.also { offset2 = offset1 }
}
val diagonalMultiIndex = multiIndex.slice(0 until lessDim).toIntArray() +
intArrayOf(multiIndex[n - 1] + offset1) +
multiIndex.slice(lessDim until greaterDim - 1).toIntArray() +
intArrayOf(multiIndex[n - 1] + offset2) +
multiIndex.slice(greaterDim - 1 until n - 1).toIntArray()
resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex]
}
return resTensor.tensor
}
/**
* Applies the [transform] function to each element of the tensor and returns the resulting modified tensor.
*
* @param transform the function to be applied to each element of the tensor.
* @return the resulting tensor after applying the function.
*/
public fun Tensor<Double>.map(transform: (Double) -> Double): DoubleTensor {
return DoubleTensor(
tensor.shape,
tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(),
tensor.bufferStart
)
}
/**
* Compares element-wise two tensors with a specified precision.
*
* @param other the tensor to compare with `input` tensor.
* @param epsilon permissible error when comparing two Double values.
* @return true if two tensors have the same shape and elements, false otherwise.
*/
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean =
tensor.eq(other) { x, y -> abs(x - y) < epsilon }
/**
* Compares element-wise two tensors.
* Comparison of two Double values occurs with 1e-5 precision.
*
* @param other the tensor to compare with `input` tensor.
* @return true if two tensors have the same shape and elements, false otherwise.
*/
public infix fun Tensor<Double>.eq(other: Tensor<Double>): Boolean = tensor.eq(other, 1e-5)
private fun Tensor<Double>.eq(
other: Tensor<Double>,
eqFunction: (Double, Double) -> Boolean
): Boolean {
checkShapesCompatible(tensor, other)
val n = tensor.numElements
if (n != other.tensor.numElements) {
return false
}
for (i in 0 until n) {
if (!eqFunction(
tensor.mutableBuffer[tensor.bufferStart + i],
other.tensor.mutableBuffer[other.tensor.bufferStart + i]
)
) {
return false
}
}
return true
}
/**
* Returns a tensor of random numbers drawn from normal distributions with 0.0 mean and 1.0 standard deviation.
*
* @param shape the desired shape for the output tensor.
* @param seed the random seed of the pseudo-random number generator.
* @return tensor of a given shape filled with numbers from the normal distribution
* with 0.0 mean and 1.0 standard deviation.
*/
public fun randomNormal(shape: IntArray, seed: Long = 0): DoubleTensor =
DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed))
/**
* Returns a tensor with the same shape as `input` of random numbers drawn from normal distributions
* with 0.0 mean and 1.0 standard deviation.
*
* @param seed the random seed of the pseudo-random number generator.
* @return tensor with the same shape as `input` filled with numbers from the normal distribution
* with 0.0 mean and 1.0 standard deviation.
*/
public fun Tensor<Double>.randomNormalLike(seed: Long = 0): DoubleTensor =
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
/**
* Concatenates a sequence of tensors with equal shapes along the first dimension.
*
* @param tensors the [List] of tensors with same shapes to concatenate
* @return tensor with concatenation result
*/
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val shape = tensors[0].shape
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
val resShape = intArrayOf(tensors.size) + shape
val resBuffer = tensors.flatMap {
it.tensor.mutableBuffer.array().drop(it.tensor.bufferStart).take(it.tensor.numElements)
}.toDoubleArray()
return DoubleTensor(resShape, resBuffer, 0)
}
/**
* Builds tensor from rows of input tensor
*
* @param indices the [IntArray] of 1-dimensional indices
* @return tensor with rows corresponding to rows by [indices]
*/
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor {
return stack(indices.map { this[it] })
}
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(tensor.toDoubleArray())
internal fun Tensor<Double>.foldDim(
foldFunction: (DoubleArray) -> Double,
dim: Int,
keepDim: Boolean
): DoubleTensor {
check(dim < dimension) { "Dimension $dim out of range $dimension" }
val resShape = if (keepDim) {
shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
} else {
shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
}
val resNumElements = resShape.reduce(Int::times)
val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0)
for (index in resTensor.linearStructure.indices()) {
val prefix = index.take(dim).toIntArray()
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i ->
tensor[prefix + intArrayOf(i) + suffix]
})
}
return resTensor
}
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.sum() }, dim, keepDim)
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x ->
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
}, dim, keepDim)
override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
override fun Tensor<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim(
{ arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" }
arr.sum() / shape[dim]
},
dim,
keepDim
)
override fun Tensor<Double>.std(): Double = this.fold { arr ->
val mean = arr.sum() / tensor.numElements
sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1))
}
override fun Tensor<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
{ arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" }
val mean = arr.sum() / shape[dim]
sqrt(arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1))
},
dim,
keepDim
)
override fun Tensor<Double>.variance(): Double = this.fold { arr ->
val mean = arr.sum() / tensor.numElements
arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)
}
override fun Tensor<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
{ arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" }
val mean = arr.sum() / shape[dim]
arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1)
},
dim,
keepDim
)
private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
val n = x.shape[0]
return ((x - x.mean()) * (y - y.mean())).mean() * n / (n - 1)
}
override fun cov(tensors: List<Tensor<Double>>): DoubleTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val n = tensors.size
val m = tensors[0].shape[0]
check(tensors.all { it.shape contentEquals intArrayOf(m) }) { "Tensors must have same shapes" }
val resTensor = DoubleTensor(
intArrayOf(n, n),
DoubleArray(n * n) { 0.0 }
)
for (i in 0 until n) {
for (j in 0 until n) {
resTensor[intArrayOf(i, j)] = cov(tensors[i].tensor, tensors[j].tensor)
}
}
return resTensor
}
override fun Tensor<Double>.exp(): DoubleTensor = tensor.map(::exp)
override fun Tensor<Double>.ln(): DoubleTensor = tensor.map(::ln)
override fun Tensor<Double>.sqrt(): DoubleTensor = tensor.map(::sqrt)
override fun Tensor<Double>.cos(): DoubleTensor = tensor.map(::cos)
override fun Tensor<Double>.acos(): DoubleTensor = tensor.map(::acos)
override fun Tensor<Double>.cosh(): DoubleTensor = tensor.map(::cosh)
override fun Tensor<Double>.acosh(): DoubleTensor = tensor.map(::acosh)
override fun Tensor<Double>.sin(): DoubleTensor = tensor.map(::sin)
override fun Tensor<Double>.asin(): DoubleTensor = tensor.map(::asin)
override fun Tensor<Double>.sinh(): DoubleTensor = tensor.map(::sinh)
override fun Tensor<Double>.asinh(): DoubleTensor = tensor.map(::asinh)
override fun Tensor<Double>.tan(): DoubleTensor = tensor.map(::tan)
override fun Tensor<Double>.atan(): DoubleTensor = tensor.map(::atan)
override fun Tensor<Double>.tanh(): DoubleTensor = tensor.map(::tanh)
override fun Tensor<Double>.atanh(): DoubleTensor = tensor.map(::atanh)
override fun Tensor<Double>.ceil(): DoubleTensor = tensor.map(::ceil)
override fun Tensor<Double>.floor(): DoubleTensor = tensor.map(::floor)
override fun Tensor<Double>.inv(): DoubleTensor = invLU(1e-9)
override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9)
/**
* Computes the LU factorization of a matrix or batches of matrices `input`.
* Returns a tuple containing the LU factorization and pivots of `input`.
*
* @param epsilon permissible error when comparing the determinant of a matrix with zero
* @return pair of `factorization` and `pivots`.
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/
public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
computeLU(tensor, epsilon)
?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
/**
* Computes the LU factorization of a matrix or batches of matrices `input`.
* Returns a tuple containing the LU factorization and pivots of `input`.
* Uses an error of ``1e-9`` when calculating whether a matrix is degenerate.
*
* @return pair of `factorization` and `pivots`.
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/
public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
/**
* Unpacks the data and pivots from a LU factorization of a tensor.
* Given a tensor [luTensor], return tensors (P, L, U) satisfying ``P * luTensor = L * U``,
* with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices.
*
* @param luTensor the packed LU factorization data
* @param pivotsTensor the packed LU factorization pivots
* @return triple of P, L and U tensors
*/
public fun luPivot(
luTensor: Tensor<Double>,
pivotsTensor: Tensor<Int>
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
checkSquareMatrix(luTensor.shape)
check(
luTensor.shape.dropLast(2).toIntArray() contentEquals pivotsTensor.shape.dropLast(1).toIntArray() ||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
) { "Inappropriate shapes of input tensors" }
val n = luTensor.shape.last()
val pTensor = luTensor.zeroesLike()
pTensor
.matrixSequence()
.zip(pivotsTensor.tensor.vectorSequence())
.forEach { (p, pivot) -> pivInit(p.as2D(), pivot.as1D(), n) }
val lTensor = luTensor.zeroesLike()
val uTensor = luTensor.zeroesLike()
lTensor.matrixSequence()
.zip(uTensor.matrixSequence())
.zip(luTensor.tensor.matrixSequence())
.forEach { (pairLU, lu) ->
val (l, u) = pairLU
luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n)
}
return Triple(pTensor, lTensor, uTensor)
}
/**
* QR decomposition.
*
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``,
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
*
* @param epsilon permissible error when comparing tensors for equality.
* Used when checking the positive definiteness of the input matrix or matrices.
* @return pair of Q and R tensors.
*/
public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor {
checkSquareMatrix(shape)
checkPositiveDefinite(tensor, epsilon)
val n = shape.last()
val lTensor = zeroesLike()
for ((a, l) in tensor.matrixSequence().zip(lTensor.matrixSequence()))
for (i in 0 until n) choleskyHelper(a.as2D(), l.as2D(), n)
return lTensor
}
override fun Tensor<Double>.cholesky(): DoubleTensor = cholesky(1e-6)
override fun Tensor<Double>.qr(): Pair<DoubleTensor, DoubleTensor> {
checkSquareMatrix(shape)
val qTensor = zeroesLike()
val rTensor = zeroesLike()
tensor.matrixSequence()
.zip(
(qTensor.matrixSequence()
.zip(rTensor.matrixSequence()))
).forEach { (matrix, qr) ->
val (q, r) = qr
qrHelper(matrix.asTensor(), q.asTensor(), r.as2D())
}
return qTensor to rTensor
}
override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
svd(epsilon = 1e-10)
/**
* Singular Value Decomposition.
*
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
* The singular value decomposition is represented as a triple `(U, S, V)`,
* such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``.
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
*
* @param epsilon permissible error when calculating the dot product of vectors,
* i.e. the precision with which the cosine approaches 1 in an iterative algorithm.
* @return triple `(U, S, V)`.
*/
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = tensor.dimension
val commonShape = tensor.shape.sliceArray(0 until size - 2)
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
val sTensor = zeros(commonShape + intArrayOf(min(n, m)))
val vTensor = zeros(commonShape + intArrayOf(min(n, m), m))
tensor.matrixSequence()
.zip(
uTensor.matrixSequence()
.zip(
sTensor.vectorSequence()
.zip(vTensor.matrixSequence())
)
).forEach { (matrix, USV) ->
val matrixSize = matrix.shape.reduce { acc, i -> acc * i }
val curMatrix = DoubleTensor(
matrix.shape,
matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize)
.toDoubleArray()
)
svdHelper(curMatrix, USV, m, n, epsilon)
}
return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
}
override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
symEig(epsilon = 1e-15)
/**
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
* represented by a pair (eigenvalues, eigenvectors).
*
* @param epsilon permissible error when comparing tensors for equality
* and when the cosine approaches 1 in the SVD algorithm.
* @return a pair (eigenvalues, eigenvectors)
*/
public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(tensor, epsilon)
val (u, s, v) = tensor.svd(epsilon)
val shp = s.shape + intArrayOf(1)
val utv = u.transpose() dot v
val n = s.shape.last()
for (matrix in utv.matrixSequence())
cleanSymHelper(matrix.as2D(), n)
val eig = (utv dot s.view(shp)).view(s.shape)
return eig to v
}
/**
* Computes the determinant of a square matrix input, or of each square matrix in a batched input
* using LU factorization algorithm.
*
* @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero
* @return the determinant.
*/
public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
checkSquareMatrix(tensor.shape)
val luTensor = tensor.copy()
val pivotsTensor = tensor.setUpPivots()
val n = shape.size
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
detTensorShape[n - 2] = 1
val resBuffer = DoubleArray(detTensorShape.reduce(Int::times)) { 0.0 }
val detTensor = DoubleTensor(
detTensorShape,
resBuffer
)
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (lu, pivots) ->
resBuffer[index] = if (luHelper(lu.as2D(), pivots.as1D(), epsilon))
0.0 else luMatrixDet(lu.as2D(), pivots.as1D())
}
return detTensor
}
/**
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input
* using LU factorization algorithm.
* Given a square matrix `a`, return the matrix `aInv` satisfying
* ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``.
*
* @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero
* @return the multiplicative inverse of a matrix.
*/
public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
val (luTensor, pivotsTensor) = luFactor(epsilon)
val invTensor = luTensor.zeroesLike()
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
for ((luP, invMatrix) in seq) {
val (lu, pivots) = luP
luMatrixInv(lu.as2D(), pivots.as1D(), invMatrix.as2D())
}
return invTensor
}
/**
* LUP decomposition
*
* Computes the LUP decomposition of a matrix or a batch of matrices.
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
* with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices.
*
* @param epsilon permissible error when comparing the determinant of a matrix with zero
* @return triple of P, L and U tensors
*/
public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val (lu, pivots) = tensor.luFactor(epsilon)
return luPivot(lu, pivots)
}
override fun Tensor<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
}

View File

@ -0,0 +1,17 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.IntBuffer
/**
* Default [BufferedTensor] implementation for [Int] values
*/
public class IntTensor internal constructor(
shape: IntArray,
buffer: IntArray,
offset: Int = 0
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)

View File

@ -0,0 +1,57 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.Strides
import kotlin.math.max
internal fun stridesFromShape(shape: IntArray): IntArray {
val nDim = shape.size
val res = IntArray(nDim)
if (nDim == 0)
return res
var current = nDim - 1
res[current] = 1
while (current > 0) {
res[current - 1] = max(1, shape[current]) * res[current]
current--
}
return res
}
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
val res = IntArray(nDim)
var current = offset
var strideIndex = 0
while (strideIndex < nDim) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex++
}
return res
}
/**
* This [Strides] implementation follows the last dimension first convention
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
*
* @param shape the shape of the tensor.
*/
internal class TensorLinearStructure(override val shape: IntArray) : Strides {
override val strides: IntArray
get() = stridesFromShape(shape)
override fun index(offset: Int): IntArray =
indexFromOffset(offset, strides, shape.size)
override val linearSize: Int
get() = shape.reduce(Int::times)
}

View File

@ -0,0 +1,146 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.tensors.core.DoubleTensor
import kotlin.math.max
internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTensor, linearSize: Int) {
for (linearIndex in 0 until linearSize) {
val totalMultiIndex = resTensor.linearStructure.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.linearStructure.offset(curMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + curLinearIndex]
}
}
internal fun broadcastShapes(vararg shapes: IntArray): IntArray {
var totalDim = 0
for (shape in shapes) {
totalDim = max(totalDim, shape.size)
}
val totalShape = IntArray(totalDim) { 0 }
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
totalShape[i + offset] = max(totalShape[i + offset], curDim)
}
}
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
check(curDim == 1 || totalShape[i + offset] == curDim) {
"Shapes are not compatible and cannot be broadcast"
}
}
}
return totalShape
}
internal fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): DoubleTensor {
require(tensor.shape.size <= newShape.size) {
"Tensor is not compatible with the new shape"
}
val n = newShape.reduce { acc, i -> acc * i }
val resTensor = DoubleTensor(newShape, DoubleArray(n))
for (i in tensor.shape.indices) {
val curDim = tensor.shape[i]
val offset = newShape.size - tensor.shape.size
check(curDim == 1 || newShape[i + offset] == curDim) {
"Tensor is not compatible with the new shape and cannot be broadcast"
}
}
multiIndexBroadCasting(tensor, resTensor, n)
return resTensor
}
internal fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
val n = totalShape.reduce { acc, i -> acc * i }
return tensors.map { tensor ->
val resTensor = DoubleTensor(totalShape, DoubleArray(n))
multiIndexBroadCasting(tensor, resTensor, n)
resTensor
}
}
internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
val onlyTwoDims = tensors.asSequence().onEach {
require(it.shape.size >= 2) {
"Tensors must have at least 2 dimensions"
}
}.any { it.shape.size != 2 }
if (!onlyTwoDims) {
return tensors.asList()
}
val totalShape = broadcastShapes(*(tensors.map { it.shape.sliceArray(0..it.shape.size - 3) }).toTypedArray())
val n = totalShape.reduce { acc, i -> acc * i }
return buildList {
for (tensor in tensors) {
val matrixShape = tensor.shape.sliceArray(tensor.shape.size - 2 until tensor.shape.size).copyOf()
val matrixSize = matrixShape[0] * matrixShape[1]
val matrix = DoubleTensor(matrixShape, DoubleArray(matrixSize))
val outerTensor = DoubleTensor(totalShape, DoubleArray(n))
val resTensor = DoubleTensor(totalShape + matrixShape, DoubleArray(n * matrixSize))
for (linearIndex in 0 until n) {
val totalMultiIndex = outerTensor.linearStructure.index(linearIndex)
var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf()
curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) { 1 } + curMultiIndex
val newTensor = DoubleTensor(curMultiIndex + matrixShape, tensor.mutableBuffer.array())
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i]
} else {
curMultiIndex[i] = 0
}
}
for (i in 0 until matrixSize) {
val curLinearIndex = newTensor.linearStructure.offset(
curMultiIndex +
matrix.linearStructure.index(i)
)
val newLinearIndex = resTensor.linearStructure.offset(
totalMultiIndex +
matrix.linearStructure.index(i)
)
resTensor.mutableBuffer.array()[resTensor.bufferStart + newLinearIndex] =
newTensor.mutableBuffer.array()[newTensor.bufferStart + curLinearIndex]
}
}
add(resTensor)
}
}
}

View File

@ -0,0 +1,64 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
internal fun checkEmptyShape(shape: IntArray) =
check(shape.isNotEmpty()) {
"Illegal empty shape provided"
}
internal fun checkEmptyDoubleBuffer(buffer: DoubleArray) =
check(buffer.isNotEmpty()) {
"Illegal empty buffer provided"
}
internal fun checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray) =
check(buffer.size == shape.reduce(Int::times)) {
"Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided"
}
internal fun <T> checkShapesCompatible(a: Tensor<T>, b: Tensor<T>) =
check(a.shape contentEquals b.shape) {
"Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} "
}
internal fun checkTranspose(dim: Int, i: Int, j: Int) =
check((i < dim) and (j < dim)) {
"Cannot transpose $i to $j for a tensor of dim $dim"
}
internal fun <T> checkView(a: Tensor<T>, shape: IntArray) =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
internal fun checkSquareMatrix(shape: IntArray) {
val n = shape.size
check(n >= 2) {
"Expected tensor with 2 or more dimensions, got size $n instead"
}
check(shape[n - 1] == shape[n - 2]) {
"Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices"
}
}
internal fun DoubleTensorAlgebra.checkSymmetric(
tensor: Tensor<Double>, epsilon: Double = 1e-6
) =
check(tensor.eq(tensor.transpose(), epsilon)) {
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
}
internal fun DoubleTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor, epsilon: Double = 1e-6) {
checkSymmetric(tensor, epsilon)
for (mat in tensor.matrixSequence())
check(mat.asTensor().detLU().value() > 0.0) {
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
}
}

View File

@ -0,0 +1,342 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.MutableStructure1D
import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.*
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.valueOrNull
import kotlin.math.abs
import kotlin.math.min
import kotlin.math.sign
import kotlin.math.sqrt
internal fun <T> BufferedTensor<T>.vectorSequence(): Sequence<BufferedTensor<T>> = sequence {
val n = shape.size
val vectorOffset = shape[n - 1]
val vectorShape = intArrayOf(shape.last())
for (offset in 0 until numElements step vectorOffset) {
val vector = BufferedTensor(vectorShape, mutableBuffer, bufferStart + offset)
yield(vector)
}
}
internal fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTensor<T>> = sequence {
val n = shape.size
check(n >= 2) { "Expected tensor with 2 or more dimensions, got size $n" }
val matrixOffset = shape[n - 1] * shape[n - 2]
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1])
for (offset in 0 until numElements step matrixOffset) {
val matrix = BufferedTensor(matrixShape, mutableBuffer, bufferStart + offset)
yield(matrix)
}
}
internal fun dotHelper(
a: MutableStructure2D<Double>,
b: MutableStructure2D<Double>,
res: MutableStructure2D<Double>,
l: Int, m: Int, n: Int
) {
for (i in 0 until l) {
for (j in 0 until n) {
var curr = 0.0
for (k in 0 until m) {
curr += a[i, k] * b[k, j]
}
res[i, j] = curr
}
}
}
internal fun luHelper(
lu: MutableStructure2D<Double>,
pivots: MutableStructure1D<Int>,
epsilon: Double
): Boolean {
val m = lu.rowNum
for (row in 0..m) pivots[row] = row
for (i in 0 until m) {
var maxVal = 0.0
var maxInd = i
for (k in i until m) {
val absA = abs(lu[k, i])
if (absA > maxVal) {
maxVal = absA
maxInd = k
}
}
if (abs(maxVal) < epsilon)
return true // matrix is singular
if (maxInd != i) {
val j = pivots[i]
pivots[i] = pivots[maxInd]
pivots[maxInd] = j
for (k in 0 until m) {
val tmp = lu[i, k]
lu[i, k] = lu[maxInd, k]
lu[maxInd, k] = tmp
}
pivots[m] += 1
}
for (j in i + 1 until m) {
lu[j, i] /= lu[i, i]
for (k in i + 1 until m) {
lu[j, k] -= lu[j, i] * lu[i, k]
}
}
}
return false
}
internal fun <T> BufferedTensor<T>.setUpPivots(): IntTensor {
val n = this.shape.size
val m = this.shape.last()
val pivotsShape = IntArray(n - 1) { i -> this.shape[i] }
pivotsShape[n - 2] = m + 1
return IntTensor(
pivotsShape,
IntArray(pivotsShape.reduce(Int::times)) { 0 }
)
}
internal fun DoubleTensorAlgebra.computeLU(
tensor: DoubleTensor,
epsilon: Double
): Pair<DoubleTensor, IntTensor>? {
checkSquareMatrix(tensor.shape)
val luTensor = tensor.copy()
val pivotsTensor = tensor.setUpPivots()
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
if (luHelper(lu.as2D(), pivots.as1D(), epsilon))
return null
return Pair(luTensor, pivotsTensor)
}
internal fun pivInit(
p: MutableStructure2D<Double>,
pivot: MutableStructure1D<Int>,
n: Int
) {
for (i in 0 until n) {
p[i, pivot[i]] = 1.0
}
}
internal fun luPivotHelper(
l: MutableStructure2D<Double>,
u: MutableStructure2D<Double>,
lu: MutableStructure2D<Double>,
n: Int
) {
for (i in 0 until n) {
for (j in 0 until n) {
if (i == j) {
l[i, j] = 1.0
}
if (j < i) {
l[i, j] = lu[i, j]
}
if (j >= i) {
u[i, j] = lu[i, j]
}
}
}
}
internal fun choleskyHelper(
a: MutableStructure2D<Double>,
l: MutableStructure2D<Double>,
n: Int
) {
for (i in 0 until n) {
for (j in 0 until i) {
var h = a[i, j]
for (k in 0 until j) {
h -= l[i, k] * l[j, k]
}
l[i, j] = h / l[j, j]
}
var h = a[i, i]
for (j in 0 until i) {
h -= l[i, j] * l[i, j]
}
l[i, i] = sqrt(h)
}
}
internal fun luMatrixDet(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>): Double {
if (lu[0, 0] == 0.0) {
return 0.0
}
val m = lu.shape[0]
val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
}
internal fun luMatrixInv(
lu: MutableStructure2D<Double>,
pivots: MutableStructure1D<Int>,
invMatrix: MutableStructure2D<Double>
) {
val m = lu.shape[0]
for (j in 0 until m) {
for (i in 0 until m) {
if (pivots[i] == j) {
invMatrix[i, j] = 1.0
}
for (k in 0 until i) {
invMatrix[i, j] -= lu[i, k] * invMatrix[k, j]
}
}
for (i in m - 1 downTo 0) {
for (k in i + 1 until m) {
invMatrix[i, j] -= lu[i, k] * invMatrix[k, j]
}
invMatrix[i, j] /= lu[i, i]
}
}
}
internal fun DoubleTensorAlgebra.qrHelper(
matrix: DoubleTensor,
q: DoubleTensor,
r: MutableStructure2D<Double>
) {
checkSquareMatrix(matrix.shape)
val n = matrix.shape[0]
val qM = q.as2D()
val matrixT = matrix.transpose(0, 1)
val qT = q.transpose(0, 1)
for (j in 0 until n) {
val v = matrixT[j]
val vv = v.as1D()
if (j > 0) {
for (i in 0 until j) {
r[i, j] = (qT[i] dot matrixT[j]).value()
for (k in 0 until n) {
val qTi = qT[i].as1D()
vv[k] = vv[k] - r[i, j] * qTi[k]
}
}
}
r[j, j] = DoubleTensorAlgebra { (v dot v).sqrt().value() }
for (i in 0 until n) {
qM[i, j] = vv[i] / r[j, j]
}
}
}
internal fun DoubleTensorAlgebra.svd1d(a: DoubleTensor, epsilon: Double = 1e-10): DoubleTensor {
val (n, m) = a.shape
var v: DoubleTensor
val b: DoubleTensor
if (n > m) {
b = a.transpose(0, 1).dot(a)
v = DoubleTensor(intArrayOf(m), getRandomUnitVector(m, 0))
} else {
b = a.dot(a.transpose(0, 1))
v = DoubleTensor(intArrayOf(n), getRandomUnitVector(n, 0))
}
var lastV: DoubleTensor
while (true) {
lastV = v
v = b.dot(lastV)
val norm = DoubleTensorAlgebra { (v dot v).sqrt().value() }
v = v.times(1.0 / norm)
if (abs(v.dot(lastV).value()) > 1 - epsilon) {
return v
}
}
}
internal fun DoubleTensorAlgebra.svdHelper(
matrix: DoubleTensor,
USV: Pair<BufferedTensor<Double>, Pair<BufferedTensor<Double>, BufferedTensor<Double>>>,
m: Int, n: Int, epsilon: Double
) {
val res = ArrayList<Triple<Double, DoubleTensor, DoubleTensor>>(0)
val (matrixU, SV) = USV
val (matrixS, matrixV) = SV
for (k in 0 until min(n, m)) {
var a = matrix.copy()
for ((singularValue, u, v) in res.slice(0 until k)) {
val outerProduct = DoubleArray(u.shape[0] * v.shape[0])
for (i in 0 until u.shape[0]) {
for (j in 0 until v.shape[0]) {
outerProduct[i * v.shape[0] + j] = u[i].value() * v[j].value()
}
}
a = a - singularValue.times(DoubleTensor(intArrayOf(u.shape[0], v.shape[0]), outerProduct))
}
var v: DoubleTensor
var u: DoubleTensor
var norm: Double
if (n > m) {
v = svd1d(a, epsilon)
u = matrix.dot(v)
norm = DoubleTensorAlgebra { (u dot u).sqrt().value() }
u = u.times(1.0 / norm)
} else {
u = svd1d(a, epsilon)
v = matrix.transpose(0, 1).dot(u)
norm = DoubleTensorAlgebra { (v dot v).sqrt().value() }
v = v.times(1.0 / norm)
}
res.add(Triple(norm, u, v))
}
val s = res.map { it.first }.toDoubleArray()
val uBuffer = res.map { it.second }.flatMap { it.mutableBuffer.array().toList() }.toDoubleArray()
val vBuffer = res.map { it.third }.flatMap { it.mutableBuffer.array().toList() }.toDoubleArray()
for (i in uBuffer.indices) {
matrixU.mutableBuffer.array()[matrixU.bufferStart + i] = uBuffer[i]
}
for (i in s.indices) {
matrixS.mutableBuffer.array()[matrixS.bufferStart + i] = s[i]
}
for (i in vBuffer.indices) {
matrixV.mutableBuffer.array()[matrixV.bufferStart + i] = vBuffer[i]
}
}
internal fun cleanSymHelper(matrix: MutableStructure2D<Double>, n: Int) {
for (i in 0 until n)
for (j in 0 until n) {
if (i == j) {
matrix[i, j] = sign(matrix[i, j])
} else {
matrix[i, j] = 0.0
}
}
}

View File

@ -0,0 +1,44 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.BufferedTensor
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.IntTensor
internal fun BufferedTensor<Int>.asTensor(): IntTensor =
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor =
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
BufferedTensor(
this.shape,
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0
)
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
else -> this.copyToBufferedTensor()
}
internal val Tensor<Double>.tensor: DoubleTensor
get() = when (this) {
is DoubleTensor -> this
else -> this.toBufferedTensor().asTensor()
}
internal val Tensor<Int>.tensor: IntTensor
get() = when (this) {
is IntTensor -> this
else -> this.toBufferedTensor().asTensor()
}

View File

@ -0,0 +1,124 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.samplers.GaussianSampler
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.core.BufferedTensor
import space.kscience.kmath.tensors.core.DoubleTensor
import kotlin.math.*
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer] or copy the data.
*/
internal fun Buffer<Int>.array(): IntArray = when (this) {
is IntBuffer -> array
else -> this.toIntArray()
}
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data.
*/
internal fun Buffer<Double>.array(): DoubleArray = when (this) {
is DoubleBuffer -> array
else -> this.toDoubleArray()
}
internal fun getRandomNormals(n: Int, seed: Long): DoubleArray {
val distribution = GaussianSampler(0.0, 1.0)
val generator = RandomGenerator.default(seed)
return distribution.sample(generator).nextBufferBlocking(n).toDoubleArray()
}
internal fun getRandomUnitVector(n: Int, seed: Long): DoubleArray {
val unnorm = getRandomNormals(n, seed)
val norm = sqrt(unnorm.sumOf { it * it })
return unnorm.map { it / norm }.toDoubleArray()
}
internal fun minusIndexFrom(n: Int, i: Int): Int = if (i >= 0) i else {
val ii = n + i
check(ii >= 0) {
"Out of bound index $i for tensor of dim $n"
}
ii
}
internal fun <T> BufferedTensor<T>.minusIndex(i: Int): Int = minusIndexFrom(this.dimension, i)
internal fun format(value: Double, digits: Int = 4): String = buildString {
val res = buildString {
val ten = 10.0
val approxOrder = if (value == 0.0) 0 else ceil(log10(abs(value))).toInt()
val order = if (
((value % ten) == 0.0) ||
(value == 1.0) ||
((1 / value) % ten == 0.0)
) approxOrder else approxOrder - 1
val lead = value / ten.pow(order)
if (value >= 0.0) append(' ')
append(round(lead * ten.pow(digits)) / ten.pow(digits))
when {
order == 0 -> Unit
order > 0 -> {
append("e+")
append(order)
}
else -> {
append('e')
append(order)
}
}
}
val fLength = digits + 6
append(res)
repeat(fLength - res.length) { append(' ') }
}
internal fun DoubleTensor.toPrettyString(): String = buildString {
var offset = 0
val shape = this@toPrettyString.shape
val linearStructure = this@toPrettyString.linearStructure
val vectorSize = shape.last()
append("DoubleTensor(\n")
var charOffset = 3
for (vector in vectorSequence()) {
repeat(charOffset) { append(' ') }
val index = linearStructure.index(offset)
for (ind in index.reversed()) {
if (ind != 0) {
break
}
append('[')
charOffset += 1
}
val values = vector.as1D().toMutableList().map(::format)
values.joinTo(this, separator = ", ")
append(']')
charOffset -= 1
index.reversed().zip(shape.reversed()).drop(1).forEach { (ind, maxInd) ->
if (ind != maxInd - 1) {
return@forEach
}
append(']')
charOffset -= 1
}
offset += vectorSize
if (this@toPrettyString.numElements == offset) {
break
}
append(",\n")
}
append("\n)")
}

View File

@ -0,0 +1,37 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.tensor
/**
* Casts [Tensor] of [Double] to [DoubleTensor]
*/
public fun Tensor<Double>.toDoubleTensor(): DoubleTensor = this.tensor
/**
* Casts [Tensor] of [Int] to [IntTensor]
*/
public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor
/**
* Returns [DoubleArray] of tensor elements
*/
public fun DoubleTensor.toDoubleArray(): DoubleArray {
return DoubleArray(numElements) { i ->
mutableBuffer[bufferStart + i]
}
}
/**
* Returns [IntArray] of tensor elements
*/
public fun IntTensor.toIntArray(): IntArray {
return IntArray(numElements) { i ->
mutableBuffer[bufferStart + i]
}
}

View File

@ -0,0 +1,105 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.internal.*
import kotlin.test.Test
import kotlin.test.assertTrue
internal class TestBroadcasting {
@Test
fun testBroadcastShapes() = DoubleTensorAlgebra {
assertTrue(
broadcastShapes(
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
) contentEquals intArrayOf(1, 2, 3)
)
assertTrue(
broadcastShapes(
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7), intArrayOf(5, 1, 7)
) contentEquals intArrayOf(5, 6, 7)
)
}
@Test
fun testBroadcastTo() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val res = broadcastTo(tensor2, tensor1.shape)
assertTrue(res.shape contentEquals intArrayOf(2, 3))
assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
}
@Test
fun testBroadcastTensors() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val res = broadcastTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[0].mutableBuffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].mutableBuffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
}
@Test
fun testBroadcastOuterTensors() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(1, 1, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 1, 1))
assertTrue(res[0].mutableBuffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0))
assertTrue(res[2].mutableBuffer.array() contentEquals doubleArrayOf(500.0))
}
@Test
fun testBroadcastOuterTensorsShapes() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 1, 3, 2, 3), DoubleArray(2 * 1 * 3 * 2 * 3) {0.0})
val tensor2 = fromArray(intArrayOf(4, 2, 5, 1, 3, 3), DoubleArray(4 * 2 * 5 * 1 * 3 * 3) {0.0})
val tensor3 = fromArray(intArrayOf(1, 1), doubleArrayOf(500.0))
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(4, 2, 5, 3, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(4, 2, 5, 3, 3, 3))
assertTrue(res[2].shape contentEquals intArrayOf(4, 2, 5, 3, 1, 1))
}
@Test
fun testMinusTensor() = BroadcastDoubleTensorAlgebra.invoke {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val tensor21 = tensor2 - tensor1
val tensor31 = tensor3 - tensor1
val tensor32 = tensor3 - tensor2
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
assertTrue(tensor21.mutableBuffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue(tensor31.shape contentEquals intArrayOf(1, 2, 3))
assertTrue(
tensor31.mutableBuffer.array()
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
)
assertTrue(tensor32.shape contentEquals intArrayOf(1, 1, 3))
assertTrue(tensor32.mutableBuffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
}
}

View File

@ -0,0 +1,158 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.operations.invoke
import kotlin.math.*
import kotlin.test.Test
import kotlin.test.assertTrue
internal class TestDoubleAnalyticTensorAlgebra {
val shape = intArrayOf(2, 1, 3, 2)
val buffer = doubleArrayOf(
27.1, 20.0, 19.84,
23.123, 3.0, 2.0,
3.23, 133.7, 25.3,
100.3, 11.0, 12.012
)
val tensor = DoubleTensor(shape, buffer)
fun DoubleArray.fmap(transform: (Double) -> Double): DoubleArray {
return this.map(transform).toDoubleArray()
}
fun expectedTensor(transform: (Double) -> Double): DoubleTensor {
return DoubleTensor(shape, buffer.fmap(transform))
}
@Test
fun testExp() = DoubleTensorAlgebra {
assertTrue { tensor.exp() eq expectedTensor(::exp) }
}
@Test
fun testLog() = DoubleTensorAlgebra {
assertTrue { tensor.ln() eq expectedTensor(::ln) }
}
@Test
fun testSqrt() = DoubleTensorAlgebra {
assertTrue { tensor.sqrt() eq expectedTensor(::sqrt) }
}
@Test
fun testCos() = DoubleTensorAlgebra {
assertTrue { tensor.cos() eq expectedTensor(::cos) }
}
@Test
fun testCosh() = DoubleTensorAlgebra {
assertTrue { tensor.cosh() eq expectedTensor(::cosh) }
}
@Test
fun testAcosh() = DoubleTensorAlgebra {
assertTrue { tensor.acosh() eq expectedTensor(::acosh) }
}
@Test
fun testSin() = DoubleTensorAlgebra {
assertTrue { tensor.sin() eq expectedTensor(::sin) }
}
@Test
fun testSinh() = DoubleTensorAlgebra {
assertTrue { tensor.sinh() eq expectedTensor(::sinh) }
}
@Test
fun testAsinh() = DoubleTensorAlgebra {
assertTrue { tensor.asinh() eq expectedTensor(::asinh) }
}
@Test
fun testTan() = DoubleTensorAlgebra {
assertTrue { tensor.tan() eq expectedTensor(::tan) }
}
@Test
fun testAtan() = DoubleTensorAlgebra {
assertTrue { tensor.atan() eq expectedTensor(::atan) }
}
@Test
fun testTanh() = DoubleTensorAlgebra {
assertTrue { tensor.tanh() eq expectedTensor(::tanh) }
}
@Test
fun testCeil() = DoubleTensorAlgebra {
assertTrue { tensor.ceil() eq expectedTensor(::ceil) }
}
@Test
fun testFloor() = DoubleTensorAlgebra {
assertTrue { tensor.floor() eq expectedTensor(::floor) }
}
val shape2 = intArrayOf(2, 2)
val buffer2 = doubleArrayOf(
1.0, 2.0,
-3.0, 4.0
)
val tensor2 = DoubleTensor(shape2, buffer2)
@Test
fun testMin() = DoubleTensorAlgebra {
assertTrue { tensor2.min() == -3.0 }
assertTrue { tensor2.min(0, true) eq fromArray(
intArrayOf(1, 2),
doubleArrayOf(-3.0, 2.0)
)}
assertTrue { tensor2.min(1, false) eq fromArray(
intArrayOf(2),
doubleArrayOf(1.0, -3.0)
)}
}
@Test
fun testMax() = DoubleTensorAlgebra {
assertTrue { tensor2.max() == 4.0 }
assertTrue { tensor2.max(0, true) eq fromArray(
intArrayOf(1, 2),
doubleArrayOf(1.0, 4.0)
)}
assertTrue { tensor2.max(1, false) eq fromArray(
intArrayOf(2),
doubleArrayOf(2.0, 4.0)
)}
}
@Test
fun testSum() = DoubleTensorAlgebra {
assertTrue { tensor2.sum() == 4.0 }
assertTrue { tensor2.sum(0, true) eq fromArray(
intArrayOf(1, 2),
doubleArrayOf(-2.0, 6.0)
)}
assertTrue { tensor2.sum(1, false) eq fromArray(
intArrayOf(2),
doubleArrayOf(3.0, 1.0)
)}
}
@Test
fun testMean() = DoubleTensorAlgebra {
assertTrue { tensor2.mean() == 1.0 }
assertTrue { tensor2.mean(0, true) eq fromArray(
intArrayOf(1, 2),
doubleArrayOf(-1.0, 3.0)
)}
assertTrue { tensor2.mean(1, false) eq fromArray(
intArrayOf(2),
doubleArrayOf(1.5, 0.5)
)}
}
}

Some files were not shown because too many files have changed in this diff Show More