diff --git a/CHANGELOG.md b/CHANGELOG.md index f3fe37b6b..89e02d3b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140). - Automatic README generation for features (#139) - Native support for `memory`, `core` and `dimensions` +- `kmath-ejml` to supply EJML SimpleMatrix wrapper. ### Changed - Package changed from `scientifik` to `kscience.kmath`. @@ -14,6 +15,7 @@ - Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`) - `Polynomial` secondary constructor made function. - Kotlin version: 1.3.72 -> 1.4.20-M1 +- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library. ### Deprecated diff --git a/README.md b/README.md index 8bc85bf2b..708bd8eb1 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,9 @@ can be used for a wide variety of purposes from high performance calculations to * **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free to submit a feature request if you want something to be done first. - + +* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures. + ## Planned features * **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks. diff --git a/build.gradle.kts b/build.gradle.kts index 499f49d1d..05e2d5979 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -2,9 +2,9 @@ plugins { id("ru.mipt.npm.project") } -val kmathVersion by extra("0.2.0-dev-2") -val bintrayRepo by extra("kscience") -val githubProject by extra("kmath") +val kmathVersion: String by extra("0.2.0-dev-2") +val bintrayRepo: String by extra("kscience") +val githubProject: String by extra("kmath") allprojects { repositories { @@ -22,6 +22,6 @@ subprojects { if (name.startsWith("kmath")) apply() } -readme{ +readme { readmeTemplate = file("docs/templates/README-TEMPLATE.md") } diff --git a/docs/linear.md b/docs/linear.md index 883df275e..6ccc6caac 100644 --- a/docs/linear.md +++ b/docs/linear.md @@ -6,10 +6,10 @@ back-ends. The new operations added as extensions to contexts instead of being m Two major contexts used for linear algebra and hyper-geometry: -* `VectorSpace` forms a mathematical space on top of array-like structure (`Buffer` and its typealias `Point` used for geometry). +* `VectorSpace` forms a mathematical space on top of array-like structure (`Buffer` and its type alias `Point` used for geometry). * `MatrixContext` forms a space-like context for 2d-structures. It does not store matrix size and therefore does not implement -`Space` interface (it is not possible to create zero element without knowing the matrix size). +`Space` interface (it is impossible to create zero element without knowing the matrix size). ## Vector spaces diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 3f18d3cf3..900da966b 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -26,9 +26,12 @@ dependencies { implementation(project(":kmath-prob")) implementation(project(":kmath-viktor")) implementation(project(":kmath-dimensions")) + implementation(project(":kmath-ejml")) implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11") implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20") - "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath + implementation("org.slf4j:slf4j-simple:1.7.30") + "benchmarksImplementation"("org.jetbrains.kotlinx:kotlinx.benchmark.runtime-jvm:0.2.0-dev-8") + "benchmarksImplementation"(sourceSets.main.get().output + sourceSets.main.get().runtimeClasspath) } // Configure benchmark diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt index 2673552f5..a91d02253 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt @@ -6,34 +6,33 @@ import org.openjdk.jmh.annotations.State import java.nio.IntBuffer @State(Scope.Benchmark) -class ArrayBenchmark { +internal class ArrayBenchmark { @Benchmark fun benchmarkArrayRead() { var res = 0 - for (i in 1.._root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) res += _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.array[_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size - i] + for (i in 1..size) res += array[size - i] } @Benchmark fun benchmarkBufferRead() { var res = 0 - for (i in 1.._root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) res += _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.arrayBuffer.get( - _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size - i) + for (i in 1..size) res += arrayBuffer.get( + size - i + ) } @Benchmark fun nativeBufferRead() { var res = 0 - for (i in 1.._root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) res += _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.nativeBuffer.get( - _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size - i) + for (i in 1..size) res += nativeBuffer.get( + size - i + ) } companion object { const val size: Int = 1000 - val array: IntArray = IntArray(_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) { it } - val arrayBuffer: IntBuffer = IntBuffer.wrap(_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.array) - - val nativeBuffer: IntBuffer = IntBuffer.allocate(_root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size).also { - for (i in 0 until _root_ide_package_.kscience.kmath.structures.ArrayBenchmark.Companion.size) it.put(i, i) - } + val array: IntArray = IntArray(size) { it } + val arrayBuffer: IntBuffer = IntBuffer.wrap(array) + val nativeBuffer: IntBuffer = IntBuffer.allocate(size).also { for (i in 0 until size) it.put(i, i) } } } diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/structures/BufferBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/structures/BufferBenchmark.kt index 009d51001..8b6fd4a51 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/structures/BufferBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/structures/BufferBenchmark.kt @@ -7,11 +7,10 @@ import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State @State(Scope.Benchmark) -class BufferBenchmark { - +internal class BufferBenchmark { @Benchmark fun genericRealBufferReadWrite() { - val buffer = RealBuffer(size){it.toDouble()} + val buffer = RealBuffer(size) { it.toDouble() } (0 until size).forEach { buffer[it] @@ -20,7 +19,7 @@ class BufferBenchmark { @Benchmark fun complexBufferReadWrite() { - val buffer = MutableBuffer.complex(size / 2){Complex(it.toDouble(), -it.toDouble())} + val buffer = MutableBuffer.complex(size / 2) { Complex(it.toDouble(), -it.toDouble()) } (0 until size / 2).forEach { buffer[it] @@ -28,6 +27,6 @@ class BufferBenchmark { } companion object { - const val size = 100 + const val size: Int = 100 } } \ No newline at end of file diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/structures/NDFieldBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/structures/NDFieldBenchmark.kt index 64f279c39..8ec47ae81 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/structures/NDFieldBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/structures/NDFieldBenchmark.kt @@ -7,7 +7,7 @@ import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State @State(Scope.Benchmark) -class NDFieldBenchmark { +internal class NDFieldBenchmark { @Benchmark fun autoFieldAdd() { bufferedField { @@ -40,11 +40,10 @@ class NDFieldBenchmark { } companion object { - val dim = 1000 - val n = 100 - - val bufferedField = NDField.auto(RealField, dim, dim) - val specializedField = NDField.real(dim, dim) - val genericField = NDField.boxing(RealField, dim, dim) + const val dim: Int = 1000 + const val n: Int = 100 + val bufferedField: BufferedNDField = NDField.auto(RealField, dim, dim) + val specializedField: RealNDField = NDField.real(dim, dim) + val genericField: BoxingNDField = NDField.boxing(RealField, dim, dim) } } \ No newline at end of file diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt index a4b831f7c..464925ca0 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt @@ -9,9 +9,9 @@ import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State @State(Scope.Benchmark) -class ViktorBenchmark { - final val dim = 1000 - final val n = 100 +internal class ViktorBenchmark { + final val dim: Int = 1000 + final val n: Int = 100 // automatically build context most suited for given type. final val autoField: BufferedNDField = NDField.auto(RealField, dim, dim) @@ -42,7 +42,7 @@ class ViktorBenchmark { } @Benchmark - fun realdFieldLog() { + fun realFieldLog() { realField { val fortyTwo = produce { 42.0 } var res = one diff --git a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt index a5768f1f5..f0a32e5bd 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -1,4 +1,4 @@ -//package kscience.kmath.ast +package kscience.kmath.ast // //import kscience.kmath.asm.compile //import kscience.kmath.expressions.Expression diff --git a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt index aeb760998..7f06f4a1f 100644 --- a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt +++ b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt @@ -6,9 +6,9 @@ import kscience.kmath.chains.collectWithState import kscience.kmath.prob.RandomGenerator import kscience.kmath.prob.samplers.ZigguratNormalizedGaussianSampler -data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) +private data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) -fun Chain.mean(): Chain = collectWithState(AveragingChainState(), { it.copy() }) { chain -> +private fun Chain.mean(): Chain = collectWithState(AveragingChainState(), { it.copy() }) { chain -> val next = chain.next() num++ value += next diff --git a/examples/src/main/kotlin/kscience/kmath/linear/LinearAlgebraBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/linear/LinearAlgebraBenchmark.kt new file mode 100644 index 000000000..3316f3236 --- /dev/null +++ b/examples/src/main/kotlin/kscience/kmath/linear/LinearAlgebraBenchmark.kt @@ -0,0 +1,50 @@ +package kscience.kmath.linear + +import kscience.kmath.commons.linear.CMMatrixContext +import kscience.kmath.commons.linear.inverse +import kscience.kmath.commons.linear.toCM +import kscience.kmath.ejml.EjmlMatrixContext +import kscience.kmath.ejml.inverse +import kscience.kmath.operations.RealField +import kscience.kmath.operations.invoke +import kscience.kmath.structures.Matrix +import kotlin.random.Random +import kotlin.system.measureTimeMillis + +fun main() { + val random = Random(1224) + val dim = 100 + //creating invertible matrix + val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } + val l = Matrix.real(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 } + val matrix = l dot u + + val n = 5000 // iterations + + MatrixContext.real { + repeat(50) { inverse(matrix) } + val inverseTime = measureTimeMillis { repeat(n) { inverse(matrix) } } + println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis") + } + + //commons-math + + val commonsTime = measureTimeMillis { + CMMatrixContext { + val cm = matrix.toCM() //avoid overhead on conversion + repeat(n) { inverse(cm) } + } + } + + + println("[commons-math] Inversion of $n matrices $dim x $dim finished in $commonsTime millis") + + val ejmlTime = measureTimeMillis { + (EjmlMatrixContext(RealField)) { + val km = matrix.toEjml() //avoid overhead on conversion + repeat(n) { inverse(km) } + } + } + + println("[ejml] Inversion of $n matrices $dim x $dim finished in $ejmlTime millis") +} diff --git a/examples/src/main/kotlin/kscience/kmath/linear/MultiplicationBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/linear/MultiplicationBenchmark.kt new file mode 100644 index 000000000..d1011e8f5 --- /dev/null +++ b/examples/src/main/kotlin/kscience/kmath/linear/MultiplicationBenchmark.kt @@ -0,0 +1,38 @@ +package kscience.kmath.linear + +import kscience.kmath.commons.linear.CMMatrixContext +import kscience.kmath.commons.linear.toCM +import kscience.kmath.ejml.EjmlMatrixContext +import kscience.kmath.operations.RealField +import kscience.kmath.operations.invoke +import kscience.kmath.structures.Matrix +import kotlin.random.Random +import kotlin.system.measureTimeMillis + +fun main() { + val random = Random(12224) + val dim = 1000 + //creating invertible matrix + val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } + val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } + +// //warmup +// matrix1 dot matrix2 + + CMMatrixContext { + val cmMatrix1 = matrix1.toCM() + val cmMatrix2 = matrix2.toCM() + val cmTime = measureTimeMillis { cmMatrix1 dot cmMatrix2 } + println("CM implementation time: $cmTime") + } + + (EjmlMatrixContext(RealField)) { + val ejmlMatrix1 = matrix1.toEjml() + val ejmlMatrix2 = matrix2.toEjml() + val ejmlTime = measureTimeMillis { ejmlMatrix1 dot ejmlMatrix2 } + println("EJML implementation time: $ejmlTime") + } + + val genericTime = measureTimeMillis { val res = matrix1 dot matrix2 } + println("Generic implementation time: $genericTime") +} diff --git a/examples/src/main/kotlin/kscience/kmath/operations/BigIntDemo.kt b/examples/src/main/kotlin/kscience/kmath/operations/BigIntDemo.kt index 692ea6d8c..0e9811ff8 100644 --- a/examples/src/main/kotlin/kscience/kmath/operations/BigIntDemo.kt +++ b/examples/src/main/kotlin/kscience/kmath/operations/BigIntDemo.kt @@ -1,8 +1,6 @@ package kscience.kmath.operations fun main() { - val res = BigIntField { - number(1) * 2 - } + val res = BigIntField { number(1) * 2 } println("bigint:$res") } \ No newline at end of file diff --git a/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt b/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt index 3c97940a8..34b3c9981 100644 --- a/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt +++ b/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt @@ -5,15 +5,19 @@ import kscience.kmath.structures.NDField import kscience.kmath.structures.complex fun main() { + // 2d element val element = NDElement.complex(2, 2) { index: IntArray -> Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) } + println(element) - val compute = (NDField.complex(8)) { + // 1d element operation + val result = with(NDField.complex(8)) { val a = produce { (it) -> i * it - it.toDouble() } val b = 3 val c = Complex(1.0, 1.0) (a pow b) + c } + println(result) } diff --git a/examples/src/main/kotlin/kscience/kmath/structures/StructureReadBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/structures/StructureReadBenchmark.kt index a2bfea2f9..51fd4f956 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/StructureReadBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/StructureReadBenchmark.kt @@ -4,32 +4,30 @@ import kotlin.system.measureTimeMillis fun main() { val n = 6000 - val array = DoubleArray(n * n) { 1.0 } val buffer = RealBuffer(array) val strides = DefaultStrides(intArrayOf(n, n)) - val structure = BufferNDStructure(strides, buffer) measureTimeMillis { - var res: Double = 0.0 + var res = 0.0 strides.indices().forEach { res = structure[it] } } // warmup val time1 = measureTimeMillis { - var res: Double = 0.0 + var res = 0.0 strides.indices().forEach { res = structure[it] } } println("Structure reading finished in $time1 millis") val time2 = measureTimeMillis { - var res: Double = 0.0 + var res = 0.0 strides.indices().forEach { res = buffer[strides.offset(it)] } } println("Buffer reading finished in $time2 millis") val time3 = measureTimeMillis { - var res: Double = 0.0 + var res = 0.0 strides.indices().forEach { res = array[strides.offset(it)] } } println("Array reading finished in $time3 millis") diff --git a/examples/src/main/kotlin/kscience/kmath/structures/StructureWriteBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/structures/StructureWriteBenchmark.kt index b2975393f..db55b454f 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/StructureWriteBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/StructureWriteBenchmark.kt @@ -4,24 +4,17 @@ import kotlin.system.measureTimeMillis fun main() { val n = 6000 - val structure = NDStructure.build(intArrayOf(n, n), Buffer.Companion::auto) { 1.0 } - structure.mapToBuffer { it + 1 } // warm-up - - val time1 = measureTimeMillis { - val res = structure.mapToBuffer { it + 1 } - } + val time1 = measureTimeMillis { val res = structure.mapToBuffer { it + 1 } } println("Structure mapping finished in $time1 millis") - val array = DoubleArray(n * n) { 1.0 } val time2 = measureTimeMillis { val target = DoubleArray(n * n) - val res = array.forEachIndexed { index, value -> - target[index] = value + 1 - } + val res = array.forEachIndexed { index, value -> target[index] = value + 1 } } + println("Array mapping finished in $time2 millis") val buffer = RealBuffer(DoubleArray(n * n) { 1.0 }) diff --git a/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt b/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt index bf83a9f05..987eea16f 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt @@ -6,7 +6,7 @@ import kscience.kmath.dimensions.DMatrixContext import kscience.kmath.dimensions.Dimension import kscience.kmath.operations.RealField -fun DMatrixContext.simple() { +private fun DMatrixContext.simple() { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i + j).toDouble() } @@ -14,12 +14,11 @@ fun DMatrixContext.simple() { m1.transpose() + m2 } - -object D5 : Dimension { +private object D5 : Dimension { override val dim: UInt = 5u } -fun DMatrixContext.custom() { +private fun DMatrixContext.custom() { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i - j).toDouble() } val m3 = produce { i, j -> (i - j).toDouble() } diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index df876df10..a0afcdc4f 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -14,7 +14,6 @@ kotlin.sourceSets { implementation("org.ow2.asm:asm:8.0.1") implementation("org.ow2.asm:asm-commons:8.0.1") implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0") - implementation(kotlin("reflect")) } } } diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt index 53425e7e3..2b6fa6247 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt @@ -8,7 +8,6 @@ import kscience.kmath.ast.MST import kscience.kmath.ast.MstExpression import kscience.kmath.expressions.Expression import kscience.kmath.operations.Algebra -import kotlin.reflect.KClass /** * Compiles given MST to an Expression using AST compiler. @@ -18,7 +17,8 @@ import kotlin.reflect.KClass * @return the compiled expression. * @author Alexander Nozik */ -public fun MST.compileWith(type: KClass, algebra: Algebra): Expression { +@PublishedApi +internal fun MST.compileWith(type: Class, algebra: Algebra): Expression { fun AsmBuilder.visit(node: MST): Unit = when (node) { is MST.Symbolic -> { val symbol = try { @@ -61,11 +61,12 @@ public fun MST.compileWith(type: KClass, algebra: Algebra): Expr * * @author Alexander Nozik. */ -public inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class, this) +public inline fun Algebra.expression(mst: MST): Expression = + mst.compileWith(T::class.java, this) /** * Optimizes performance of an [MstExpression] using ASM codegen. * * @author Alexander Nozik. */ -public inline fun MstExpression.compile(): Expression = mst.compileWith(T::class, algebra) +public inline fun MstExpression.compile(): Expression = mst.compileWith(T::class.java, algebra) diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt index ab2de97aa..06f02a94d 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -10,7 +10,6 @@ import org.objectweb.asm.Opcodes.* import org.objectweb.asm.commons.InstructionAdapter import java.util.* import java.util.stream.Collectors -import kotlin.reflect.KClass /** * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. @@ -23,7 +22,7 @@ import kotlin.reflect.KClass * @author Iaroslav Postovalov */ internal class AsmBuilder internal constructor( - private val classOfT: KClass<*>, + private val classOfT: Class<*>, private val algebra: Algebra, private val className: String, private val invokeLabel0Visitor: AsmBuilder.() -> Unit @@ -32,7 +31,7 @@ internal class AsmBuilder internal constructor( * Internal classloader of [AsmBuilder] with alias to define class from byte array. */ private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { - internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) + fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) } /** @@ -43,7 +42,7 @@ internal class AsmBuilder internal constructor( /** * ASM Type for [algebra]. */ - private val tAlgebraType: Type = algebra::class.asm + private val tAlgebraType: Type = algebra.javaClass.asm /** * ASM type for [T]. @@ -55,16 +54,6 @@ internal class AsmBuilder internal constructor( */ private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! - /** - * Index of `this` variable in invoke method of the built subclass. - */ - private val invokeThisVar: Int = 0 - - /** - * Index of `arguments` variable in invoke method of the built subclass. - */ - private val invokeArgumentsVar: Int = 1 - /** * List of constants to provide to the subclass. */ @@ -76,22 +65,22 @@ internal class AsmBuilder internal constructor( private lateinit var invokeMethodVisitor: InstructionAdapter /** - * State if this [AsmBuilder] needs to generate constants field. + * States whether this [AsmBuilder] needs to generate constants field. */ private var hasConstants: Boolean = true /** - * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. + * States whether [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. */ internal var primitiveMode: Boolean = false /** - * Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + * Primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. */ internal var primitiveMask: Type = OBJECT_TYPE /** - * Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + * Boxed primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. */ internal var primitiveMaskBoxed: Type = OBJECT_TYPE @@ -103,7 +92,7 @@ internal class AsmBuilder internal constructor( /** * Stack of useful objects types on stack expected by algebra calls. */ - internal val expectationStack: ArrayDeque = ArrayDeque(listOf(tType)) + internal val expectationStack: ArrayDeque = ArrayDeque(1).also { it.push(tType) } /** * The cache for instance built by this builder. @@ -361,7 +350,7 @@ internal class AsmBuilder internal constructor( * from it). */ private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) { - val boxed = value::class.asm + val boxed = value.javaClass.asm val primitive = BOXED_TO_PRIMITIVES[boxed] if (primitive != null) { @@ -475,17 +464,27 @@ internal class AsmBuilder internal constructor( internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string) internal companion object { + /** + * Index of `this` variable in invoke method of the built subclass. + */ + private const val invokeThisVar: Int = 0 + + /** + * Index of `arguments` variable in invoke method of the built subclass. + */ + private const val invokeArgumentsVar: Int = 1 + /** * Maps JVM primitive numbers boxed types to their primitive ASM types. */ - private val SIGNATURE_LETTERS: Map, Type> by lazy { + private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( - java.lang.Byte::class to Type.BYTE_TYPE, - java.lang.Short::class to Type.SHORT_TYPE, - java.lang.Integer::class to Type.INT_TYPE, - java.lang.Long::class to Type.LONG_TYPE, - java.lang.Float::class to Type.FLOAT_TYPE, - java.lang.Double::class to Type.DOUBLE_TYPE + java.lang.Byte::class.java to Type.BYTE_TYPE, + java.lang.Short::class.java to Type.SHORT_TYPE, + java.lang.Integer::class.java to Type.INT_TYPE, + java.lang.Long::class.java to Type.LONG_TYPE, + java.lang.Float::class.java to Type.FLOAT_TYPE, + java.lang.Double::class.java to Type.DOUBLE_TYPE ) } @@ -523,43 +522,43 @@ internal class AsmBuilder internal constructor( /** * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ - private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } /** * ASM type for [Expression]. */ - internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } + internal val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/expressions/Expression") } /** * ASM type for [java.lang.Number]. */ - internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } + internal val NUMBER_TYPE: Type by lazy { Type.getObjectType("java/lang/Number") } /** * ASM type for [java.util.Map]. */ - internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } + internal val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") } /** * ASM type for [java.lang.Object]. */ - internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } + internal val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") } /** * ASM type for array of [java.lang.Object]. */ @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") - internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } + internal val OBJECT_ARRAY_TYPE: Type by lazy { Type.getType("[Ljava/lang/Object;") } /** * ASM type for [Algebra]. */ - internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } + internal val ALGEBRA_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/operations/Algebra") } /** * ASM type for [java.lang.String]. */ - internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } + internal val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") } /** * ASM type for MapIntrinsics. diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt index 67fce40ac..ef9751502 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt @@ -10,9 +10,9 @@ import org.objectweb.asm.* import org.objectweb.asm.Opcodes.INVOKEVIRTUAL import org.objectweb.asm.commons.InstructionAdapter import java.lang.reflect.Method +import java.util.* import kotlin.contracts.InvocationKind import kotlin.contracts.contract -import kotlin.reflect.KClass private val methodNameAdapters: Map, String> by lazy { hashMapOf( @@ -26,12 +26,12 @@ private val methodNameAdapters: Map, String> by lazy { } /** - * Returns ASM [Type] for given [KClass]. + * Returns ASM [Type] for given [Class]. * * @author Iaroslav Postovalov */ -internal val KClass<*>.asm: Type - get() = Type.getType(java) +internal inline val Class<*>.asm: Type + get() = Type.getType(this) /** * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. @@ -140,7 +140,7 @@ private fun AsmBuilder.buildExpectationStack( if (specific != null) mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) } else - repeat(arity) { expectationStack.push(tType) } + expectationStack.addAll(Collections.nCopies(arity, tType)) return specific != null } @@ -169,7 +169,7 @@ private fun AsmBuilder.tryInvokeSpecific( val arity = parameterTypes.size val theName = methodNameAdapters[name to arity] ?: name val spec = findSpecific(context, theName, parameterTypes) ?: return false - val owner = context::class.asm + val owner = context.javaClass.asm invokeAlgebraOperation( owner = owner.internalName, diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt index 15e6625db..94cd3b321 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt @@ -7,6 +7,7 @@ import com.github.h0tk3y.betterParse.grammar.parser import com.github.h0tk3y.betterParse.grammar.tryParseToEnd import com.github.h0tk3y.betterParse.lexer.Token import com.github.h0tk3y.betterParse.lexer.TokenMatch +import com.github.h0tk3y.betterParse.lexer.literalToken import com.github.h0tk3y.betterParse.lexer.regexToken import com.github.h0tk3y.betterParse.parser.ParseResult import com.github.h0tk3y.betterParse.parser.Parser @@ -23,14 +24,14 @@ public object ArithmeticsEvaluator : Grammar() { // TODO replace with "...".toRegex() when better-parse 0.4.1 is released private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?") private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*") - private val lpar: Token by regexToken("\\(") - private val rpar: Token by regexToken("\\)") - private val comma: Token by regexToken(",") - private val mul: Token by regexToken("\\*") - private val pow: Token by regexToken("\\^") - private val div: Token by regexToken("/") - private val minus: Token by regexToken("-") - private val plus: Token by regexToken("\\+") + private val lpar: Token by literalToken("(") + private val rpar: Token by literalToken(")") + private val comma: Token by literalToken(",") + private val mul: Token by literalToken("*") + private val pow: Token by literalToken("^") + private val div: Token by literalToken("/") + private val minus: Token by literalToken("-") + private val plus: Token by literalToken("+") private val ws: Token by regexToken("\\s+", ignore = true) private val number: Parser by num use { MST.Numeric(text.toDouble()) } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt index 8a09cc793..c39f0d04c 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt @@ -9,14 +9,17 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import kotlin.properties.ReadOnlyProperty /** - * A field wrapping commons-math derivative structures + * A field over commons-math [DerivativeStructure]. + * + * @property order The derivation order. + * @property parameters The map of free parameters. */ public class DerivativeStructureField( public val order: Int, public val parameters: Map ) : ExtendedField { - public override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) } - public override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) } + public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) } + public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) } private val variables: Map = parameters.mapValues { (key, value) -> DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 49844a2be..5b050dd36 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -4,7 +4,7 @@ import kscience.kmath.operations.* internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) : Expression { - public override operator fun invoke(arguments: Map): T = + override operator fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) } @@ -14,17 +14,17 @@ internal class FunctionalBinaryOperation( val first: Expression, val second: Expression ) : Expression { - public override operator fun invoke(arguments: Map): T = + override operator fun invoke(arguments: Map): T = context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) } internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression { - public override operator fun invoke(arguments: Map): T = + override operator fun invoke(arguments: Map): T = arguments[name] ?: default ?: error("Parameter not found: $name") } internal class FunctionalConstantExpression(val value: T) : Expression { - public override operator fun invoke(arguments: Map): T = value + override operator fun invoke(arguments: Map): T = value } internal class FunctionalConstProductExpression( @@ -32,7 +32,7 @@ internal class FunctionalConstProductExpression( private val expr: Expression, val const: Number ) : Expression { - public override operator fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) + override operator fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } /** @@ -139,16 +139,27 @@ public open class FunctionalExpressionField(algebra: A) : public open class FunctionalExpressionExtendedField(algebra: A) : FunctionalExpressionField(algebra), ExtendedField> where A : ExtendedField, A : NumericAlgebra { - public override fun sin(arg: Expression): Expression = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) - public override fun cos(arg: Expression): Expression = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) - public override fun asin(arg: Expression): Expression = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) - public override fun acos(arg: Expression): Expression = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) - public override fun atan(arg: Expression): Expression = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) + public override fun sin(arg: Expression): Expression = + unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) + + public override fun cos(arg: Expression): Expression = + unaryOperation(TrigonometricOperations.COS_OPERATION, arg) + + public override fun asin(arg: Expression): Expression = + unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) + + public override fun acos(arg: Expression): Expression = + unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) + + public override fun atan(arg: Expression): Expression = + unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) public override fun power(arg: Expression, pow: Number): Expression = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) - public override fun exp(arg: Expression): Expression = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) + public override fun exp(arg: Expression): Expression = + unaryOperation(ExponentialOperations.EXP_OPERATION, arg) + public override fun ln(arg: Expression): Expression = unaryOperation(ExponentialOperations.LN_OPERATION, arg) public override fun unaryOperation(operation: String, arg: Expression): Expression = diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt index 65dc8df76..5d9af8608 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt @@ -24,7 +24,11 @@ public interface FeaturedMatrix : Matrix { public companion object } -public inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix = +public inline fun Structure2D.Companion.real( + rows: Int, + columns: Int, + initializer: (Int, Int) -> Double +): Matrix = MatrixContext.real.produce(rows, columns, initializer) /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt index d66530472..f4dbce89a 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt @@ -18,20 +18,52 @@ public interface MatrixContext : SpaceOperations> { */ public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix + public override fun binaryOperation(operation: String, left: Matrix, right: Matrix): Matrix = when (operation) { + "dot" -> left dot right + else -> super.binaryOperation(operation, left, right) + } + + /** + * Computes the dot product of this matrix and another one. + * + * @receiver the multiplicand. + * @param other the multiplier. + * @return the dot product. + */ public infix fun Matrix.dot(other: Matrix): Matrix + /** + * Computes the dot product of this matrix and a vector. + * + * @receiver the multiplicand. + * @param vector the multiplier. + * @return the dot product. + */ public infix fun Matrix.dot(vector: Point): Point + /** + * Multiplies a matrix by its element. + * + * @receiver the multiplicand. + * @param value the multiplier. + * @receiver the product. + */ public operator fun Matrix.times(value: T): Matrix + /** + * Multiplies an element by a matrix of it. + * + * @receiver the multiplicand. + * @param value the multiplier. + * @receiver the product. + */ public operator fun T.times(m: Matrix): Matrix = m * this public companion object { /** * Non-boxing double matrix */ - public val real: RealMatrixContext - get() = RealMatrixContext + public val real: RealMatrixContext = RealMatrixContext /** * A structured matrix with custom buffer @@ -60,7 +92,7 @@ public interface GenericMatrixContext> : MatrixContext { */ public fun point(size: Int, initializer: (Int) -> T): Point - override infix fun Matrix.dot(other: Matrix): Matrix { + public override infix fun Matrix.dot(other: Matrix): Matrix { //TODO add typed error require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } @@ -71,7 +103,7 @@ public interface GenericMatrixContext> : MatrixContext { } } - override infix fun Matrix.dot(vector: Point): Point { + public override infix fun Matrix.dot(vector: Point): Point { //TODO add typed error require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" } @@ -81,10 +113,10 @@ public interface GenericMatrixContext> : MatrixContext { } } - override operator fun Matrix.unaryMinus(): Matrix = + public override operator fun Matrix.unaryMinus(): Matrix = produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } } - override fun add(a: Matrix, b: Matrix): Matrix { + public override fun add(a: Matrix, b: Matrix): Matrix { require(a.rowNum == b.rowNum && a.colNum == b.colNum) { "Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]" } @@ -92,7 +124,7 @@ public interface GenericMatrixContext> : MatrixContext { return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } } } - override operator fun Matrix.minus(b: Matrix): Matrix { + public override operator fun Matrix.minus(b: Matrix): Matrix { require(rowNum == b.rowNum && colNum == b.colNum) { "Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]" } @@ -100,11 +132,11 @@ public interface GenericMatrixContext> : MatrixContext { return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } } } - override fun multiply(a: Matrix, k: Number): Matrix = + public override fun multiply(a: Matrix, k: Number): Matrix = produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } } public operator fun Number.times(matrix: FeaturedMatrix): Matrix = matrix * this - override operator fun Matrix.times(value: T): Matrix = + public override operator fun Matrix.times(value: T): Matrix = produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt index b4a610eb1..bfcd5959f 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt @@ -1,10 +1,7 @@ package kscience.kmath.misc import kscience.kmath.linear.Point -import kscience.kmath.operations.ExtendedField -import kscience.kmath.operations.Field -import kscience.kmath.operations.invoke -import kscience.kmath.operations.sum +import kscience.kmath.operations.* import kscience.kmath.structures.asBuffer import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -17,23 +14,37 @@ import kotlin.contracts.contract /** * Differentiable variable with value and derivative of differentiation ([deriv]) result * with respect to this variable. + * + * @param T the non-nullable type of value. + * @property value The value of this variable. */ public open class Variable(public val value: T) +/** + * Represents result of [deriv] call. + * + * @param T the non-nullable type of value. + * @param value the value of result. + * @property deriv The mapping of differentiated variables to their derivatives. + * @property context The field over [T]. + */ public class DerivationResult( value: T, public val deriv: Map, T>, public val context: Field ) : Variable(value) { + /** + * Returns derivative of [variable] or returns [Ring.zero] in [context]. + */ public fun deriv(variable: Variable): T = deriv[variable] ?: context.zero /** - * compute divergence + * Computes the divergence. */ public fun div(): T = context { sum(deriv.values) } /** - * Compute a gradient for variables in given order + * Computes the gradient for variables in given order. */ public fun grad(vararg variables: Variable): Point { check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" } @@ -53,6 +64,9 @@ public class DerivationResult( * assertEquals(17.0, y.x) // the value of result (y) * assertEquals(9.0, x.d) // dy/dx * ``` + * + * @param body the action in [AutoDiffField] context returning [Variable] to differentiate with respect to. + * @return the result of differentiation. */ public inline fun > F.deriv(body: AutoDiffField.() -> Variable): DerivationResult { contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } @@ -65,12 +79,15 @@ public inline fun > F.deriv(body: AutoDiffField.() - } } +/** + * Represents field in context of which functions can be derived. + */ public abstract class AutoDiffField> : Field> { public abstract val context: F /** * A variable accessing inner state of derivatives. - * Use this function in inner builders to avoid creating additional derivative bindings + * Use this value in inner builders to avoid creating additional derivative bindings. */ public abstract var Variable.d: T @@ -87,6 +104,9 @@ public abstract class AutoDiffField> : Field> */ public abstract fun derive(value: R, block: F.(R) -> Unit): R + /** + * + */ public abstract fun variable(value: T): Variable public inline fun variable(block: F.() -> T): Variable = variable(context.block()) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt index 4590c58fc..20f289596 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt @@ -299,7 +299,7 @@ public class BigInt internal constructor( for (i in mag.indices) { val cur: ULong = carry + mag[i].toULong() * x.toULong() - result[i] = (cur and BASE.toULong()).toUInt() + result[i] = (cur and BASE).toUInt() carry = cur shr BASE_SIZE } result[resultLength - 1] = (carry and BASE).toUInt() @@ -316,7 +316,7 @@ public class BigInt internal constructor( for (j in mag2.indices) { val cur: ULong = result[i + j].toULong() + mag1[i].toULong() * mag2[j].toULong() + carry - result[i + j] = (cur and BASE.toULong()).toUInt() + result[i + j] = (cur and BASE).toUInt() carry = cur shr BASE_SIZE } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt index 24bfec054..37055a5c8 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt @@ -6,6 +6,7 @@ import kscience.kmath.memory.MemoryWriter import kscience.kmath.structures.Buffer import kscience.kmath.structures.MemoryBuffer import kscience.kmath.structures.MutableBuffer +import kscience.kmath.structures.MutableMemoryBuffer import kotlin.math.* /** @@ -159,7 +160,7 @@ public object ComplexField : ExtendedField, Norm { } /** - * Represents complex number. + * Represents `double`-based complex number. * * @property re The real part. * @property im The imaginary part. @@ -176,11 +177,16 @@ public data class Complex(val re: Double, val im: Double) : FieldElement { - override val objectSize: Int = 16 + override fun toString(): String { + return "($re + i*$im)" + } - override fun MemoryReader.read(offset: Int): Complex = - Complex(readDouble(offset), readDouble(offset + 8)) + + public companion object : MemorySpec { + override val objectSize: Int + get() = 16 + + override fun MemoryReader.read(offset: Int): Complex = Complex(readDouble(offset), readDouble(offset + 8)) override fun MemoryWriter.write(offset: Int, value: Complex) { writeDouble(offset, value.re) @@ -197,8 +203,16 @@ public data class Complex(val re: Double, val im: Double) : FieldElement Complex): Buffer = MemoryBuffer.create(Complex, size, init) -public inline fun MutableBuffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer = - MemoryBuffer.create(Complex, size, init) +/** + * Creates a new buffer of complex numbers with the specified [size], where each element is calculated by calling the + * specified [init] function. + */ +public inline fun MutableBuffer.Companion.complex(size: Int, init: (Int) -> Complex): MutableBuffer = + MutableMemoryBuffer.create(Complex, size, init) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt index c3c859f7a..dc65b12c4 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt @@ -15,8 +15,9 @@ public class BoxingNDField>( public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) - public override fun check(vararg elements: NDBuffer) { - check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } + public override fun check(vararg elements: NDBuffer): Array> { + require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } + return elements } public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement = @@ -75,6 +76,6 @@ public inline fun , R> F.nd( vararg shape: Int, action: NDField.() -> R ): R { - val ndfield: BoxingNDField = NDField.boxing(this, *shape, bufferFactory = bufferFactory) + val ndfield = NDField.boxing(this, *shape, bufferFactory = bufferFactory) return ndfield.action() } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDRing.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDRing.kt index 461b0387c..b6794984c 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDRing.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDRing.kt @@ -14,8 +14,9 @@ public class BoxingNDRing>( public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) - override fun check(vararg elements: NDBuffer) { - require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } + override fun check(vararg elements: NDBuffer): Array> { + if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") + return elements } override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement = diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt index 66b4f19e1..3dcd0322c 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt @@ -5,8 +5,10 @@ import kscience.kmath.operations.* public interface BufferedNDAlgebra : NDAlgebra> { public val strides: Strides - public override fun check(vararg elements: NDBuffer): Unit = - require(elements.all { it.strides == strides }) { ("Strides mismatch") } + public override fun check(vararg elements: NDBuffer): Array> { + require(elements.all { it.strides == strides }) { "Strides mismatch" } + return elements + } /** * Convert any [NDStructure] to buffered structure using strides from this context. diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt index 53587e503..5174eb314 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt @@ -46,35 +46,48 @@ public interface Buffer { asSequence().mapIndexed { index, value -> value == other[index] }.all { it } public companion object { - public inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer { - val array = DoubleArray(size) { initializer(it) } - return RealBuffer(array) - } + /** + * Creates a [RealBuffer] with the specified [size], where each element is calculated by calling the specified + * [initializer] function. + */ + public inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer = + RealBuffer(size) { initializer(it) } /** - * Create a boxing buffer of given type + * Creates a [ListBuffer] of given type [T] with given [size]. Each element is calculated by calling the + * specified [initializer] function. */ public inline fun boxing(size: Int, initializer: (Int) -> T): Buffer = ListBuffer(List(size, initializer)) + // TODO add resolution based on Annotation or companion resolution + + /** + * Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([IntBuffer], + * [RealBuffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ @Suppress("UNCHECKED_CAST") - public inline fun auto(type: KClass, size: Int, crossinline initializer: (Int) -> T): Buffer { - //TODO add resolution based on Annotation or companion resolution - return when (type) { - Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer - Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer - Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer - Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer + public inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): Buffer = + when (type) { + Double::class -> RealBuffer(size) { initializer(it) as Double } as Buffer + Short::class -> ShortBuffer(size) { initializer(it) as Short } as Buffer + Int::class -> IntBuffer(size) { initializer(it) as Int } as Buffer + Long::class -> LongBuffer(size) { initializer(it) as Long } as Buffer + Float::class -> FloatBuffer(size) { initializer(it) as Float } as Buffer Complex::class -> complex(size) { initializer(it) as Complex } as Buffer else -> boxing(size, initializer) } - } /** - * Create most appropriate immutable buffer for given type avoiding boxing wherever possible + * Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([IntBuffer], + * [RealBuffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. */ @Suppress("UNCHECKED_CAST") - public inline fun auto(size: Int, crossinline initializer: (Int) -> T): Buffer = + public inline fun auto(size: Int, initializer: (Int) -> T): Buffer = auto(T::class, size, initializer) } } @@ -117,25 +130,40 @@ public interface MutableBuffer : Buffer { public inline fun boxing(size: Int, initializer: (Int) -> T): MutableBuffer = MutableListBuffer(MutableList(size, initializer)) + /** + * Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used + * ([IntBuffer], [RealBuffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ @Suppress("UNCHECKED_CAST") public inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer = when (type) { - Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer - Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer - Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer - Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer + Double::class -> RealBuffer(size) { initializer(it) as Double } as MutableBuffer + Short::class -> ShortBuffer(size) { initializer(it) as Short } as MutableBuffer + Int::class -> IntBuffer(size) { initializer(it) as Int } as MutableBuffer + Float::class -> FloatBuffer(size) { initializer(it) as Float } as MutableBuffer + Long::class -> LongBuffer(size) { initializer(it) as Long } as MutableBuffer + Complex::class -> complex(size) { initializer(it) as Complex } as MutableBuffer else -> boxing(size, initializer) } /** - * Create most appropriate mutable buffer for given type avoiding boxing wherever possible + * Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used + * ([IntBuffer], [RealBuffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. */ @Suppress("UNCHECKED_CAST") public inline fun auto(size: Int, initializer: (Int) -> T): MutableBuffer = auto(T::class, size, initializer) - public val real: MutableBufferFactory = - { size, initializer -> RealBuffer(DoubleArray(size) { initializer(it) }) } + /** + * Creates a [RealBuffer] with the specified [size], where each element is calculated by calling the specified + * [initializer] function. + */ + public inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer = + RealBuffer(size) { initializer(it) } } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt index e3fda0e10..4965e37cf 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt @@ -48,7 +48,8 @@ public fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, Valu /** * A real buffer which supports flags for each value like NaN or Missing */ -public class FlaggedRealBuffer(public val values: DoubleArray, public val flags: ByteArray) : FlaggedBuffer, Buffer { +public class FlaggedRealBuffer(public val values: DoubleArray, public val flags: ByteArray) : FlaggedBuffer, + Buffer { init { require(values.size == flags.size) { "Values and flags must have the same dimensions" } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt index c171e7c1d..66c9212cf 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt @@ -53,7 +53,7 @@ public class MutableMemoryBuffer(memory: Memory, spec: MemorySpec) : public inline fun create( spec: MemorySpec, size: Int, - crossinline initializer: (Int) -> T + initializer: (Int) -> T ): MutableMemoryBuffer = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> (0 until size).forEach { buffer[it] = initializer(it) } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt index 03c601717..c1cfcbe49 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt @@ -7,49 +7,77 @@ import kscience.kmath.operations.Space import kotlin.native.concurrent.ThreadLocal /** - * An exception is thrown when the expected ans actual shape of NDArray differs + * An exception is thrown when the expected ans actual shape of NDArray differs. + * + * @property expected the expected shape. + * @property actual the actual shape. */ -public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) : RuntimeException() +public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) : + RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.") /** - * The base interface for all nd-algebra implementations - * @param T the type of nd-structure element - * @param C the type of the element context - * @param N the type of the structure + * The base interface for all ND-algebra implementations. + * + * @param T the type of ND-structure element. + * @param C the type of the element context. + * @param N the type of the structure. */ public interface NDAlgebra> { + /** + * The shape of ND-structures this algebra operates on. + */ public val shape: IntArray + + /** + * The algebra over elements of ND structure. + */ public val elementContext: C /** - * Produce a new [N] structure using given initializer function + * Produces a new [N] structure using given initializer function. */ public fun produce(initializer: C.(IntArray) -> T): N /** - * Map elements from one structure to another one + * Maps elements from one structure to another one by applying [transform] to them. */ public fun map(arg: N, transform: C.(T) -> T): N /** - * Map indexed elements + * Maps elements from one structure to another one by applying [transform] to them alongside with their indices. */ public fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N /** - * Combine two structures into one + * Combines two structures into one. */ public fun combine(a: N, b: N, transform: C.(T, T) -> T): N /** - * Check if given elements are consistent with this context + * Checks if given element is consistent with this context. + * + * @param element the structure to check. + * @return the valid structure. */ - public fun check(vararg elements: N): Unit = elements.forEach { - if (!shape.contentEquals(it.shape)) throw ShapeMismatchException(shape, it.shape) + public fun check(element: N): N { + if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape) + return element } /** - * element-by-element invoke a function working on [T] on a [NDStructure] + * Checks if given elements are consistent with this context. + * + * @param elements the structures to check. + * @return the array of valid structures. + */ + public fun check(vararg elements: N): Array = elements + .map(NDStructure::shape) + .singleOrNull { !shape.contentEquals(it) } + ?.let { throw ShapeMismatchException(shape, it) } + ?: elements + + /** + * Element-wise invocation of function working on [T] on a [NDStructure]. */ public operator fun Function1.invoke(structure: N): N = map(structure) { value -> this@invoke(value) } @@ -57,42 +85,107 @@ public interface NDAlgebra> { } /** - * An nd-space over element space + * Space of [NDStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param S the type of space of structure elements. */ public interface NDSpace, N : NDStructure> : Space, NDAlgebra { /** - * Element-by-element addition + * Element-wise addition. + * + * @param a the addend. + * @param b the augend. + * @return the sum. */ - override fun add(a: N, b: N): N = combine(a, b) { aValue, bValue -> add(aValue, bValue) } + public override fun add(a: N, b: N): N = combine(a, b) { aValue, bValue -> add(aValue, bValue) } /** - * Multiply all elements by constant + * Element-wise multiplication by scalar. + * + * @param a the multiplicand. + * @param k the multiplier. + * @return the product. */ - override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) } + public override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) } - //TODO move to extensions after KEEP-176 + // TODO move to extensions after KEEP-176 + + /** + * Adds an ND structure to an element of it. + * + * @receiver the addend. + * @param arg the augend. + * @return the sum. + */ public operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) } + /** + * Subtracts an element from ND structure of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ public operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) } + /** + * Adds an element to ND structure of it. + * + * @receiver the addend. + * @param arg the augend. + * @return the sum. + */ public operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) } + + /** + * Subtracts an ND structure from an element of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ public operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) } public companion object } /** - * An nd-ring over element ring + * Ring of [NDStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param R the type of ring of structure elements. */ public interface NDRing, N : NDStructure> : Ring, NDSpace { /** - * Element-by-element multiplication + * Element-wise multiplication. + * + * @param a the multiplicand. + * @param b the multiplier. + * @return the product. */ - override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } + public override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } //TODO move to extensions after KEEP-176 + + /** + * Multiplies an ND structure by an element of it. + * + * @receiver the multiplicand. + * @param arg the multiplier. + * @return the product. + */ public operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) } + /** + * Multiplies an element by a ND structure of it. + * + * @receiver the multiplicand. + * @param arg the multiplier. + * @return the product. + */ public operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) } public companion object @@ -103,17 +196,35 @@ public interface NDRing, N : NDStructure> : Ring, NDSpace, N : NDStructure> : Field, NDRing { /** - * Element-by-element division + * Element-wise division. + * + * @param a the dividend. + * @param b the divisor. + * @return the quotient. */ - override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) } + public override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) } //TODO move to extensions after KEEP-176 + /** + * Divides an ND structure by an element of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ public operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) } + /** + * Divides an element by an ND structure of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ public operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) } @ThreadLocal @@ -121,12 +232,12 @@ public interface NDField, N : NDStructure> : Field, NDRing private val realNDFieldCache: MutableMap = hashMapOf() /** - * Create a nd-field for [Double] values or pull it from cache if it was created previously + * Create a nd-field for [Double] values or pull it from cache if it was created previously. */ public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) } /** - * Create a nd-field with boxing generic buffer + * Create an ND field with boxing generic buffer. */ public fun > boxing( field: F, diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt index fd679d073..08160adf4 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt @@ -38,9 +38,8 @@ public interface NDStructure { */ public fun elements(): Sequence> - override fun equals(other: Any?): Boolean - - override fun hashCode(): Int + public override fun equals(other: Any?): Boolean + public override fun hashCode(): Int public companion object { /** @@ -50,13 +49,8 @@ public interface NDStructure { if (st1 === st2) return true // fast comparison of buffers if possible - if ( - st1 is NDBuffer && - st2 is NDBuffer && - st1.strides == st2.strides - ) { + if (st1 is NDBuffer && st2 is NDBuffer && st1.strides == st2.strides) return st1.buffer.contentEquals(st2.buffer) - } //element by element comparison if it could not be avoided return st1.elements().all { (index, value) -> value == st2[index] } @@ -70,7 +64,7 @@ public interface NDStructure { public fun build( strides: Strides, bufferFactory: BufferFactory = Buffer.Companion::boxing, - initializer: (IntArray) -> T + initializer: (IntArray) -> T, ): BufferNDStructure = BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) @@ -79,40 +73,40 @@ public interface NDStructure { */ public inline fun auto( strides: Strides, - crossinline initializer: (IntArray) -> T + crossinline initializer: (IntArray) -> T, ): BufferNDStructure = BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) public inline fun auto( type: KClass, strides: Strides, - crossinline initializer: (IntArray) -> T + crossinline initializer: (IntArray) -> T, ): BufferNDStructure = BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) }) public fun build( shape: IntArray, bufferFactory: BufferFactory = Buffer.Companion::boxing, - initializer: (IntArray) -> T + initializer: (IntArray) -> T, ): BufferNDStructure = build(DefaultStrides(shape), bufferFactory, initializer) public inline fun auto( shape: IntArray, - crossinline initializer: (IntArray) -> T + crossinline initializer: (IntArray) -> T, ): BufferNDStructure = auto(DefaultStrides(shape), initializer) @JvmName("autoVarArg") public inline fun auto( vararg shape: Int, - crossinline initializer: (IntArray) -> T + crossinline initializer: (IntArray) -> T, ): BufferNDStructure = auto(DefaultStrides(shape), initializer) public inline fun auto( type: KClass, vararg shape: Int, - crossinline initializer: (IntArray) -> T + crossinline initializer: (IntArray) -> T, ): BufferNDStructure = auto(type, DefaultStrides(shape), initializer) } @@ -274,6 +268,22 @@ public abstract class NDBuffer : NDStructure { result = 31 * result + buffer.hashCode() return result } + + override fun toString(): String { + val bufferRepr: String = when (shape.size) { + 1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ") + 2 -> (0 until shape[0]).joinToString(prefix = "[", postfix = "]", separator = ", ") { i -> + (0 until shape[1]).joinToString(prefix = "[", postfix = "]", separator = ", ") { j -> + val offset = strides.offset(intArrayOf(i, j)) + buffer[offset].toString() + } + } + else -> "..." + } + return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)" + } + + } /** @@ -281,7 +291,7 @@ public abstract class NDBuffer : NDStructure { */ public class BufferNDStructure( override val strides: Strides, - override val buffer: Buffer + override val buffer: Buffer, ) : NDBuffer() { init { if (strides.linearSize != buffer.size) { @@ -295,7 +305,7 @@ public class BufferNDStructure( */ public inline fun NDStructure.mapToBuffer( factory: BufferFactory = Buffer.Companion::auto, - crossinline transform: (T) -> R + crossinline transform: (T) -> R, ): BufferNDStructure { return if (this is BufferNDStructure) BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) @@ -310,7 +320,7 @@ public inline fun NDStructure.mapToBuffer( */ public class MutableBufferNDStructure( override val strides: Strides, - override val buffer: MutableBuffer + override val buffer: MutableBuffer, ) : NDBuffer(), MutableNDStructure { init { @@ -324,7 +334,7 @@ public class MutableBufferNDStructure( public inline fun NDStructure.combine( struct: NDStructure, - crossinline block: (T, T) -> T + crossinline block: (T, T) -> T, ): NDStructure { require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" } return NDStructure.auto(shape) { block(this[it], struct[it]) } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure1D.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure1D.kt index af5cc9e3f..95422ac60 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure1D.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure1D.kt @@ -17,7 +17,7 @@ public interface Structure1D : NDStructure, Buffer { /** * A 1D wrapper for nd-structure */ -private inline class Structure1DWrapper(public val structure: NDStructure) : Structure1D { +private inline class Structure1DWrapper(val structure: NDStructure) : Structure1D { override val shape: IntArray get() = structure.shape override val size: Int get() = structure.shape[0] diff --git a/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/coroutines/coroutinesExtra.kt b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/coroutines/coroutinesExtra.kt index 351207111..7dcdc0d62 100644 --- a/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/coroutines/coroutinesExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/coroutines/coroutinesExtra.kt @@ -21,7 +21,8 @@ internal class LazyDeferred(val dispatcher: CoroutineDispatcher, val block: s } public class AsyncFlow internal constructor(internal val deferredFlow: Flow>) : Flow { - override suspend fun collect(collector: FlowCollector): Unit = deferredFlow.collect { collector.emit((it.await())) } + override suspend fun collect(collector: FlowCollector): Unit = + deferredFlow.collect { collector.emit((it.await())) } } public fun Flow.async( diff --git a/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Dimensions.kt b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Dimensions.kt index f49e1e0f0..9450f9174 100644 --- a/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Dimensions.kt +++ b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Dimensions.kt @@ -3,8 +3,9 @@ package kscience.kmath.dimensions import kotlin.reflect.KClass /** - * An abstract class which is not used in runtime. Designates a size of some structure. - * Could be replaced later by fully inline constructs + * Represents a quantity of dimensions in certain structure. + * + * @property dim The number of dimensions. */ public interface Dimension { public val dim: UInt @@ -16,18 +17,33 @@ public fun KClass.dim(): UInt = Dimension.resolve(this).dim public expect fun Dimension.Companion.resolve(type: KClass): D +/** + * Finds or creates [Dimension] with [Dimension.dim] equal to [dim]. + */ public expect fun Dimension.Companion.of(dim: UInt): Dimension +/** + * Finds [Dimension.dim] of given type [D]. + */ public inline fun Dimension.Companion.dim(): UInt = D::class.dim() +/** + * Type representing 1 dimension. + */ public object D1 : Dimension { override val dim: UInt get() = 1U } +/** + * Type representing 2 dimensions. + */ public object D2 : Dimension { override val dim: UInt get() = 2U } +/** + * Type representing 3 dimensions. + */ public object D3 : Dimension { override val dim: UInt get() = 3U } diff --git a/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt b/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt similarity index 100% rename from kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt rename to kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt diff --git a/kmath-ejml/build.gradle.kts b/kmath-ejml/build.gradle.kts new file mode 100644 index 000000000..fa4aa3e39 --- /dev/null +++ b/kmath-ejml/build.gradle.kts @@ -0,0 +1,8 @@ +plugins { + id("ru.mipt.npm.jvm") +} + +dependencies { + implementation("org.ejml:ejml-simple:0.39") + implementation(project(":kmath-core")) +} diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt new file mode 100644 index 000000000..ed6b1571e --- /dev/null +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt @@ -0,0 +1,71 @@ +package kscience.kmath.ejml + +import org.ejml.dense.row.factory.DecompositionFactory_DDRM +import org.ejml.simple.SimpleMatrix +import kscience.kmath.linear.DeterminantFeature +import kscience.kmath.linear.FeaturedMatrix +import kscience.kmath.linear.LUPDecompositionFeature +import kscience.kmath.linear.MatrixFeature +import kscience.kmath.structures.NDStructure + +/** + * Represents featured matrix over EJML [SimpleMatrix]. + * + * @property origin the underlying [SimpleMatrix]. + * @author Iaroslav Postovalov + */ +public class EjmlMatrix(public val origin: SimpleMatrix, features: Set? = null) : FeaturedMatrix { + public override val rowNum: Int + get() = origin.numRows() + + public override val colNum: Int + get() = origin.numCols() + + public override val shape: IntArray + get() = intArrayOf(origin.numRows(), origin.numCols()) + + public override val features: Set = setOf( + object : LUPDecompositionFeature, DeterminantFeature { + override val determinant: Double + get() = origin.determinant() + + private val lup by lazy { + val ludecompositionF64 = DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()) + .also { it.decompose(origin.ddrm.copy()) } + + Triple( + EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))), + EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))), + EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))), + ) + } + + override val l: FeaturedMatrix + get() = lup.second + + override val u: FeaturedMatrix + get() = lup.third + + override val p: FeaturedMatrix + get() = lup.first + } + ) union features.orEmpty() + + public override fun suggestFeature(vararg features: MatrixFeature): EjmlMatrix = + EjmlMatrix(origin, this.features + features) + + public override operator fun get(i: Int, j: Int): Double = origin[i, j] + + public override fun equals(other: Any?): Boolean { + if (other is EjmlMatrix) return origin.isIdentical(other.origin, 0.0) + return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + } + + public override fun hashCode(): Int { + var result = origin.hashCode() + result = 31 * result + features.hashCode() + return result + } + + public override fun toString(): String = "EjmlMatrix(origin=$origin, features=$features)" +} diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt new file mode 100644 index 000000000..52826a7b1 --- /dev/null +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -0,0 +1,86 @@ +package kscience.kmath.ejml + +import org.ejml.simple.SimpleMatrix +import kscience.kmath.linear.MatrixContext +import kscience.kmath.linear.Point +import kscience.kmath.operations.Space +import kscience.kmath.operations.invoke +import kscience.kmath.structures.Matrix + +/** + * Represents context of basic operations operating with [EjmlMatrix]. + * + * @author Iaroslav Postovalov + */ +public class EjmlMatrixContext(private val space: Space) : MatrixContext { + /** + * Converts this matrix to EJML one. + */ + public fun Matrix.toEjml(): EjmlMatrix = + if (this is EjmlMatrix) this else produce(rowNum, colNum) { i, j -> get(i, j) } + + /** + * Converts this vector to EJML one. + */ + public fun Point.toEjml(): EjmlVector = + if (this is EjmlVector) this else EjmlVector(SimpleMatrix(size, 1).also { + (0 until it.numRows()).forEach { row -> it[row, 0] = get(row) } + }) + + override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix = + EjmlMatrix(SimpleMatrix(rows, columns).also { + (0 until it.numRows()).forEach { row -> + (0 until it.numCols()).forEach { col -> it[row, col] = initializer(row, col) } + } + }) + + public override fun Matrix.dot(other: Matrix): EjmlMatrix = + EjmlMatrix(toEjml().origin.mult(other.toEjml().origin)) + + public override fun Matrix.dot(vector: Point): EjmlVector = + EjmlVector(toEjml().origin.mult(vector.toEjml().origin)) + + public override fun add(a: Matrix, b: Matrix): EjmlMatrix = + EjmlMatrix(a.toEjml().origin + b.toEjml().origin) + + public override operator fun Matrix.minus(b: Matrix): EjmlMatrix = + EjmlMatrix(toEjml().origin - b.toEjml().origin) + + public override fun multiply(a: Matrix, k: Number): EjmlMatrix = + produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } } + + public override operator fun Matrix.times(value: Double): EjmlMatrix = EjmlMatrix(toEjml().origin.scale(value)) + + public companion object +} + +/** + * Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix. + * + * @param a the base matrix. + * @param b n by p matrix. + * @return the solution for 'x' that is n by p. + * @author Iaroslav Postovalov + */ +public fun EjmlMatrixContext.solve(a: Matrix, b: Matrix): EjmlMatrix = + EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin)) + +/** + * Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix. + * + * @param a the base matrix. + * @param b n by p vector. + * @return the solution for 'x' that is n by p. + * @author Iaroslav Postovalov + */ +public fun EjmlMatrixContext.solve(a: Matrix, b: Point): EjmlVector = + EjmlVector(a.toEjml().origin.solve(b.toEjml().origin)) + +/** + * Returns the inverse of given matrix: b = a^(-1). + * + * @param a the matrix. + * @return the inverse of this matrix. + * @author Iaroslav Postovalov + */ +public fun EjmlMatrixContext.inverse(a: Matrix): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert()) diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlVector.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlVector.kt new file mode 100644 index 000000000..f7cd1b66d --- /dev/null +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlVector.kt @@ -0,0 +1,40 @@ +package kscience.kmath.ejml + +import org.ejml.simple.SimpleMatrix +import kscience.kmath.linear.Point +import kscience.kmath.structures.Buffer + +/** + * Represents point over EJML [SimpleMatrix]. + * + * @property origin the underlying [SimpleMatrix]. + * @author Iaroslav Postovalov + */ +public class EjmlVector internal constructor(public val origin: SimpleMatrix) : Point { + public override val size: Int + get() = origin.numRows() + + init { + require(origin.numCols() == 1) { "Only single column matrices are allowed" } + } + + public override operator fun get(index: Int): Double = origin[index] + + public override operator fun iterator(): Iterator = object : Iterator { + private var cursor: Int = 0 + + override fun next(): Double { + cursor += 1 + return origin[cursor - 1] + } + + override fun hasNext(): Boolean = cursor < origin.numCols() * origin.numRows() + } + + public override fun contentEquals(other: Buffer<*>): Boolean { + if (other is EjmlVector) return origin.isIdentical(other.origin, 0.0) + return super.contentEquals(other) + } + + public override fun toString(): String = "EjmlVector(origin=$origin)" +} diff --git a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt new file mode 100644 index 000000000..e0f15be83 --- /dev/null +++ b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt @@ -0,0 +1,75 @@ +package kscience.kmath.ejml + +import kscience.kmath.linear.DeterminantFeature +import kscience.kmath.linear.LUPDecompositionFeature +import kscience.kmath.linear.MatrixFeature +import kscience.kmath.linear.getFeature +import org.ejml.dense.row.factory.DecompositionFactory_DDRM +import org.ejml.simple.SimpleMatrix +import kotlin.random.Random +import kotlin.random.asJavaRandom +import kotlin.test.* + +internal class EjmlMatrixTest { + private val random = Random(0) + + private val randomMatrix: SimpleMatrix + get() { + val s = random.nextInt(2, 100) + return SimpleMatrix.random_DDRM(s, s, 0.0, 10.0, random.asJavaRandom()) + } + + @Test + fun rowNum() { + val m = randomMatrix + assertEquals(m.numRows(), EjmlMatrix(m).rowNum) + } + + @Test + fun colNum() { + val m = randomMatrix + assertEquals(m.numCols(), EjmlMatrix(m).rowNum) + } + + @Test + fun shape() { + val m = randomMatrix + val w = EjmlMatrix(m) + assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList()) + } + + @Test + fun features() { + val m = randomMatrix + val w = EjmlMatrix(m) + val det = w.getFeature>() ?: fail() + assertEquals(m.determinant(), det.determinant) + val lup = w.getFeature>() ?: fail() + + val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols()) + .also { it.decompose(m.ddrm.copy()) } + + assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))), lup.l) + assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))), lup.u) + assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))), lup.p) + } + + private object SomeFeature : MatrixFeature {} + + @Test + fun suggestFeature() { + assertNotNull(EjmlMatrix(randomMatrix).suggestFeature(SomeFeature).getFeature()) + } + + @Test + fun get() { + val m = randomMatrix + assertEquals(m[0, 0], EjmlMatrix(m)[0, 0]) + } + + @Test + fun origin() { + val m = randomMatrix + assertSame(m, EjmlMatrix(m).origin) + } +} diff --git a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlVectorTest.kt b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlVectorTest.kt new file mode 100644 index 000000000..e27f977d2 --- /dev/null +++ b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlVectorTest.kt @@ -0,0 +1,47 @@ +package kscience.kmath.ejml + +import org.ejml.simple.SimpleMatrix +import kotlin.random.Random +import kotlin.random.asJavaRandom +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertSame + +internal class EjmlVectorTest { + private val random = Random(0) + + private val randomMatrix: SimpleMatrix + get() = SimpleMatrix.random_DDRM(random.nextInt(2, 100), 1, 0.0, 10.0, random.asJavaRandom()) + + @Test + fun size() { + val m = randomMatrix + val w = EjmlVector(m) + assertEquals(m.numRows(), w.size) + } + + @Test + fun get() { + val m = randomMatrix + val w = EjmlVector(m) + assertEquals(m[0, 0], w[0]) + } + + @Test + fun iterator() { + val m = randomMatrix + val w = EjmlVector(m) + + assertEquals( + m.iterator(true, 0, 0, m.numRows() - 1, 0).asSequence().toList(), + w.iterator().asSequence().toList() + ) + } + + @Test + fun origin() { + val m = randomMatrix + val w = EjmlVector(m) + assertSame(m, w.origin) + } +} diff --git a/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/GeometrySpace.kt b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/GeometrySpace.kt index 64badacf5..54d2510cf 100644 --- a/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/GeometrySpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/GeometrySpace.kt @@ -4,7 +4,7 @@ import kscience.kmath.operations.Space public interface Vector -public interface GeometrySpace: Space { +public interface GeometrySpace : Space { /** * L2 distance */ diff --git a/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/Histogram.kt index 98300dada..370a01215 100644 --- a/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/Histogram.kt @@ -10,9 +10,10 @@ import kscience.kmath.structures.RealBuffer */ public interface Bin : Domain { /** - * The value of this bin + * The value of this bin. */ public val value: Number + public val center: Point } diff --git a/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt b/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt index 87292f17e..eebb41019 100644 --- a/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt +++ b/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt @@ -5,10 +5,7 @@ import kscience.kmath.histogram.fill import kscience.kmath.histogram.put import kscience.kmath.real.RealVector import kotlin.random.Random -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertFalse -import kotlin.test.assertTrue +import kotlin.test.* internal class MultivariateHistogramTest { @Test @@ -18,7 +15,7 @@ internal class MultivariateHistogramTest { (-1.0..1.0) ) histogram.put(0.55, 0.55) - val bin = histogram.find { it.value.toInt() > 0 }!! + val bin = histogram.find { it.value.toInt() > 0 } ?: fail() assertTrue { bin.contains(RealVector(0.55, 0.55)) } assertTrue { bin.contains(RealVector(0.6, 0.5)) } assertFalse { bin.contains(RealVector(-0.55, 0.55)) } diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt index b3f1524ea..965635e09 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt @@ -63,7 +63,7 @@ public fun Sampler.sampleBuffer( //clear list from previous run tmp.clear() //Fill list - repeat(size){ + repeat(size) { tmp.add(chain.next()) } //return new buffer with elements from tmp diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/RandomGenerator.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/RandomGenerator.kt index 0d95d6f97..2dd4ce51e 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/RandomGenerator.kt +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/RandomGenerator.kt @@ -3,16 +3,59 @@ package kscience.kmath.prob import kotlin.random.Random /** - * A basic generator + * An interface that is implemented by random number generator algorithms. */ public interface RandomGenerator { + /** + * Gets the next random [Boolean] value. + */ public fun nextBoolean(): Boolean + + /** + * Gets the next random [Double] value uniformly distributed between 0 (inclusive) and 1 (exclusive). + */ public fun nextDouble(): Double + + /** + * Gets the next random `Int` from the random number generator. + * + * Generates an `Int` random value uniformly distributed between [Int.MIN_VALUE] and [Int.MAX_VALUE] (inclusive). + */ public fun nextInt(): Int + + /** + * Gets the next random non-negative `Int` from the random number generator less than the specified [until] bound. + * + * Generates an `Int` random value uniformly distributed between `0` (inclusive) and the specified [until] bound + * (exclusive). + */ public fun nextInt(until: Int): Int + + /** + * Gets the next random `Long` from the random number generator. + * + * Generates a `Long` random value uniformly distributed between [Long.MIN_VALUE] and [Long.MAX_VALUE] (inclusive). + */ public fun nextLong(): Long + + /** + * Gets the next random non-negative `Long` from the random number generator less than the specified [until] bound. + * + * Generates a `Long` random value uniformly distributed between `0` (inclusive) and the specified [until] bound (exclusive). + */ public fun nextLong(until: Long): Long + + /** + * Fills a subrange of the specified byte [array] starting from [fromIndex] inclusive and ending [toIndex] exclusive + * with random bytes. + * + * @return [array] with the subrange filled with random bytes. + */ public fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size) + + /** + * Creates a byte array of the specified [size], filled with random bytes. + */ public fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) } /** @@ -25,12 +68,21 @@ public interface RandomGenerator { public fun fork(): RandomGenerator public companion object { - public val default: DefaultGenerator by lazy { DefaultGenerator() } + /** + * The [DefaultGenerator] instance. + */ + public val default: DefaultGenerator by lazy(::DefaultGenerator) + /** + * Returns [DefaultGenerator] of given [seed]. + */ public fun default(seed: Long): DefaultGenerator = DefaultGenerator(Random(seed)) } } +/** + * Implements [RandomGenerator] by delegating all operations to [Random]. + */ public inline class DefaultGenerator(public val random: Random = Random) : RandomGenerator { public override fun nextBoolean(): Boolean = random.nextBoolean() public override fun nextDouble(): Double = random.nextDouble() diff --git a/kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/SamplerTest.kt b/kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/SamplerTest.kt index 3d8a4f531..75db5c402 100644 --- a/kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/SamplerTest.kt +++ b/kmath-prob/src/jvmTest/kotlin/kscience/kmath/prob/SamplerTest.kt @@ -6,7 +6,7 @@ import kotlin.test.Test class SamplerTest { @Test - fun bufferSamplerTest(){ + fun bufferSamplerTest() { val sampler: Sampler = BasicSampler { it.chain { nextDouble() } } val data = sampler.sampleBuffer(RandomGenerator.default, 100) diff --git a/settings.gradle.kts b/settings.gradle.kts index 78372f1fa..ad42ec250 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -40,5 +40,6 @@ include( ":kmath-for-real", ":kmath-geometry", ":kmath-ast", - ":examples" + ":examples", + ":kmath-ejml" )