diff --git a/build.gradle.kts b/build.gradle.kts index 10e030520..934cf956f 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -2,7 +2,7 @@ plugins { id("scientifik.publish") apply false } -val kmathVersion by extra("0.1.4-dev-7") +val kmathVersion by extra("0.1.4-dev-8") val bintrayRepo by extra("scientifik") val githubProject by extra("kmath") diff --git a/doc/buffers.md b/doc/buffers.md index b0b7489b3..52a9df86e 100644 --- a/doc/buffers.md +++ b/doc/buffers.md @@ -2,7 +2,7 @@ Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`). There are different types of buffers: -* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays. +* Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays. * Boxing `ListBuffer` wrapping a list * Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value * `MemoryBuffer` allows direct allocation of objects in continuous memory block. diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 2fab47ac0..73def3572 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { java kotlin("jvm") - kotlin("plugin.allopen") version "1.3.71" - id("kotlinx.benchmark") version "0.2.0-dev-7" + kotlin("plugin.allopen") version "1.3.72" + id("kotlinx.benchmark") version "0.2.0-dev-8" } configure { @@ -24,6 +24,7 @@ sourceSets { } dependencies { + implementation(project(":kmath-ast")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) @@ -33,8 +34,8 @@ dependencies { implementation(project(":kmath-dimensions")) implementation("com.kyonifer:koma-core-ejml:0.12") implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") - implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7") - "benchmarksCompile"(sourceSets.main.get().compileClasspath) + implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8") + "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath } // Configure benchmark diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt index 9676b5e4a..e40b0c4b7 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt @@ -10,8 +10,8 @@ import scientifik.kmath.operations.complex class BufferBenchmark { @Benchmark - fun genericDoubleBufferReadWrite() { - val buffer = DoubleBuffer(size){it.toDouble()} + fun genericRealBufferReadWrite() { + val buffer = RealBuffer(size){it.toDouble()} (0 until size).forEach { buffer[it] diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt index be4115d81..f7b9661ef 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt @@ -20,48 +20,39 @@ class ViktorBenchmark { final val viktorField = ViktorNDField(intArrayOf(dim, dim)) @Benchmark - fun `Automatic field addition`() { + fun automaticFieldAddition() { autoField.run { var res = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += one } } } @Benchmark - fun `Viktor field addition`() { + fun viktorFieldAddition() { viktorField.run { var res = one - repeat(n) { - res += one - } + repeat(n) { res += one } } } @Benchmark - fun `Raw Viktor`() { + fun rawViktor() { val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) var res = one - repeat(n) { - res = res + one - } + repeat(n) { res = res + one } } @Benchmark - fun `Real field log`() { + fun realdFieldLog() { realField.run { val fortyTwo = produce { 42.0 } var res = one - - repeat(n) { - res = ln(fortyTwo) - } + repeat(n) { res = ln(fortyTwo) } } } @Benchmark - fun `Raw Viktor log`() { + fun rawViktorLog() { val fortyTwo = F64Array.full(dim, dim, init = 42.0) var res: F64Array repeat(n) { diff --git a/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt new file mode 100644 index 000000000..17a70a4aa --- /dev/null +++ b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -0,0 +1,70 @@ +package scientifik.kmath.ast + +import scientifik.kmath.asm.compile +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.expressionInField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.RealField +import kotlin.random.Random +import kotlin.system.measureTimeMillis + +class ExpressionsInterpretersBenchmark { + private val algebra: Field = RealField + fun functionalExpression() { + val expr = algebra.expressionInField { + variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) + } + + invokeAndSum(expr) + } + + fun mstExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + } + + invokeAndSum(expr) + } + + fun asmExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + }.compile() + + invokeAndSum(expr) + } + + private fun invokeAndSum(expr: Expression) { + val random = Random(0) + var sum = 0.0 + + repeat(1000000) { + sum += expr("x" to random.nextDouble()) + } + + println(sum) + } +} + +fun main() { + val benchmark = ExpressionsInterpretersBenchmark() + + val fe = measureTimeMillis { + benchmark.functionalExpression() + } + + println("fe=$fe") + + val mst = measureTimeMillis { + benchmark.mstExpression() + } + + println("mst=$mst") + + val asm = measureTimeMillis { + benchmark.asmExpression() + } + + println("asm=$asm") +} diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt index cc8b68d85..991cd34a1 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt @@ -27,7 +27,7 @@ fun main() { val complexTime = measureTimeMillis { complexField.run { - var res = one + var res: NDBuffer = one repeat(n) { res += 1.0 } diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt index cfd1206ff..2aafb504d 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt @@ -23,14 +23,14 @@ fun main() { measureAndPrint("Automatic field addition") { autoField.run { - var res = one + var res: NDBuffer = one repeat(n) { - res += 1.0 + res += number(1.0) } } } - measureAndPrint("Element addition"){ + measureAndPrint("Element addition") { var res = genericField.one repeat(n) { res += 1.0 @@ -63,7 +63,7 @@ fun main() { genericField.run { var res: NDBuffer = one repeat(n) { - res += 1.0 + res += one // con't avoid using `one` due to resolution ambiguity } } } diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt index ecfb4ab20..a33fdb2c4 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt @@ -6,7 +6,7 @@ fun main(args: Array) { val n = 6000 val array = DoubleArray(n * n) { 1.0 } - val buffer = DoubleBuffer(array) + val buffer = RealBuffer(array) val strides = DefaultStrides(intArrayOf(n, n)) val structure = BufferNDStructure(strides, buffer) diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt index 2d16cc8f4..0241f12ad 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt @@ -26,10 +26,10 @@ fun main(args: Array) { } println("Array mapping finished in $time2 millis") - val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 }) + val buffer = RealBuffer(DoubleArray(n * n) { 1.0 }) val time3 = measureTimeMillis { - val target = DoubleBuffer(DoubleArray(n * n)) + val target = RealBuffer(DoubleArray(n * n)) val res = array.forEachIndexed { index, value -> target[index] = value + 1 } diff --git a/kmath-ast/README.md b/kmath-ast/README.md new file mode 100644 index 000000000..62b18b4b5 --- /dev/null +++ b/kmath-ast/README.md @@ -0,0 +1,62 @@ +# AST-based expression representation and operations (`kmath-ast`) + +This subproject implements the following features: + +- Expression Language and its parser. +- MST as expression language's syntax intermediate representation. +- Type-safe builder of MST. +- Evaluating expressions by traversing MST. + +## Dynamic expression code generation with OW2 ASM + +`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds +a special implementation of `Expression` with implemented `invoke` function. + +For example, the following builder: + +```kotlin + RealField.mstInField { symbol("x") + 2 }.compile() +``` + +… leads to generation of bytecode, which can be decompiled to the following Java class: + +```java +package scientifik.kmath.asm.generated; + +import java.util.Map; +import scientifik.kmath.asm.internal.MapIntrinsics; +import scientifik.kmath.expressions.Expression; +import scientifik.kmath.operations.RealField; + +public final class AsmCompiledExpression_1073786867_0 implements Expression { + private final RealField algebra; + private final Object[] constants; + + public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) { + this.algebra = algebra; + this.constants = constants; + } + + public final Double invoke(Map arguments) { + return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D); + } +} + +``` + +### Example Usage + +This API is an extension to MST and MstExpression APIs. You may optimize both MST and MSTExpression: + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +RealField.expression("x+2".parseMath()) +``` + +### Known issues + +- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid +class loading overhead. +- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. + +Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts new file mode 100644 index 000000000..511571fc9 --- /dev/null +++ b/kmath-ast/build.gradle.kts @@ -0,0 +1,37 @@ +plugins { + id("scientifik.mpp") +} + +repositories { + maven("https://dl.bintray.com/hotkeytlt/maven") +} + +kotlin.sourceSets { +// all { +// languageSettings.apply{ +// enableLanguageFeature("NewInference") +// } +// } + commonMain { + dependencies { + api(project(":kmath-core")) + implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3") + implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3") + } + } + + jvmMain { + dependencies { + implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3") + implementation("org.ow2.asm:asm:8.0.1") + implementation("org.ow2.asm:asm-commons:8.0.1") + implementation(kotlin("reflect")) + } + } + + jsMain { + dependencies { + implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3") + } + } +} \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt new file mode 100644 index 000000000..142d27f93 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt @@ -0,0 +1,67 @@ +package scientifik.kmath.ast + +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.NumericAlgebra +import scientifik.kmath.operations.RealField + +/** + * A Mathematical Syntax Tree node for mathematical expressions + */ +sealed class MST { + + /** + * A node containing unparsed string + */ + data class Symbolic(val value: String) : MST() + + /** + * A node containing a number + */ + data class Numeric(val value: Number) : MST() + + /** + * A node containing an unary operation + */ + data class Unary(val operation: String, val value: MST) : MST() { + companion object { + const val ABS_OPERATION = "abs" + //TODO add operations + } + } + + /** + * A node containing binary operation + */ + data class Binary(val operation: String, val left: MST, val right: MST) : MST() { + companion object + } +} + +//TODO add a function with positional arguments + +//TODO add a function with named arguments + +fun Algebra.evaluate(node: MST): T { + return when (node) { + is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) + ?: error("Numeric nodes are not supported by $this") + is MST.Symbolic -> symbol(node.value) + is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) + is MST.Binary -> when { + this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + node.left is MST.Numeric && node.right is MST.Numeric -> { + val number = RealField.binaryOperation( + node.operation, + node.left.value.toDouble(), + node.right.value.toDouble() + ) + number(number) + } + node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) + node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) + else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + } + } +} + +fun MST.compile(algebra: Algebra): T = algebra.evaluate(this) \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt new file mode 100644 index 000000000..007cf57c4 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt @@ -0,0 +1,72 @@ +package scientifik.kmath.ast + +import scientifik.kmath.operations.* + +object MstAlgebra : NumericAlgebra { + override fun number(value: Number): MST = MST.Numeric(value) + + override fun symbol(value: String): MST = MST.Symbolic(value) + + override fun unaryOperation(operation: String, arg: MST): MST = + MST.Unary(operation, arg) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MST.Binary(operation, left, right) +} + +object MstSpace : Space, NumericAlgebra { + override val zero: MST = number(0.0) + + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) + + override fun add(a: MST, b: MST): MST = + binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: MST, k: Number): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MstAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) +} + +object MstRing : Ring, NumericAlgebra { + override val zero: MST = number(0.0) + override val one: MST = number(1.0) + + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: MST, k: Number): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k)) + + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MstAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) +} + +object MstField : Field { + override val zero: MST = number(0.0) + override val one: MST = number(1.0) + + override fun symbol(value: String): MST = MstAlgebra.symbol(value) + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: MST, k: Number): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k)) + + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MstAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) +} diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt new file mode 100644 index 000000000..1468c3ad4 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt @@ -0,0 +1,55 @@ +package scientifik.kmath.ast + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.FunctionalExpressionField +import scientifik.kmath.expressions.FunctionalExpressionRing +import scientifik.kmath.expressions.FunctionalExpressionSpace +import scientifik.kmath.operations.* + +/** + * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. + */ +class MstExpression(val algebra: Algebra, val mst: MST) : Expression { + + /** + * Substitute algebra raw value + */ + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { + override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) + override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: T, right: T): T = + algebra.binaryOperation(operation, left, right) + + override fun number(value: Number): T = if (algebra is NumericAlgebra) + algebra.number(value) + else + error("Numeric nodes are not supported by $this") + } + + override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) +} + + +inline fun , E : Algebra> A.mst( + mstAlgebra: E, + block: E.() -> MST +): MstExpression = MstExpression(this, mstAlgebra.block()) + +inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression = + MstExpression(this, MstSpace.block()) + +inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression = + MstExpression(this, MstRing.block()) + +inline fun Field.mstInField(block: MstField.() -> MST): MstExpression = + MstExpression(this, MstField.block()) + +inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression = + algebra.mstInSpace(block) + +inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression = + algebra.mstInRing(block) + +inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression = + algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt new file mode 100644 index 000000000..30a92c5ae --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt @@ -0,0 +1,59 @@ +package scientifik.kmath.ast + +import com.github.h0tk3y.betterParse.combinators.* +import com.github.h0tk3y.betterParse.grammar.Grammar +import com.github.h0tk3y.betterParse.grammar.parseToEnd +import com.github.h0tk3y.betterParse.grammar.parser +import com.github.h0tk3y.betterParse.grammar.tryParseToEnd +import com.github.h0tk3y.betterParse.parser.ParseResult +import com.github.h0tk3y.betterParse.parser.Parser +import scientifik.kmath.operations.FieldOperations +import scientifik.kmath.operations.PowerOperations +import scientifik.kmath.operations.RingOperations +import scientifik.kmath.operations.SpaceOperations + +/** + * TODO move to common + */ +private object ArithmeticsEvaluator : Grammar() { + val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex()) + val lpar by token("\\(".toRegex()) + val rpar by token("\\)".toRegex()) + val mul by token("\\*".toRegex()) + val pow by token("\\^".toRegex()) + val div by token("/".toRegex()) + val minus by token("-".toRegex()) + val plus by token("\\+".toRegex()) + val ws by token("\\s+".toRegex(), ignore = true) + + val number: Parser by num use { MST.Numeric(text.toDouble()) } + + val term: Parser by number or + (skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or + (skip(lpar) and parser(this::rootParser) and skip(rpar)) + + val powChain by leftAssociative(term, pow) { a, _, b -> + MST.Binary(PowerOperations.POW_OPERATION, a, b) + } + + val divMulChain: Parser by leftAssociative(powChain, div or mul use { type }) { a, op, b -> + if (op == div) { + MST.Binary(FieldOperations.DIV_OPERATION, a, b) + } else { + MST.Binary(RingOperations.TIMES_OPERATION, a, b) + } + } + + val subSumChain: Parser by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b -> + if (op == plus) { + MST.Binary(SpaceOperations.PLUS_OPERATION, a, b) + } else { + MST.Binary(SpaceOperations.MINUS_OPERATION, a, b) + } + } + + override val rootParser: Parser by subSumChain +} + +fun String.tryParseMath(): ParseResult = ArithmeticsEvaluator.tryParseToEnd(this) +fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this) \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt new file mode 100644 index 000000000..ef2330533 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -0,0 +1,60 @@ +package scientifik.kmath.asm + +import scientifik.kmath.asm.internal.AsmBuilder +import scientifik.kmath.asm.internal.buildAlgebraOperationCall +import scientifik.kmath.asm.internal.buildName +import scientifik.kmath.ast.MST +import scientifik.kmath.ast.MstExpression +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.NumericAlgebra +import kotlin.reflect.KClass + +/** + * Compile given MST to an Expression using AST compiler + */ +fun MST.compileWith(type: KClass, algebra: Algebra): Expression { + fun AsmBuilder.visit(node: MST) { + when (node) { + is MST.Symbolic -> loadVariable(node.value) + + is MST.Numeric -> { + val constant = if (algebra is NumericAlgebra) + algebra.number(node.value) + else + error("Number literals are not supported in $algebra") + + loadTConstant(constant) + } + + is MST.Unary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "unaryOperation", + arity = 1 + ) { visit(node.value) } + + is MST.Binary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "binaryOperation", + arity = 2 + ) { + visit(node.left) + visit(node.right) + } + } + } + + return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() +} + +/** + * Compile an [MST] to ASM using given algebra + */ +inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class, this) + +/** + * Optimize performance of an [MstExpression] using ASM codegen + */ +inline fun MstExpression.compile(): Expression = mst.compileWith(T::class, algebra) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt new file mode 100644 index 000000000..cea6be933 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -0,0 +1,518 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.* +import org.objectweb.asm.Opcodes.* +import org.objectweb.asm.commons.InstructionAdapter +import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader +import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra +import java.util.* +import kotlin.reflect.KClass + +/** + * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. + * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. + * + * @param T the type of AsmExpression to unwrap. + * @param algebra the algebra the applied AsmExpressions use. + * @param className the unique class name of new loaded class. + * @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. + */ +internal class AsmBuilder internal constructor( + private val classOfT: KClass<*>, + private val algebra: Algebra, + private val className: String, + private val invokeLabel0Visitor: AsmBuilder.() -> Unit +) { + /** + * Internal classloader of [AsmBuilder] with alias to define class from byte array. + */ + private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) + } + + /** + * The instance of [ClassLoader] used by this builder. + */ + private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + + /** + * ASM Type for [algebra] + */ + private val tAlgebraType: Type = algebra::class.asm + + /** + * ASM type for [T] + */ + internal val tType: Type = classOfT.asm + + /** + * ASM type for new class + */ + private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! + + /** + * Index of `this` variable in invoke method of the built subclass. + */ + private val invokeThisVar: Int = 0 + + /** + * Index of `arguments` variable in invoke method of the built subclass. + */ + private val invokeArgumentsVar: Int = 1 + + /** + * List of constants to provide to the subclass. + */ + private val constants: MutableList = mutableListOf() + + /** + * Method visitor of `invoke` method of the subclass. + */ + private lateinit var invokeMethodVisitor: InstructionAdapter + + /** + * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. + */ + internal var primitiveMode: Boolean = false + + /** + * Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + */ + internal var primitiveMask: Type = OBJECT_TYPE + + /** + * Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + */ + internal var primitiveMaskBoxed: Type = OBJECT_TYPE + + /** + * Stack of useful objects types on stack to verify types. + */ + private val typeStack: ArrayDeque = ArrayDeque() + + /** + * Stack of useful objects types on stack expected by algebra calls. + */ + internal val expectationStack: ArrayDeque = ArrayDeque().apply { push(tType) } + + /** + * The cache for instance built by this builder. + */ + private var generatedInstance: Expression? = null + + /** + * Subclasses, loads and instantiates [Expression] for given parameters. + * + * The built instance is cached. + */ + @Suppress("UNCHECKED_CAST") + fun getInstance(): Expression { + generatedInstance?.let { return it } + + if (SIGNATURE_LETTERS.containsKey(classOfT)) { + primitiveMode = true + primitiveMask = SIGNATURE_LETTERS.getValue(classOfT) + primitiveMaskBoxed = tType + } + + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { + visit( + V1_8, + ACC_PUBLIC or ACC_FINAL or ACC_SUPER, + classType.internalName, + "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", + OBJECT_TYPE.internalName, + arrayOf(EXPRESSION_TYPE.internalName) + ) + + visitField( + access = ACC_PRIVATE or ACC_FINAL, + name = "algebra", + descriptor = tAlgebraType.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd + ) + + visitField( + access = ACC_PRIVATE or ACC_FINAL, + name = "constants", + descriptor = OBJECT_ARRAY_TYPE.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd + ) + + visitMethod( + ACC_PUBLIC, + "", + Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE), + null, + null + ).instructionAdapter { + val thisVar = 0 + val algebraVar = 1 + val constantsVar = 2 + val l0 = label() + load(thisVar, classType) + invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) + label() + load(thisVar, classType) + load(algebraVar, tAlgebraType) + putfield(classType.internalName, "algebra", tAlgebraType.descriptor) + label() + load(thisVar, classType) + load(constantsVar, OBJECT_ARRAY_TYPE) + putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + label() + visitInsn(RETURN) + val l4 = label() + visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) + + visitLocalVariable( + "algebra", + tAlgebraType.descriptor, + null, + l0, + l4, + algebraVar + ) + + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) + visitMaxs(0, 3) + visitEnd() + } + + visitMethod( + ACC_PUBLIC or ACC_FINAL, + "invoke", + Type.getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", + null + ).instructionAdapter { + invokeMethodVisitor = this + visitCode() + val l0 = label() + invokeLabel0Visitor() + areturn(tType) + val l1 = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + l0, + l1, + invokeThisVar + ) + + visitLocalVariable( + "arguments", + MAP_TYPE.descriptor, + "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", + l0, + l1, + invokeArgumentsVar + ) + + visitMaxs(0, 2) + visitEnd() + } + + visitMethod( + ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, + "invoke", + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + null, + null + ).instructionAdapter { + val thisVar = 0 + val argumentsVar = 1 + visitCode() + val l0 = label() + load(thisVar, OBJECT_TYPE) + load(argumentsVar, MAP_TYPE) + invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false) + areturn(tType) + val l1 = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + l0, + l1, + thisVar + ) + + visitMaxs(0, 2) + visitEnd() + } + + visitEnd() + } + + val new = classLoader + .defineClass(className, classWriter.toByteArray()) + .constructors + .first() + .newInstance(algebra, constants.toTypedArray()) as Expression + + generatedInstance = new + return new + } + + /** + * Loads a [T] constant from [constants]. + */ + internal fun loadTConstant(value: T) { + if (classOfT in INLINABLE_NUMBERS) { + val expectedType = expectationStack.pop() + val mustBeBoxed = expectedType.sort == Type.OBJECT + loadNumberConstant(value as Number, mustBeBoxed) + if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) + return + } + + loadConstant(value as Any, tType) + } + + /** + * Boxes the current value and pushes it. + */ + private fun box(): Unit = invokeMethodVisitor.invokestatic( + tType.internalName, + "valueOf", + Type.getMethodDescriptor(tType, primitiveMask), + false + ) + + /** + * Unboxes the current boxed value and pushes it. + */ + private fun unbox(): Unit = invokeMethodVisitor.invokevirtual( + NUMBER_TYPE.internalName, + NUMBER_CONVERTER_METHODS.getValue(primitiveMask), + Type.getMethodDescriptor(primitiveMask), + false + ) + + /** + * Loads [java.lang.Object] constant from constants. + */ + private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { + val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex + loadThis() + getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + iconst(idx) + visitInsn(AALOAD) + checkcast(type) + } + + /** + * Loads this variable. + */ + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) + + /** + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive + * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded + * from it). + */ + private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) { + val boxed = value::class.asm + val primitive = BOXED_TO_PRIMITIVES[boxed] + + if (primitive != null) { + when (primitive) { + Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + } + + if (mustBeBoxed) { + box() + invokeMethodVisitor.checkcast(tType) + } + + return + } + + loadConstant(value, boxed) + + if (!mustBeBoxed) unbox() + else invokeMethodVisitor.checkcast(tType) + } + + /** + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be + * provided. + */ + internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { + load(invokeArgumentsVar, MAP_TYPE) + aconst(name) + + if (defaultValue != null) + loadTConstant(defaultValue) + else + aconst(null) + + invokestatic( + MAP_INTRINSICS_TYPE.internalName, + "getOrFail", + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE), + false + ) + + checkcast(tType) + + val expectedType = expectationStack.pop() + + if (expectedType.sort == Type.OBJECT) + typeStack.push(tType) + else { + unbox() + typeStack.push(primitiveMask) + } + } + + /** + * Loads algebra from according field of the class and casts it to class of [algebra] provided. + */ + internal fun loadAlgebra() { + loadThis() + invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor) + } + + /** + * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is + * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be + * called before the arguments and this operation. + * + * The result is casted to [T] automatically. + */ + internal fun invokeAlgebraOperation( + owner: String, + method: String, + descriptor: String, + expectedArity: Int, + opcode: Int = INVOKEINTERFACE + ) { + run loop@{ + repeat(expectedArity) { + if (typeStack.isEmpty()) return@loop + typeStack.pop() + } + } + + invokeMethodVisitor.visitMethodInsn( + opcode, + owner, + method, + descriptor, + opcode == INVOKEINTERFACE + ) + + invokeMethodVisitor.checkcast(tType) + val isLastExpr = expectationStack.size == 1 + val expectedType = expectationStack.pop() + + if (expectedType.sort == Type.OBJECT || isLastExpr) + typeStack.push(tType) + else { + unbox() + typeStack.push(primitiveMask) + } + } + + /** + * Writes a LDC Instruction with string constant provided. + */ + internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string) + + internal companion object { + /** + * Maps JVM primitive numbers boxed types to their primitive ASM types. + */ + private val SIGNATURE_LETTERS: Map, Type> by lazy { + hashMapOf( + java.lang.Byte::class to Type.BYTE_TYPE, + java.lang.Short::class to Type.SHORT_TYPE, + java.lang.Integer::class to Type.INT_TYPE, + java.lang.Long::class to Type.LONG_TYPE, + java.lang.Float::class to Type.FLOAT_TYPE, + java.lang.Double::class to Type.DOUBLE_TYPE + ) + } + + /** + * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. + */ + private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } + + /** + * Maps primitive ASM types to [Number] functions unboxing them. + */ + private val NUMBER_CONVERTER_METHODS: Map by lazy { + hashMapOf( + Type.BYTE_TYPE to "byteValue", + Type.SHORT_TYPE to "shortValue", + Type.INT_TYPE to "intValue", + Type.LONG_TYPE to "longValue", + Type.FLOAT_TYPE to "floatValue", + Type.DOUBLE_TYPE to "doubleValue" + ) + } + + /** + * Provides boxed number types values of which can be stored in JVM bytecode constant pool. + */ + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + + /** + * ASM type for [Expression]. + */ + internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } + + /** + * ASM type for [java.lang.Number]. + */ + internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } + + /** + * ASM type for [java.util.Map]. + */ + internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } + + /** + * ASM type for [java.lang.Object]. + */ + internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } + + /** + * ASM type for array of [java.lang.Object]. + */ + @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") + internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } + + /** + * ASM type for [Algebra]. + */ + internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } + + /** + * ASM type for [java.lang.String]. + */ + internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } + + /** + * ASM type for MapIntrinsics. + */ + internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") } + } +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt new file mode 100644 index 000000000..46d07976d --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt @@ -0,0 +1,148 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.* +import org.objectweb.asm.Opcodes.INVOKEVIRTUAL +import org.objectweb.asm.commons.InstructionAdapter +import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra +import kotlin.reflect.KClass + +private val methodNameAdapters: Map, String> by lazy { + hashMapOf( + "+" to 2 to "add", + "*" to 2 to "multiply", + "/" to 2 to "divide", + "+" to 1 to "unaryPlus", + "-" to 1 to "unaryMinus", + "-" to 2 to "minus" + ) +} + +internal val KClass<*>.asm: Type + get() = Type.getType(java) + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor]. + */ +private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. + */ +internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = + instructionAdapter().apply(block) + +/** + * Constructs a [Label], then applies it to this visitor. + */ +internal fun MethodVisitor.label(): Label { + val l = Label() + visitLabel(l) + return l +} + +/** + * Creates a class name for [Expression] subclassed to implement [mst] provided. + * + * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there + * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. + */ +internal tailrec fun buildName(mst: MST, collision: Int = 0): String { + val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(mst, collision + 1) +} + +@Suppress("FunctionName") +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = + ClassWriter(flags).apply(block) + +internal inline fun ClassWriter.visitField( + access: Int, + name: String, + descriptor: String, + signature: String?, + value: Any?, + block: FieldVisitor.() -> Unit +): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) + +/** + * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds + * type expectation stack for needed arity. + * + * @return `true` if contains, else `false`. + */ +private fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { + val theName = methodNameAdapters[name to arity] ?: name + val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null + val t = if (primitiveMode && hasSpecific) primitiveMask else tType + repeat(arity) { expectationStack.push(t) } + return hasSpecific +} + +/** + * Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts + * [AsmBuilder.invokeAlgebraOperation] of this method. + * + * @return `true` if contains, else `false`. + */ +private fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { + val theName = methodNameAdapters[name to arity] ?: name + + context.javaClass.methods.find { + var suitableSignature = it.name == theName && it.parameters.size == arity + + if (primitiveMode && it.isBridge) + suitableSignature = false + + suitableSignature + } ?: return false + + val owner = context::class.asm + + invokeAlgebraOperation( + owner = owner.internalName, + method = theName, + descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }), + expectedArity = arity, + opcode = INVOKEVIRTUAL + ) + + return true +} + +/** + * Builds specialized algebra call with option to fallback to generic algebra operation accepting String. + */ +internal fun AsmBuilder.buildAlgebraOperationCall( + context: Algebra, + name: String, + fallbackMethodName: String, + arity: Int, + parameters: AsmBuilder.() -> Unit +) { + loadAlgebra() + if (!buildExpectationStack(context, name, arity)) loadStringConstant(name) + parameters() + + if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = fallbackMethodName, + + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + *Array(arity) { AsmBuilder.OBJECT_TYPE } + ), + + expectedArity = arity + ) +} + diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt new file mode 100644 index 000000000..f47293687 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt @@ -0,0 +1,10 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.Label +import org.objectweb.asm.commons.InstructionAdapter + +internal fun InstructionAdapter.label(): Label { + val l = Label() + visitLabel(l) + return l +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt new file mode 100644 index 000000000..7f7126b55 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt @@ -0,0 +1,7 @@ +@file:JvmName("MapIntrinsics") + +package scientifik.kmath.asm.internal + +internal fun Map.getOrFail(key: K, default: V?): V { + return this[key] ?: default ?: error("Parameter not found: $key") +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt new file mode 100644 index 000000000..3acc6eb28 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -0,0 +1,110 @@ +package scietifik.kmath.asm + +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.mstInRing +import scientifik.kmath.ast.mstInSpace +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestAsmAlgebras { + @Test + fun space() { + val res1 = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }("x" to 2.toByte()) + + val res2 = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }.compile()("x" to 2.toByte()) + + assertEquals(res1, res2) + } + + @Test + fun ring() { + val res1 = ByteRing.mstInRing { + binaryOperation( + "+", + + unaryOperation( + "+", + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * number(2) + }("x" to 3.toByte()) + + val res2 = ByteRing.mstInRing { + binaryOperation( + "+", + + unaryOperation( + "+", + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * number(2) + }.compile()("x" to 3.toByte()) + + assertEquals(res1, res2) + } + + @Test + fun field() { + val res1 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( + "+", + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }("x" to 2.0) + + val res2 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( + "+", + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }.compile()("x" to 2.0) + + assertEquals(res1, res2) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt new file mode 100644 index 000000000..36c254c38 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -0,0 +1,31 @@ +package scietifik.kmath.asm + +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.mstInSpace +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestAsmExpressions { + @Test + fun testUnaryOperationInvocation() { + val expression = RealField.mstInSpace { -symbol("x") }.compile() + val res = expression("x" to 2.0) + assertEquals(-2.0, res) + } + + @Test + fun testBinaryOperationInvocation() { + val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile() + val res = expression("x" to 2.0) + assertEquals(-1.0, res) + } + + @Test + fun testConstProductInvocation() { + val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) + assertEquals(4.0, res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt new file mode 100644 index 000000000..b571e076f --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt @@ -0,0 +1,46 @@ +package scietifik.kmath.asm + +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestAsmSpecialization { + @Test + fun testUnaryPlus() { + val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() + assertEquals(2.0, expr("x" to 2.0)) + } + + @Test + fun testUnaryMinus() { + val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() + assertEquals(-2.0, expr("x" to 2.0)) + } + + @Test + fun testAdd() { + val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() + assertEquals(4.0, expr("x" to 2.0)) + } + + @Test + fun testSine() { + val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 0.0)) + } + + @Test + fun testMinus() { + val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 2.0)) + } + + @Test + fun testDivide() { + val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() + assertEquals(1.0, expr("x" to 2.0)) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt new file mode 100644 index 000000000..aafc75448 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt @@ -0,0 +1,22 @@ +package scietifik.kmath.asm + +import scientifik.kmath.ast.mstInRing +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +internal class TestAsmVariables { + @Test + fun testVariableWithoutDefault() { + val expr = ByteRing.mstInRing { symbol("x") } + assertEquals(1.toByte(), expr("x" to 1.toByte())) + } + + @Test + fun testVariableWithoutDefaultFails() { + val expr = ByteRing.mstInRing { symbol("x") } + assertFailsWith { expr() } + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt new file mode 100644 index 000000000..23203172e --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -0,0 +1,26 @@ +package scietifik.kmath.ast + +import scientifik.kmath.asm.compile +import scientifik.kmath.asm.expression +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.parseMath +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.Complex +import scientifik.kmath.operations.ComplexField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class AsmTest { + @Test + fun `compile MST`() { + val mst = "2+2*(2+2)".parseMath() + val res = ComplexField.expression(mst)() + assertEquals(Complex(10.0, 0.0), res) + } + + @Test + fun `compile MSTExpression`() { + val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()() + assertEquals(Complex(10.0, 0.0), res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt new file mode 100644 index 000000000..5394a4b00 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt @@ -0,0 +1,25 @@ +package scietifik.kmath.ast + +import scientifik.kmath.ast.evaluate +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.parseMath +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.Complex +import scientifik.kmath.operations.ComplexField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class ParserTest { + @Test + fun `evaluate MST`() { + val mst = "2+2*(2+2)".parseMath() + val res = ComplexField.evaluate(mst) + assertEquals(Complex(10.0, 0.0), res) + } + + @Test + fun `evaluate MSTExpression`() { + val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() + assertEquals(Complex(10.0, 0.0), res) + } +} diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index d5c038dc4..54c404f57 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -2,7 +2,7 @@ package scientifik.kmath.commons.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.Field import kotlin.properties.ReadOnlyProperty @@ -59,8 +59,10 @@ class DerivativeStructureField( override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() - override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() + override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() + override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() + override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { is Double -> arg.pow(pow) @@ -74,10 +76,10 @@ class DerivativeStructureField( override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() - operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble()) - operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble()) - operator fun Number.plus(s: DerivativeStructure) = s + this - operator fun Number.minus(s: DerivativeStructure) = s - this + override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) + override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) + override operator fun Number.plus(b: DerivativeStructure) = b + this + override operator fun Number.minus(b: DerivativeStructure) = b - this } /** @@ -113,7 +115,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1) /** * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) */ -object DiffExpressionContext : ExpressionContext, Field { +object DiffExpressionAlgebra : ExpressionAlgebra, Field { override fun variable(name: String, default: Double?) = DiffExpression { variable(name, default?.const()) } @@ -136,6 +138,3 @@ object DiffExpressionContext : ExpressionContext, Field override fun divide(a: DiffExpression, b: DiffExpression) = DiffExpression { a.function(this) / b.function(this) } } - - - diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt new file mode 100644 index 000000000..13e79d60e --- /dev/null +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -0,0 +1,38 @@ +package scientifik.kmath.commons.random + +import scientifik.kmath.prob.RandomGenerator + +class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : + org.apache.commons.math3.random.RandomGenerator { + private var generator = factory(intArrayOf()) + + override fun nextBoolean(): Boolean = generator.nextBoolean() + + override fun nextFloat(): Float = generator.nextDouble().toFloat() + + override fun setSeed(seed: Int) { + generator = factory(intArrayOf(seed)) + } + + override fun setSeed(seed: IntArray) { + generator = factory(seed) + } + + override fun setSeed(seed: Long) { + setSeed(seed.toInt()) + } + + override fun nextBytes(bytes: ByteArray) { + generator.fillBytes(bytes) + } + + override fun nextInt(): Int = generator.nextInt() + + override fun nextInt(n: Int): Int = generator.nextInt(n) + + override fun nextGaussian(): Double = TODO() + + override fun nextDouble(): Double = generator.nextDouble() + + override fun nextLong(): Long = generator.nextLong() +} \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt index bcb3ea87b..eb1b5b69a 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt @@ -18,7 +18,7 @@ object Transformations { private fun Buffer.toArray(): Array = Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) } - private fun Buffer.asArray() = if (this is DoubleBuffer) { + private fun Buffer.asArray() = if (this is RealBuffer) { array } else { DoubleArray(size) { i -> get(i) } diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index 092f3deb7..18c0cc771 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -8,4 +8,4 @@ kotlin.sourceSets { api(project(":kmath-memory")) } } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt new file mode 100644 index 000000000..333b77cb4 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt @@ -0,0 +1,15 @@ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +/** + * A simple geometric domain + */ +interface Domain { + operator fun contains(point: Point): Boolean + + /** + * Number of hyperspace dimensions + */ + val dimension: Int +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt new file mode 100644 index 000000000..e0019c96b --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt @@ -0,0 +1,67 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.RealBuffer +import scientifik.kmath.structures.indices + +/** + * + * HyperSquareDomain class. + * + * @author Alexander Nozik + */ +class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain { + + override operator fun contains(point: Point): Boolean = point.indices.all { i -> + point[i] in lower[i]..upper[i] + } + + override val dimension: Int get() = lower.size + + override fun getLowerBound(num: Int, point: Point): Double? = lower[num] + + override fun getLowerBound(num: Int): Double? = lower[num] + + override fun getUpperBound(num: Int, point: Point): Double? = upper[num] + + override fun getUpperBound(num: Int): Double? = upper[num] + + override fun nearestInDomain(point: Point): Point { + val res: DoubleArray = DoubleArray(point.size) { i -> + when { + point[i] < lower[i] -> lower[i] + point[i] > upper[i] -> upper[i] + else -> point[i] + } + } + return RealBuffer(*res) + } + + override fun volume(): Double { + var res = 1.0 + for (i in 0 until dimension) { + if (lower[i].isInfinite() || upper[i].isInfinite()) { + return Double.POSITIVE_INFINITY + } + if (upper[i] > lower[i]) { + res *= upper[i] - lower[i] + } + } + return res + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt new file mode 100644 index 000000000..89115887e --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt @@ -0,0 +1,65 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +/** + * n-dimensional volume + * + * @author Alexander Nozik + */ +interface RealDomain: Domain { + + fun nearestInDomain(point: Point): Point + + /** + * The lower edge for the domain going down from point + * @param num + * @param point + * @return + */ + fun getLowerBound(num: Int, point: Point): Double? + + /** + * The upper edge of the domain going up from point + * @param num + * @param point + * @return + */ + fun getUpperBound(num: Int, point: Point): Double? + + /** + * Global lower edge + * @param num + * @return + */ + fun getLowerBound(num: Int): Double? + + /** + * Global upper edge + * @param num + * @return + */ + fun getUpperBound(num: Int): Double? + + /** + * Hyper volume + * @return + */ + fun volume(): Double + +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt new file mode 100644 index 000000000..e49fd3b37 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt @@ -0,0 +1,36 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +class UnconstrainedDomain(override val dimension: Int) : RealDomain { + + override operator fun contains(point: Point): Boolean = true + + override fun getLowerBound(num: Int, point: Point): Double? = Double.NEGATIVE_INFINITY + + override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY + + override fun getUpperBound(num: Int, point: Point): Double? = Double.POSITIVE_INFINITY + + override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY + + override fun nearestInDomain(point: Point): Point = point + + override fun volume(): Double = Double.POSITIVE_INFINITY + +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt new file mode 100644 index 000000000..ef521d5ea --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt @@ -0,0 +1,48 @@ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.asBuffer + +inline class UnivariateDomain(val range: ClosedFloatingPointRange) : RealDomain { + + operator fun contains(d: Double): Boolean = range.contains(d) + + override operator fun contains(point: Point): Boolean { + require(point.size == 0) + return contains(point[0]) + } + + override fun nearestInDomain(point: Point): Point { + require(point.size == 1) + val value = point[0] + return when{ + value in range -> point + value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer() + else -> doubleArrayOf(range.start).asBuffer() + } + } + + override fun getLowerBound(num: Int, point: Point): Double? { + require(num == 0) + return range.start + } + + override fun getUpperBound(num: Int, point: Point): Double? { + require(num == 0) + return range.endInclusive + } + + override fun getLowerBound(num: Int): Double? { + require(num == 0) + return range.start + } + + override fun getUpperBound(num: Int): Double? { + require(num == 0) + return range.endInclusive + } + + override fun volume(): Double = range.endInclusive - range.start + + override val dimension: Int get() = 1 +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt new file mode 100644 index 000000000..9f1503285 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt @@ -0,0 +1,23 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.Space + +/** + * Create a functional expression on this [Space] + */ +fun Space.spaceExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression = + FunctionalExpressionSpace(this).run(block) + +/** + * Create a functional expression on this [Ring] + */ +fun Ring.ringExpression(block: FunctionalExpressionRing>.() -> Expression): Expression = + FunctionalExpressionRing(this).run(block) + +/** + * Create a functional expression on this [Field] + */ +fun Field.fieldExpression(block: FunctionalExpressionField>.() -> Expression): Expression = + FunctionalExpressionField(this).run(block) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index aa7407c0a..e512b1cd8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -1,14 +1,21 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space +import scientifik.kmath.operations.Algebra /** * An elementary function that could be invoked on a map of arguments */ interface Expression { operator fun invoke(arguments: Map): T + + companion object +} + +/** + * Create simple lazily evaluated expression inside given algebra + */ +fun Algebra.expression(block: Algebra.(arguments: Map) -> T): Expression = object: Expression { + override fun invoke(arguments: Map): T = block(arguments) } operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) @@ -16,77 +23,14 @@ operator fun Expression.invoke(vararg pairs: Pair): T = invoke /** * A context for expression construction */ -interface ExpressionContext { +interface ExpressionAlgebra : Algebra { /** * Introduce a variable into expression context */ - fun variable(name: String, default: T? = null): Expression + fun variable(name: String, default: T? = null): E /** * A constant expression which does not depend on arguments */ - fun const(value: T): Expression -} - -internal class VariableExpression(val name: String, val default: T? = null) : Expression { - override fun invoke(arguments: Map): T = - arguments[name] ?: default ?: error("Parameter not found: $name") -} - -internal class ConstantExpression(val value: T) : Expression { - override fun invoke(arguments: Map): T = value -} - -internal class SumExpression(val context: Space, val first: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = context.add(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ProductExpression(val context: Ring, val first: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = - context.multiply(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : - Expression { - override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) -} - -internal class DivExpession(val context: Field, val expr: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) -} - -open class ExpressionSpace(val space: Space) : Space>, ExpressionContext { - override val zero: Expression = ConstantExpression(space.zero) - - override fun const(value: T): Expression = ConstantExpression(value) - - override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) - - override fun add(a: Expression, b: Expression): Expression = SumExpression(space, a, b) - - override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(space, a, k) - - - operator fun Expression.plus(arg: T) = this + const(arg) - operator fun Expression.minus(arg: T) = this - const(arg) - - operator fun T.plus(arg: Expression) = arg + this - operator fun T.minus(arg: Expression) = arg - this -} - - -class ExpressionField(val field: Field) : Field>, ExpressionSpace(field) { - override val one: Expression = ConstantExpression(field.one) - override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) - - override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) - - operator fun Expression.times(arg: T) = this * const(arg) - operator fun Expression.div(arg: T) = this / const(arg) - - operator fun T.times(arg: Expression) = arg * this - operator fun T.div(arg: Expression) = arg / this + fun const(value: T): E } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt new file mode 100644 index 000000000..a8a26aa33 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -0,0 +1,146 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.* + +internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) : + Expression { + override fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) +} + +internal class FunctionalBinaryOperation( + val context: Algebra, + val name: String, + val first: Expression, + val second: Expression +) : Expression { + override fun invoke(arguments: Map): T = + context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) +} + +internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression { + override fun invoke(arguments: Map): T = + arguments[name] ?: default ?: error("Parameter not found: $name") +} + +internal class FunctionalConstantExpression(val value: T) : Expression { + override fun invoke(arguments: Map): T = value +} + +internal class FunctionalConstProductExpression( + val context: Space, + private val expr: Expression, + val const: Number +) : Expression { + override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) +} + +/** + * A context class for [Expression] construction. + * + * @param algebra The algebra to provide for Expressions built. + */ +abstract class FunctionalExpressionAlgebra>(val algebra: A) : ExpressionAlgebra> { + + /** + * Builds an Expression of constant expression which does not depend on arguments. + */ + override fun const(value: T): Expression = FunctionalConstantExpression(value) + + /** + * Builds an Expression to access a variable. + */ + override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default) + + /** + * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. + */ + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) + + /** + * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. + */ + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) +} + +/** + * A context class for [Expression] construction for [Space] algebras. + */ +open class FunctionalExpressionSpace>(algebra: A) : + FunctionalExpressionAlgebra(algebra), Space> { + + override val zero: Expression get() = const(algebra.zero) + + /** + * Builds an Expression of addition of two another expressions. + */ + override fun add(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) + + /** + * Builds an Expression of multiplication of expression by number. + */ + override fun multiply(a: Expression, k: Number): Expression = + FunctionalConstProductExpression(algebra, a, k) + + operator fun Expression.plus(arg: T): Expression = this + const(arg) + operator fun Expression.minus(arg: T): Expression = this - const(arg) + operator fun T.plus(arg: Expression): Expression = arg + this + operator fun T.minus(arg: Expression): Expression = arg - this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) +} + +open class FunctionalExpressionRing(algebra: A) : FunctionalExpressionSpace(algebra), + Ring> where A : Ring, A : NumericAlgebra { + override val one: Expression + get() = const(algebra.one) + + /** + * Builds an Expression of multiplication of two expressions. + */ + override fun multiply(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) + + operator fun Expression.times(arg: T): Expression = this * const(arg) + operator fun T.times(arg: Expression): Expression = arg * this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) +} + +open class FunctionalExpressionField(algebra: A) : + FunctionalExpressionRing(algebra), + Field> where A : Field, A : NumericAlgebra { + /** + * Builds an Expression of division an expression by another one. + */ + override fun divide(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) + + operator fun Expression.div(arg: T): Expression = this / const(arg) + operator fun T.div(arg: Expression): Expression = arg / this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) +} + +inline fun > A.expressionInSpace(block: FunctionalExpressionSpace.() -> Expression): Expression = + FunctionalExpressionSpace(this).block() + +inline fun > A.expressionInRing(block: FunctionalExpressionRing.() -> Expression): Expression = + FunctionalExpressionRing(this).block() + +inline fun > A.expressionInField(block: FunctionalExpressionField.() -> Expression): Expression = + FunctionalExpressionField(this).block() \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt index 73b18b810..c4c38284b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt @@ -30,11 +30,11 @@ object RealMatrixContext : GenericMatrixContext { override val elementContext get() = RealField override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix { - val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } + val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } return BufferMatrix(rows, columns, buffer) } - override inline fun point(size: Int, initializer: (Int) -> Double): Point = DoubleBuffer(size,initializer) + override inline fun point(size: Int, initializer: (Int) -> Double): Point = RealBuffer(size,initializer) } class BufferMatrix( @@ -102,7 +102,7 @@ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix.unsafeArray(): DoubleArray = if (this is DoubleBuffer) { + fun Buffer.unsafeArray(): DoubleArray = if (this is RealBuffer) { array } else { DoubleArray(size) { get(it) } @@ -119,6 +119,6 @@ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix, F : Field> GenericMatrixContext.lup( luRow[col] = sum // maintain best permutation choice - if (abs(sum) > largest) { - largest = abs(sum) + if (this@lup.abs(sum) > largest) { + largest = this@lup.abs(sum) max = row } } // Singularity check - if (checkSingular(abs(lu[max, col]))) { + if (checkSingular(this@lup.abs(lu[max, col]))) { error("The matrix is singular") } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt index ed77054cf..076701a4f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt @@ -90,20 +90,20 @@ abstract class AutoDiffField> : Field> { // Overloads for Double constants - operator fun Number.plus(that: Variable): Variable = - derive(variable { this@plus.toDouble() * one + that.value }) { z -> - that.d += z.d + override operator fun Number.plus(b: Variable): Variable = + derive(variable { this@plus.toDouble() * one + b.value }) { z -> + b.d += z.d } - operator fun Variable.plus(b: Number): Variable = b.plus(this) + override operator fun Variable.plus(b: Number): Variable = b.plus(this) - operator fun Number.minus(that: Variable): Variable = - derive(variable { this@minus.toDouble() * one - that.value }) { z -> - that.d -= z.d + override operator fun Number.minus(b: Variable): Variable = + derive(variable { this@minus.toDouble() * one - b.value }) { z -> + b.d -= z.d } - operator fun Variable.minus(that: Number): Variable = - derive(variable { this@minus.value - one * that.toDouble() }) { z -> + override operator fun Variable.minus(b: Number): Variable = + derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt index 90ce5da68..f040fb8d4 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt @@ -1,5 +1,7 @@ package scientifik.kmath.misc +import kotlin.math.abs + /** * Convert double range to sequence. * @@ -8,28 +10,36 @@ package scientifik.kmath.misc * * If step is negative, the same goes from upper boundary downwards */ -fun ClosedFloatingPointRange.toSequence(step: Double): Sequence = - when { - step == 0.0 -> error("Zero step in double progression") - step > 0 -> sequence { - var current = start - while (current <= endInclusive) { - yield(current) - current += step - } - } - else -> sequence { - var current = endInclusive - while (current >= start) { - yield(current) - current += step - } - } +fun ClosedFloatingPointRange.toSequenceWithStep(step: Double): Sequence = when { + step == 0.0 -> error("Zero step in double progression") + step > 0 -> sequence { + var current = start + while (current <= endInclusive) { + yield(current) + current += step } + } + else -> sequence { + var current = endInclusive + while (current >= start) { + yield(current) + current += step + } + } +} + +/** + * Convert double range to sequence with the fixed number of points + */ +fun ClosedFloatingPointRange.toSequenceWithPoints(numPoints: Int): Sequence { + require(numPoints > 1) { "The number of points should be more than 2" } + return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1)) +} /** * Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints] */ +@Deprecated("Replace by 'toSequenceWithPoints'") fun ClosedFloatingPointRange.toGrid(numPoints: Int): DoubleArray { if (numPoints < 2) error("Can't create generic grid with less than two points") return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index 485185526..52b6bba02 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -6,9 +6,43 @@ annotation class KMathContext /** * Marker interface for any algebra */ -interface Algebra +interface Algebra { + /** + * Wrap raw string or variable + */ + fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") -inline operator fun , R> T.invoke(block: T.() -> R): R = run(block) + /** + * Dynamic call of unary operation with name [operation] on [arg] + */ + fun unaryOperation(operation: String, arg: T): T + + /** + * Dynamic call of binary operation [operation] on [left] and [right] + */ + fun binaryOperation(operation: String, left: T, right: T): T +} + +/** + * An algebra with numeric representation of members + */ +interface NumericAlgebra : Algebra { + /** + * Wrap a number + */ + fun number(value: Number): T + + fun leftSideNumberOperation(operation: String, left: Number, right: T): T = + binaryOperation(operation, number(left), right) + + fun rightSideNumberOperation(operation: String, left: T, right: Number): T = + leftSideNumberOperation(operation, right, left) +} + +/** + * Call a block with an [Algebra] as receiver + */ +inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) /** * Space-like operations without neutral element @@ -24,14 +58,34 @@ interface SpaceOperations : Algebra { */ fun multiply(a: T, k: Number): T - //Operation to be performed in this context + //Operation to be performed in this context. Could be moved to extensions in case of KEEP-176 operator fun T.unaryMinus(): T = multiply(this, -1.0) + operator fun T.unaryPlus(): T = this + operator fun T.plus(b: T): T = add(this, b) operator fun T.minus(b: T): T = add(this, -b) operator fun T.times(k: Number) = multiply(this, k.toDouble()) operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble()) operator fun Number.times(b: T) = b * this + + override fun unaryOperation(operation: String, arg: T): T = when (operation) { + PLUS_OPERATION -> arg + MINUS_OPERATION -> -arg + else -> error("Unary operation $operation not defined in $this") + } + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + PLUS_OPERATION -> add(left, right) + MINUS_OPERATION -> left - right + else -> error("Binary operation $operation not defined in $this") + } + + companion object { + const val PLUS_OPERATION = "+" + const val MINUS_OPERATION = "-" + const val NOT_OPERATION = "!" + } } @@ -60,22 +114,48 @@ interface RingOperations : SpaceOperations { fun multiply(a: T, b: T): T operator fun T.times(b: T): T = multiply(this, b) + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + TIMES_OPERATION -> multiply(left, right) + else -> super.binaryOperation(operation, left, right) + } + + companion object { + const val TIMES_OPERATION = "*" + } } /** * The same as {@link Space} but with additional multiplication operation */ -interface Ring : Space, RingOperations { +interface Ring : Space, RingOperations, NumericAlgebra { /** * neutral operation for multiplication */ val one: T -// operator fun T.plus(b: Number) = this.plus(b * one) -// operator fun Number.plus(b: T) = b + this -// -// operator fun T.minus(b: Number) = this.minus(b * one) -// operator fun Number.minus(b: T) = -b + this + override fun number(value: Number): T = one * value.toDouble() + + override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) { + SpaceOperations.PLUS_OPERATION -> left + right + SpaceOperations.MINUS_OPERATION -> left - right + RingOperations.TIMES_OPERATION -> left * right + else -> super.leftSideNumberOperation(operation, left, right) + } + + override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { + SpaceOperations.PLUS_OPERATION -> left + right + SpaceOperations.MINUS_OPERATION -> left - right + RingOperations.TIMES_OPERATION -> left * right + else -> super.rightSideNumberOperation(operation, left, right) + } + + + operator fun T.plus(b: Number) = this.plus(number(b)) + operator fun Number.plus(b: T) = b + this + + operator fun T.minus(b: Number) = this.minus(number(b)) + operator fun Number.minus(b: T) = -b + this } /** @@ -85,6 +165,15 @@ interface FieldOperations : RingOperations { fun divide(a: T, b: T): T operator fun T.div(b: T): T = divide(this, b) + + override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { + DIV_OPERATION -> divide(left, right) + else -> super.binaryOperation(operation, left, right) + } + + companion object { + const val DIV_OPERATION = "/" + } } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 6c529f55e..398ea4395 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -8,10 +8,12 @@ import scientifik.memory.MemorySpec import scientifik.memory.MemoryWriter import kotlin.math.* +private val PI_DIV_2 = Complex(PI / 2, 0) + /** * A field for complex numbers */ -object ComplexField : ExtendedFieldOperations, Field { +object ComplexField : ExtendedField { override val zero: Complex = Complex(0.0, 0.0) override val one: Complex = Complex(1.0, 0.0) @@ -30,9 +32,11 @@ object ComplexField : ExtendedFieldOperations, Field { return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm) } - override fun sin(arg: Complex): Complex = i / 2 * (exp(-i * arg) - exp(i * arg)) - + override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 + override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg) + override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg) + override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2 override fun power(arg: Complex, pow: Number): Complex = arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta)) @@ -50,6 +54,12 @@ object ComplexField : ExtendedFieldOperations, Field { operator fun Complex.minus(d: Double) = add(this, -d.toComplex()) operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this) + + override fun symbol(value: String): Complex = if (value == "i") { + i + } else { + super.symbol(value) + } } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 2b6a92f14..953c5a112 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -7,12 +7,32 @@ import kotlin.math.pow as kpow * Advanced Number-like field that implements basic operations */ interface ExtendedFieldOperations : - FieldOperations, - TrigonometricOperations, + InverseTrigonometricOperations, PowerOperations, - ExponentialOperations + ExponentialOperations { -interface ExtendedField : ExtendedFieldOperations, Field + override fun tan(arg: T): T = sin(arg) / cos(arg) + + override fun unaryOperation(operation: String, arg: T): T = when (operation) { + TrigonometricOperations.COS_OPERATION -> cos(arg) + TrigonometricOperations.SIN_OPERATION -> sin(arg) + TrigonometricOperations.TAN_OPERATION -> tan(arg) + InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg) + InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg) + InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg) + PowerOperations.SQRT_OPERATION -> sqrt(arg) + ExponentialOperations.EXP_OPERATION -> exp(arg) + ExponentialOperations.LN_OPERATION -> ln(arg) + else -> super.unaryOperation(operation, arg) + } +} + +interface ExtendedField : ExtendedFieldOperations, Field { + override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { + PowerOperations.POW_OPERATION -> power(left, right) + else -> super.rightSideNumberOperation(operation, left, right) + } +} /** * Real field element wrapping double. @@ -44,6 +64,10 @@ object RealField : ExtendedField, Norm { override inline fun sin(arg: Double) = kotlin.math.sin(arg) override inline fun cos(arg: Double) = kotlin.math.cos(arg) + override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) + override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) + override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) + override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble()) @@ -75,6 +99,10 @@ object FloatField : ExtendedField, Norm { override inline fun sin(arg: Float) = kotlin.math.sin(arg) override inline fun cos(arg: Float) = kotlin.math.cos(arg) + override inline fun tan(arg: Float) = kotlin.math.tan(arg) + override inline fun acos(arg: Float) = kotlin.math.acos(arg) + override inline fun asin(arg: Float) = kotlin.math.asin(arg) + override inline fun atan(arg: Float) = kotlin.math.atan(arg) override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat()) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index bd83932e7..709f0260f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -13,16 +13,33 @@ package scientifik.kmath.operations interface TrigonometricOperations : FieldOperations { fun sin(arg: T): T fun cos(arg: T): T + fun tan(arg: T): T - fun tg(arg: T): T = sin(arg) / cos(arg) + companion object { + const val SIN_OPERATION = "sin" + const val COS_OPERATION = "cos" + const val TAN_OPERATION = "tan" + } +} - fun ctg(arg: T): T = cos(arg) / sin(arg) +interface InverseTrigonometricOperations : TrigonometricOperations { + fun asin(arg: T): T + fun acos(arg: T): T + fun atan(arg: T): T + + companion object { + const val ASIN_OPERATION = "asin" + const val ACOS_OPERATION = "acos" + const val ATAN_OPERATION = "atan" + } } fun >> sin(arg: T): T = arg.context.sin(arg) fun >> cos(arg: T): T = arg.context.cos(arg) -fun >> tg(arg: T): T = arg.context.tg(arg) -fun >> ctg(arg: T): T = arg.context.ctg(arg) +fun >> tan(arg: T): T = arg.context.tan(arg) +fun >> asin(arg: T): T = arg.context.asin(arg) +fun >> acos(arg: T): T = arg.context.acos(arg) +fun >> atan(arg: T): T = arg.context.atan(arg) /* Power and roots */ @@ -34,6 +51,11 @@ interface PowerOperations : Algebra { fun sqrt(arg: T) = power(arg, 0.5) infix fun T.pow(pow: Number) = power(this, pow) + + companion object { + const val POW_OPERATION = "pow" + const val SQRT_OPERATION = "sqrt" + } } infix fun >> T.pow(power: Double): T = context.power(this, power) @@ -42,9 +64,14 @@ fun >> sqr(arg: T): T = arg pow 2.0 /* Exponential */ -interface ExponentialOperations: Algebra { +interface ExponentialOperations : Algebra { fun exp(arg: T): T fun ln(arg: T): T + + companion object { + const val EXP_OPERATION = "exp" + const val LN_OPERATION = "ln" + } } fun >> exp(arg: T): T = arg.context.exp(arg) @@ -54,4 +81,4 @@ interface Norm { fun norm(arg: T): R } -fun >, R> norm(arg: T): R = arg.context.norm(arg) \ No newline at end of file +fun >, R> norm(arg: T): R = arg.context.norm(arg) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt index 613a0d7ca..5789de9ee 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt @@ -37,9 +37,9 @@ interface Buffer { companion object { - inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer { + inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer { val array = DoubleArray(size) { initializer(it) } - return DoubleBuffer(array) + return RealBuffer(array) } /** @@ -51,7 +51,7 @@ interface Buffer { inline fun auto(type: KClass, size: Int, crossinline initializer: (Int) -> T): Buffer { //TODO add resolution based on Annotation or companion resolution return when (type) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer + Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer @@ -93,7 +93,7 @@ interface MutableBuffer : Buffer { @Suppress("UNCHECKED_CAST") inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer { return when (type) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer + Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer @@ -109,12 +109,11 @@ interface MutableBuffer : Buffer { auto(T::class, size, initializer) val real: MutableBufferFactory = { size: Int, initializer: (Int) -> Double -> - DoubleBuffer(DoubleArray(size) { initializer(it) }) + RealBuffer(DoubleArray(size) { initializer(it) }) } } } - inline class ListBuffer(val list: List) : Buffer { override val size: Int @@ -163,57 +162,6 @@ class ArrayBuffer(private val array: Array) : MutableBuffer { fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) -inline class ShortBuffer(val array: ShortArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Short = array[index] - - override fun set(index: Int, value: Short) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = ShortBuffer(array.copyOf()) - -} - -fun ShortArray.asBuffer() = ShortBuffer(this) - -inline class IntBuffer(val array: IntArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Int = array[index] - - override fun set(index: Int, value: Int) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = IntBuffer(array.copyOf()) - -} - -fun IntArray.asBuffer() = IntBuffer(this) - -inline class LongBuffer(val array: LongArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Long = array[index] - - override fun set(index: Int, value: Long) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = LongBuffer(array.copyOf()) - -} - -fun LongArray.asBuffer() = LongBuffer(this) - inline class ReadOnlyBuffer(val buffer: MutableBuffer) : Buffer { override val size: Int get() = buffer.size diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt index a79366a99..c7e672c28 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt @@ -79,6 +79,13 @@ class ComplexNDField(override val shape: IntArray) : override fun cos(arg: NDBuffer) = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): NDBuffer = map(arg) { tan(it) } + + override fun asin(arg: NDBuffer): NDBuffer = map(arg) { asin(it) } + + override fun acos(arg: NDBuffer): NDBuffer = map(arg) {acos(it)} + + override fun atan(arg: NDBuffer): NDBuffer = map(arg) {atan(it)} } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt index 3437644ff..776cff880 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt @@ -1,13 +1,8 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.* +import scientifik.kmath.operations.ExtendedField -interface ExtendedNDField> : - NDField, - TrigonometricOperations, - PowerOperations, - ExponentialOperations - where F : ExtendedFieldOperations, F : Field +interface ExtendedNDField, N : NDStructure> : NDField, ExtendedField ///** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt new file mode 100644 index 000000000..749e4eeec --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt @@ -0,0 +1,53 @@ +package scientifik.kmath.structures + +import kotlin.experimental.and + +enum class ValueFlag(val mask: Byte) { + NAN(0b0000_0001), + MISSING(0b0000_0010), + NEGATIVE_INFINITY(0b0000_0100), + POSITIVE_INFINITY(0b0000_1000) +} + +/** + * A buffer with flagged values + */ +interface FlaggedBuffer : Buffer { + fun getFlag(index: Int): Byte +} + +/** + * The value is valid if all flags are down + */ +fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte() + +fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte() + +fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING) + +/** + * A real buffer which supports flags for each value like NaN or Missing + */ +class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer, Buffer { + init { + require(values.size == flags.size) { "Values and flags must have the same dimensions" } + } + + override fun getFlag(index: Int): Byte = flags[index] + + override val size: Int get() = values.size + + override fun get(index: Int): Double? = if (isValid(index)) values[index] else null + + override fun iterator(): Iterator = values.indices.asSequence().map { + if (isValid(it)) values[it] else null + }.iterator() +} + +inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { + for(i in indices){ + if(isValid(i)){ + block(values[i]) + } + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt new file mode 100644 index 000000000..a354c5de0 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.structures + +inline class IntBuffer(val array: IntArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Int = array[index] + + override fun set(index: Int, value: Int) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + IntBuffer(array.copyOf()) + +} + + +fun IntArray.asBuffer() = IntBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt new file mode 100644 index 000000000..fa6229a71 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt @@ -0,0 +1,19 @@ +package scientifik.kmath.structures + +inline class LongBuffer(val array: LongArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Long = array[index] + + override fun set(index: Int, value: Long) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + LongBuffer(array.copyOf()) + +} + +fun LongArray.asBuffer() = LongBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt similarity index 59% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt index c0b7f713b..f48ace3a9 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt @@ -1,6 +1,6 @@ package scientifik.kmath.structures -inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer { +inline class RealBuffer(val array: DoubleArray) : MutableBuffer { override val size: Int get() = array.size override fun get(index: Int): Double = array[index] @@ -12,23 +12,23 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer { override fun iterator() = array.iterator() override fun copy(): MutableBuffer = - DoubleBuffer(array.copyOf()) + RealBuffer(array.copyOf()) } @Suppress("FunctionName") -inline fun DoubleBuffer(size: Int, init: (Int) -> Double): DoubleBuffer = DoubleBuffer(DoubleArray(size) { init(it) }) +inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) }) @Suppress("FunctionName") -fun DoubleBuffer(vararg doubles: Double): DoubleBuffer = DoubleBuffer(doubles) +fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles) /** * Transform buffer of doubles into array for high performance operations */ val MutableBuffer.array: DoubleArray - get() = if (this is DoubleBuffer) { + get() = if (this is RealBuffer) { array } else { DoubleArray(size) { get(it) } } -fun DoubleArray.asBuffer() = DoubleBuffer(this) \ No newline at end of file +fun DoubleArray.asBuffer() = RealBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 88c8c29db..826203d1f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -9,145 +9,172 @@ import kotlin.math.* * A simple field over linear buffers of [Double] */ object RealBufferFieldOperations : ExtendedFieldOperations> { - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { val kValue = k.toDouble() - return if (a is DoubleBuffer) { + + return if (a is RealBuffer) { val aArray = a.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * kValue }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) + } else + RealBuffer(DoubleArray(a.size) { a[it] * kValue }) } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] * b[it] }) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) } - override fun sin(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) - } + override fun sin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) + } else { + RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) } - override fun cos(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) - } + override fun cos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + + override fun tan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { tan(arg[it]) }) + + override fun asin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { asin(array[it]) }) + } else { + RealBuffer(DoubleArray(arg.size) { asin(arg[it]) }) } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) - } - } + override fun acos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { acos(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { acos(arg[it]) }) - override fun exp(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - } - } + override fun atan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { atan(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) - override fun ln(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) - } - } + override fun power(arg: Buffer, pow: Number): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) + } else + RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + + override fun exp(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) + + override fun ln(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } class RealBufferField(val size: Int) : ExtendedField> { + override val zero: Buffer by lazy { RealBuffer(size) { 0.0 } } + override val one: Buffer by lazy { RealBuffer(size) { 1.0 } } - override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } - - override val one: Buffer by lazy { DoubleBuffer(size) { 1.0 } } - - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.add(a, b) } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, k) } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, b) } - - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.divide(a, b) } - override fun sin(arg: Buffer): DoubleBuffer { + override fun sin(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.sin(arg) } - override fun cos(arg: Buffer): DoubleBuffer { + override fun cos(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.cos(arg) } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { + override fun tan(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.tan(arg) + } + + override fun asin(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.asin(arg) + } + + override fun acos(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.acos(arg) + } + + override fun atan(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.atan(arg) + } + + override fun power(arg: Buffer, pow: Number): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) } - override fun exp(arg: Buffer): DoubleBuffer { + override fun exp(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.exp(arg) } - override fun ln(arg: Buffer): DoubleBuffer { + override fun ln(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.ln(arg) } - -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 8c1bd4239..8c90f90c7 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -16,7 +16,7 @@ class RealNDField(override val shape: IntArray) : override val one by lazy { produce { one } } inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = - DoubleBuffer(DoubleArray(size) { initializer(it) }) + RealBuffer(DoubleArray(size) { initializer(it) }) /** * Inline transform an NDStructure to @@ -74,6 +74,13 @@ class RealNDField(override val shape: IntArray) : override fun cos(arg: NDBuffer) = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): NDBuffer = map(arg) { tan(it) } + + override fun asin(arg: NDBuffer): NDBuffer = map(arg) { asin(it) } + + override fun acos(arg: NDBuffer): NDBuffer = map(arg) { acos(it) } + + override fun atan(arg: NDBuffer): NDBuffer = map(arg) { atan(it) } } @@ -82,7 +89,7 @@ class RealNDField(override val shape: IntArray) : */ inline fun BufferedNDField.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) } - return BufferedNDFieldElement(this, DoubleBuffer(array)) + return BufferedNDFieldElement(this, RealBuffer(array)) } /** @@ -96,7 +103,7 @@ inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: Int */ inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) } - return BufferedNDFieldElement(context, DoubleBuffer(array)) + return BufferedNDFieldElement(context, RealBuffer(array)) } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt new file mode 100644 index 000000000..f4b2f7d13 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.structures + +inline class ShortBuffer(val array: ShortArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Short = array[index] + + override fun set(index: Int, value: Short) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + ShortBuffer(array.copyOf()) + +} + + +fun ShortArray.asBuffer() = ShortBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt index 033b2792f..9eae60efc 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt @@ -9,7 +9,7 @@ import kotlin.test.assertEquals class ExpressionFieldTest { @Test fun testExpression() { - val context = ExpressionField(RealField) + val context = FunctionalExpressionField(RealField) val expression = with(context) { val x = variable("x", 2.0) x * x + 2 * x + one @@ -20,7 +20,7 @@ class ExpressionFieldTest { @Test fun testComplex() { - val context = ExpressionField(ComplexField) + val context = FunctionalExpressionField(ComplexField) val expression = with(context) { val x = variable("x", Complex(2.0, 0.0)) x * x + 2 * x + one @@ -31,23 +31,23 @@ class ExpressionFieldTest { @Test fun separateContext() { - fun ExpressionField.expression(): Expression { + fun FunctionalExpressionField.expression(): Expression { val x = variable("x") return x * x + 2 * x + one } - val expression = ExpressionField(RealField).expression() + val expression = FunctionalExpressionField(RealField).expression() assertEquals(expression("x" to 1.0), 4.0) } @Test fun valueExpression() { - val expressionBuilder: ExpressionField.() -> Expression = { + val expressionBuilder: FunctionalExpressionField.() -> Expression = { val x = variable("x") x * x + 2 * x + one } - val expression = ExpressionField(RealField).expressionBuilder() + val expression = FunctionalExpressionField(RealField).expressionBuilder() assertEquals(expression("x" to 1.0), 4.0) } } \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt new file mode 100644 index 000000000..e69de29bb diff --git a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt similarity index 92% rename from kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt rename to kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt index 76ca199c5..e6f09c040 100644 --- a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt @@ -1,10 +1,12 @@ package scientifik.kmath.operations -import scientifik.kmath.structures.* import java.math.BigDecimal import java.math.BigInteger import java.math.MathContext +/** + * A field wrapper for Java [BigInteger] + */ object JBigIntegerField : Field { override val zero: BigInteger = BigInteger.ZERO override val one: BigInteger = BigInteger.ONE @@ -18,6 +20,9 @@ object JBigIntegerField : Field { override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) } +/** + * A Field wrapper for Java [BigDecimal] + */ class JBigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field { override val zero: BigDecimal = BigDecimal.ZERO override val one: BigDecimal = BigDecimal.ONE diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt index bef21a680..54da66bb7 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt @@ -5,7 +5,7 @@ import kotlinx.coroutines.flow.* import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.BufferFactory -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asBuffer /** @@ -45,7 +45,7 @@ fun Flow.chunked(bufferSize: Int, bufferFactory: BufferFactory): Flow< /** * Specialized flow chunker for real buffer */ -fun Flow.chunked(bufferSize: Int): Flow = flow { +fun Flow.chunked(bufferSize: Int): Flow = flow { require(bufferSize > 0) { "Resulting chunk size must be more than zero" } if (this@chunked is BlockingRealChain) { @@ -61,13 +61,13 @@ fun Flow.chunked(bufferSize: Int): Flow = flow { array[counter] = element counter++ if (counter == bufferSize) { - val buffer = DoubleBuffer(array) + val buffer = RealBuffer(array) emit(buffer) counter = 0 } } if (counter > 0) { - emit(DoubleBuffer(counter) { array[it] }) + emit(RealBuffer(counter) { array[it] }) } } } diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt index ff4c835ed..2b89904e3 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt @@ -7,31 +7,28 @@ import scientifik.kmath.operations.Norm import scientifik.kmath.operations.RealField import scientifik.kmath.operations.SpaceElement import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asIterable import kotlin.math.sqrt +typealias RealPoint = Point + fun DoubleArray.asVector() = RealVector(this.asBuffer()) fun List.asVector() = RealVector(this.asBuffer()) - object VectorL2Norm : Norm, Double> { override fun norm(arg: Point): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) } inline class RealVector(private val point: Point) : - SpaceElement, RealVector, VectorSpace>, Point { + SpaceElement>, RealPoint { - override val context: VectorSpace - get() = space( - point.size - ) + override val context: VectorSpace get() = space(point.size) - override fun unwrap(): Point = point + override fun unwrap(): RealPoint = point - override fun Point.wrap(): RealVector = - RealVector(this) + override fun RealPoint.wrap(): RealVector = RealVector(this) override val size: Int get() = point.size @@ -44,16 +41,12 @@ inline class RealVector(private val point: Point) : private val spaceCache = HashMap>() inline operator fun invoke(dim: Int, initializer: (Int) -> Double) = - RealVector(DoubleBuffer(dim, initializer)) + RealVector(RealBuffer(dim, initializer)) operator fun invoke(vararg values: Double): RealVector = values.asVector() - fun space(dim: Int): BufferVectorSpace = - spaceCache.getOrPut(dim) { - BufferVectorSpace( - dim, - RealField - ) { size, init -> Buffer.real(size, init) } - } + fun space(dim: Int): BufferVectorSpace = spaceCache.getOrPut(dim) { + BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) } + } } } \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt index d9ee4d90b..82c0e86b2 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt @@ -1,8 +1,8 @@ package scientifik.kmath.real -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer /** - * Simplified [DoubleBuffer] to array comparison + * Simplified [RealBuffer] to array comparison */ -fun DoubleBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) \ No newline at end of file +fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt index 813d89577..65f86eec7 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt @@ -5,8 +5,8 @@ import scientifik.kmath.linear.RealMatrixContext.elementContext import scientifik.kmath.linear.VirtualMatrix import scientifik.kmath.operations.sum import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.Matrix +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asIterable import kotlin.math.pow @@ -27,6 +27,10 @@ typealias RealMatrix = Matrix fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = MatrixContext.real.produce(rowNum, colNum, initializer) +fun Array.toMatrix(): RealMatrix{ + return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] } +} + fun Sequence.toMatrix(): RealMatrix = toList().let { MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] } } @@ -129,22 +133,22 @@ fun Matrix.extractColumns(columnRange: IntRange): RealMatrix = fun Matrix.extractColumn(columnIndex: Int): RealMatrix = extractColumns(columnIndex..columnIndex) -fun Matrix.sumByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> val column = columns[j] with(elementContext) { sum(column.asIterable()) } } -fun Matrix.minByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column") } -fun Matrix.maxByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column") } -fun Matrix.averageByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().average() } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index 329af72a1..43d50ad20 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -1,17 +1,9 @@ package scientifik.kmath.histogram +import scientifik.kmath.domains.Domain import scientifik.kmath.linear.Point import scientifik.kmath.structures.ArrayBuffer -import scientifik.kmath.structures.DoubleBuffer - -/** - * A simple geometric domain - * TODO move to geometry module - */ -interface Domain { - operator fun contains(vector: Point): Boolean - val dimension: Int -} +import scientifik.kmath.structures.RealBuffer /** * The bin in the histogram. The histogram is by definition always done in the real space @@ -51,9 +43,9 @@ interface MutableHistogram> : Histogram { fun MutableHistogram.put(vararg point: T) = put(ArrayBuffer(point)) fun MutableHistogram.put(vararg point: Number) = - put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) + put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) -fun MutableHistogram.put(vararg point: Double) = put(DoubleBuffer(point)) +fun MutableHistogram.put(vararg point: Double) = put(RealBuffer(point)) fun MutableHistogram.fill(sequence: Iterable>) = sequence.forEach { put(it) } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index 4438f5d60..628a68461 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -1,8 +1,8 @@ package scientifik.kmath.histogram import scientifik.kmath.linear.Point -import scientifik.kmath.real.asVector import scientifik.kmath.operations.SpaceOperations +import scientifik.kmath.real.asVector import scientifik.kmath.structures.* import kotlin.math.floor @@ -21,7 +21,7 @@ data class BinDef>(val space: SpaceOperations>, val c class MultivariateBin>(val def: BinDef, override val value: Number) : Bin { - override fun contains(vector: Point): Boolean = def.contains(vector) + override fun contains(point: Point): Boolean = def.contains(point) override val dimension: Int get() = def.center.size @@ -50,7 +50,7 @@ class RealHistogram( override val dimension: Int get() = lower.size - private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } + private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } init { // argument checks diff --git a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt index dcc5ac0eb..af01205bf 100644 --- a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt +++ b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt @@ -16,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) - override fun contains(vector: Buffer): Boolean = contains(vector[0]) + override fun contains(point: Buffer): Boolean = contains(point[0]) internal operator fun inc() = this.also { counter.increment() } diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt index 0896f0dcb..7999aa2ab 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt +++ b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt @@ -10,6 +10,7 @@ interface MemorySpec { val objectSize: Int fun MemoryReader.read(offset: Int): T + //TODO consider thread safety fun MemoryWriter.write(offset: Int, value: T) } diff --git a/settings.gradle.kts b/settings.gradle.kts index afb5598b4..c7348b34e 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -3,10 +3,12 @@ pluginManagement { val toolsVersion = "0.5.0" plugins { + id("kotlinx.benchmark") version "0.2.0-dev-8" id("scientifik.mpp") version toolsVersion id("scientifik.jvm") version toolsVersion id("scientifik.atomic") version toolsVersion id("scientifik.publish") version toolsVersion + kotlin("plugin.allopen") version "1.3.72" } repositories { @@ -45,5 +47,6 @@ include( ":kmath-dimensions", ":kmath-for-real", ":kmath-geometry", + ":kmath-ast", ":examples" )