diff --git a/build.gradle.kts b/build.gradle.kts index 6d102a77a..052b457c5 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -2,7 +2,7 @@ plugins { id("scientifik.publish") apply false } -val kmathVersion by extra("0.1.4-dev-7") +val kmathVersion by extra("0.1.4-dev-8") val bintrayRepo by extra("scientifik") val githubProject by extra("kmath") diff --git a/doc/buffers.md b/doc/buffers.md index b0b7489b3..52a9df86e 100644 --- a/doc/buffers.md +++ b/doc/buffers.md @@ -2,7 +2,7 @@ Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`). There are different types of buffers: -* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays. +* Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays. * Boxing `ListBuffer` wrapping a list * Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value * `MemoryBuffer` allows direct allocation of objects in continuous memory block. diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 2fab47ac0..73def3572 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { java kotlin("jvm") - kotlin("plugin.allopen") version "1.3.71" - id("kotlinx.benchmark") version "0.2.0-dev-7" + kotlin("plugin.allopen") version "1.3.72" + id("kotlinx.benchmark") version "0.2.0-dev-8" } configure { @@ -24,6 +24,7 @@ sourceSets { } dependencies { + implementation(project(":kmath-ast")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) @@ -33,8 +34,8 @@ dependencies { implementation(project(":kmath-dimensions")) implementation("com.kyonifer:koma-core-ejml:0.12") implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") - implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7") - "benchmarksCompile"(sourceSets.main.get().compileClasspath) + implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8") + "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath } // Configure benchmark diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt index 9676b5e4a..e40b0c4b7 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt @@ -10,8 +10,8 @@ import scientifik.kmath.operations.complex class BufferBenchmark { @Benchmark - fun genericDoubleBufferReadWrite() { - val buffer = DoubleBuffer(size){it.toDouble()} + fun genericRealBufferReadWrite() { + val buffer = RealBuffer(size){it.toDouble()} (0 until size).forEach { buffer[it] diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt index be4115d81..f7b9661ef 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt @@ -20,48 +20,39 @@ class ViktorBenchmark { final val viktorField = ViktorNDField(intArrayOf(dim, dim)) @Benchmark - fun `Automatic field addition`() { + fun automaticFieldAddition() { autoField.run { var res = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += one } } } @Benchmark - fun `Viktor field addition`() { + fun viktorFieldAddition() { viktorField.run { var res = one - repeat(n) { - res += one - } + repeat(n) { res += one } } } @Benchmark - fun `Raw Viktor`() { + fun rawViktor() { val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) var res = one - repeat(n) { - res = res + one - } + repeat(n) { res = res + one } } @Benchmark - fun `Real field log`() { + fun realdFieldLog() { realField.run { val fortyTwo = produce { 42.0 } var res = one - - repeat(n) { - res = ln(fortyTwo) - } + repeat(n) { res = ln(fortyTwo) } } } @Benchmark - fun `Raw Viktor log`() { + fun rawViktorLog() { val fortyTwo = F64Array.full(dim, dim, init = 42.0) var res: F64Array repeat(n) { diff --git a/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt new file mode 100644 index 000000000..17a70a4aa --- /dev/null +++ b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -0,0 +1,70 @@ +package scientifik.kmath.ast + +import scientifik.kmath.asm.compile +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.expressionInField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.RealField +import kotlin.random.Random +import kotlin.system.measureTimeMillis + +class ExpressionsInterpretersBenchmark { + private val algebra: Field = RealField + fun functionalExpression() { + val expr = algebra.expressionInField { + variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) + } + + invokeAndSum(expr) + } + + fun mstExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + } + + invokeAndSum(expr) + } + + fun asmExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + }.compile() + + invokeAndSum(expr) + } + + private fun invokeAndSum(expr: Expression) { + val random = Random(0) + var sum = 0.0 + + repeat(1000000) { + sum += expr("x" to random.nextDouble()) + } + + println(sum) + } +} + +fun main() { + val benchmark = ExpressionsInterpretersBenchmark() + + val fe = measureTimeMillis { + benchmark.functionalExpression() + } + + println("fe=$fe") + + val mst = measureTimeMillis { + benchmark.mstExpression() + } + + println("mst=$mst") + + val asm = measureTimeMillis { + benchmark.asmExpression() + } + + println("asm=$asm") +} diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt index ecfb4ab20..a33fdb2c4 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt @@ -6,7 +6,7 @@ fun main(args: Array) { val n = 6000 val array = DoubleArray(n * n) { 1.0 } - val buffer = DoubleBuffer(array) + val buffer = RealBuffer(array) val strides = DefaultStrides(intArrayOf(n, n)) val structure = BufferNDStructure(strides, buffer) diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt index 2d16cc8f4..0241f12ad 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt @@ -26,10 +26,10 @@ fun main(args: Array) { } println("Array mapping finished in $time2 millis") - val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 }) + val buffer = RealBuffer(DoubleArray(n * n) { 1.0 }) val time3 = measureTimeMillis { - val target = DoubleBuffer(DoubleArray(n * n)) + val target = RealBuffer(DoubleArray(n * n)) val res = array.forEachIndexed { index, value -> target[index] = value + 1 } diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 4563e17cf..62b18b4b5 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -24,6 +24,7 @@ For example, the following builder: package scientifik.kmath.asm.generated; import java.util.Map; +import scientifik.kmath.asm.internal.MapIntrinsics; import scientifik.kmath.expressions.Expression; import scientifik.kmath.operations.RealField; @@ -37,23 +38,23 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression arguments) { - return (Double)this.algebra.add(((Double)arguments.get("x")).doubleValue(), 2.0D); + return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D); } } + ``` ### Example Usage -This API is an extension to MST and MSTExpression APIs. You may optimize both MST and MSTExpression: +This API is an extension to MST and MstExpression APIs. You may optimize both MST and MSTExpression: ```kotlin RealField.mstInField { symbol("x") + 2 }.compile() -RealField.expression("2+2".parseMath()) +RealField.expression("x+2".parseMath()) ``` ### Known issues -- Using numeric algebras causes boxing and calling bridge methods. - The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead. - This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt deleted file mode 100644 index 07194a7bb..000000000 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ /dev/null @@ -1,76 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.operations.* - -object MSTAlgebra : NumericAlgebra { - override fun number(value: Number): MST = MST.Numeric(value) - - override fun symbol(value: String): MST = MST.Symbolic(value) - - override fun unaryOperation(operation: String, arg: MST): MST = - MST.Unary(operation, arg) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MST.Binary(operation, left, right) -} - -object MSTSpace : Space, NumericAlgebra { - override val zero: MST = number(0.0) - - override fun number(value: Number): MST = MST.Numeric(value) - - override fun symbol(value: String): MST = MST.Symbolic(value) - - override fun add(a: MST, b: MST): MST = - binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - - override fun multiply(a: MST, k: Number): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) - - override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) -} - -object MSTRing : Ring, NumericAlgebra { - override fun number(value: Number): MST = MST.Numeric(value) - override fun symbol(value: String): MST = MST.Symbolic(value) - - override val zero: MST = MSTSpace.number(0.0) - override val one: MST = number(1.0) - override fun add(a: MST, b: MST): MST = - MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - - override fun multiply(a: MST, k: Number): MST = - MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) - - override fun multiply(a: MST, b: MST): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) -} - -object MSTField : Field{ - override fun symbol(value: String): MST = MST.Symbolic(value) - override fun number(value: Number): MST = MST.Numeric(value) - - override val zero: MST = MSTSpace.number(0.0) - override val one: MST = number(1.0) - override fun add(a: MST, b: MST): MST = - MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - - - override fun multiply(a: MST, k: Number): MST = - MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) - - override fun multiply(a: MST, b: MST): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) - - override fun divide(a: MST, b: MST): MST = - binaryOperation(FieldOperations.DIV_OPERATION, a, b) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) -} diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt deleted file mode 100644 index 61703cac7..000000000 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt +++ /dev/null @@ -1,55 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.FunctionalExpressionField -import scientifik.kmath.expressions.FunctionalExpressionRing -import scientifik.kmath.expressions.FunctionalExpressionSpace -import scientifik.kmath.operations.* - -/** - * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. - */ -class MSTExpression(val algebra: Algebra, val mst: MST) : Expression { - - /** - * Substitute algebra raw value - */ - private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra{ - override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) - override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: T, right: T): T =algebra.binaryOperation(operation, left, right) - - override fun number(value: Number): T = if(algebra is NumericAlgebra){ - algebra.number(value) - } else{ - error("Numeric nodes are not supported by $this") - } - } - - override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) -} - - -inline fun , E : Algebra> A.mst( - mstAlgebra: E, - block: E.() -> MST -): MSTExpression = MSTExpression(this, mstAlgebra.block()) - -inline fun Space.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = - MSTExpression(this, MSTSpace.block()) - -inline fun Ring.mstInRing(block: MSTRing.() -> MST): MSTExpression = - MSTExpression(this, MSTRing.block()) - -inline fun Field.mstInField(block: MSTField.() -> MST): MSTExpression = - MSTExpression(this, MSTField.block()) - -inline fun > FunctionalExpressionSpace.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = - algebra.mstInSpace(block) - -inline fun > FunctionalExpressionRing.mstInRing(block: MSTRing.() -> MST): MSTExpression = - algebra.mstInRing(block) - -inline fun > FunctionalExpressionField.mstInField(block: MSTField.() -> MST): MSTExpression = - algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt new file mode 100644 index 000000000..007cf57c4 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt @@ -0,0 +1,72 @@ +package scientifik.kmath.ast + +import scientifik.kmath.operations.* + +object MstAlgebra : NumericAlgebra { + override fun number(value: Number): MST = MST.Numeric(value) + + override fun symbol(value: String): MST = MST.Symbolic(value) + + override fun unaryOperation(operation: String, arg: MST): MST = + MST.Unary(operation, arg) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MST.Binary(operation, left, right) +} + +object MstSpace : Space, NumericAlgebra { + override val zero: MST = number(0.0) + + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) + + override fun add(a: MST, b: MST): MST = + binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: MST, k: Number): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MstAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) +} + +object MstRing : Ring, NumericAlgebra { + override val zero: MST = number(0.0) + override val one: MST = number(1.0) + + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: MST, k: Number): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k)) + + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MstAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) +} + +object MstField : Field { + override val zero: MST = number(0.0) + override val one: MST = number(1.0) + + override fun symbol(value: String): MST = MstAlgebra.symbol(value) + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: MST, k: Number): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k)) + + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MstAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) +} diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt new file mode 100644 index 000000000..1468c3ad4 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt @@ -0,0 +1,55 @@ +package scientifik.kmath.ast + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.FunctionalExpressionField +import scientifik.kmath.expressions.FunctionalExpressionRing +import scientifik.kmath.expressions.FunctionalExpressionSpace +import scientifik.kmath.operations.* + +/** + * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. + */ +class MstExpression(val algebra: Algebra, val mst: MST) : Expression { + + /** + * Substitute algebra raw value + */ + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { + override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) + override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: T, right: T): T = + algebra.binaryOperation(operation, left, right) + + override fun number(value: Number): T = if (algebra is NumericAlgebra) + algebra.number(value) + else + error("Numeric nodes are not supported by $this") + } + + override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) +} + + +inline fun , E : Algebra> A.mst( + mstAlgebra: E, + block: E.() -> MST +): MstExpression = MstExpression(this, mstAlgebra.block()) + +inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression = + MstExpression(this, MstSpace.block()) + +inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression = + MstExpression(this, MstRing.block()) + +inline fun Field.mstInField(block: MstField.() -> MST): MstExpression = + MstExpression(this, MstField.block()) + +inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression = + algebra.mstInSpace(block) + +inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression = + algebra.mstInRing(block) + +inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression = + algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt index cec61a8ff..30a92c5ae 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt @@ -16,15 +16,15 @@ import scientifik.kmath.operations.SpaceOperations * TODO move to common */ private object ArithmeticsEvaluator : Grammar() { - val num by token("-?[\\d.]+(?:[eE]-?\\d+)?") - val lpar by token("\\(") - val rpar by token("\\)") - val mul by token("\\*") - val pow by token("\\^") - val div by token("/") - val minus by token("-") - val plus by token("\\+") - val ws by token("\\s+", ignore = true) + val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex()) + val lpar by token("\\(".toRegex()) + val rpar by token("\\)".toRegex()) + val mul by token("\\*".toRegex()) + val pow by token("\\^".toRegex()) + val div by token("/".toRegex()) + val minus by token("-".toRegex()) + val plus by token("\\+".toRegex()) + val ws by token("\\s+".toRegex(), ignore = true) val number: Parser by num use { MST.Numeric(text.toDouble()) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index a3af80ccd..ef2330533 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,12 +1,10 @@ package scientifik.kmath.asm -import org.objectweb.asm.Type import scientifik.kmath.asm.internal.AsmBuilder -import scientifik.kmath.asm.internal.buildExpectationStack +import scientifik.kmath.asm.internal.buildAlgebraOperationCall import scientifik.kmath.asm.internal.buildName -import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST -import scientifik.kmath.ast.MSTExpression +import scientifik.kmath.ast.MstExpression import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.NumericAlgebra @@ -29,43 +27,21 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< loadTConstant(constant) } - is MST.Unary -> { - loadAlgebra() - if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation) - visit(node.value) + is MST.Unary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "unaryOperation", + arity = 1 + ) { visit(node.value) } - if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = "unaryOperation", - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - AsmBuilder.OBJECT_TYPE - ), - - tArity = 1 - ) - } - is MST.Binary -> { - loadAlgebra() - if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) + is MST.Binary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "binaryOperation", + arity = 2 + ) { visit(node.left) visit(node.right) - - if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = "binaryOperation", - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - AsmBuilder.OBJECT_TYPE, - AsmBuilder.OBJECT_TYPE - ), - - tArity = 2 - ) } } } @@ -79,6 +55,6 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class, this) /** - * Optimize performance of an [MSTExpression] using ASM codegen + * Optimize performance of an [MstExpression] using ASM codegen */ -inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) +inline fun MstExpression.compile(): Expression = mst.compileWith(T::class, algebra) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 536d6136d..cea6be933 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -1,8 +1,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.* -import org.objectweb.asm.Opcodes.AALOAD -import org.objectweb.asm.Opcodes.RETURN +import org.objectweb.asm.Opcodes.* import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.ast.MST @@ -18,6 +17,7 @@ import kotlin.reflect.KClass * @param T the type of AsmExpression to unwrap. * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. + * @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. */ internal class AsmBuilder internal constructor( private val classOfT: KClass<*>, @@ -37,8 +37,19 @@ internal class AsmBuilder internal constructor( */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + /** + * ASM Type for [algebra] + */ private val tAlgebraType: Type = algebra::class.asm + + /** + * ASM type for [T] + */ internal val tType: Type = classOfT.asm + + /** + * ASM type for new class + */ private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! /** @@ -60,15 +71,31 @@ internal class AsmBuilder internal constructor( * Method visitor of `invoke` method of the subclass. */ private lateinit var invokeMethodVisitor: InstructionAdapter - internal var primitiveMode = false - @Suppress("PropertyName") - internal var PRIMITIVE_MASK: Type = OBJECT_TYPE + /** + * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. + */ + internal var primitiveMode: Boolean = false - @Suppress("PropertyName") - internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE - private val typeStack = Stack() - internal val expectationStack: Stack = Stack().apply { push(tType) } + /** + * Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + */ + internal var primitiveMask: Type = OBJECT_TYPE + + /** + * Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + */ + internal var primitiveMaskBoxed: Type = OBJECT_TYPE + + /** + * Stack of useful objects types on stack to verify types. + */ + private val typeStack: ArrayDeque = ArrayDeque() + + /** + * Stack of useful objects types on stack expected by algebra calls. + */ + internal val expectationStack: ArrayDeque = ArrayDeque().apply { push(tType) } /** * The cache for instance built by this builder. @@ -86,14 +113,14 @@ internal class AsmBuilder internal constructor( if (SIGNATURE_LETTERS.containsKey(classOfT)) { primitiveMode = true - PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT) - PRIMITIVE_MASK_BOXED = tType + primitiveMask = SIGNATURE_LETTERS.getValue(classOfT) + primitiveMaskBoxed = tType } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( - Opcodes.V1_8, - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, + V1_8, + ACC_PUBLIC or ACC_FINAL or ACC_SUPER, classType.internalName, "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", OBJECT_TYPE.internalName, @@ -101,7 +128,7 @@ internal class AsmBuilder internal constructor( ) visitField( - access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + access = ACC_PRIVATE or ACC_FINAL, name = "algebra", descriptor = tAlgebraType.descriptor, signature = null, @@ -110,7 +137,7 @@ internal class AsmBuilder internal constructor( ) visitField( - access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + access = ACC_PRIVATE or ACC_FINAL, name = "constants", descriptor = OBJECT_ARRAY_TYPE.descriptor, signature = null, @@ -119,7 +146,7 @@ internal class AsmBuilder internal constructor( ) visitMethod( - Opcodes.ACC_PUBLIC, + ACC_PUBLIC, "", Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE), null, @@ -159,7 +186,7 @@ internal class AsmBuilder internal constructor( } visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + ACC_PUBLIC or ACC_FINAL, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", @@ -195,7 +222,7 @@ internal class AsmBuilder internal constructor( } visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, "invoke", Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), null, @@ -238,34 +265,43 @@ internal class AsmBuilder internal constructor( } /** - * Loads a constant from + * Loads a [T] constant from [constants]. */ internal fun loadTConstant(value: T) { if (classOfT in INLINABLE_NUMBERS) { - val expectedType = expectationStack.pop()!! + val expectedType = expectationStack.pop() val mustBeBoxed = expectedType.sort == Type.OBJECT loadNumberConstant(value as Number, mustBeBoxed) - if (mustBeBoxed) typeStack.push(tType) else typeStack.push(PRIMITIVE_MASK) + if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) return } loadConstant(value as Any, tType) } + /** + * Boxes the current value and pushes it. + */ private fun box(): Unit = invokeMethodVisitor.invokestatic( tType.internalName, "valueOf", - Type.getMethodDescriptor(tType, PRIMITIVE_MASK), + Type.getMethodDescriptor(tType, primitiveMask), false ) + /** + * Unboxes the current boxed value and pushes it. + */ private fun unbox(): Unit = invokeMethodVisitor.invokevirtual( NUMBER_TYPE.internalName, - NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK), - Type.getMethodDescriptor(PRIMITIVE_MASK), + NUMBER_CONVERTER_METHODS.getValue(primitiveMask), + Type.getMethodDescriptor(primitiveMask), false ) + /** + * Loads [java.lang.Object] constant from constants. + */ private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex loadThis() @@ -275,6 +311,9 @@ internal class AsmBuilder internal constructor( checkcast(type) } + /** + * Loads this variable. + */ private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) /** @@ -305,46 +344,40 @@ internal class AsmBuilder internal constructor( } loadConstant(value, boxed) + if (!mustBeBoxed) unbox() else invokeMethodVisitor.checkcast(tType) } /** - * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be + * provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { - load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) + load(invokeArgumentsVar, MAP_TYPE) + aconst(name) - if (defaultValue != null) { - loadStringConstant(name) + if (defaultValue != null) loadTConstant(defaultValue) + else + aconst(null) - invokeinterface( - MAP_TYPE.internalName, - "getOrDefault", - Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) - ) - - invokeMethodVisitor.checkcast(tType) - return - } - - loadStringConstant(name) - - invokeinterface( - MAP_TYPE.internalName, - "get", - Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) + invokestatic( + MAP_INTRINSICS_TYPE.internalName, + "getOrFail", + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE), + false ) - invokeMethodVisitor.checkcast(tType) - val expectedType = expectationStack.pop()!! + checkcast(tType) + + val expectedType = expectationStack.pop() if (expectedType.sort == Type.OBJECT) typeStack.push(tType) else { unbox() - typeStack.push(PRIMITIVE_MASK) + typeStack.push(primitiveMask) } } @@ -358,7 +391,7 @@ internal class AsmBuilder internal constructor( /** * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is - * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interface. [loadAlgebra] should be + * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be * called before the arguments and this operation. * * The result is casted to [T] automatically. @@ -367,12 +400,12 @@ internal class AsmBuilder internal constructor( owner: String, method: String, descriptor: String, - tArity: Int, - opcode: Int = Opcodes.INVOKEINTERFACE + expectedArity: Int, + opcode: Int = INVOKEINTERFACE ) { run loop@{ - repeat(tArity) { - if (typeStack.empty()) return@loop + repeat(expectedArity) { + if (typeStack.isEmpty()) return@loop typeStack.pop() } } @@ -382,18 +415,18 @@ internal class AsmBuilder internal constructor( owner, method, descriptor, - opcode == Opcodes.INVOKEINTERFACE + opcode == INVOKEINTERFACE ) invokeMethodVisitor.checkcast(tType) val isLastExpr = expectationStack.size == 1 - val expectedType = expectationStack.pop()!! + val expectedType = expectationStack.pop() if (expectedType.sort == Type.OBJECT || isLastExpr) typeStack.push(tType) else { unbox() - typeStack.push(PRIMITIVE_MASK) + typeStack.push(primitiveMask) } } @@ -404,7 +437,7 @@ internal class AsmBuilder internal constructor( internal companion object { /** - * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. + * Maps JVM primitive numbers boxed types to their primitive ASM types. */ private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( @@ -417,8 +450,14 @@ internal class AsmBuilder internal constructor( ) } + /** + * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. + */ private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } + /** + * Maps primitive ASM types to [Number] functions unboxing them. + */ private val NUMBER_CONVERTER_METHODS: Map by lazy { hashMapOf( Type.BYTE_TYPE to "byteValue", @@ -434,14 +473,46 @@ internal class AsmBuilder internal constructor( * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + + /** + * ASM type for [Expression]. + */ internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } + + /** + * ASM type for [java.lang.Number]. + */ internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } + + /** + * ASM type for [java.util.Map]. + */ internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } + + /** + * ASM type for [java.lang.Object]. + */ internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } + /** + * ASM type for array of [java.lang.Object]. + */ @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } + + /** + * ASM type for [Algebra]. + */ internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } + + /** + * ASM type for [java.lang.String]. + */ internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } + + /** + * ASM type for MapIntrinsics. + */ + internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt deleted file mode 100644 index 41dbf5807..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt +++ /dev/null @@ -1,22 +0,0 @@ -package scientifik.kmath.asm.internal - -import scientifik.kmath.ast.MST -import scientifik.kmath.expressions.Expression - -/** - * Creates a class name for [Expression] subclassed to implement [mst] provided. - * - * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there - * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. - */ -internal tailrec fun buildName(mst: MST, collision: Int = 0): String { - val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(mst, collision + 1) -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt deleted file mode 100644 index af5c1049d..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ /dev/null @@ -1,17 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.ClassWriter -import org.objectweb.asm.FieldVisitor -import org.objectweb.asm.MethodVisitor - -internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = - ClassWriter(flags).apply(block) - -internal inline fun ClassWriter.visitField( - access: Int, - name: String, - descriptor: String, - signature: String?, - value: Any?, - block: FieldVisitor.() -> Unit -): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt deleted file mode 100644 index dc0b35531..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt +++ /dev/null @@ -1,7 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.Type -import kotlin.reflect.KClass - -internal val KClass<*>.asm: Type - get() = Type.getType(java) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt new file mode 100644 index 000000000..46d07976d --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt @@ -0,0 +1,148 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.* +import org.objectweb.asm.Opcodes.INVOKEVIRTUAL +import org.objectweb.asm.commons.InstructionAdapter +import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra +import kotlin.reflect.KClass + +private val methodNameAdapters: Map, String> by lazy { + hashMapOf( + "+" to 2 to "add", + "*" to 2 to "multiply", + "/" to 2 to "divide", + "+" to 1 to "unaryPlus", + "-" to 1 to "unaryMinus", + "-" to 2 to "minus" + ) +} + +internal val KClass<*>.asm: Type + get() = Type.getType(java) + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor]. + */ +private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. + */ +internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = + instructionAdapter().apply(block) + +/** + * Constructs a [Label], then applies it to this visitor. + */ +internal fun MethodVisitor.label(): Label { + val l = Label() + visitLabel(l) + return l +} + +/** + * Creates a class name for [Expression] subclassed to implement [mst] provided. + * + * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there + * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. + */ +internal tailrec fun buildName(mst: MST, collision: Int = 0): String { + val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(mst, collision + 1) +} + +@Suppress("FunctionName") +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = + ClassWriter(flags).apply(block) + +internal inline fun ClassWriter.visitField( + access: Int, + name: String, + descriptor: String, + signature: String?, + value: Any?, + block: FieldVisitor.() -> Unit +): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) + +/** + * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds + * type expectation stack for needed arity. + * + * @return `true` if contains, else `false`. + */ +private fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { + val theName = methodNameAdapters[name to arity] ?: name + val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null + val t = if (primitiveMode && hasSpecific) primitiveMask else tType + repeat(arity) { expectationStack.push(t) } + return hasSpecific +} + +/** + * Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts + * [AsmBuilder.invokeAlgebraOperation] of this method. + * + * @return `true` if contains, else `false`. + */ +private fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { + val theName = methodNameAdapters[name to arity] ?: name + + context.javaClass.methods.find { + var suitableSignature = it.name == theName && it.parameters.size == arity + + if (primitiveMode && it.isBridge) + suitableSignature = false + + suitableSignature + } ?: return false + + val owner = context::class.asm + + invokeAlgebraOperation( + owner = owner.internalName, + method = theName, + descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }), + expectedArity = arity, + opcode = INVOKEVIRTUAL + ) + + return true +} + +/** + * Builds specialized algebra call with option to fallback to generic algebra operation accepting String. + */ +internal fun AsmBuilder.buildAlgebraOperationCall( + context: Algebra, + name: String, + fallbackMethodName: String, + arity: Int, + parameters: AsmBuilder.() -> Unit +) { + loadAlgebra() + if (!buildExpectationStack(context, name, arity)) loadStringConstant(name) + parameters() + + if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = fallbackMethodName, + + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + *Array(arity) { AsmBuilder.OBJECT_TYPE } + ), + + expectedArity = arity + ) +} + diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt new file mode 100644 index 000000000..7f7126b55 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt @@ -0,0 +1,7 @@ +@file:JvmName("MapIntrinsics") + +package scientifik.kmath.asm.internal + +internal fun Map.getOrFail(key: K, default: V?): V { + return this[key] ?: default ?: error("Parameter not found: $key") +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt deleted file mode 100644 index aaae02ebb..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ /dev/null @@ -1,9 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.commons.InstructionAdapter - -internal fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) - -internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = - instructionAdapter().apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt deleted file mode 100644 index 4c7a0d57e..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ /dev/null @@ -1,61 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.Opcodes -import org.objectweb.asm.Type -import scientifik.kmath.operations.Algebra - -private val methodNameAdapters: Map by lazy { - hashMapOf( - "+" to "add", - "*" to "multiply", - "/" to "divide" - ) -} - -/** - * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds - * type expectation stack for needed arity. - * - * @return `true` if contains, else `false`. - */ -internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { - val aName = methodNameAdapters[name] ?: name - - val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null - val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else tType - repeat(arity) { expectationStack.push(t) } - - return hasSpecific -} - -/** - * Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts - * [AsmBuilder.invokeAlgebraOperation] of this method. - * - * @return `true` if contains, else `false`. - */ -internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { - val aName = methodNameAdapters[name] ?: name - - val method = - context.javaClass.methods.find { - var suitableSignature = it.name == aName && it.parameters.size == arity - - if (primitiveMode && it.isBridge) - suitableSignature = false - - suitableSignature - } ?: return false - - val owner = context::class.java.name.replace('.', '/') - - invokeAlgebraOperation( - owner = owner, - method = aName, - descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }), - tArity = arity, - opcode = Opcodes.INVOKEVIRTUAL - ) - - return true -} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 4c2be811e..3acc6eb28 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -10,7 +10,7 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -class TestAsmAlgebras { +internal class TestAsmAlgebras { @Test fun space() { val res1 = ByteRing.mstInSpace { @@ -92,8 +92,8 @@ class TestAsmAlgebras { "+", (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), - 1 / 2 + number(2.0) * one - ) + number(1) / 2 + number(2.0) * one + ) + zero }("x" to 2.0) val res2 = RealField.mstInField { @@ -101,8 +101,8 @@ class TestAsmAlgebras { "+", (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), - 1 / 2 + number(2.0) * one - ) + number(1) / 2 + number(2.0) * one + ) + zero }.compile()("x" to 2.0) assertEquals(res1, res2) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index 824201aa7..36c254c38 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -8,7 +8,7 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -class TestAsmExpressions { +internal class TestAsmExpressions { @Test fun testUnaryOperationInvocation() { val expression = RealField.mstInSpace { -symbol("x") }.compile() diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt new file mode 100644 index 000000000..b571e076f --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt @@ -0,0 +1,46 @@ +package scietifik.kmath.asm + +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestAsmSpecialization { + @Test + fun testUnaryPlus() { + val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() + assertEquals(2.0, expr("x" to 2.0)) + } + + @Test + fun testUnaryMinus() { + val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() + assertEquals(-2.0, expr("x" to 2.0)) + } + + @Test + fun testAdd() { + val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() + assertEquals(4.0, expr("x" to 2.0)) + } + + @Test + fun testSine() { + val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 0.0)) + } + + @Test + fun testMinus() { + val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 2.0)) + } + + @Test + fun testDivide() { + val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() + assertEquals(1.0, expr("x" to 2.0)) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt new file mode 100644 index 000000000..aafc75448 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt @@ -0,0 +1,22 @@ +package scietifik.kmath.asm + +import scientifik.kmath.ast.mstInRing +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +internal class TestAsmVariables { + @Test + fun testVariableWithoutDefault() { + val expr = ByteRing.mstInRing { symbol("x") } + assertEquals(1.toByte(), expr("x" to 1.toByte())) + } + + @Test + fun testVariableWithoutDefaultFails() { + val expr = ByteRing.mstInRing { symbol("x") } + assertFailsWith { expr() } + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index 08d7fff47..23203172e 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -10,7 +10,7 @@ import scientifik.kmath.operations.ComplexField import kotlin.test.Test import kotlin.test.assertEquals -class AsmTest { +internal class AsmTest { @Test fun `compile MST`() { val mst = "2+2*(2+2)".parseMath() diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt new file mode 100644 index 000000000..13e79d60e --- /dev/null +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -0,0 +1,38 @@ +package scientifik.kmath.commons.random + +import scientifik.kmath.prob.RandomGenerator + +class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : + org.apache.commons.math3.random.RandomGenerator { + private var generator = factory(intArrayOf()) + + override fun nextBoolean(): Boolean = generator.nextBoolean() + + override fun nextFloat(): Float = generator.nextDouble().toFloat() + + override fun setSeed(seed: Int) { + generator = factory(intArrayOf(seed)) + } + + override fun setSeed(seed: IntArray) { + generator = factory(seed) + } + + override fun setSeed(seed: Long) { + setSeed(seed.toInt()) + } + + override fun nextBytes(bytes: ByteArray) { + generator.fillBytes(bytes) + } + + override fun nextInt(): Int = generator.nextInt() + + override fun nextInt(n: Int): Int = generator.nextInt(n) + + override fun nextGaussian(): Double = TODO() + + override fun nextDouble(): Double = generator.nextDouble() + + override fun nextLong(): Long = generator.nextLong() +} \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt index bcb3ea87b..eb1b5b69a 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt @@ -18,7 +18,7 @@ object Transformations { private fun Buffer.toArray(): Array = Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) } - private fun Buffer.asArray() = if (this is DoubleBuffer) { + private fun Buffer.asArray() = if (this is RealBuffer) { array } else { DoubleArray(size) { i -> get(i) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt new file mode 100644 index 000000000..333b77cb4 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt @@ -0,0 +1,15 @@ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +/** + * A simple geometric domain + */ +interface Domain { + operator fun contains(point: Point): Boolean + + /** + * Number of hyperspace dimensions + */ + val dimension: Int +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt new file mode 100644 index 000000000..e0019c96b --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt @@ -0,0 +1,67 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.RealBuffer +import scientifik.kmath.structures.indices + +/** + * + * HyperSquareDomain class. + * + * @author Alexander Nozik + */ +class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain { + + override operator fun contains(point: Point): Boolean = point.indices.all { i -> + point[i] in lower[i]..upper[i] + } + + override val dimension: Int get() = lower.size + + override fun getLowerBound(num: Int, point: Point): Double? = lower[num] + + override fun getLowerBound(num: Int): Double? = lower[num] + + override fun getUpperBound(num: Int, point: Point): Double? = upper[num] + + override fun getUpperBound(num: Int): Double? = upper[num] + + override fun nearestInDomain(point: Point): Point { + val res: DoubleArray = DoubleArray(point.size) { i -> + when { + point[i] < lower[i] -> lower[i] + point[i] > upper[i] -> upper[i] + else -> point[i] + } + } + return RealBuffer(*res) + } + + override fun volume(): Double { + var res = 1.0 + for (i in 0 until dimension) { + if (lower[i].isInfinite() || upper[i].isInfinite()) { + return Double.POSITIVE_INFINITY + } + if (upper[i] > lower[i]) { + res *= upper[i] - lower[i] + } + } + return res + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt new file mode 100644 index 000000000..89115887e --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt @@ -0,0 +1,65 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +/** + * n-dimensional volume + * + * @author Alexander Nozik + */ +interface RealDomain: Domain { + + fun nearestInDomain(point: Point): Point + + /** + * The lower edge for the domain going down from point + * @param num + * @param point + * @return + */ + fun getLowerBound(num: Int, point: Point): Double? + + /** + * The upper edge of the domain going up from point + * @param num + * @param point + * @return + */ + fun getUpperBound(num: Int, point: Point): Double? + + /** + * Global lower edge + * @param num + * @return + */ + fun getLowerBound(num: Int): Double? + + /** + * Global upper edge + * @param num + * @return + */ + fun getUpperBound(num: Int): Double? + + /** + * Hyper volume + * @return + */ + fun volume(): Double + +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt new file mode 100644 index 000000000..e49fd3b37 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt @@ -0,0 +1,36 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +class UnconstrainedDomain(override val dimension: Int) : RealDomain { + + override operator fun contains(point: Point): Boolean = true + + override fun getLowerBound(num: Int, point: Point): Double? = Double.NEGATIVE_INFINITY + + override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY + + override fun getUpperBound(num: Int, point: Point): Double? = Double.POSITIVE_INFINITY + + override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY + + override fun nearestInDomain(point: Point): Point = point + + override fun volume(): Double = Double.POSITIVE_INFINITY + +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt new file mode 100644 index 000000000..ef521d5ea --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt @@ -0,0 +1,48 @@ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.asBuffer + +inline class UnivariateDomain(val range: ClosedFloatingPointRange) : RealDomain { + + operator fun contains(d: Double): Boolean = range.contains(d) + + override operator fun contains(point: Point): Boolean { + require(point.size == 0) + return contains(point[0]) + } + + override fun nearestInDomain(point: Point): Point { + require(point.size == 1) + val value = point[0] + return when{ + value in range -> point + value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer() + else -> doubleArrayOf(range.start).asBuffer() + } + } + + override fun getLowerBound(num: Int, point: Point): Double? { + require(num == 0) + return range.start + } + + override fun getUpperBound(num: Int, point: Point): Double? { + require(num == 0) + return range.endInclusive + } + + override fun getLowerBound(num: Int): Double? { + require(num == 0) + return range.start + } + + override fun getUpperBound(num: Int): Double? { + require(num == 0) + return range.endInclusive + } + + override fun volume(): Double = range.endInclusive - range.start + + override val dimension: Int get() = 1 +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt index 73b18b810..c4c38284b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt @@ -30,11 +30,11 @@ object RealMatrixContext : GenericMatrixContext { override val elementContext get() = RealField override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix { - val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } + val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } return BufferMatrix(rows, columns, buffer) } - override inline fun point(size: Int, initializer: (Int) -> Double): Point = DoubleBuffer(size,initializer) + override inline fun point(size: Int, initializer: (Int) -> Double): Point = RealBuffer(size,initializer) } class BufferMatrix( @@ -102,7 +102,7 @@ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix.unsafeArray(): DoubleArray = if (this is DoubleBuffer) { + fun Buffer.unsafeArray(): DoubleArray = if (this is RealBuffer) { array } else { DoubleArray(size) { get(it) } @@ -119,6 +119,6 @@ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix { companion object { - inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer { + inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer { val array = DoubleArray(size) { initializer(it) } - return DoubleBuffer(array) + return RealBuffer(array) } /** @@ -51,7 +51,7 @@ interface Buffer { inline fun auto(type: KClass, size: Int, crossinline initializer: (Int) -> T): Buffer { //TODO add resolution based on Annotation or companion resolution return when (type) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer + Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer @@ -93,7 +93,7 @@ interface MutableBuffer : Buffer { @Suppress("UNCHECKED_CAST") inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer { return when (type) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer + Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer @@ -109,12 +109,11 @@ interface MutableBuffer : Buffer { auto(T::class, size, initializer) val real: MutableBufferFactory = { size: Int, initializer: (Int) -> Double -> - DoubleBuffer(DoubleArray(size) { initializer(it) }) + RealBuffer(DoubleArray(size) { initializer(it) }) } } } - inline class ListBuffer(val list: List) : Buffer { override val size: Int @@ -163,57 +162,6 @@ class ArrayBuffer(private val array: Array) : MutableBuffer { fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) -inline class ShortBuffer(val array: ShortArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Short = array[index] - - override fun set(index: Int, value: Short) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = ShortBuffer(array.copyOf()) - -} - -fun ShortArray.asBuffer() = ShortBuffer(this) - -inline class IntBuffer(val array: IntArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Int = array[index] - - override fun set(index: Int, value: Int) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = IntBuffer(array.copyOf()) - -} - -fun IntArray.asBuffer() = IntBuffer(this) - -inline class LongBuffer(val array: LongArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Long = array[index] - - override fun set(index: Int, value: Long) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = LongBuffer(array.copyOf()) - -} - -fun LongArray.asBuffer() = LongBuffer(this) - inline class ReadOnlyBuffer(val buffer: MutableBuffer) : Buffer { override val size: Int get() = buffer.size diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt new file mode 100644 index 000000000..749e4eeec --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt @@ -0,0 +1,53 @@ +package scientifik.kmath.structures + +import kotlin.experimental.and + +enum class ValueFlag(val mask: Byte) { + NAN(0b0000_0001), + MISSING(0b0000_0010), + NEGATIVE_INFINITY(0b0000_0100), + POSITIVE_INFINITY(0b0000_1000) +} + +/** + * A buffer with flagged values + */ +interface FlaggedBuffer : Buffer { + fun getFlag(index: Int): Byte +} + +/** + * The value is valid if all flags are down + */ +fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte() + +fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte() + +fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING) + +/** + * A real buffer which supports flags for each value like NaN or Missing + */ +class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer, Buffer { + init { + require(values.size == flags.size) { "Values and flags must have the same dimensions" } + } + + override fun getFlag(index: Int): Byte = flags[index] + + override val size: Int get() = values.size + + override fun get(index: Int): Double? = if (isValid(index)) values[index] else null + + override fun iterator(): Iterator = values.indices.asSequence().map { + if (isValid(it)) values[it] else null + }.iterator() +} + +inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { + for(i in indices){ + if(isValid(i)){ + block(values[i]) + } + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt new file mode 100644 index 000000000..a354c5de0 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.structures + +inline class IntBuffer(val array: IntArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Int = array[index] + + override fun set(index: Int, value: Int) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + IntBuffer(array.copyOf()) + +} + + +fun IntArray.asBuffer() = IntBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt new file mode 100644 index 000000000..fa6229a71 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt @@ -0,0 +1,19 @@ +package scientifik.kmath.structures + +inline class LongBuffer(val array: LongArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Long = array[index] + + override fun set(index: Int, value: Long) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + LongBuffer(array.copyOf()) + +} + +fun LongArray.asBuffer() = LongBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt similarity index 59% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt index c0b7f713b..f48ace3a9 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt @@ -1,6 +1,6 @@ package scientifik.kmath.structures -inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer { +inline class RealBuffer(val array: DoubleArray) : MutableBuffer { override val size: Int get() = array.size override fun get(index: Int): Double = array[index] @@ -12,23 +12,23 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer { override fun iterator() = array.iterator() override fun copy(): MutableBuffer = - DoubleBuffer(array.copyOf()) + RealBuffer(array.copyOf()) } @Suppress("FunctionName") -inline fun DoubleBuffer(size: Int, init: (Int) -> Double): DoubleBuffer = DoubleBuffer(DoubleArray(size) { init(it) }) +inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) }) @Suppress("FunctionName") -fun DoubleBuffer(vararg doubles: Double): DoubleBuffer = DoubleBuffer(doubles) +fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles) /** * Transform buffer of doubles into array for high performance operations */ val MutableBuffer.array: DoubleArray - get() = if (this is DoubleBuffer) { + get() = if (this is RealBuffer) { array } else { DoubleArray(size) { get(it) } } -fun DoubleArray.asBuffer() = DoubleBuffer(this) \ No newline at end of file +fun DoubleArray.asBuffer() = RealBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 22b33aa4d..8c90f90c7 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -16,7 +16,7 @@ class RealNDField(override val shape: IntArray) : override val one by lazy { produce { one } } inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = - DoubleBuffer(DoubleArray(size) { initializer(it) }) + RealBuffer(DoubleArray(size) { initializer(it) }) /** * Inline transform an NDStructure to @@ -89,7 +89,7 @@ class RealNDField(override val shape: IntArray) : */ inline fun BufferedNDField.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) } - return BufferedNDFieldElement(this, DoubleBuffer(array)) + return BufferedNDFieldElement(this, RealBuffer(array)) } /** @@ -103,7 +103,7 @@ inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: Int */ inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) } - return BufferedNDFieldElement(context, DoubleBuffer(array)) + return BufferedNDFieldElement(context, RealBuffer(array)) } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt new file mode 100644 index 000000000..f4b2f7d13 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.structures + +inline class ShortBuffer(val array: ShortArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Short = array[index] + + override fun set(index: Int, value: Short) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + ShortBuffer(array.copyOf()) + +} + + +fun ShortArray.asBuffer() = ShortBuffer(this) \ No newline at end of file diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt index bef21a680..54da66bb7 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt @@ -5,7 +5,7 @@ import kotlinx.coroutines.flow.* import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.BufferFactory -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asBuffer /** @@ -45,7 +45,7 @@ fun Flow.chunked(bufferSize: Int, bufferFactory: BufferFactory): Flow< /** * Specialized flow chunker for real buffer */ -fun Flow.chunked(bufferSize: Int): Flow = flow { +fun Flow.chunked(bufferSize: Int): Flow = flow { require(bufferSize > 0) { "Resulting chunk size must be more than zero" } if (this@chunked is BlockingRealChain) { @@ -61,13 +61,13 @@ fun Flow.chunked(bufferSize: Int): Flow = flow { array[counter] = element counter++ if (counter == bufferSize) { - val buffer = DoubleBuffer(array) + val buffer = RealBuffer(array) emit(buffer) counter = 0 } } if (counter > 0) { - emit(DoubleBuffer(counter) { array[it] }) + emit(RealBuffer(counter) { array[it] }) } } } diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt index ff4c835ed..2b89904e3 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt @@ -7,31 +7,28 @@ import scientifik.kmath.operations.Norm import scientifik.kmath.operations.RealField import scientifik.kmath.operations.SpaceElement import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asIterable import kotlin.math.sqrt +typealias RealPoint = Point + fun DoubleArray.asVector() = RealVector(this.asBuffer()) fun List.asVector() = RealVector(this.asBuffer()) - object VectorL2Norm : Norm, Double> { override fun norm(arg: Point): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) } inline class RealVector(private val point: Point) : - SpaceElement, RealVector, VectorSpace>, Point { + SpaceElement>, RealPoint { - override val context: VectorSpace - get() = space( - point.size - ) + override val context: VectorSpace get() = space(point.size) - override fun unwrap(): Point = point + override fun unwrap(): RealPoint = point - override fun Point.wrap(): RealVector = - RealVector(this) + override fun RealPoint.wrap(): RealVector = RealVector(this) override val size: Int get() = point.size @@ -44,16 +41,12 @@ inline class RealVector(private val point: Point) : private val spaceCache = HashMap>() inline operator fun invoke(dim: Int, initializer: (Int) -> Double) = - RealVector(DoubleBuffer(dim, initializer)) + RealVector(RealBuffer(dim, initializer)) operator fun invoke(vararg values: Double): RealVector = values.asVector() - fun space(dim: Int): BufferVectorSpace = - spaceCache.getOrPut(dim) { - BufferVectorSpace( - dim, - RealField - ) { size, init -> Buffer.real(size, init) } - } + fun space(dim: Int): BufferVectorSpace = spaceCache.getOrPut(dim) { + BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) } + } } } \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt index d9ee4d90b..82c0e86b2 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt @@ -1,8 +1,8 @@ package scientifik.kmath.real -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer /** - * Simplified [DoubleBuffer] to array comparison + * Simplified [RealBuffer] to array comparison */ -fun DoubleBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) \ No newline at end of file +fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt index 813d89577..65f86eec7 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt @@ -5,8 +5,8 @@ import scientifik.kmath.linear.RealMatrixContext.elementContext import scientifik.kmath.linear.VirtualMatrix import scientifik.kmath.operations.sum import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.Matrix +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asIterable import kotlin.math.pow @@ -27,6 +27,10 @@ typealias RealMatrix = Matrix fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = MatrixContext.real.produce(rowNum, colNum, initializer) +fun Array.toMatrix(): RealMatrix{ + return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] } +} + fun Sequence.toMatrix(): RealMatrix = toList().let { MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] } } @@ -129,22 +133,22 @@ fun Matrix.extractColumns(columnRange: IntRange): RealMatrix = fun Matrix.extractColumn(columnIndex: Int): RealMatrix = extractColumns(columnIndex..columnIndex) -fun Matrix.sumByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> val column = columns[j] with(elementContext) { sum(column.asIterable()) } } -fun Matrix.minByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column") } -fun Matrix.maxByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column") } -fun Matrix.averageByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().average() } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index 329af72a1..43d50ad20 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -1,17 +1,9 @@ package scientifik.kmath.histogram +import scientifik.kmath.domains.Domain import scientifik.kmath.linear.Point import scientifik.kmath.structures.ArrayBuffer -import scientifik.kmath.structures.DoubleBuffer - -/** - * A simple geometric domain - * TODO move to geometry module - */ -interface Domain { - operator fun contains(vector: Point): Boolean - val dimension: Int -} +import scientifik.kmath.structures.RealBuffer /** * The bin in the histogram. The histogram is by definition always done in the real space @@ -51,9 +43,9 @@ interface MutableHistogram> : Histogram { fun MutableHistogram.put(vararg point: T) = put(ArrayBuffer(point)) fun MutableHistogram.put(vararg point: Number) = - put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) + put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) -fun MutableHistogram.put(vararg point: Double) = put(DoubleBuffer(point)) +fun MutableHistogram.put(vararg point: Double) = put(RealBuffer(point)) fun MutableHistogram.fill(sequence: Iterable>) = sequence.forEach { put(it) } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index 4438f5d60..628a68461 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -1,8 +1,8 @@ package scientifik.kmath.histogram import scientifik.kmath.linear.Point -import scientifik.kmath.real.asVector import scientifik.kmath.operations.SpaceOperations +import scientifik.kmath.real.asVector import scientifik.kmath.structures.* import kotlin.math.floor @@ -21,7 +21,7 @@ data class BinDef>(val space: SpaceOperations>, val c class MultivariateBin>(val def: BinDef, override val value: Number) : Bin { - override fun contains(vector: Point): Boolean = def.contains(vector) + override fun contains(point: Point): Boolean = def.contains(point) override val dimension: Int get() = def.center.size @@ -50,7 +50,7 @@ class RealHistogram( override val dimension: Int get() = lower.size - private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } + private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } init { // argument checks diff --git a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt index dcc5ac0eb..af01205bf 100644 --- a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt +++ b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt @@ -16,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) - override fun contains(vector: Buffer): Boolean = contains(vector[0]) + override fun contains(point: Buffer): Boolean = contains(point[0]) internal operator fun inc() = this.also { counter.increment() } diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt index 0896f0dcb..7999aa2ab 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt +++ b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt @@ -10,6 +10,7 @@ interface MemorySpec { val objectSize: Int fun MemoryReader.read(offset: Int): T + //TODO consider thread safety fun MemoryWriter.write(offset: Int, value: T) } diff --git a/settings.gradle.kts b/settings.gradle.kts index 465ecfca8..487e1d87f 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -3,10 +3,12 @@ pluginManagement { val toolsVersion = "0.5.0" plugins { + id("kotlinx.benchmark") version "0.2.0-dev-8" id("scientifik.mpp") version toolsVersion id("scientifik.jvm") version toolsVersion id("scientifik.atomic") version toolsVersion id("scientifik.publish") version toolsVersion + kotlin("plugin.allopen") version "1.3.72" } repositories {