From 675ace272c06f5bada1640d8c9219687f3e1ee91 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Tue, 23 Jun 2020 03:38:20 +0700 Subject: [PATCH 01/23] Minor Gradle settings modification, add benchmarks of different Expression implementatinos --- examples/build.gradle.kts | 10 ++-- .../ast/ExpressionsInterpretersBenchmark.kt | 54 +++++++++++++++++++ .../kmath/structures/ViktorBenchmark.kt | 17 ++---- settings.gradle.kts | 2 + 4 files changed, 64 insertions(+), 19 deletions(-) create mode 100644 examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 2fab47ac0..eadfc3f6b 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -1,16 +1,13 @@ -import org.jetbrains.kotlin.allopen.gradle.AllOpenExtension import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { java kotlin("jvm") - kotlin("plugin.allopen") version "1.3.71" - id("kotlinx.benchmark") version "0.2.0-dev-7" + kotlin("plugin.allopen") + id("kotlinx.benchmark") } -configure { - annotation("org.openjdk.jmh.annotations.State") -} +allOpen.annotation("org.openjdk.jmh.annotations.State") repositories { maven("http://dl.bintray.com/kyonifer/maven") @@ -24,6 +21,7 @@ sourceSets { } dependencies { + implementation(project(":kmath-ast")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt new file mode 100644 index 000000000..c5474e1d2 --- /dev/null +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -0,0 +1,54 @@ +package scientifik.kmath.ast + +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import scientifik.kmath.asm.compile +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.expressionInField +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.RealField +import kotlin.random.Random + +@State(Scope.Benchmark) +class ExpressionsInterpretersBenchmark { + private val algebra: Field = RealField + private val random: Random = Random(1) + + @Benchmark + fun functionalExpression() { + val expr = algebra.expressionInField { + variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) + } + + invokeAndSum(expr) + } + + @Benchmark + fun mstExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + } + + invokeAndSum(expr) + } + + @Benchmark + fun asmExpression() { + val expr = algebra.mstInField { + symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + }.compile() + + invokeAndSum(expr) + } + + private fun invokeAndSum(expr: Expression) { + var sum = 0.0 + + repeat(1000000) { + sum += expr("x" to random.nextDouble()) + } + + println(sum) + } +} diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt index be4115d81..5dc166cd9 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt @@ -23,9 +23,7 @@ class ViktorBenchmark { fun `Automatic field addition`() { autoField.run { var res = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += one } } } @@ -33,9 +31,7 @@ class ViktorBenchmark { fun `Viktor field addition`() { viktorField.run { var res = one - repeat(n) { - res += one - } + repeat(n) { res += one } } } @@ -43,9 +39,7 @@ class ViktorBenchmark { fun `Raw Viktor`() { val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) var res = one - repeat(n) { - res = res + one - } + repeat(n) { res = res + one } } @Benchmark @@ -53,10 +47,7 @@ class ViktorBenchmark { realField.run { val fortyTwo = produce { 42.0 } var res = one - - repeat(n) { - res = ln(fortyTwo) - } + repeat(n) { res = ln(fortyTwo) } } } diff --git a/settings.gradle.kts b/settings.gradle.kts index 465ecfca8..487e1d87f 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -3,10 +3,12 @@ pluginManagement { val toolsVersion = "0.5.0" plugins { + id("kotlinx.benchmark") version "0.2.0-dev-8" id("scientifik.mpp") version toolsVersion id("scientifik.jvm") version toolsVersion id("scientifik.atomic") version toolsVersion id("scientifik.publish") version toolsVersion + kotlin("plugin.allopen") version "1.3.72" } repositories { From 668d13c9d1b276d81f155d05d39e89ebf29b6c4a Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 23 Jun 2020 20:03:45 +0300 Subject: [PATCH 02/23] Minor refactoring + domains --- build.gradle.kts | 2 +- examples/build.gradle.kts | 8 +-- .../kmath/structures/ViktorBenchmark.kt | 10 +-- .../random/CMRandomGeneratorWrapper.kt | 38 +++++++++++ .../kotlin/scientifik/kmath/domains/Domain.kt | 15 +++++ .../kmath/domains/HyperSquareDomain.kt | 67 +++++++++++++++++++ .../scientifik/kmath/domains/RealDomain.kt | 65 ++++++++++++++++++ .../kmath/domains/UnconstrainedDomain.kt | 36 ++++++++++ .../kmath/domains/UnivariateDomain.kt | 48 +++++++++++++ .../scientifik/kmath/real/RealVector.kt | 25 +++---- .../scientifik/kmath/real/realMatrix.kt | 4 ++ .../scientifik/kmath/histogram/Histogram.kt | 10 +-- .../kotlin/scientifik/memory/MemorySpec.kt | 1 + 13 files changed, 294 insertions(+), 35 deletions(-) create mode 100644 kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt diff --git a/build.gradle.kts b/build.gradle.kts index 6d102a77a..052b457c5 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -2,7 +2,7 @@ plugins { id("scientifik.publish") apply false } -val kmathVersion by extra("0.1.4-dev-7") +val kmathVersion by extra("0.1.4-dev-8") val bintrayRepo by extra("scientifik") val githubProject by extra("kmath") diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 2fab47ac0..fb47c998f 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { java kotlin("jvm") - kotlin("plugin.allopen") version "1.3.71" - id("kotlinx.benchmark") version "0.2.0-dev-7" + kotlin("plugin.allopen") version "1.3.72" + id("kotlinx.benchmark") version "0.2.0-dev-8" } configure { @@ -33,8 +33,8 @@ dependencies { implementation(project(":kmath-dimensions")) implementation("com.kyonifer:koma-core-ejml:0.12") implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") - implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7") - "benchmarksCompile"(sourceSets.main.get().compileClasspath) + implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8") + "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath } // Configure benchmark diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt index be4115d81..54105f778 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt @@ -20,7 +20,7 @@ class ViktorBenchmark { final val viktorField = ViktorNDField(intArrayOf(dim, dim)) @Benchmark - fun `Automatic field addition`() { + fun automaticFieldAddition() { autoField.run { var res = one repeat(n) { @@ -30,7 +30,7 @@ class ViktorBenchmark { } @Benchmark - fun `Viktor field addition`() { + fun viktorFieldAddition() { viktorField.run { var res = one repeat(n) { @@ -40,7 +40,7 @@ class ViktorBenchmark { } @Benchmark - fun `Raw Viktor`() { + fun rawViktor() { val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) var res = one repeat(n) { @@ -49,7 +49,7 @@ class ViktorBenchmark { } @Benchmark - fun `Real field log`() { + fun realdFieldLog() { realField.run { val fortyTwo = produce { 42.0 } var res = one @@ -61,7 +61,7 @@ class ViktorBenchmark { } @Benchmark - fun `Raw Viktor log`() { + fun rawViktorLog() { val fortyTwo = F64Array.full(dim, dim, init = 42.0) var res: F64Array repeat(n) { diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt new file mode 100644 index 000000000..13e79d60e --- /dev/null +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -0,0 +1,38 @@ +package scientifik.kmath.commons.random + +import scientifik.kmath.prob.RandomGenerator + +class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : + org.apache.commons.math3.random.RandomGenerator { + private var generator = factory(intArrayOf()) + + override fun nextBoolean(): Boolean = generator.nextBoolean() + + override fun nextFloat(): Float = generator.nextDouble().toFloat() + + override fun setSeed(seed: Int) { + generator = factory(intArrayOf(seed)) + } + + override fun setSeed(seed: IntArray) { + generator = factory(seed) + } + + override fun setSeed(seed: Long) { + setSeed(seed.toInt()) + } + + override fun nextBytes(bytes: ByteArray) { + generator.fillBytes(bytes) + } + + override fun nextInt(): Int = generator.nextInt() + + override fun nextInt(n: Int): Int = generator.nextInt(n) + + override fun nextGaussian(): Double = TODO() + + override fun nextDouble(): Double = generator.nextDouble() + + override fun nextLong(): Long = generator.nextLong() +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt new file mode 100644 index 000000000..333b77cb4 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt @@ -0,0 +1,15 @@ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +/** + * A simple geometric domain + */ +interface Domain { + operator fun contains(point: Point): Boolean + + /** + * Number of hyperspace dimensions + */ + val dimension: Int +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt new file mode 100644 index 000000000..21912b87c --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt @@ -0,0 +1,67 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.indices + +/** + * + * HyperSquareDomain class. + * + * @author Alexander Nozik + */ +class HyperSquareDomain(private val lower: DoubleBuffer, private val upper: DoubleBuffer) : RealDomain { + + override operator fun contains(point: Point): Boolean = point.indices.all { i -> + point[i] in lower[i]..upper[i] + } + + override val dimension: Int get() = lower.size + + override fun getLowerBound(num: Int, point: Point): Double? = lower[num] + + override fun getLowerBound(num: Int): Double? = lower[num] + + override fun getUpperBound(num: Int, point: Point): Double? = upper[num] + + override fun getUpperBound(num: Int): Double? = upper[num] + + override fun nearestInDomain(point: Point): Point { + val res: DoubleArray = DoubleArray(point.size) { i -> + when { + point[i] < lower[i] -> lower[i] + point[i] > upper[i] -> upper[i] + else -> point[i] + } + } + return DoubleBuffer(*res) + } + + override fun volume(): Double { + var res = 1.0 + for (i in 0 until dimension) { + if (lower[i].isInfinite() || upper[i].isInfinite()) { + return Double.POSITIVE_INFINITY + } + if (upper[i] > lower[i]) { + res *= upper[i] - lower[i] + } + } + return res + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt new file mode 100644 index 000000000..89115887e --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt @@ -0,0 +1,65 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +/** + * n-dimensional volume + * + * @author Alexander Nozik + */ +interface RealDomain: Domain { + + fun nearestInDomain(point: Point): Point + + /** + * The lower edge for the domain going down from point + * @param num + * @param point + * @return + */ + fun getLowerBound(num: Int, point: Point): Double? + + /** + * The upper edge of the domain going up from point + * @param num + * @param point + * @return + */ + fun getUpperBound(num: Int, point: Point): Double? + + /** + * Global lower edge + * @param num + * @return + */ + fun getLowerBound(num: Int): Double? + + /** + * Global upper edge + * @param num + * @return + */ + fun getUpperBound(num: Int): Double? + + /** + * Hyper volume + * @return + */ + fun volume(): Double + +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt new file mode 100644 index 000000000..e49fd3b37 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt @@ -0,0 +1,36 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point + +class UnconstrainedDomain(override val dimension: Int) : RealDomain { + + override operator fun contains(point: Point): Boolean = true + + override fun getLowerBound(num: Int, point: Point): Double? = Double.NEGATIVE_INFINITY + + override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY + + override fun getUpperBound(num: Int, point: Point): Double? = Double.POSITIVE_INFINITY + + override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY + + override fun nearestInDomain(point: Point): Point = point + + override fun volume(): Double = Double.POSITIVE_INFINITY + +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt new file mode 100644 index 000000000..ef521d5ea --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt @@ -0,0 +1,48 @@ +package scientifik.kmath.domains + +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.asBuffer + +inline class UnivariateDomain(val range: ClosedFloatingPointRange) : RealDomain { + + operator fun contains(d: Double): Boolean = range.contains(d) + + override operator fun contains(point: Point): Boolean { + require(point.size == 0) + return contains(point[0]) + } + + override fun nearestInDomain(point: Point): Point { + require(point.size == 1) + val value = point[0] + return when{ + value in range -> point + value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer() + else -> doubleArrayOf(range.start).asBuffer() + } + } + + override fun getLowerBound(num: Int, point: Point): Double? { + require(num == 0) + return range.start + } + + override fun getUpperBound(num: Int, point: Point): Double? { + require(num == 0) + return range.endInclusive + } + + override fun getLowerBound(num: Int): Double? { + require(num == 0) + return range.start + } + + override fun getUpperBound(num: Int): Double? { + require(num == 0) + return range.endInclusive + } + + override fun volume(): Double = range.endInclusive - range.start + + override val dimension: Int get() = 1 +} \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt index ff4c835ed..23c7e19cb 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt @@ -12,26 +12,23 @@ import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asIterable import kotlin.math.sqrt +typealias RealPoint = Point + fun DoubleArray.asVector() = RealVector(this.asBuffer()) fun List.asVector() = RealVector(this.asBuffer()) - object VectorL2Norm : Norm, Double> { override fun norm(arg: Point): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) } inline class RealVector(private val point: Point) : - SpaceElement, RealVector, VectorSpace>, Point { + SpaceElement>, RealPoint { - override val context: VectorSpace - get() = space( - point.size - ) + override val context: VectorSpace get() = space(point.size) - override fun unwrap(): Point = point + override fun unwrap(): RealPoint = point - override fun Point.wrap(): RealVector = - RealVector(this) + override fun RealPoint.wrap(): RealVector = RealVector(this) override val size: Int get() = point.size @@ -48,12 +45,8 @@ inline class RealVector(private val point: Point) : operator fun invoke(vararg values: Double): RealVector = values.asVector() - fun space(dim: Int): BufferVectorSpace = - spaceCache.getOrPut(dim) { - BufferVectorSpace( - dim, - RealField - ) { size, init -> Buffer.real(size, init) } - } + fun space(dim: Int): BufferVectorSpace = spaceCache.getOrPut(dim) { + BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) } + } } } \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt index 813d89577..0f4ccf2a8 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt @@ -27,6 +27,10 @@ typealias RealMatrix = Matrix fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = MatrixContext.real.produce(rowNum, colNum, initializer) +fun Array.toMatrix(): RealMatrix{ + return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] } +} + fun Sequence.toMatrix(): RealMatrix = toList().let { MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] } } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index 329af72a1..5199669f5 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -1,18 +1,10 @@ package scientifik.kmath.histogram +import scientifik.kmath.domains.Domain import scientifik.kmath.linear.Point import scientifik.kmath.structures.ArrayBuffer import scientifik.kmath.structures.DoubleBuffer -/** - * A simple geometric domain - * TODO move to geometry module - */ -interface Domain { - operator fun contains(vector: Point): Boolean - val dimension: Int -} - /** * The bin in the histogram. The histogram is by definition always done in the real space */ diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt index 0896f0dcb..7999aa2ab 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt +++ b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt @@ -10,6 +10,7 @@ interface MemorySpec { val objectSize: Int fun MemoryReader.read(offset: Int): T + //TODO consider thread safety fun MemoryWriter.write(offset: Int, value: T) } From ea8c0db85445cdcdaa06d8ce78f839cf0d2162d6 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 23 Jun 2020 21:46:05 +0300 Subject: [PATCH 03/23] Histogram bin fix --- .../kotlin/scientifik/kmath/histogram/RealHistogram.kt | 4 ++-- .../kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index 4438f5d60..f9d815421 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -1,8 +1,8 @@ package scientifik.kmath.histogram import scientifik.kmath.linear.Point -import scientifik.kmath.real.asVector import scientifik.kmath.operations.SpaceOperations +import scientifik.kmath.real.asVector import scientifik.kmath.structures.* import kotlin.math.floor @@ -21,7 +21,7 @@ data class BinDef>(val space: SpaceOperations>, val c class MultivariateBin>(val def: BinDef, override val value: Number) : Bin { - override fun contains(vector: Point): Boolean = def.contains(vector) + override fun contains(point: Point): Boolean = def.contains(point) override val dimension: Int get() = def.center.size diff --git a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt index dcc5ac0eb..af01205bf 100644 --- a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt +++ b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt @@ -16,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) - override fun contains(vector: Buffer): Boolean = contains(vector[0]) + override fun contains(point: Buffer): Boolean = contains(point[0]) internal operator fun inc() = this.also { counter.increment() } From f7f9ce7817cb66c8b5af6e0a590c5332dc3dc1f0 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 25 Jun 2020 10:07:36 +0700 Subject: [PATCH 04/23] Delete AsmCompiledExpression abstract class, implement dynamic field generation to reduce quantity of cast instructions, minor refactor and renaming of internal APIs --- kmath-ast/README.md | 18 +- .../kmath/asm/internal/AsmBuilder.kt | 220 +++++++++--------- .../asm/internal/AsmCompiledExpression.kt | 18 -- .../kmath/asm/internal/buildName.kt | 3 +- .../kmath/asm/internal/classWriters.kt | 12 +- .../kmath/asm/internal/instructionAdapters.kt | 10 + .../kmath/asm/internal/methodVisitors.kt | 4 +- .../kmath/asm/internal/specialization.kt | 2 +- 8 files changed, 140 insertions(+), 147 deletions(-) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt diff --git a/kmath-ast/README.md b/kmath-ast/README.md index b5ca5886f..4563e17cf 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -24,20 +24,20 @@ For example, the following builder: package scientifik.kmath.asm.generated; import java.util.Map; -import scientifik.kmath.asm.internal.AsmCompiledExpression; -import scientifik.kmath.operations.Algebra; +import scientifik.kmath.expressions.Expression; import scientifik.kmath.operations.RealField; -// The class's name is build with MST's hash-code and collision fixing number. -public final class AsmCompiledExpression_45045_0 extends AsmCompiledExpression { - // Plain constructor - public AsmCompiledExpression_45045_0(Algebra algebra, Object[] constants) { - super(algebra, constants); +public final class AsmCompiledExpression_1073786867_0 implements Expression { + private final RealField algebra; + private final Object[] constants; + + public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) { + this.algebra = algebra; + this.constants = constants; } - // The actual dynamic code: public final Double invoke(Map arguments) { - return (Double)((RealField)super.algebra).add((Double)arguments.get("x"), (Double)2.0D); + return (Double)this.algebra.add(((Double)arguments.get("x")).doubleValue(), 2.0D); } } ``` diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 8f45c4044..536d6136d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -6,6 +6,7 @@ import org.objectweb.asm.Opcodes.RETURN import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import java.util.* import kotlin.reflect.KClass @@ -36,32 +37,27 @@ internal class AsmBuilder internal constructor( */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - @Suppress("PrivatePropertyName") - private val T_ALGEBRA_TYPE: Type = algebra::class.asm - - @Suppress("PrivatePropertyName") - internal val T_TYPE: Type = classOfT.asm - - @Suppress("PrivatePropertyName") - private val CLASS_TYPE: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! + private val tAlgebraType: Type = algebra::class.asm + internal val tType: Type = classOfT.asm + private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! /** - * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. + * Index of `this` variable in invoke method of the built subclass. */ private val invokeThisVar: Int = 0 /** - * Index of `arguments` variable in invoke method of [AsmCompiledExpression] built subclass. + * Index of `arguments` variable in invoke method of the built subclass. */ private val invokeArgumentsVar: Int = 1 /** - * List of constants to provide to [AsmCompiledExpression] subclass. + * List of constants to provide to the subclass. */ private val constants: MutableList = mutableListOf() /** - * Method visitor of `invoke` method of [AsmCompiledExpression] subclass. + * Method visitor of `invoke` method of the subclass. */ private lateinit var invokeMethodVisitor: InstructionAdapter internal var primitiveMode = false @@ -72,78 +68,92 @@ internal class AsmBuilder internal constructor( @Suppress("PropertyName") internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE private val typeStack = Stack() - internal val expectationStack = Stack().apply { push(T_TYPE) } + internal val expectationStack: Stack = Stack().apply { push(tType) } /** - * The cache of [AsmCompiledExpression] subclass built by this builder. + * The cache for instance built by this builder. */ - private var generatedInstance: AsmCompiledExpression? = null + private var generatedInstance: Expression? = null /** - * Subclasses, loads and instantiates the [AsmCompiledExpression] for given parameters. + * Subclasses, loads and instantiates [Expression] for given parameters. * * The built instance is cached. */ @Suppress("UNCHECKED_CAST") - fun getInstance(): AsmCompiledExpression { + fun getInstance(): Expression { generatedInstance?.let { return it } - if (SIGNATURE_LETTERS.containsKey(classOfT.java)) { + if (SIGNATURE_LETTERS.containsKey(classOfT)) { primitiveMode = true - PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT.java) - PRIMITIVE_MASK_BOXED = T_TYPE + PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT) + PRIMITIVE_MASK_BOXED = tType } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( Opcodes.V1_8, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, - CLASS_TYPE.internalName, - "L${ASM_COMPILED_EXPRESSION_TYPE.internalName}<${T_TYPE.descriptor}>;", - ASM_COMPILED_EXPRESSION_TYPE.internalName, - arrayOf() + classType.internalName, + "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", + OBJECT_TYPE.internalName, + arrayOf(EXPRESSION_TYPE.internalName) + ) + + visitField( + access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + name = "algebra", + descriptor = tAlgebraType.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd + ) + + visitField( + access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + name = "constants", + descriptor = OBJECT_ARRAY_TYPE.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd ) visitMethod( Opcodes.ACC_PUBLIC, "", - Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), + Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE), null, null ).instructionAdapter { val thisVar = 0 val algebraVar = 1 val constantsVar = 2 - val l0 = Label() - visitLabel(l0) - load(thisVar, CLASS_TYPE) - load(algebraVar, ALGEBRA_TYPE) + val l0 = label() + load(thisVar, classType) + invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) + label() + load(thisVar, classType) + load(algebraVar, tAlgebraType) + putfield(classType.internalName, "algebra", tAlgebraType.descriptor) + label() + load(thisVar, classType) load(constantsVar, OBJECT_ARRAY_TYPE) - - invokespecial( - ASM_COMPILED_EXPRESSION_TYPE.internalName, - "", - Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), - false - ) - - val l1 = Label() - visitLabel(l1) + putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + label() visitInsn(RETURN) - val l2 = Label() - visitLabel(l2) - visitLocalVariable("this", CLASS_TYPE.descriptor, null, l0, l2, thisVar) + val l4 = label() + visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) visitLocalVariable( "algebra", - ALGEBRA_TYPE.descriptor, - "L${ALGEBRA_TYPE.internalName}<${T_TYPE.descriptor}>;", + tAlgebraType.descriptor, + null, l0, - l2, + l4, algebraVar ) - visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l2, constantsVar) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) visitMaxs(0, 3) visitEnd() } @@ -151,22 +161,20 @@ internal class AsmBuilder internal constructor( visitMethod( Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, "invoke", - Type.getMethodDescriptor(T_TYPE, MAP_TYPE), - "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;)${T_TYPE.descriptor}", + Type.getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", null ).instructionAdapter { invokeMethodVisitor = this visitCode() - val l0 = Label() - visitLabel(l0) + val l0 = label() invokeLabel0Visitor() - areturn(T_TYPE) - val l1 = Label() - visitLabel(l1) + areturn(tType) + val l1 = label() visitLocalVariable( "this", - CLASS_TYPE.descriptor, + classType.descriptor, null, l0, l1, @@ -176,7 +184,7 @@ internal class AsmBuilder internal constructor( visitLocalVariable( "arguments", MAP_TYPE.descriptor, - "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;", + "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", l0, l1, invokeArgumentsVar @@ -196,18 +204,16 @@ internal class AsmBuilder internal constructor( val thisVar = 0 val argumentsVar = 1 visitCode() - val l0 = Label() - visitLabel(l0) + val l0 = label() load(thisVar, OBJECT_TYPE) load(argumentsVar, MAP_TYPE) - invokevirtual(CLASS_TYPE.internalName, "invoke", Type.getMethodDescriptor(T_TYPE, MAP_TYPE), false) - areturn(T_TYPE) - val l1 = Label() - visitLabel(l1) + invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false) + areturn(tType) + val l1 = label() visitLocalVariable( "this", - CLASS_TYPE.descriptor, + classType.descriptor, null, l0, l1, @@ -225,7 +231,7 @@ internal class AsmBuilder internal constructor( .defineClass(className, classWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants.toTypedArray()) as AsmCompiledExpression + .newInstance(algebra, constants.toTypedArray()) as Expression generatedInstance = new return new @@ -235,21 +241,21 @@ internal class AsmBuilder internal constructor( * Loads a constant from */ internal fun loadTConstant(value: T) { - if (classOfT.java in INLINABLE_NUMBERS) { + if (classOfT in INLINABLE_NUMBERS) { val expectedType = expectationStack.pop()!! val mustBeBoxed = expectedType.sort == Type.OBJECT loadNumberConstant(value as Number, mustBeBoxed) - if (mustBeBoxed) typeStack.push(T_TYPE) else typeStack.push(PRIMITIVE_MASK) + if (mustBeBoxed) typeStack.push(tType) else typeStack.push(PRIMITIVE_MASK) return } - loadConstant(value as Any, T_TYPE) + loadConstant(value as Any, tType) } private fun box(): Unit = invokeMethodVisitor.invokestatic( - T_TYPE.internalName, + tType.internalName, "valueOf", - Type.getMethodDescriptor(T_TYPE, PRIMITIVE_MASK), + Type.getMethodDescriptor(tType, PRIMITIVE_MASK), false ) @@ -263,16 +269,16 @@ internal class AsmBuilder internal constructor( private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex loadThis() - getfield(CLASS_TYPE.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) iconst(idx) visitInsn(AALOAD) checkcast(type) } - private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, CLASS_TYPE) + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) /** - * Either loads a numeric constant [value] from [AsmCompiledExpression] constants field or boxes a primitive + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded * from it). */ @@ -292,7 +298,7 @@ internal class AsmBuilder internal constructor( if (mustBeBoxed) { box() - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) } return @@ -300,11 +306,11 @@ internal class AsmBuilder internal constructor( loadConstant(value, boxed) if (!mustBeBoxed) unbox() - else invokeMethodVisitor.checkcast(T_TYPE) + else invokeMethodVisitor.checkcast(tType) } /** - * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. + * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) @@ -319,7 +325,7 @@ internal class AsmBuilder internal constructor( Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) return } @@ -331,11 +337,11 @@ internal class AsmBuilder internal constructor( Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT) - typeStack.push(T_TYPE) + typeStack.push(tType) else { unbox() typeStack.push(PRIMITIVE_MASK) @@ -343,15 +349,11 @@ internal class AsmBuilder internal constructor( } /** - * Loads algebra from according field of [AsmCompiledExpression] and casts it to class of [algebra] provided. + * Loads algebra from according field of the class and casts it to class of [algebra] provided. */ internal fun loadAlgebra() { loadThis() - - invokeMethodVisitor.run { - getfield(ASM_COMPILED_EXPRESSION_TYPE.internalName, "algebra", ALGEBRA_TYPE.descriptor) - checkcast(T_ALGEBRA_TYPE) - } + invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor) } /** @@ -368,7 +370,12 @@ internal class AsmBuilder internal constructor( tArity: Int, opcode: Int = Opcodes.INVOKEINTERFACE ) { - repeat(tArity) { if (!typeStack.empty()) typeStack.pop() } + run loop@{ + repeat(tArity) { + if (typeStack.empty()) return@loop + typeStack.pop() + } + } invokeMethodVisitor.visitMethodInsn( opcode, @@ -378,12 +385,12 @@ internal class AsmBuilder internal constructor( opcode == Opcodes.INVOKEINTERFACE ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) val isLastExpr = expectationStack.size == 1 val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT || isLastExpr) - typeStack.push(T_TYPE) + typeStack.push(tType) else { unbox() typeStack.push(PRIMITIVE_MASK) @@ -399,27 +406,18 @@ internal class AsmBuilder internal constructor( /** * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. */ - private val SIGNATURE_LETTERS: Map, Type> by lazy { + private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( - java.lang.Byte::class.java to Type.BYTE_TYPE, - java.lang.Short::class.java to Type.SHORT_TYPE, - java.lang.Integer::class.java to Type.INT_TYPE, - java.lang.Long::class.java to Type.LONG_TYPE, - java.lang.Float::class.java to Type.FLOAT_TYPE, - java.lang.Double::class.java to Type.DOUBLE_TYPE + java.lang.Byte::class to Type.BYTE_TYPE, + java.lang.Short::class to Type.SHORT_TYPE, + java.lang.Integer::class to Type.INT_TYPE, + java.lang.Long::class to Type.LONG_TYPE, + java.lang.Float::class to Type.FLOAT_TYPE, + java.lang.Double::class to Type.DOUBLE_TYPE ) } - private val BOXED_TO_PRIMITIVES: Map by lazy { - hashMapOf( - java.lang.Byte::class.asm to Type.BYTE_TYPE, - java.lang.Short::class.asm to Type.SHORT_TYPE, - java.lang.Integer::class.asm to Type.INT_TYPE, - java.lang.Long::class.asm to Type.LONG_TYPE, - java.lang.Float::class.asm to Type.FLOAT_TYPE, - java.lang.Double::class.asm to Type.DOUBLE_TYPE - ) - } + private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } private val NUMBER_CONVERTER_METHODS: Map by lazy { hashMapOf( @@ -435,15 +433,15 @@ internal class AsmBuilder internal constructor( /** * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ - private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } - internal val ASM_COMPILED_EXPRESSION_TYPE: Type = AsmCompiledExpression::class.asm - internal val NUMBER_TYPE: Type = java.lang.Number::class.asm - internal val MAP_TYPE: Type = java.util.Map::class.asm - internal val OBJECT_TYPE: Type = java.lang.Object::class.asm + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } + internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } + internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } + internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") - internal val OBJECT_ARRAY_TYPE: Type = Array::class.asm - internal val ALGEBRA_TYPE: Type = Algebra::class.asm - internal val STRING_TYPE: Type = java.lang.String::class.asm + internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } + internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } + internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt deleted file mode 100644 index 7c4a9fc99..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt +++ /dev/null @@ -1,18 +0,0 @@ -package scientifik.kmath.asm.internal - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra - -/** - * [Expression] partial implementation to have it subclassed by actual implementations. Provides unified storage for - * objects needed to implement the expression. - * - * @property algebra the algebra to delegate calls. - * @property constants the constants array to have persistent objects to reference in [invoke]. - */ -internal abstract class AsmCompiledExpression internal constructor( - @JvmField protected val algebra: Algebra, - @JvmField protected val constants: Array -) : Expression { - abstract override fun invoke(arguments: Map): T -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt index 66bd039c3..41dbf5807 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt @@ -1,9 +1,10 @@ package scientifik.kmath.asm.internal import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression /** - * Creates a class name for [AsmCompiledExpression] subclassed to implement [mst] provided. + * Creates a class name for [Expression] subclassed to implement [mst] provided. * * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt index 95d713b18..af5c1049d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt @@ -1,15 +1,17 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter +import org.objectweb.asm.FieldVisitor import org.objectweb.asm.MethodVisitor -internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = + ClassWriter(flags).apply(block) -internal inline fun ClassWriter.visitMethod( +internal inline fun ClassWriter.visitField( access: Int, name: String, descriptor: String, signature: String?, - exceptions: Array?, - block: MethodVisitor.() -> Unit -): MethodVisitor = visitMethod(access, name, descriptor, signature, exceptions).apply(block) + value: Any?, + block: FieldVisitor.() -> Unit +): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt new file mode 100644 index 000000000..f47293687 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt @@ -0,0 +1,10 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.Label +import org.objectweb.asm.commons.InstructionAdapter + +internal fun InstructionAdapter.label(): Label { + val l = Label() + visitLabel(l) + return l +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt index 7b0d346b7..aaae02ebb 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt @@ -3,7 +3,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.MethodVisitor import org.objectweb.asm.commons.InstructionAdapter -fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) +internal fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) -fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = +internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = instructionAdapter().apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index 2e15a1a93..4c7a0d57e 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -22,7 +22,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: val aName = methodNameAdapters[name] ?: name val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null - val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else T_TYPE + val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else tType repeat(arity) { expectationStack.push(t) } return hasSpecific From c9de04a6106384c0128b4d0445b3e00bc0216379 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 25 Jun 2020 10:24:21 +0700 Subject: [PATCH 05/23] Make benchmarks 'naive' --- .../ast/ExpressionsInterpretersBenchmark.kt | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) rename examples/src/{benchmarks => main}/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt (70%) diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt similarity index 70% rename from examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt rename to examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt index c5474e1d2..17a70a4aa 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -1,21 +1,16 @@ package scientifik.kmath.ast -import org.openjdk.jmh.annotations.Benchmark -import org.openjdk.jmh.annotations.Scope -import org.openjdk.jmh.annotations.State import scientifik.kmath.asm.compile import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.expressionInField +import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.Field import scientifik.kmath.operations.RealField import kotlin.random.Random +import kotlin.system.measureTimeMillis -@State(Scope.Benchmark) class ExpressionsInterpretersBenchmark { private val algebra: Field = RealField - private val random: Random = Random(1) - - @Benchmark fun functionalExpression() { val expr = algebra.expressionInField { variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) @@ -24,7 +19,6 @@ class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } - @Benchmark fun mstExpression() { val expr = algebra.mstInField { symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) @@ -33,7 +27,6 @@ class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } - @Benchmark fun asmExpression() { val expr = algebra.mstInField { symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) @@ -43,6 +36,7 @@ class ExpressionsInterpretersBenchmark { } private fun invokeAndSum(expr: Expression) { + val random = Random(0) var sum = 0.0 repeat(1000000) { @@ -52,3 +46,25 @@ class ExpressionsInterpretersBenchmark { println(sum) } } + +fun main() { + val benchmark = ExpressionsInterpretersBenchmark() + + val fe = measureTimeMillis { + benchmark.functionalExpression() + } + + println("fe=$fe") + + val mst = measureTimeMillis { + benchmark.mstExpression() + } + + println("mst=$mst") + + val asm = measureTimeMillis { + benchmark.asmExpression() + } + + println("asm=$asm") +} From b11a7f1426a707168d8174be4be43fd1d352fed3 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 19:29:31 +0700 Subject: [PATCH 06/23] Update README.md --- kmath-ast/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 4563e17cf..0e375f14d 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -48,12 +48,11 @@ This API is an extension to MST and MSTExpression APIs. You may optimize both MS ```kotlin RealField.mstInField { symbol("x") + 2 }.compile() -RealField.expression("2+2".parseMath()) +RealField.expression("x+2".parseMath()) ``` ### Known issues -- Using numeric algebras causes boxing and calling bridge methods. - The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead. - This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. From 23816d33665314016dfc46b67a34d09568e56640 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 19:42:13 +0700 Subject: [PATCH 07/23] Update KDoc comments, optimize imports --- .../kmath/asm/internal/AsmBuilder.kt | 101 +++++++++++++++--- .../kmath/asm/internal/classWriters.kt | 1 - .../kmath/asm/internal/specialization.kt | 4 +- 3 files changed, 89 insertions(+), 17 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 536d6136d..a291ba4ee 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -18,6 +18,7 @@ import kotlin.reflect.KClass * @param T the type of AsmExpression to unwrap. * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. + * @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. */ internal class AsmBuilder internal constructor( private val classOfT: KClass<*>, @@ -37,8 +38,19 @@ internal class AsmBuilder internal constructor( */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + /** + * ASM Type for [algebra] + */ private val tAlgebraType: Type = algebra::class.asm + + /** + * ASM type for [T] + */ internal val tType: Type = classOfT.asm + + /** + * ASM type for new class + */ private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! /** @@ -60,14 +72,30 @@ internal class AsmBuilder internal constructor( * Method visitor of `invoke` method of the subclass. */ private lateinit var invokeMethodVisitor: InstructionAdapter + + /** + * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. + */ internal var primitiveMode = false - @Suppress("PropertyName") - internal var PRIMITIVE_MASK: Type = OBJECT_TYPE + /** + * Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + */ + internal var primitiveMask: Type = OBJECT_TYPE - @Suppress("PropertyName") - internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE + /** + * Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. + */ + internal var primitiveMaskBoxed: Type = OBJECT_TYPE + + /** + * Stack of useful objects types on stack to verify types. + */ private val typeStack = Stack() + + /** + * Stack of useful objects types on stack expected by algebra calls. + */ internal val expectationStack: Stack = Stack().apply { push(tType) } /** @@ -86,8 +114,8 @@ internal class AsmBuilder internal constructor( if (SIGNATURE_LETTERS.containsKey(classOfT)) { primitiveMode = true - PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT) - PRIMITIVE_MASK_BOXED = tType + primitiveMask = SIGNATURE_LETTERS.getValue(classOfT) + primitiveMaskBoxed = tType } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { @@ -238,34 +266,43 @@ internal class AsmBuilder internal constructor( } /** - * Loads a constant from + * Loads a [T] constant from [constants]. */ internal fun loadTConstant(value: T) { if (classOfT in INLINABLE_NUMBERS) { val expectedType = expectationStack.pop()!! val mustBeBoxed = expectedType.sort == Type.OBJECT loadNumberConstant(value as Number, mustBeBoxed) - if (mustBeBoxed) typeStack.push(tType) else typeStack.push(PRIMITIVE_MASK) + if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) return } loadConstant(value as Any, tType) } + /** + * Boxes the current value and pushes it. + */ private fun box(): Unit = invokeMethodVisitor.invokestatic( tType.internalName, "valueOf", - Type.getMethodDescriptor(tType, PRIMITIVE_MASK), + Type.getMethodDescriptor(tType, primitiveMask), false ) + /** + * Unboxes the current boxed value and pushes it. + */ private fun unbox(): Unit = invokeMethodVisitor.invokevirtual( NUMBER_TYPE.internalName, - NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK), - Type.getMethodDescriptor(PRIMITIVE_MASK), + NUMBER_CONVERTER_METHODS.getValue(primitiveMask), + Type.getMethodDescriptor(primitiveMask), false ) + /** + * Loads [java.lang.Object] constant from constants. + */ private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex loadThis() @@ -275,6 +312,9 @@ internal class AsmBuilder internal constructor( checkcast(type) } + /** + * Loads this variable. + */ private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) /** @@ -344,7 +384,7 @@ internal class AsmBuilder internal constructor( typeStack.push(tType) else { unbox() - typeStack.push(PRIMITIVE_MASK) + typeStack.push(primitiveMask) } } @@ -393,7 +433,7 @@ internal class AsmBuilder internal constructor( typeStack.push(tType) else { unbox() - typeStack.push(PRIMITIVE_MASK) + typeStack.push(primitiveMask) } } @@ -404,7 +444,7 @@ internal class AsmBuilder internal constructor( internal companion object { /** - * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. + * Maps JVM primitive numbers boxed types to their primitive ASM types. */ private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( @@ -417,8 +457,14 @@ internal class AsmBuilder internal constructor( ) } + /** + * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. + */ private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } + /** + * Maps primitive ASM types to [Number] functions unboxing them. + */ private val NUMBER_CONVERTER_METHODS: Map by lazy { hashMapOf( Type.BYTE_TYPE to "byteValue", @@ -434,14 +480,41 @@ internal class AsmBuilder internal constructor( * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + + /** + * ASM type for [Expression]. + */ internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } + + /** + * ASM type for [java.lang.Number]. + */ internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } + + /** + * ASM type for [java.util.Map]. + */ internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } + + /** + * ASM type for [java.lang.Object]. + */ internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } + /** + * ASM type for array of [java.lang.Object]. + */ @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } + + /** + * ASM type for [Algebra]. + */ internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } + + /** + * ASM type for [java.lang.String]. + */ internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt index af5c1049d..00093aaa7 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt @@ -2,7 +2,6 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter import org.objectweb.asm.FieldVisitor -import org.objectweb.asm.MethodVisitor internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index 4c7a0d57e..252509d59 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -22,7 +22,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: val aName = methodNameAdapters[name] ?: name val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null - val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else tType + val t = if (primitiveMode && hasSpecific) primitiveMask else tType repeat(arity) { expectationStack.push(t) } return hasSpecific @@ -52,7 +52,7 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri invokeAlgebraOperation( owner = owner, method = aName, - descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }), + descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }), tArity = arity, opcode = Opcodes.INVOKEVIRTUAL ) From 46f99139e2850a4bfd63e3823b48b9e026fec017 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 19:45:33 +0700 Subject: [PATCH 08/23] Update number literal call in tests --- .../src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 4c2be811e..6ce769613 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -92,7 +92,7 @@ class TestAsmAlgebras { "+", (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), - 1 / 2 + number(2.0) * one + number(1) / 2 + number(2.0) * one ) }("x" to 2.0) @@ -101,7 +101,7 @@ class TestAsmAlgebras { "+", (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), - 1 / 2 + number(2.0) * one + number(1) / 2 + number(2.0) * one ) }.compile()("x" to 2.0) From 7faa48be582bc34f997b50104c3e3dc30f3ee979 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 19:46:32 +0700 Subject: [PATCH 09/23] Add zero call in MSTField test --- .../src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 6ce769613..079a6be0f 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -93,7 +93,7 @@ class TestAsmAlgebras { (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one - ) + ) + zero }("x" to 2.0) val res2 = RealField.mstInField { @@ -102,7 +102,7 @@ class TestAsmAlgebras { (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one - ) + ) + zero }.compile()("x" to 2.0) assertEquals(res1, res2) From 3528fa16dbc4c1750a56c7ee292044dbec2b4ef7 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 20:10:38 +0700 Subject: [PATCH 10/23] Add missing dependency in examples --- examples/build.gradle.kts | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index fb47c998f..73def3572 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -24,6 +24,7 @@ sourceSets { } dependencies { + implementation(project(":kmath-ast")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) From e2cc3c8efefa3b6a7d604990e1b1ca0a4860678b Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 25 Jun 2020 20:54:14 +0700 Subject: [PATCH 11/23] Specify type explicitly, minor implementation refactor --- .../kotlin/scientifik/kmath/asm/asm.kt | 4 +-- .../kmath/asm/internal/AsmBuilder.kt | 9 ++--- .../kmath/asm/internal/specialization.kt | 35 ++++++++----------- 3 files changed, 21 insertions(+), 27 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index a3af80ccd..bb456e6eb 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -44,7 +44,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< AsmBuilder.OBJECT_TYPE ), - tArity = 1 + expectedArity = 1 ) } is MST.Binary -> { @@ -64,7 +64,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< AsmBuilder.OBJECT_TYPE ), - tArity = 2 + expectedArity = 2 ) } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index a291ba4ee..ebe25e19f 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -76,7 +76,7 @@ internal class AsmBuilder internal constructor( /** * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. */ - internal var primitiveMode = false + internal var primitiveMode: Boolean = false /** * Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. @@ -91,7 +91,7 @@ internal class AsmBuilder internal constructor( /** * Stack of useful objects types on stack to verify types. */ - private val typeStack = Stack() + private val typeStack: Stack = Stack() /** * Stack of useful objects types on stack expected by algebra calls. @@ -345,6 +345,7 @@ internal class AsmBuilder internal constructor( } loadConstant(value, boxed) + if (!mustBeBoxed) unbox() else invokeMethodVisitor.checkcast(tType) } @@ -407,11 +408,11 @@ internal class AsmBuilder internal constructor( owner: String, method: String, descriptor: String, - tArity: Int, + expectedArity: Int, opcode: Int = Opcodes.INVOKEINTERFACE ) { run loop@{ - repeat(tArity) { + repeat(expectedArity) { if (typeStack.empty()) return@loop typeStack.pop() } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index 252509d59..e54acf6f9 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -5,11 +5,7 @@ import org.objectweb.asm.Type import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map by lazy { - hashMapOf( - "+" to "add", - "*" to "multiply", - "/" to "divide" - ) + hashMapOf("+" to "add", "*" to "multiply", "/" to "divide") } /** @@ -19,12 +15,10 @@ private val methodNameAdapters: Map by lazy { * @return `true` if contains, else `false`. */ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { - val aName = methodNameAdapters[name] ?: name - - val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null + val theName = methodNameAdapters[name] ?: name + val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null val t = if (primitiveMode && hasSpecific) primitiveMask else tType repeat(arity) { expectationStack.push(t) } - return hasSpecific } @@ -35,25 +29,24 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: * @return `true` if contains, else `false`. */ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { - val aName = methodNameAdapters[name] ?: name + val theName = methodNameAdapters[name] ?: name - val method = - context.javaClass.methods.find { - var suitableSignature = it.name == aName && it.parameters.size == arity + context.javaClass.methods.find { + var suitableSignature = it.name == theName && it.parameters.size == arity - if (primitiveMode && it.isBridge) - suitableSignature = false + if (primitiveMode && it.isBridge) + suitableSignature = false - suitableSignature - } ?: return false + suitableSignature + } ?: return false - val owner = context::class.java.name.replace('.', '/') + val owner = context::class.asm invokeAlgebraOperation( - owner = owner, - method = aName, + owner = owner.internalName, + method = theName, descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }), - tArity = arity, + expectedArity = arity, opcode = Opcodes.INVOKEVIRTUAL ) From 5ab6960e9b1a0b51ffca0cf30ee5eb83c550989c Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 15:55:01 +0700 Subject: [PATCH 12/23] Add mapIntrinsics.kt, update specialization mappings --- .../kotlin/scientifik/kmath/asm/asm.kt | 1 + .../kmath/asm/internal/AsmBuilder.kt | 37 +++++++-------- .../kmath/asm/internal/mapIntrinsics.kt | 7 +++ .../kmath/asm/internal/specialization.kt | 15 +++++-- .../scietifik/kmath/asm/TestAsmAlgebras.kt | 2 +- .../scietifik/kmath/asm/TestAsmExpressions.kt | 2 +- .../scietifik/kmath/asm/TestSpecialization.kt | 45 +++++++++++++++++++ .../kotlin/scietifik/kmath/ast/AsmTest.kt | 2 +- 8 files changed, 84 insertions(+), 27 deletions(-) create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt create mode 100644 kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index bb456e6eb..af39d9091 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -47,6 +47,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< expectedArity = 1 ) } + is MST.Binary -> { loadAlgebra() if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index ebe25e19f..89c9dca9d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -354,31 +354,23 @@ internal class AsmBuilder internal constructor( * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { - load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) + load(invokeArgumentsVar, MAP_TYPE) + aconst(name) - if (defaultValue != null) { - loadStringConstant(name) + if (defaultValue != null) loadTConstant(defaultValue) + else + aconst(null) - invokeinterface( - MAP_TYPE.internalName, - "getOrDefault", - Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) - ) - - invokeMethodVisitor.checkcast(tType) - return - } - - loadStringConstant(name) - - invokeinterface( - MAP_TYPE.internalName, - "get", - Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) + invokestatic( + MAP_INTRINSICS_TYPE.internalName, + "getOrFail", + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE), + false ) - invokeMethodVisitor.checkcast(tType) + checkcast(tType) + val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT) @@ -517,5 +509,10 @@ internal class AsmBuilder internal constructor( * ASM type for [java.lang.String]. */ internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } + + /** + * ASM type for MapIntrinsics. + */ + internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt new file mode 100644 index 000000000..7f7126b55 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt @@ -0,0 +1,7 @@ +@file:JvmName("MapIntrinsics") + +package scientifik.kmath.asm.internal + +internal fun Map.getOrFail(key: K, default: V?): V { + return this[key] ?: default ?: error("Parameter not found: $key") +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index e54acf6f9..a8d5a605f 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -4,8 +4,15 @@ import org.objectweb.asm.Opcodes import org.objectweb.asm.Type import scientifik.kmath.operations.Algebra -private val methodNameAdapters: Map by lazy { - hashMapOf("+" to "add", "*" to "multiply", "/" to "divide") +private val methodNameAdapters: Map, String> by lazy { + hashMapOf( + "+" to 2 to "add", + "*" to 2 to "multiply", + "/" to 2 to "divide", + "+" to 1 to "unaryPlus", + "-" to 1 to "unaryMinus", + "-" to 2 to "minus" + ) } /** @@ -15,7 +22,7 @@ private val methodNameAdapters: Map by lazy { * @return `true` if contains, else `false`. */ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { - val theName = methodNameAdapters[name] ?: name + val theName = methodNameAdapters[name to arity] ?: name val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null val t = if (primitiveMode && hasSpecific) primitiveMask else tType repeat(arity) { expectationStack.push(t) } @@ -29,7 +36,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: * @return `true` if contains, else `false`. */ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { - val theName = methodNameAdapters[name] ?: name + val theName = methodNameAdapters[name to arity] ?: name context.javaClass.methods.find { var suitableSignature = it.name == theName && it.parameters.size == arity diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 079a6be0f..3acc6eb28 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -10,7 +10,7 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -class TestAsmAlgebras { +internal class TestAsmAlgebras { @Test fun space() { val res1 = ByteRing.mstInSpace { diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index 824201aa7..36c254c38 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -8,7 +8,7 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -class TestAsmExpressions { +internal class TestAsmExpressions { @Test fun testUnaryOperationInvocation() { val expression = RealField.mstInSpace { -symbol("x") }.compile() diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt new file mode 100644 index 000000000..f3b07df56 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt @@ -0,0 +1,45 @@ +package scietifik.kmath.asm + +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestSpecialization { + @Test + fun testUnaryPlus() { + val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(2.0, res) + } + + @Test + fun testUnaryMinus() { + val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(-2.0, res) + } + + @Test + fun testAdd() { + val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(4.0, res) + } + + @Test + fun testMinus() { + val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(0.0, res) + } + + @Test + fun testDivide() { + val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(1.0, res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index 08d7fff47..23203172e 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -10,7 +10,7 @@ import scientifik.kmath.operations.ComplexField import kotlin.test.Test import kotlin.test.assertEquals -class AsmTest { +internal class AsmTest { @Test fun `compile MST`() { val mst = "2+2*(2+2)".parseMath() From 90c287d42fe2b2f7703ae4dad80da824bf05e125 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 15:59:24 +0700 Subject: [PATCH 13/23] Add tests for MapInstrinsics --- .../scietifik/kmath/asm/TestAsmVariables.kt | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt new file mode 100644 index 000000000..aafc75448 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt @@ -0,0 +1,22 @@ +package scietifik.kmath.asm + +import scientifik.kmath.ast.mstInRing +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +internal class TestAsmVariables { + @Test + fun testVariableWithoutDefault() { + val expr = ByteRing.mstInRing { symbol("x") } + assertEquals(1.toByte(), expr("x" to 1.toByte())) + } + + @Test + fun testVariableWithoutDefaultFails() { + val expr = ByteRing.mstInRing { symbol("x") } + assertFailsWith { expr() } + } +} From 092728b1c328a6e515de5c68ade5e99bc460c4e5 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 16:01:50 +0700 Subject: [PATCH 14/23] Replace Stack with ArrayDeque --- .../scientifik/kmath/asm/internal/AsmBuilder.kt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 89c9dca9d..c9f797787 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -91,12 +91,12 @@ internal class AsmBuilder internal constructor( /** * Stack of useful objects types on stack to verify types. */ - private val typeStack: Stack = Stack() + private val typeStack: ArrayDeque = ArrayDeque() /** * Stack of useful objects types on stack expected by algebra calls. */ - internal val expectationStack: Stack = Stack().apply { push(tType) } + internal val expectationStack: ArrayDeque = ArrayDeque().apply { push(tType) } /** * The cache for instance built by this builder. @@ -270,7 +270,7 @@ internal class AsmBuilder internal constructor( */ internal fun loadTConstant(value: T) { if (classOfT in INLINABLE_NUMBERS) { - val expectedType = expectationStack.pop()!! + val expectedType = expectationStack.pop() val mustBeBoxed = expectedType.sort == Type.OBJECT loadNumberConstant(value as Number, mustBeBoxed) if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) @@ -371,7 +371,7 @@ internal class AsmBuilder internal constructor( checkcast(tType) - val expectedType = expectationStack.pop()!! + val expectedType = expectationStack.pop() if (expectedType.sort == Type.OBJECT) typeStack.push(tType) @@ -405,7 +405,7 @@ internal class AsmBuilder internal constructor( ) { run loop@{ repeat(expectedArity) { - if (typeStack.empty()) return@loop + if (typeStack.isEmpty()) return@loop typeStack.pop() } } @@ -420,7 +420,7 @@ internal class AsmBuilder internal constructor( invokeMethodVisitor.checkcast(tType) val isLastExpr = expectationStack.size == 1 - val expectedType = expectationStack.pop()!! + val expectedType = expectationStack.pop() if (expectedType.sort == Type.OBJECT || isLastExpr) typeStack.push(tType) From 2df97ca4c3e2f33680dae70181e06c3dc2fe1349 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 16:05:13 +0700 Subject: [PATCH 15/23] Update README.md, add suppression --- kmath-ast/README.md | 4 +++- .../kotlin/scientifik/kmath/asm/internal/classWriters.kt | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 0e375f14d..12d425460 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -24,6 +24,7 @@ For example, the following builder: package scientifik.kmath.asm.generated; import java.util.Map; +import scientifik.kmath.asm.internal.MapIntrinsics; import scientifik.kmath.expressions.Expression; import scientifik.kmath.operations.RealField; @@ -37,9 +38,10 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression arguments) { - return (Double)this.algebra.add(((Double)arguments.get("x")).doubleValue(), 2.0D); + return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D); } } + ``` ### Example Usage diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt index 00093aaa7..7f0770b28 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt @@ -3,6 +3,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter import org.objectweb.asm.FieldVisitor +@Suppress("FunctionName") internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) From 0ee1d31571a59659fbac33482a4e8cce16cf83fb Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 20:57:47 +0700 Subject: [PATCH 16/23] Fix MSTField and MSTRing invalid unary operation, update according ASM tests --- .../kotlin/scientifik/kmath/ast/MSTAlgebra.kt | 37 +++++++++---------- ...ialization.kt => TestAsmSpecialization.kt} | 23 ++++++------ 2 files changed, 29 insertions(+), 31 deletions(-) rename kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/{TestSpecialization.kt => TestAsmSpecialization.kt} (68%) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt index 07194a7bb..f741fc8c4 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt @@ -34,43 +34,40 @@ object MSTSpace : Space, NumericAlgebra { } object MSTRing : Ring, NumericAlgebra { - override fun number(value: Number): MST = MST.Numeric(value) - override fun symbol(value: String): MST = MST.Symbolic(value) - override val zero: MST = MSTSpace.number(0.0) override val one: MST = number(1.0) - override fun add(a: MST, b: MST): MST = - MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun number(value: Number): MST = MST.Numeric(value) + override fun symbol(value: String): MST = MST.Symbolic(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) - override fun multiply(a: MST, b: MST): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = MSTAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) } -object MSTField : Field{ - override fun symbol(value: String): MST = MST.Symbolic(value) - override fun number(value: Number): MST = MST.Numeric(value) - +object MSTField : Field { override val zero: MST = MSTSpace.number(0.0) override val one: MST = number(1.0) - override fun add(a: MST, b: MST): MST = - MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MST.Numeric(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) - override fun multiply(a: MST, b: MST): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) - - override fun divide(a: MST, b: MST): MST = - binaryOperation(FieldOperations.DIV_OPERATION, a, b) + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = MSTAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt similarity index 68% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt rename to kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt index f3b07df56..b571e076f 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt @@ -7,39 +7,40 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -internal class TestSpecialization { +internal class TestAsmSpecialization { @Test fun testUnaryPlus() { val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(2.0, res) + assertEquals(2.0, expr("x" to 2.0)) } @Test fun testUnaryMinus() { val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(-2.0, res) + assertEquals(-2.0, expr("x" to 2.0)) } @Test fun testAdd() { val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(4.0, res) + assertEquals(4.0, expr("x" to 2.0)) + } + + @Test + fun testSine() { + val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 0.0)) } @Test fun testMinus() { val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(0.0, res) + assertEquals(0.0, expr("x" to 2.0)) } @Test fun testDivide() { val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(1.0, res) + assertEquals(1.0, expr("x" to 2.0)) } } From d962ab4d11298fd6cd699edd584c3c12c260d11e Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 21:02:22 +0700 Subject: [PATCH 17/23] Rename and refactor MstAlgebra (ex-MSTAlgebra) (and its subclasses), MstExpression (ex-MSTExpression) --- .../scientifik/kmath/ast/MSTExpression.kt | 55 ------------------- .../ast/{MSTAlgebra.kt => MstAlgebra.kt} | 41 +++++++------- .../scientifik/kmath/ast/MstExpression.kt | 55 +++++++++++++++++++ .../kotlin/scientifik/kmath/asm/asm.kt | 6 +- 4 files changed, 78 insertions(+), 79 deletions(-) delete mode 100644 kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt rename kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/{MSTAlgebra.kt => MstAlgebra.kt} (62%) create mode 100644 kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt deleted file mode 100644 index 61703cac7..000000000 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt +++ /dev/null @@ -1,55 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.FunctionalExpressionField -import scientifik.kmath.expressions.FunctionalExpressionRing -import scientifik.kmath.expressions.FunctionalExpressionSpace -import scientifik.kmath.operations.* - -/** - * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. - */ -class MSTExpression(val algebra: Algebra, val mst: MST) : Expression { - - /** - * Substitute algebra raw value - */ - private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra{ - override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) - override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: T, right: T): T =algebra.binaryOperation(operation, left, right) - - override fun number(value: Number): T = if(algebra is NumericAlgebra){ - algebra.number(value) - } else{ - error("Numeric nodes are not supported by $this") - } - } - - override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) -} - - -inline fun , E : Algebra> A.mst( - mstAlgebra: E, - block: E.() -> MST -): MSTExpression = MSTExpression(this, mstAlgebra.block()) - -inline fun Space.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = - MSTExpression(this, MSTSpace.block()) - -inline fun Ring.mstInRing(block: MSTRing.() -> MST): MSTExpression = - MSTExpression(this, MSTRing.block()) - -inline fun Field.mstInField(block: MSTField.() -> MST): MSTExpression = - MSTExpression(this, MSTField.block()) - -inline fun > FunctionalExpressionSpace.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = - algebra.mstInSpace(block) - -inline fun > FunctionalExpressionRing.mstInRing(block: MSTRing.() -> MST): MSTExpression = - algebra.mstInRing(block) - -inline fun > FunctionalExpressionField.mstInField(block: MSTField.() -> MST): MSTExpression = - algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt similarity index 62% rename from kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt rename to kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt index f741fc8c4..007cf57c4 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt @@ -2,7 +2,7 @@ package scientifik.kmath.ast import scientifik.kmath.operations.* -object MSTAlgebra : NumericAlgebra { +object MstAlgebra : NumericAlgebra { override fun number(value: Number): MST = MST.Numeric(value) override fun symbol(value: String): MST = MST.Symbolic(value) @@ -14,12 +14,11 @@ object MSTAlgebra : NumericAlgebra { MST.Binary(operation, left, right) } -object MSTSpace : Space, NumericAlgebra { +object MstSpace : Space, NumericAlgebra { override val zero: MST = number(0.0) - override fun number(value: Number): MST = MST.Numeric(value) - - override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) @@ -28,46 +27,46 @@ object MSTSpace : Space, NumericAlgebra { binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) + MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) } -object MSTRing : Ring, NumericAlgebra { - override val zero: MST = MSTSpace.number(0.0) +object MstRing : Ring, NumericAlgebra { + override val zero: MST = number(0.0) override val one: MST = number(1.0) - override fun number(value: Number): MST = MST.Numeric(value) - override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k)) override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) + MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) } -object MSTField : Field { - override val zero: MST = MSTSpace.number(0.0) +object MstField : Field { + override val zero: MST = number(0.0) override val one: MST = number(1.0) - override fun symbol(value: String): MST = MST.Symbolic(value) - override fun number(value: Number): MST = MST.Numeric(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) + override fun number(value: Number): MST = MstAlgebra.number(value) override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k)) override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) + MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) } diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt new file mode 100644 index 000000000..1468c3ad4 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt @@ -0,0 +1,55 @@ +package scientifik.kmath.ast + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.FunctionalExpressionField +import scientifik.kmath.expressions.FunctionalExpressionRing +import scientifik.kmath.expressions.FunctionalExpressionSpace +import scientifik.kmath.operations.* + +/** + * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. + */ +class MstExpression(val algebra: Algebra, val mst: MST) : Expression { + + /** + * Substitute algebra raw value + */ + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { + override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) + override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: T, right: T): T = + algebra.binaryOperation(operation, left, right) + + override fun number(value: Number): T = if (algebra is NumericAlgebra) + algebra.number(value) + else + error("Numeric nodes are not supported by $this") + } + + override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) +} + + +inline fun , E : Algebra> A.mst( + mstAlgebra: E, + block: E.() -> MST +): MstExpression = MstExpression(this, mstAlgebra.block()) + +inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression = + MstExpression(this, MstSpace.block()) + +inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression = + MstExpression(this, MstRing.block()) + +inline fun Field.mstInField(block: MstField.() -> MST): MstExpression = + MstExpression(this, MstField.block()) + +inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression = + algebra.mstInSpace(block) + +inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression = + algebra.mstInRing(block) + +inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression = + algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index af39d9091..468ed01ba 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -6,7 +6,7 @@ import scientifik.kmath.asm.internal.buildExpectationStack import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST -import scientifik.kmath.ast.MSTExpression +import scientifik.kmath.ast.MstExpression import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.NumericAlgebra @@ -80,6 +80,6 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class, this) /** - * Optimize performance of an [MSTExpression] using ASM codegen + * Optimize performance of an [MstExpression] using ASM codegen */ -inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) +inline fun MstExpression.compile(): Expression = mst.compileWith(T::class, algebra) From ec46f5cf229109fe266e888f29b301d62bb2d3ea Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 21:02:31 +0700 Subject: [PATCH 18/23] Update README.md --- kmath-ast/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 12d425460..62b18b4b5 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -46,7 +46,7 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression Date: Fri, 26 Jun 2020 21:39:39 +0700 Subject: [PATCH 19/23] Add explicit toRegex call to have better IDE support --- .../kotlin/scientifik/kmath/ast/parser.kt | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt index cec61a8ff..30a92c5ae 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt @@ -16,15 +16,15 @@ import scientifik.kmath.operations.SpaceOperations * TODO move to common */ private object ArithmeticsEvaluator : Grammar() { - val num by token("-?[\\d.]+(?:[eE]-?\\d+)?") - val lpar by token("\\(") - val rpar by token("\\)") - val mul by token("\\*") - val pow by token("\\^") - val div by token("/") - val minus by token("-") - val plus by token("\\+") - val ws by token("\\s+", ignore = true) + val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex()) + val lpar by token("\\(".toRegex()) + val rpar by token("\\)".toRegex()) + val mul by token("\\*".toRegex()) + val pow by token("\\^".toRegex()) + val div by token("/".toRegex()) + val minus by token("-".toRegex()) + val plus by token("\\+".toRegex()) + val ws by token("\\s+".toRegex(), ignore = true) val number: Parser by num use { MST.Numeric(text.toDouble()) } From bf89aa09e561da09bde2df50355d7adf312c75cc Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 22:05:42 +0700 Subject: [PATCH 20/23] Add static imports for Opcodes --- .../kmath/asm/internal/AsmBuilder.kt | 26 +++++++++---------- .../kmath/asm/internal/specialization.kt | 4 +-- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index c9f797787..cea6be933 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -1,8 +1,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.* -import org.objectweb.asm.Opcodes.AALOAD -import org.objectweb.asm.Opcodes.RETURN +import org.objectweb.asm.Opcodes.* import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.ast.MST @@ -120,8 +119,8 @@ internal class AsmBuilder internal constructor( val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( - Opcodes.V1_8, - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, + V1_8, + ACC_PUBLIC or ACC_FINAL or ACC_SUPER, classType.internalName, "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", OBJECT_TYPE.internalName, @@ -129,7 +128,7 @@ internal class AsmBuilder internal constructor( ) visitField( - access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + access = ACC_PRIVATE or ACC_FINAL, name = "algebra", descriptor = tAlgebraType.descriptor, signature = null, @@ -138,7 +137,7 @@ internal class AsmBuilder internal constructor( ) visitField( - access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + access = ACC_PRIVATE or ACC_FINAL, name = "constants", descriptor = OBJECT_ARRAY_TYPE.descriptor, signature = null, @@ -147,7 +146,7 @@ internal class AsmBuilder internal constructor( ) visitMethod( - Opcodes.ACC_PUBLIC, + ACC_PUBLIC, "", Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE), null, @@ -187,7 +186,7 @@ internal class AsmBuilder internal constructor( } visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + ACC_PUBLIC or ACC_FINAL, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", @@ -223,7 +222,7 @@ internal class AsmBuilder internal constructor( } visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, "invoke", Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), null, @@ -351,7 +350,8 @@ internal class AsmBuilder internal constructor( } /** - * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be + * provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { load(invokeArgumentsVar, MAP_TYPE) @@ -391,7 +391,7 @@ internal class AsmBuilder internal constructor( /** * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is - * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interface. [loadAlgebra] should be + * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be * called before the arguments and this operation. * * The result is casted to [T] automatically. @@ -401,7 +401,7 @@ internal class AsmBuilder internal constructor( method: String, descriptor: String, expectedArity: Int, - opcode: Int = Opcodes.INVOKEINTERFACE + opcode: Int = INVOKEINTERFACE ) { run loop@{ repeat(expectedArity) { @@ -415,7 +415,7 @@ internal class AsmBuilder internal constructor( owner, method, descriptor, - opcode == Opcodes.INVOKEINTERFACE + opcode == INVOKEINTERFACE ) invokeMethodVisitor.checkcast(tType) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index a8d5a605f..a6d2c045b 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -1,6 +1,6 @@ package scientifik.kmath.asm.internal -import org.objectweb.asm.Opcodes +import org.objectweb.asm.Opcodes.INVOKEVIRTUAL import org.objectweb.asm.Type import scientifik.kmath.operations.Algebra @@ -54,7 +54,7 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri method = theName, descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }), expectedArity = arity, - opcode = Opcodes.INVOKEVIRTUAL + opcode = INVOKEVIRTUAL ) return true From 4b067f7a97c2a147b046944dfb66ddb5c78b5577 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sat, 27 Jun 2020 12:19:43 +0300 Subject: [PATCH 21/23] DoubleBuffer -> RealBuffer. Number algebra refactoring. --- doc/buffers.md | 2 +- .../kmath/structures/BufferBenchmark.kt | 4 +- .../structures/StructureReadBenchmark.kt | 2 +- .../structures/StructureWriteBenchmark.kt | 4 +- .../commons/transform/Transformations.kt | 2 +- .../kmath/domains/HyperSquareDomain.kt | 6 +- .../scientifik/kmath/linear/BufferMatrix.kt | 8 +- .../scientifik/kmath/structures/Buffers.kt | 62 +----------- .../kmath/structures/FlaggedBuffer.kt | 53 +++++++++++ .../scientifik/kmath/structures/IntBuffer.kt | 20 ++++ .../scientifik/kmath/structures/LongBuffer.kt | 19 ++++ .../{DoubleBuffer.kt => RealBuffer.kt} | 12 +-- .../kmath/structures/RealBufferField.kt | 94 +++++++++---------- .../kmath/structures/RealNDField.kt | 6 +- .../kmath/structures/ShortBuffer.kt | 20 ++++ .../scientifik/kmath/streaming/BufferFlow.kt | 8 +- .../scientifik/kmath/real/RealVector.kt | 4 +- .../scientifik/kmath/real/realBuffer.kt | 6 +- .../scientifik/kmath/real/realMatrix.kt | 10 +- .../scientifik/kmath/histogram/Histogram.kt | 6 +- .../kmath/histogram/RealHistogram.kt | 2 +- 21 files changed, 205 insertions(+), 145 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt rename kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/{DoubleBuffer.kt => RealBuffer.kt} (59%) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt diff --git a/doc/buffers.md b/doc/buffers.md index b0b7489b3..52a9df86e 100644 --- a/doc/buffers.md +++ b/doc/buffers.md @@ -2,7 +2,7 @@ Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`). There are different types of buffers: -* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays. +* Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays. * Boxing `ListBuffer` wrapping a list * Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value * `MemoryBuffer` allows direct allocation of objects in continuous memory block. diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt index 9676b5e4a..e40b0c4b7 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt @@ -10,8 +10,8 @@ import scientifik.kmath.operations.complex class BufferBenchmark { @Benchmark - fun genericDoubleBufferReadWrite() { - val buffer = DoubleBuffer(size){it.toDouble()} + fun genericRealBufferReadWrite() { + val buffer = RealBuffer(size){it.toDouble()} (0 until size).forEach { buffer[it] diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt index ecfb4ab20..a33fdb2c4 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt @@ -6,7 +6,7 @@ fun main(args: Array) { val n = 6000 val array = DoubleArray(n * n) { 1.0 } - val buffer = DoubleBuffer(array) + val buffer = RealBuffer(array) val strides = DefaultStrides(intArrayOf(n, n)) val structure = BufferNDStructure(strides, buffer) diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt index 2d16cc8f4..0241f12ad 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt @@ -26,10 +26,10 @@ fun main(args: Array) { } println("Array mapping finished in $time2 millis") - val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 }) + val buffer = RealBuffer(DoubleArray(n * n) { 1.0 }) val time3 = measureTimeMillis { - val target = DoubleBuffer(DoubleArray(n * n)) + val target = RealBuffer(DoubleArray(n * n)) val res = array.forEachIndexed { index, value -> target[index] = value + 1 } diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt index bcb3ea87b..eb1b5b69a 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt @@ -18,7 +18,7 @@ object Transformations { private fun Buffer.toArray(): Array = Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) } - private fun Buffer.asArray() = if (this is DoubleBuffer) { + private fun Buffer.asArray() = if (this is RealBuffer) { array } else { DoubleArray(size) { i -> get(i) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt index 21912b87c..e0019c96b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt @@ -16,7 +16,7 @@ package scientifik.kmath.domains import scientifik.kmath.linear.Point -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.indices /** @@ -25,7 +25,7 @@ import scientifik.kmath.structures.indices * * @author Alexander Nozik */ -class HyperSquareDomain(private val lower: DoubleBuffer, private val upper: DoubleBuffer) : RealDomain { +class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain { override operator fun contains(point: Point): Boolean = point.indices.all { i -> point[i] in lower[i]..upper[i] @@ -49,7 +49,7 @@ class HyperSquareDomain(private val lower: DoubleBuffer, private val upper: Doub else -> point[i] } } - return DoubleBuffer(*res) + return RealBuffer(*res) } override fun volume(): Double { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt index 73b18b810..c4c38284b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt @@ -30,11 +30,11 @@ object RealMatrixContext : GenericMatrixContext { override val elementContext get() = RealField override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix { - val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } + val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } return BufferMatrix(rows, columns, buffer) } - override inline fun point(size: Int, initializer: (Int) -> Double): Point = DoubleBuffer(size,initializer) + override inline fun point(size: Int, initializer: (Int) -> Double): Point = RealBuffer(size,initializer) } class BufferMatrix( @@ -102,7 +102,7 @@ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix.unsafeArray(): DoubleArray = if (this is DoubleBuffer) { + fun Buffer.unsafeArray(): DoubleArray = if (this is RealBuffer) { array } else { DoubleArray(size) { get(it) } @@ -119,6 +119,6 @@ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix { companion object { - inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer { + inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer { val array = DoubleArray(size) { initializer(it) } - return DoubleBuffer(array) + return RealBuffer(array) } /** @@ -51,7 +51,7 @@ interface Buffer { inline fun auto(type: KClass, size: Int, crossinline initializer: (Int) -> T): Buffer { //TODO add resolution based on Annotation or companion resolution return when (type) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer + Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer @@ -93,7 +93,7 @@ interface MutableBuffer : Buffer { @Suppress("UNCHECKED_CAST") inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer { return when (type) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer + Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer @@ -109,12 +109,11 @@ interface MutableBuffer : Buffer { auto(T::class, size, initializer) val real: MutableBufferFactory = { size: Int, initializer: (Int) -> Double -> - DoubleBuffer(DoubleArray(size) { initializer(it) }) + RealBuffer(DoubleArray(size) { initializer(it) }) } } } - inline class ListBuffer(val list: List) : Buffer { override val size: Int @@ -163,57 +162,6 @@ class ArrayBuffer(private val array: Array) : MutableBuffer { fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) -inline class ShortBuffer(val array: ShortArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Short = array[index] - - override fun set(index: Int, value: Short) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = ShortBuffer(array.copyOf()) - -} - -fun ShortArray.asBuffer() = ShortBuffer(this) - -inline class IntBuffer(val array: IntArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Int = array[index] - - override fun set(index: Int, value: Int) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = IntBuffer(array.copyOf()) - -} - -fun IntArray.asBuffer() = IntBuffer(this) - -inline class LongBuffer(val array: LongArray) : MutableBuffer { - override val size: Int get() = array.size - - override fun get(index: Int): Long = array[index] - - override fun set(index: Int, value: Long) { - array[index] = value - } - - override fun iterator() = array.iterator() - - override fun copy(): MutableBuffer = LongBuffer(array.copyOf()) - -} - -fun LongArray.asBuffer() = LongBuffer(this) - inline class ReadOnlyBuffer(val buffer: MutableBuffer) : Buffer { override val size: Int get() = buffer.size diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt new file mode 100644 index 000000000..749e4eeec --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt @@ -0,0 +1,53 @@ +package scientifik.kmath.structures + +import kotlin.experimental.and + +enum class ValueFlag(val mask: Byte) { + NAN(0b0000_0001), + MISSING(0b0000_0010), + NEGATIVE_INFINITY(0b0000_0100), + POSITIVE_INFINITY(0b0000_1000) +} + +/** + * A buffer with flagged values + */ +interface FlaggedBuffer : Buffer { + fun getFlag(index: Int): Byte +} + +/** + * The value is valid if all flags are down + */ +fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte() + +fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte() + +fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING) + +/** + * A real buffer which supports flags for each value like NaN or Missing + */ +class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer, Buffer { + init { + require(values.size == flags.size) { "Values and flags must have the same dimensions" } + } + + override fun getFlag(index: Int): Byte = flags[index] + + override val size: Int get() = values.size + + override fun get(index: Int): Double? = if (isValid(index)) values[index] else null + + override fun iterator(): Iterator = values.indices.asSequence().map { + if (isValid(it)) values[it] else null + }.iterator() +} + +inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { + for(i in indices){ + if(isValid(i)){ + block(values[i]) + } + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt new file mode 100644 index 000000000..a354c5de0 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.structures + +inline class IntBuffer(val array: IntArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Int = array[index] + + override fun set(index: Int, value: Int) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + IntBuffer(array.copyOf()) + +} + + +fun IntArray.asBuffer() = IntBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt new file mode 100644 index 000000000..fa6229a71 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt @@ -0,0 +1,19 @@ +package scientifik.kmath.structures + +inline class LongBuffer(val array: LongArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Long = array[index] + + override fun set(index: Int, value: Long) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + LongBuffer(array.copyOf()) + +} + +fun LongArray.asBuffer() = LongBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt similarity index 59% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt index c0b7f713b..f48ace3a9 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/DoubleBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt @@ -1,6 +1,6 @@ package scientifik.kmath.structures -inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer { +inline class RealBuffer(val array: DoubleArray) : MutableBuffer { override val size: Int get() = array.size override fun get(index: Int): Double = array[index] @@ -12,23 +12,23 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer { override fun iterator() = array.iterator() override fun copy(): MutableBuffer = - DoubleBuffer(array.copyOf()) + RealBuffer(array.copyOf()) } @Suppress("FunctionName") -inline fun DoubleBuffer(size: Int, init: (Int) -> Double): DoubleBuffer = DoubleBuffer(DoubleArray(size) { init(it) }) +inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) }) @Suppress("FunctionName") -fun DoubleBuffer(vararg doubles: Double): DoubleBuffer = DoubleBuffer(doubles) +fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles) /** * Transform buffer of doubles into array for high performance operations */ val MutableBuffer.array: DoubleArray - get() = if (this is DoubleBuffer) { + get() = if (this is RealBuffer) { array } else { DoubleArray(size) { get(it) } } -fun DoubleArray.asBuffer() = DoubleBuffer(this) \ No newline at end of file +fun DoubleArray.asBuffer() = RealBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 88c8c29db..a91000a2a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -9,143 +9,143 @@ import kotlin.math.* * A simple field over linear buffers of [Double] */ object RealBufferFieldOperations : ExtendedFieldOperations> { - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) + RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) } else { - DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) + RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) } } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { val kValue = k.toDouble() - return if (a is DoubleBuffer) { + return if (a is RealBuffer) { val aArray = a.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue }) + RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * kValue }) + RealBuffer(DoubleArray(a.size) { a[it] * kValue }) } } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) + RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) + RealBuffer(DoubleArray(a.size) { a[it] * b[it] }) } } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) + RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) } else { - DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) + RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) } } - override fun sin(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { + override fun sin(arg: Buffer): RealBuffer { + return if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) + RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) } } - override fun cos(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { + override fun cos(arg: Buffer): RealBuffer { + return if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) + RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) } } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { - return if (arg is DoubleBuffer) { + override fun power(arg: Buffer, pow: Number): RealBuffer { + return if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) + RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) } else { - DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) } } - override fun exp(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { + override fun exp(arg: Buffer): RealBuffer { + return if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) + RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) } } - override fun ln(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { + override fun ln(arg: Buffer): RealBuffer { + return if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) + RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } } } class RealBufferField(val size: Int) : ExtendedField> { - override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } + override val zero: Buffer by lazy { RealBuffer(size) { 0.0 } } - override val one: Buffer by lazy { DoubleBuffer(size) { 1.0 } } + override val one: Buffer by lazy { RealBuffer(size) { 1.0 } } - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.add(a, b) } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, k) } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, b) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.divide(a, b) } - override fun sin(arg: Buffer): DoubleBuffer { + override fun sin(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.sin(arg) } - override fun cos(arg: Buffer): DoubleBuffer { + override fun cos(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.cos(arg) } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { + override fun power(arg: Buffer, pow: Number): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) } - override fun exp(arg: Buffer): DoubleBuffer { + override fun exp(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.exp(arg) } - override fun ln(arg: Buffer): DoubleBuffer { + override fun ln(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.ln(arg) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 8c1bd4239..4a5f10790 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -16,7 +16,7 @@ class RealNDField(override val shape: IntArray) : override val one by lazy { produce { one } } inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = - DoubleBuffer(DoubleArray(size) { initializer(it) }) + RealBuffer(DoubleArray(size) { initializer(it) }) /** * Inline transform an NDStructure to @@ -82,7 +82,7 @@ class RealNDField(override val shape: IntArray) : */ inline fun BufferedNDField.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) } - return BufferedNDFieldElement(this, DoubleBuffer(array)) + return BufferedNDFieldElement(this, RealBuffer(array)) } /** @@ -96,7 +96,7 @@ inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: Int */ inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) } - return BufferedNDFieldElement(context, DoubleBuffer(array)) + return BufferedNDFieldElement(context, RealBuffer(array)) } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt new file mode 100644 index 000000000..f4b2f7d13 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.structures + +inline class ShortBuffer(val array: ShortArray) : MutableBuffer { + override val size: Int get() = array.size + + override fun get(index: Int): Short = array[index] + + override fun set(index: Int, value: Short) { + array[index] = value + } + + override fun iterator() = array.iterator() + + override fun copy(): MutableBuffer = + ShortBuffer(array.copyOf()) + +} + + +fun ShortArray.asBuffer() = ShortBuffer(this) \ No newline at end of file diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt index bef21a680..54da66bb7 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt @@ -5,7 +5,7 @@ import kotlinx.coroutines.flow.* import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.BufferFactory -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asBuffer /** @@ -45,7 +45,7 @@ fun Flow.chunked(bufferSize: Int, bufferFactory: BufferFactory): Flow< /** * Specialized flow chunker for real buffer */ -fun Flow.chunked(bufferSize: Int): Flow = flow { +fun Flow.chunked(bufferSize: Int): Flow = flow { require(bufferSize > 0) { "Resulting chunk size must be more than zero" } if (this@chunked is BlockingRealChain) { @@ -61,13 +61,13 @@ fun Flow.chunked(bufferSize: Int): Flow = flow { array[counter] = element counter++ if (counter == bufferSize) { - val buffer = DoubleBuffer(array) + val buffer = RealBuffer(array) emit(buffer) counter = 0 } } if (counter > 0) { - emit(DoubleBuffer(counter) { array[it] }) + emit(RealBuffer(counter) { array[it] }) } } } diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt index 23c7e19cb..2b89904e3 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt @@ -7,7 +7,7 @@ import scientifik.kmath.operations.Norm import scientifik.kmath.operations.RealField import scientifik.kmath.operations.SpaceElement import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asIterable import kotlin.math.sqrt @@ -41,7 +41,7 @@ inline class RealVector(private val point: Point) : private val spaceCache = HashMap>() inline operator fun invoke(dim: Int, initializer: (Int) -> Double) = - RealVector(DoubleBuffer(dim, initializer)) + RealVector(RealBuffer(dim, initializer)) operator fun invoke(vararg values: Double): RealVector = values.asVector() diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt index d9ee4d90b..82c0e86b2 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt @@ -1,8 +1,8 @@ package scientifik.kmath.real -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer /** - * Simplified [DoubleBuffer] to array comparison + * Simplified [RealBuffer] to array comparison */ -fun DoubleBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) \ No newline at end of file +fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt index 0f4ccf2a8..65f86eec7 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt @@ -5,8 +5,8 @@ import scientifik.kmath.linear.RealMatrixContext.elementContext import scientifik.kmath.linear.VirtualMatrix import scientifik.kmath.operations.sum import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.Matrix +import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asIterable import kotlin.math.pow @@ -133,22 +133,22 @@ fun Matrix.extractColumns(columnRange: IntRange): RealMatrix = fun Matrix.extractColumn(columnIndex: Int): RealMatrix = extractColumns(columnIndex..columnIndex) -fun Matrix.sumByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> val column = columns[j] with(elementContext) { sum(column.asIterable()) } } -fun Matrix.minByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column") } -fun Matrix.maxByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column") } -fun Matrix.averageByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> +fun Matrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> columns[j].asIterable().average() } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index 5199669f5..43d50ad20 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -3,7 +3,7 @@ package scientifik.kmath.histogram import scientifik.kmath.domains.Domain import scientifik.kmath.linear.Point import scientifik.kmath.structures.ArrayBuffer -import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.RealBuffer /** * The bin in the histogram. The histogram is by definition always done in the real space @@ -43,9 +43,9 @@ interface MutableHistogram> : Histogram { fun MutableHistogram.put(vararg point: T) = put(ArrayBuffer(point)) fun MutableHistogram.put(vararg point: Number) = - put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) + put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) -fun MutableHistogram.put(vararg point: Double) = put(DoubleBuffer(point)) +fun MutableHistogram.put(vararg point: Double) = put(RealBuffer(point)) fun MutableHistogram.fill(sequence: Iterable>) = sequence.forEach { put(it) } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index f9d815421..628a68461 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -50,7 +50,7 @@ class RealHistogram( override val dimension: Int get() = lower.size - private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } + private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } init { // argument checks From efcfb4425366ee32c2aafb733c1dab8567ea9abe Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 27 Jun 2020 20:04:22 +0700 Subject: [PATCH 22/23] Refactor Algebra call building --- .../kotlin/scientifik/kmath/asm/asm.kt | 51 +++++-------------- .../kmath/asm/internal/specialization.kt | 29 ++++++++++- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 468ed01ba..ef2330533 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,10 +1,8 @@ package scientifik.kmath.asm -import org.objectweb.asm.Type import scientifik.kmath.asm.internal.AsmBuilder -import scientifik.kmath.asm.internal.buildExpectationStack +import scientifik.kmath.asm.internal.buildAlgebraOperationCall import scientifik.kmath.asm.internal.buildName -import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST import scientifik.kmath.ast.MstExpression import scientifik.kmath.expressions.Expression @@ -29,44 +27,21 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< loadTConstant(constant) } - is MST.Unary -> { - loadAlgebra() - if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation) - visit(node.value) + is MST.Unary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "unaryOperation", + arity = 1 + ) { visit(node.value) } - if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = "unaryOperation", - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - AsmBuilder.OBJECT_TYPE - ), - - expectedArity = 1 - ) - } - - is MST.Binary -> { - loadAlgebra() - if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) + is MST.Binary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "binaryOperation", + arity = 2 + ) { visit(node.left) visit(node.right) - - if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = "binaryOperation", - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - AsmBuilder.OBJECT_TYPE, - AsmBuilder.OBJECT_TYPE - ), - - expectedArity = 2 - ) } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index a6d2c045b..003dc47dd 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -21,7 +21,7 @@ private val methodNameAdapters: Map, String> by lazy { * * @return `true` if contains, else `false`. */ -internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { +private fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { val theName = methodNameAdapters[name to arity] ?: name val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null val t = if (primitiveMode && hasSpecific) primitiveMask else tType @@ -35,7 +35,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: * * @return `true` if contains, else `false`. */ -internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { +private fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val theName = methodNameAdapters[name to arity] ?: name context.javaClass.methods.find { @@ -59,3 +59,28 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri return true } + +internal fun AsmBuilder.buildAlgebraOperationCall( + context: Algebra, + name: String, + fallbackMethodName: String, + arity: Int, + parameters: AsmBuilder.() -> Unit +) { + loadAlgebra() + if (!buildExpectationStack(context, name, arity)) loadStringConstant(name) + parameters() + + if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = fallbackMethodName, + + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + *Array(arity) { AsmBuilder.OBJECT_TYPE } + ), + + expectedArity = arity + ) +} From e98fc126c4118bcf8750ea85bb25bca2d85b145f Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 27 Jun 2020 20:15:14 +0700 Subject: [PATCH 23/23] Merge various codegen utilities into one file --- .../kmath/asm/internal/buildName.kt | 22 ------- .../kmath/asm/internal/classWriters.kt | 17 ----- .../scientifik/kmath/asm/internal/classes.kt | 7 -- .../{specialization.kt => codegenUtils.kt} | 64 ++++++++++++++++++- .../kmath/asm/internal/instructionAdapters.kt | 10 --- .../kmath/asm/internal/methodVisitors.kt | 9 --- 6 files changed, 63 insertions(+), 66 deletions(-) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{specialization.kt => codegenUtils.kt} (57%) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt deleted file mode 100644 index 41dbf5807..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt +++ /dev/null @@ -1,22 +0,0 @@ -package scientifik.kmath.asm.internal - -import scientifik.kmath.ast.MST -import scientifik.kmath.expressions.Expression - -/** - * Creates a class name for [Expression] subclassed to implement [mst] provided. - * - * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there - * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. - */ -internal tailrec fun buildName(mst: MST, collision: Int = 0): String { - val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(mst, collision + 1) -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt deleted file mode 100644 index 7f0770b28..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ /dev/null @@ -1,17 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.ClassWriter -import org.objectweb.asm.FieldVisitor - -@Suppress("FunctionName") -internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = - ClassWriter(flags).apply(block) - -internal inline fun ClassWriter.visitField( - access: Int, - name: String, - descriptor: String, - signature: String?, - value: Any?, - block: FieldVisitor.() -> Unit -): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt deleted file mode 100644 index dc0b35531..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt +++ /dev/null @@ -1,7 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.Type -import kotlin.reflect.KClass - -internal val KClass<*>.asm: Type - get() = Type.getType(java) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt similarity index 57% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt index 003dc47dd..46d07976d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt @@ -1,8 +1,12 @@ package scientifik.kmath.asm.internal +import org.objectweb.asm.* import org.objectweb.asm.Opcodes.INVOKEVIRTUAL -import org.objectweb.asm.Type +import org.objectweb.asm.commons.InstructionAdapter +import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra +import kotlin.reflect.KClass private val methodNameAdapters: Map, String> by lazy { hashMapOf( @@ -15,6 +19,60 @@ private val methodNameAdapters: Map, String> by lazy { ) } +internal val KClass<*>.asm: Type + get() = Type.getType(java) + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor]. + */ +private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. + */ +internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = + instructionAdapter().apply(block) + +/** + * Constructs a [Label], then applies it to this visitor. + */ +internal fun MethodVisitor.label(): Label { + val l = Label() + visitLabel(l) + return l +} + +/** + * Creates a class name for [Expression] subclassed to implement [mst] provided. + * + * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there + * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. + */ +internal tailrec fun buildName(mst: MST, collision: Int = 0): String { + val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(mst, collision + 1) +} + +@Suppress("FunctionName") +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = + ClassWriter(flags).apply(block) + +internal inline fun ClassWriter.visitField( + access: Int, + name: String, + descriptor: String, + signature: String?, + value: Any?, + block: FieldVisitor.() -> Unit +): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) + /** * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds * type expectation stack for needed arity. @@ -60,6 +118,9 @@ private fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Strin return true } +/** + * Builds specialized algebra call with option to fallback to generic algebra operation accepting String. + */ internal fun AsmBuilder.buildAlgebraOperationCall( context: Algebra, name: String, @@ -84,3 +145,4 @@ internal fun AsmBuilder.buildAlgebraOperationCall( expectedArity = arity ) } + diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt deleted file mode 100644 index f47293687..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt +++ /dev/null @@ -1,10 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.Label -import org.objectweb.asm.commons.InstructionAdapter - -internal fun InstructionAdapter.label(): Label { - val l = Label() - visitLabel(l) - return l -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt deleted file mode 100644 index aaae02ebb..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ /dev/null @@ -1,9 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.commons.InstructionAdapter - -internal fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) - -internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = - instructionAdapter().apply(block)