Dev #127
@ -2,7 +2,7 @@ plugins {
|
|||||||
id("scientifik.publish") apply false
|
id("scientifik.publish") apply false
|
||||||
}
|
}
|
||||||
|
|
||||||
val kmathVersion by extra("0.1.4-dev-7")
|
val kmathVersion by extra("0.1.4-dev-8")
|
||||||
|
|
||||||
val bintrayRepo by extra("scientifik")
|
val bintrayRepo by extra("scientifik")
|
||||||
val githubProject by extra("kmath")
|
val githubProject by extra("kmath")
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`).
|
Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`).
|
||||||
There are different types of buffers:
|
There are different types of buffers:
|
||||||
|
|
||||||
* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays.
|
* Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays.
|
||||||
* Boxing `ListBuffer` wrapping a list
|
* Boxing `ListBuffer` wrapping a list
|
||||||
* Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value
|
* Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value
|
||||||
* `MemoryBuffer` allows direct allocation of objects in continuous memory block.
|
* `MemoryBuffer` allows direct allocation of objects in continuous memory block.
|
||||||
|
@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
|
|||||||
plugins {
|
plugins {
|
||||||
java
|
java
|
||||||
kotlin("jvm")
|
kotlin("jvm")
|
||||||
kotlin("plugin.allopen") version "1.3.71"
|
kotlin("plugin.allopen") version "1.3.72"
|
||||||
id("kotlinx.benchmark") version "0.2.0-dev-7"
|
id("kotlinx.benchmark") version "0.2.0-dev-8"
|
||||||
}
|
}
|
||||||
|
|
||||||
configure<AllOpenExtension> {
|
configure<AllOpenExtension> {
|
||||||
@ -24,6 +24,7 @@ sourceSets {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
|
implementation(project(":kmath-ast"))
|
||||||
implementation(project(":kmath-core"))
|
implementation(project(":kmath-core"))
|
||||||
implementation(project(":kmath-coroutines"))
|
implementation(project(":kmath-coroutines"))
|
||||||
implementation(project(":kmath-commons"))
|
implementation(project(":kmath-commons"))
|
||||||
@ -33,8 +34,8 @@ dependencies {
|
|||||||
implementation(project(":kmath-dimensions"))
|
implementation(project(":kmath-dimensions"))
|
||||||
implementation("com.kyonifer:koma-core-ejml:0.12")
|
implementation("com.kyonifer:koma-core-ejml:0.12")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
|
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7")
|
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8")
|
||||||
"benchmarksCompile"(sourceSets.main.get().compileClasspath)
|
"benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure benchmark
|
// Configure benchmark
|
||||||
|
@ -10,8 +10,8 @@ import scientifik.kmath.operations.complex
|
|||||||
class BufferBenchmark {
|
class BufferBenchmark {
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun genericDoubleBufferReadWrite() {
|
fun genericRealBufferReadWrite() {
|
||||||
val buffer = DoubleBuffer(size){it.toDouble()}
|
val buffer = RealBuffer(size){it.toDouble()}
|
||||||
|
|
||||||
(0 until size).forEach {
|
(0 until size).forEach {
|
||||||
buffer[it]
|
buffer[it]
|
||||||
|
@ -20,48 +20,39 @@ class ViktorBenchmark {
|
|||||||
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
|
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun `Automatic field addition`() {
|
fun automaticFieldAddition() {
|
||||||
autoField.run {
|
autoField.run {
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) {
|
repeat(n) { res += one }
|
||||||
res += 1.0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun `Viktor field addition`() {
|
fun viktorFieldAddition() {
|
||||||
viktorField.run {
|
viktorField.run {
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) {
|
repeat(n) { res += one }
|
||||||
res += one
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun `Raw Viktor`() {
|
fun rawViktor() {
|
||||||
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
|
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) {
|
repeat(n) { res = res + one }
|
||||||
res = res + one
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun `Real field log`() {
|
fun realdFieldLog() {
|
||||||
realField.run {
|
realField.run {
|
||||||
val fortyTwo = produce { 42.0 }
|
val fortyTwo = produce { 42.0 }
|
||||||
var res = one
|
var res = one
|
||||||
|
repeat(n) { res = ln(fortyTwo) }
|
||||||
repeat(n) {
|
|
||||||
res = ln(fortyTwo)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun `Raw Viktor log`() {
|
fun rawViktorLog() {
|
||||||
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
|
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
|
||||||
var res: F64Array
|
var res: F64Array
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
|
@ -0,0 +1,70 @@
|
|||||||
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.compile
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.expressions.expressionInField
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
class ExpressionsInterpretersBenchmark {
|
||||||
|
private val algebra: Field<Double> = RealField
|
||||||
|
fun functionalExpression() {
|
||||||
|
val expr = algebra.expressionInField {
|
||||||
|
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
invokeAndSum(expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun mstExpression() {
|
||||||
|
val expr = algebra.mstInField {
|
||||||
|
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
invokeAndSum(expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun asmExpression() {
|
||||||
|
val expr = algebra.mstInField {
|
||||||
|
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||||
|
}.compile()
|
||||||
|
|
||||||
|
invokeAndSum(expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun invokeAndSum(expr: Expression<Double>) {
|
||||||
|
val random = Random(0)
|
||||||
|
var sum = 0.0
|
||||||
|
|
||||||
|
repeat(1000000) {
|
||||||
|
sum += expr("x" to random.nextDouble())
|
||||||
|
}
|
||||||
|
|
||||||
|
println(sum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val benchmark = ExpressionsInterpretersBenchmark()
|
||||||
|
|
||||||
|
val fe = measureTimeMillis {
|
||||||
|
benchmark.functionalExpression()
|
||||||
|
}
|
||||||
|
|
||||||
|
println("fe=$fe")
|
||||||
|
|
||||||
|
val mst = measureTimeMillis {
|
||||||
|
benchmark.mstExpression()
|
||||||
|
}
|
||||||
|
|
||||||
|
println("mst=$mst")
|
||||||
|
|
||||||
|
val asm = measureTimeMillis {
|
||||||
|
benchmark.asmExpression()
|
||||||
|
}
|
||||||
|
|
||||||
|
println("asm=$asm")
|
||||||
|
}
|
@ -6,7 +6,7 @@ fun main(args: Array<String>) {
|
|||||||
val n = 6000
|
val n = 6000
|
||||||
|
|
||||||
val array = DoubleArray(n * n) { 1.0 }
|
val array = DoubleArray(n * n) { 1.0 }
|
||||||
val buffer = DoubleBuffer(array)
|
val buffer = RealBuffer(array)
|
||||||
val strides = DefaultStrides(intArrayOf(n, n))
|
val strides = DefaultStrides(intArrayOf(n, n))
|
||||||
|
|
||||||
val structure = BufferNDStructure(strides, buffer)
|
val structure = BufferNDStructure(strides, buffer)
|
||||||
|
@ -26,10 +26,10 @@ fun main(args: Array<String>) {
|
|||||||
}
|
}
|
||||||
println("Array mapping finished in $time2 millis")
|
println("Array mapping finished in $time2 millis")
|
||||||
|
|
||||||
val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 })
|
val buffer = RealBuffer(DoubleArray(n * n) { 1.0 })
|
||||||
|
|
||||||
val time3 = measureTimeMillis {
|
val time3 = measureTimeMillis {
|
||||||
val target = DoubleBuffer(DoubleArray(n * n))
|
val target = RealBuffer(DoubleArray(n * n))
|
||||||
val res = array.forEachIndexed { index, value ->
|
val res = array.forEachIndexed { index, value ->
|
||||||
target[index] = value + 1
|
target[index] = value + 1
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,7 @@ For example, the following builder:
|
|||||||
package scientifik.kmath.asm.generated;
|
package scientifik.kmath.asm.generated;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import scientifik.kmath.asm.internal.MapIntrinsics;
|
||||||
import scientifik.kmath.expressions.Expression;
|
import scientifik.kmath.expressions.Expression;
|
||||||
import scientifik.kmath.operations.RealField;
|
import scientifik.kmath.operations.RealField;
|
||||||
|
|
||||||
@ -37,23 +38,23 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression<Doub
|
|||||||
}
|
}
|
||||||
|
|
||||||
public final Double invoke(Map<String, ? extends Double> arguments) {
|
public final Double invoke(Map<String, ? extends Double> arguments) {
|
||||||
return (Double)this.algebra.add(((Double)arguments.get("x")).doubleValue(), 2.0D);
|
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Example Usage
|
### Example Usage
|
||||||
|
|
||||||
This API is an extension to MST and MSTExpression APIs. You may optimize both MST and MSTExpression:
|
This API is an extension to MST and MstExpression APIs. You may optimize both MST and MSTExpression:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||||
RealField.expression("2+2".parseMath())
|
RealField.expression("x+2".parseMath())
|
||||||
```
|
```
|
||||||
|
|
||||||
### Known issues
|
### Known issues
|
||||||
|
|
||||||
- Using numeric algebras causes boxing and calling bridge methods.
|
|
||||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
|
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
|
||||||
class loading overhead.
|
class loading overhead.
|
||||||
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
||||||
|
@ -1,76 +0,0 @@
|
|||||||
package scientifik.kmath.ast
|
|
||||||
|
|
||||||
import scientifik.kmath.operations.*
|
|
||||||
|
|
||||||
object MSTAlgebra : NumericAlgebra<MST> {
|
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST =
|
|
||||||
MST.Unary(operation, arg)
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
|
||||||
MST.Binary(operation, left, right)
|
|
||||||
}
|
|
||||||
|
|
||||||
object MSTSpace : Space<MST>, NumericAlgebra<MST> {
|
|
||||||
override val zero: MST = number(0.0)
|
|
||||||
|
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
|
||||||
|
|
||||||
override fun add(a: MST, b: MST): MST =
|
|
||||||
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
|
||||||
|
|
||||||
override fun multiply(a: MST, k: Number): MST =
|
|
||||||
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
|
||||||
MSTAlgebra.binaryOperation(operation, left, right)
|
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
object MSTRing : Ring<MST>, NumericAlgebra<MST> {
|
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
|
||||||
|
|
||||||
override val zero: MST = MSTSpace.number(0.0)
|
|
||||||
override val one: MST = number(1.0)
|
|
||||||
override fun add(a: MST, b: MST): MST =
|
|
||||||
MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
|
||||||
|
|
||||||
override fun multiply(a: MST, k: Number): MST =
|
|
||||||
MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k))
|
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST =
|
|
||||||
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
|
||||||
MSTAlgebra.binaryOperation(operation, left, right)
|
|
||||||
}
|
|
||||||
|
|
||||||
object MSTField : Field<MST>{
|
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
|
||||||
|
|
||||||
override val zero: MST = MSTSpace.number(0.0)
|
|
||||||
override val one: MST = number(1.0)
|
|
||||||
override fun add(a: MST, b: MST): MST =
|
|
||||||
MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
|
||||||
|
|
||||||
|
|
||||||
override fun multiply(a: MST, k: Number): MST =
|
|
||||||
MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k))
|
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST =
|
|
||||||
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
|
||||||
|
|
||||||
override fun divide(a: MST, b: MST): MST =
|
|
||||||
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
|
||||||
MSTAlgebra.binaryOperation(operation, left, right)
|
|
||||||
}
|
|
@ -1,55 +0,0 @@
|
|||||||
package scientifik.kmath.ast
|
|
||||||
|
|
||||||
import scientifik.kmath.expressions.Expression
|
|
||||||
import scientifik.kmath.expressions.FunctionalExpressionField
|
|
||||||
import scientifik.kmath.expressions.FunctionalExpressionRing
|
|
||||||
import scientifik.kmath.expressions.FunctionalExpressionSpace
|
|
||||||
import scientifik.kmath.operations.*
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
|
|
||||||
*/
|
|
||||||
class MSTExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Substitute algebra raw value
|
|
||||||
*/
|
|
||||||
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T>{
|
|
||||||
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
|
||||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T =algebra.binaryOperation(operation, left, right)
|
|
||||||
|
|
||||||
override fun number(value: Number): T = if(algebra is NumericAlgebra){
|
|
||||||
algebra.number(value)
|
|
||||||
} else{
|
|
||||||
error("Numeric nodes are not supported by $this")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
|
||||||
mstAlgebra: E,
|
|
||||||
block: E.() -> MST
|
|
||||||
): MSTExpression<T> = MSTExpression(this, mstAlgebra.block())
|
|
||||||
|
|
||||||
inline fun <reified T : Any> Space<T>.mstInSpace(block: MSTSpace.() -> MST): MSTExpression<T> =
|
|
||||||
MSTExpression(this, MSTSpace.block())
|
|
||||||
|
|
||||||
inline fun <reified T : Any> Ring<T>.mstInRing(block: MSTRing.() -> MST): MSTExpression<T> =
|
|
||||||
MSTExpression(this, MSTRing.block())
|
|
||||||
|
|
||||||
inline fun <reified T : Any> Field<T>.mstInField(block: MSTField.() -> MST): MSTExpression<T> =
|
|
||||||
MSTExpression(this, MSTField.block())
|
|
||||||
|
|
||||||
inline fun <reified T: Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MSTSpace.() -> MST): MSTExpression<T> =
|
|
||||||
algebra.mstInSpace(block)
|
|
||||||
|
|
||||||
inline fun <reified T: Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MSTRing.() -> MST): MSTExpression<T> =
|
|
||||||
algebra.mstInRing(block)
|
|
||||||
|
|
||||||
inline fun <reified T: Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MSTField.() -> MST): MSTExpression<T> =
|
|
||||||
algebra.mstInField(block)
|
|
@ -0,0 +1,72 @@
|
|||||||
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
object MstAlgebra : NumericAlgebra<MST> {
|
||||||
|
override fun number(value: Number): MST = MST.Numeric(value)
|
||||||
|
|
||||||
|
override fun symbol(value: String): MST = MST.Symbolic(value)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: MST): MST =
|
||||||
|
MST.Unary(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MST.Binary(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||||
|
override val zero: MST = number(0.0)
|
||||||
|
|
||||||
|
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||||
|
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||||
|
|
||||||
|
override fun add(a: MST, b: MST): MST =
|
||||||
|
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun multiply(a: MST, k: Number): MST =
|
||||||
|
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MstAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||||
|
override val zero: MST = number(0.0)
|
||||||
|
override val one: MST = number(1.0)
|
||||||
|
|
||||||
|
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||||
|
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||||
|
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun multiply(a: MST, k: Number): MST =
|
||||||
|
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
|
||||||
|
|
||||||
|
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MstAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
object MstField : Field<MST> {
|
||||||
|
override val zero: MST = number(0.0)
|
||||||
|
override val one: MST = number(1.0)
|
||||||
|
|
||||||
|
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||||
|
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||||
|
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun multiply(a: MST, k: Number): MST =
|
||||||
|
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
|
||||||
|
|
||||||
|
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
|
MstAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||||
|
}
|
@ -0,0 +1,55 @@
|
|||||||
|
package scientifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.expressions.FunctionalExpressionField
|
||||||
|
import scientifik.kmath.expressions.FunctionalExpressionRing
|
||||||
|
import scientifik.kmath.expressions.FunctionalExpressionSpace
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
|
||||||
|
*/
|
||||||
|
class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Substitute algebra raw value
|
||||||
|
*/
|
||||||
|
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
|
||||||
|
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
||||||
|
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
|
algebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun number(value: Number): T = if (algebra is NumericAlgebra)
|
||||||
|
algebra.number(value)
|
||||||
|
else
|
||||||
|
error("Numeric nodes are not supported by $this")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||||
|
mstAlgebra: E,
|
||||||
|
block: E.() -> MST
|
||||||
|
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
|
||||||
|
MstExpression(this, MstSpace.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
|
||||||
|
MstExpression(this, MstRing.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> =
|
||||||
|
MstExpression(this, MstField.block())
|
||||||
|
|
||||||
|
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
|
||||||
|
algebra.mstInSpace(block)
|
||||||
|
|
||||||
|
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
|
||||||
|
algebra.mstInRing(block)
|
||||||
|
|
||||||
|
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> =
|
||||||
|
algebra.mstInField(block)
|
@ -16,15 +16,15 @@ import scientifik.kmath.operations.SpaceOperations
|
|||||||
* TODO move to common
|
* TODO move to common
|
||||||
*/
|
*/
|
||||||
private object ArithmeticsEvaluator : Grammar<MST>() {
|
private object ArithmeticsEvaluator : Grammar<MST>() {
|
||||||
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?")
|
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
|
||||||
val lpar by token("\\(")
|
val lpar by token("\\(".toRegex())
|
||||||
val rpar by token("\\)")
|
val rpar by token("\\)".toRegex())
|
||||||
val mul by token("\\*")
|
val mul by token("\\*".toRegex())
|
||||||
val pow by token("\\^")
|
val pow by token("\\^".toRegex())
|
||||||
val div by token("/")
|
val div by token("/".toRegex())
|
||||||
val minus by token("-")
|
val minus by token("-".toRegex())
|
||||||
val plus by token("\\+")
|
val plus by token("\\+".toRegex())
|
||||||
val ws by token("\\s+", ignore = true)
|
val ws by token("\\s+".toRegex(), ignore = true)
|
||||||
|
|
||||||
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||||
|
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
package scientifik.kmath.asm
|
package scientifik.kmath.asm
|
||||||
|
|
||||||
import org.objectweb.asm.Type
|
|
||||||
import scientifik.kmath.asm.internal.AsmBuilder
|
import scientifik.kmath.asm.internal.AsmBuilder
|
||||||
import scientifik.kmath.asm.internal.buildExpectationStack
|
import scientifik.kmath.asm.internal.buildAlgebraOperationCall
|
||||||
import scientifik.kmath.asm.internal.buildName
|
import scientifik.kmath.asm.internal.buildName
|
||||||
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
|
||||||
import scientifik.kmath.ast.MST
|
import scientifik.kmath.ast.MST
|
||||||
import scientifik.kmath.ast.MSTExpression
|
import scientifik.kmath.ast.MstExpression
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.NumericAlgebra
|
import scientifik.kmath.operations.NumericAlgebra
|
||||||
@ -29,43 +27,21 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
|||||||
loadTConstant(constant)
|
loadTConstant(constant)
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Unary -> {
|
is MST.Unary -> buildAlgebraOperationCall(
|
||||||
loadAlgebra()
|
context = algebra,
|
||||||
if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation)
|
name = node.operation,
|
||||||
visit(node.value)
|
fallbackMethodName = "unaryOperation",
|
||||||
|
arity = 1
|
||||||
|
) { visit(node.value) }
|
||||||
|
|
||||||
if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation(
|
is MST.Binary -> buildAlgebraOperationCall(
|
||||||
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
context = algebra,
|
||||||
method = "unaryOperation",
|
name = node.operation,
|
||||||
|
fallbackMethodName = "binaryOperation",
|
||||||
descriptor = Type.getMethodDescriptor(
|
arity = 2
|
||||||
AsmBuilder.OBJECT_TYPE,
|
) {
|
||||||
AsmBuilder.STRING_TYPE,
|
|
||||||
AsmBuilder.OBJECT_TYPE
|
|
||||||
),
|
|
||||||
|
|
||||||
tArity = 1
|
|
||||||
)
|
|
||||||
}
|
|
||||||
is MST.Binary -> {
|
|
||||||
loadAlgebra()
|
|
||||||
if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation)
|
|
||||||
visit(node.left)
|
visit(node.left)
|
||||||
visit(node.right)
|
visit(node.right)
|
||||||
|
|
||||||
if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation(
|
|
||||||
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
|
||||||
method = "binaryOperation",
|
|
||||||
|
|
||||||
descriptor = Type.getMethodDescriptor(
|
|
||||||
AsmBuilder.OBJECT_TYPE,
|
|
||||||
AsmBuilder.STRING_TYPE,
|
|
||||||
AsmBuilder.OBJECT_TYPE,
|
|
||||||
AsmBuilder.OBJECT_TYPE
|
|
||||||
),
|
|
||||||
|
|
||||||
tArity = 2
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -79,6 +55,6 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
|||||||
inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize performance of an [MSTExpression] using ASM codegen
|
* Optimize performance of an [MstExpression] using ASM codegen
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any> MSTExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
package scientifik.kmath.asm.internal
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
import org.objectweb.asm.*
|
import org.objectweb.asm.*
|
||||||
import org.objectweb.asm.Opcodes.AALOAD
|
import org.objectweb.asm.Opcodes.*
|
||||||
import org.objectweb.asm.Opcodes.RETURN
|
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||||
import scientifik.kmath.ast.MST
|
import scientifik.kmath.ast.MST
|
||||||
@ -18,6 +17,7 @@ import kotlin.reflect.KClass
|
|||||||
* @param T the type of AsmExpression to unwrap.
|
* @param T the type of AsmExpression to unwrap.
|
||||||
* @param algebra the algebra the applied AsmExpressions use.
|
* @param algebra the algebra the applied AsmExpressions use.
|
||||||
* @param className the unique class name of new loaded class.
|
* @param className the unique class name of new loaded class.
|
||||||
|
* @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
|
||||||
*/
|
*/
|
||||||
internal class AsmBuilder<T> internal constructor(
|
internal class AsmBuilder<T> internal constructor(
|
||||||
private val classOfT: KClass<*>,
|
private val classOfT: KClass<*>,
|
||||||
@ -37,8 +37,19 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
*/
|
*/
|
||||||
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM Type for [algebra]
|
||||||
|
*/
|
||||||
private val tAlgebraType: Type = algebra::class.asm
|
private val tAlgebraType: Type = algebra::class.asm
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [T]
|
||||||
|
*/
|
||||||
internal val tType: Type = classOfT.asm
|
internal val tType: Type = classOfT.asm
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for new class
|
||||||
|
*/
|
||||||
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -60,15 +71,31 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* Method visitor of `invoke` method of the subclass.
|
* Method visitor of `invoke` method of the subclass.
|
||||||
*/
|
*/
|
||||||
private lateinit var invokeMethodVisitor: InstructionAdapter
|
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||||
internal var primitiveMode = false
|
|
||||||
|
|
||||||
@Suppress("PropertyName")
|
/**
|
||||||
internal var PRIMITIVE_MASK: Type = OBJECT_TYPE
|
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
||||||
|
*/
|
||||||
|
internal var primitiveMode: Boolean = false
|
||||||
|
|
||||||
@Suppress("PropertyName")
|
/**
|
||||||
internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE
|
* Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||||
private val typeStack = Stack<Type>()
|
*/
|
||||||
internal val expectationStack: Stack<Type> = Stack<Type>().apply { push(tType) }
|
internal var primitiveMask: Type = OBJECT_TYPE
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||||
|
*/
|
||||||
|
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stack of useful objects types on stack to verify types.
|
||||||
|
*/
|
||||||
|
private val typeStack: ArrayDeque<Type> = ArrayDeque()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stack of useful objects types on stack expected by algebra calls.
|
||||||
|
*/
|
||||||
|
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>().apply { push(tType) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The cache for instance built by this builder.
|
* The cache for instance built by this builder.
|
||||||
@ -86,14 +113,14 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
|
|
||||||
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
|
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
|
||||||
primitiveMode = true
|
primitiveMode = true
|
||||||
PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT)
|
primitiveMask = SIGNATURE_LETTERS.getValue(classOfT)
|
||||||
PRIMITIVE_MASK_BOXED = tType
|
primitiveMaskBoxed = tType
|
||||||
}
|
}
|
||||||
|
|
||||||
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||||
visit(
|
visit(
|
||||||
Opcodes.V1_8,
|
V1_8,
|
||||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
|
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
|
||||||
classType.internalName,
|
classType.internalName,
|
||||||
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
|
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
|
||||||
OBJECT_TYPE.internalName,
|
OBJECT_TYPE.internalName,
|
||||||
@ -101,7 +128,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
visitField(
|
visitField(
|
||||||
access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL,
|
access = ACC_PRIVATE or ACC_FINAL,
|
||||||
name = "algebra",
|
name = "algebra",
|
||||||
descriptor = tAlgebraType.descriptor,
|
descriptor = tAlgebraType.descriptor,
|
||||||
signature = null,
|
signature = null,
|
||||||
@ -110,7 +137,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
visitField(
|
visitField(
|
||||||
access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL,
|
access = ACC_PRIVATE or ACC_FINAL,
|
||||||
name = "constants",
|
name = "constants",
|
||||||
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
||||||
signature = null,
|
signature = null,
|
||||||
@ -119,7 +146,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
Opcodes.ACC_PUBLIC,
|
ACC_PUBLIC,
|
||||||
"<init>",
|
"<init>",
|
||||||
Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE),
|
Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE),
|
||||||
null,
|
null,
|
||||||
@ -159,7 +186,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
|
ACC_PUBLIC or ACC_FINAL,
|
||||||
"invoke",
|
"invoke",
|
||||||
Type.getMethodDescriptor(tType, MAP_TYPE),
|
Type.getMethodDescriptor(tType, MAP_TYPE),
|
||||||
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
||||||
@ -195,7 +222,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
|
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
||||||
"invoke",
|
"invoke",
|
||||||
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
||||||
null,
|
null,
|
||||||
@ -238,34 +265,43 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads a constant from
|
* Loads a [T] constant from [constants].
|
||||||
*/
|
*/
|
||||||
internal fun loadTConstant(value: T) {
|
internal fun loadTConstant(value: T) {
|
||||||
if (classOfT in INLINABLE_NUMBERS) {
|
if (classOfT in INLINABLE_NUMBERS) {
|
||||||
val expectedType = expectationStack.pop()!!
|
val expectedType = expectationStack.pop()
|
||||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||||
loadNumberConstant(value as Number, mustBeBoxed)
|
loadNumberConstant(value as Number, mustBeBoxed)
|
||||||
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(PRIMITIVE_MASK)
|
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
loadConstant(value as Any, tType)
|
loadConstant(value as Any, tType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Boxes the current value and pushes it.
|
||||||
|
*/
|
||||||
private fun box(): Unit = invokeMethodVisitor.invokestatic(
|
private fun box(): Unit = invokeMethodVisitor.invokestatic(
|
||||||
tType.internalName,
|
tType.internalName,
|
||||||
"valueOf",
|
"valueOf",
|
||||||
Type.getMethodDescriptor(tType, PRIMITIVE_MASK),
|
Type.getMethodDescriptor(tType, primitiveMask),
|
||||||
false
|
false
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Unboxes the current boxed value and pushes it.
|
||||||
|
*/
|
||||||
private fun unbox(): Unit = invokeMethodVisitor.invokevirtual(
|
private fun unbox(): Unit = invokeMethodVisitor.invokevirtual(
|
||||||
NUMBER_TYPE.internalName,
|
NUMBER_TYPE.internalName,
|
||||||
NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK),
|
NUMBER_CONVERTER_METHODS.getValue(primitiveMask),
|
||||||
Type.getMethodDescriptor(PRIMITIVE_MASK),
|
Type.getMethodDescriptor(primitiveMask),
|
||||||
false
|
false
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads [java.lang.Object] constant from constants.
|
||||||
|
*/
|
||||||
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
||||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||||
loadThis()
|
loadThis()
|
||||||
@ -275,6 +311,9 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
checkcast(type)
|
checkcast(type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads this variable.
|
||||||
|
*/
|
||||||
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
|
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -305,46 +344,40 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
loadConstant(value, boxed)
|
loadConstant(value, boxed)
|
||||||
|
|
||||||
if (!mustBeBoxed) unbox()
|
if (!mustBeBoxed) unbox()
|
||||||
else invokeMethodVisitor.checkcast(tType)
|
else invokeMethodVisitor.checkcast(tType)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided.
|
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
|
||||||
|
* provided.
|
||||||
*/
|
*/
|
||||||
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
||||||
load(invokeArgumentsVar, OBJECT_ARRAY_TYPE)
|
load(invokeArgumentsVar, MAP_TYPE)
|
||||||
|
aconst(name)
|
||||||
|
|
||||||
if (defaultValue != null) {
|
if (defaultValue != null)
|
||||||
loadStringConstant(name)
|
|
||||||
loadTConstant(defaultValue)
|
loadTConstant(defaultValue)
|
||||||
|
else
|
||||||
|
aconst(null)
|
||||||
|
|
||||||
invokeinterface(
|
invokestatic(
|
||||||
MAP_TYPE.internalName,
|
MAP_INTRINSICS_TYPE.internalName,
|
||||||
"getOrDefault",
|
"getOrFail",
|
||||||
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE)
|
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
|
||||||
)
|
false
|
||||||
|
|
||||||
invokeMethodVisitor.checkcast(tType)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
loadStringConstant(name)
|
|
||||||
|
|
||||||
invokeinterface(
|
|
||||||
MAP_TYPE.internalName,
|
|
||||||
"get",
|
|
||||||
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
invokeMethodVisitor.checkcast(tType)
|
checkcast(tType)
|
||||||
val expectedType = expectationStack.pop()!!
|
|
||||||
|
val expectedType = expectationStack.pop()
|
||||||
|
|
||||||
if (expectedType.sort == Type.OBJECT)
|
if (expectedType.sort == Type.OBJECT)
|
||||||
typeStack.push(tType)
|
typeStack.push(tType)
|
||||||
else {
|
else {
|
||||||
unbox()
|
unbox()
|
||||||
typeStack.push(PRIMITIVE_MASK)
|
typeStack.push(primitiveMask)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -358,7 +391,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is
|
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is
|
||||||
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interface. [loadAlgebra] should be
|
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be
|
||||||
* called before the arguments and this operation.
|
* called before the arguments and this operation.
|
||||||
*
|
*
|
||||||
* The result is casted to [T] automatically.
|
* The result is casted to [T] automatically.
|
||||||
@ -367,12 +400,12 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
owner: String,
|
owner: String,
|
||||||
method: String,
|
method: String,
|
||||||
descriptor: String,
|
descriptor: String,
|
||||||
tArity: Int,
|
expectedArity: Int,
|
||||||
opcode: Int = Opcodes.INVOKEINTERFACE
|
opcode: Int = INVOKEINTERFACE
|
||||||
) {
|
) {
|
||||||
run loop@{
|
run loop@{
|
||||||
repeat(tArity) {
|
repeat(expectedArity) {
|
||||||
if (typeStack.empty()) return@loop
|
if (typeStack.isEmpty()) return@loop
|
||||||
typeStack.pop()
|
typeStack.pop()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -382,18 +415,18 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
owner,
|
owner,
|
||||||
method,
|
method,
|
||||||
descriptor,
|
descriptor,
|
||||||
opcode == Opcodes.INVOKEINTERFACE
|
opcode == INVOKEINTERFACE
|
||||||
)
|
)
|
||||||
|
|
||||||
invokeMethodVisitor.checkcast(tType)
|
invokeMethodVisitor.checkcast(tType)
|
||||||
val isLastExpr = expectationStack.size == 1
|
val isLastExpr = expectationStack.size == 1
|
||||||
val expectedType = expectationStack.pop()!!
|
val expectedType = expectationStack.pop()
|
||||||
|
|
||||||
if (expectedType.sort == Type.OBJECT || isLastExpr)
|
if (expectedType.sort == Type.OBJECT || isLastExpr)
|
||||||
typeStack.push(tType)
|
typeStack.push(tType)
|
||||||
else {
|
else {
|
||||||
unbox()
|
unbox()
|
||||||
typeStack.push(PRIMITIVE_MASK)
|
typeStack.push(primitiveMask)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -404,7 +437,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
|
|
||||||
internal companion object {
|
internal companion object {
|
||||||
/**
|
/**
|
||||||
* Maps JVM primitive numbers boxed types to their letters of JVM signature convention.
|
* Maps JVM primitive numbers boxed types to their primitive ASM types.
|
||||||
*/
|
*/
|
||||||
private val SIGNATURE_LETTERS: Map<KClass<out Any>, Type> by lazy {
|
private val SIGNATURE_LETTERS: Map<KClass<out Any>, Type> by lazy {
|
||||||
hashMapOf(
|
hashMapOf(
|
||||||
@ -417,8 +450,14 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
||||||
|
*/
|
||||||
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
|
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps primitive ASM types to [Number] functions unboxing them.
|
||||||
|
*/
|
||||||
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
|
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
|
||||||
hashMapOf(
|
hashMapOf(
|
||||||
Type.BYTE_TYPE to "byteValue",
|
Type.BYTE_TYPE to "byteValue",
|
||||||
@ -434,14 +473,46 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
||||||
*/
|
*/
|
||||||
private val INLINABLE_NUMBERS: Set<KClass<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
private val INLINABLE_NUMBERS: Set<KClass<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [Expression].
|
||||||
|
*/
|
||||||
internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm }
|
internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [java.lang.Number].
|
||||||
|
*/
|
||||||
internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm }
|
internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [java.util.Map].
|
||||||
|
*/
|
||||||
internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm }
|
internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [java.lang.Object].
|
||||||
|
*/
|
||||||
internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm }
|
internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for array of [java.lang.Object].
|
||||||
|
*/
|
||||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
||||||
internal val OBJECT_ARRAY_TYPE: Type by lazy { Array<java.lang.Object>::class.asm }
|
internal val OBJECT_ARRAY_TYPE: Type by lazy { Array<java.lang.Object>::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [Algebra].
|
||||||
|
*/
|
||||||
internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm }
|
internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [java.lang.String].
|
||||||
|
*/
|
||||||
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
|
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for MapIntrinsics.
|
||||||
|
*/
|
||||||
|
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,22 +0,0 @@
|
|||||||
package scientifik.kmath.asm.internal
|
|
||||||
|
|
||||||
import scientifik.kmath.ast.MST
|
|
||||||
import scientifik.kmath.expressions.Expression
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a class name for [Expression] subclassed to implement [mst] provided.
|
|
||||||
*
|
|
||||||
* This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there
|
|
||||||
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
|
|
||||||
*/
|
|
||||||
internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
|
|
||||||
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
|
||||||
|
|
||||||
try {
|
|
||||||
Class.forName(name)
|
|
||||||
} catch (ignored: ClassNotFoundException) {
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
return buildName(mst, collision + 1)
|
|
||||||
}
|
|
@ -1,17 +0,0 @@
|
|||||||
package scientifik.kmath.asm.internal
|
|
||||||
|
|
||||||
import org.objectweb.asm.ClassWriter
|
|
||||||
import org.objectweb.asm.FieldVisitor
|
|
||||||
import org.objectweb.asm.MethodVisitor
|
|
||||||
|
|
||||||
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter =
|
|
||||||
ClassWriter(flags).apply(block)
|
|
||||||
|
|
||||||
internal inline fun ClassWriter.visitField(
|
|
||||||
access: Int,
|
|
||||||
name: String,
|
|
||||||
descriptor: String,
|
|
||||||
signature: String?,
|
|
||||||
value: Any?,
|
|
||||||
block: FieldVisitor.() -> Unit
|
|
||||||
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
|
|
@ -1,7 +0,0 @@
|
|||||||
package scientifik.kmath.asm.internal
|
|
||||||
|
|
||||||
import org.objectweb.asm.Type
|
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
internal val KClass<*>.asm: Type
|
|
||||||
get() = Type.getType(java)
|
|
@ -0,0 +1,148 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import org.objectweb.asm.*
|
||||||
|
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
|
||||||
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
|
import scientifik.kmath.ast.MST
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||||
|
hashMapOf(
|
||||||
|
"+" to 2 to "add",
|
||||||
|
"*" to 2 to "multiply",
|
||||||
|
"/" to 2 to "divide",
|
||||||
|
"+" to 1 to "unaryPlus",
|
||||||
|
"-" to 1 to "unaryMinus",
|
||||||
|
"-" to 2 to "minus"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal val KClass<*>.asm: Type
|
||||||
|
get() = Type.getType(java)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [InstructionAdapter] from this [MethodVisitor].
|
||||||
|
*/
|
||||||
|
private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it.
|
||||||
|
*/
|
||||||
|
internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
|
||||||
|
instructionAdapter().apply(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructs a [Label], then applies it to this visitor.
|
||||||
|
*/
|
||||||
|
internal fun MethodVisitor.label(): Label {
|
||||||
|
val l = Label()
|
||||||
|
visitLabel(l)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a class name for [Expression] subclassed to implement [mst] provided.
|
||||||
|
*
|
||||||
|
* This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there
|
||||||
|
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
|
||||||
|
*/
|
||||||
|
internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
|
||||||
|
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
||||||
|
|
||||||
|
try {
|
||||||
|
Class.forName(name)
|
||||||
|
} catch (ignored: ClassNotFoundException) {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildName(mst, collision + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("FunctionName")
|
||||||
|
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter =
|
||||||
|
ClassWriter(flags).apply(block)
|
||||||
|
|
||||||
|
internal inline fun ClassWriter.visitField(
|
||||||
|
access: Int,
|
||||||
|
name: String,
|
||||||
|
descriptor: String,
|
||||||
|
signature: String?,
|
||||||
|
value: Any?,
|
||||||
|
block: FieldVisitor.() -> Unit
|
||||||
|
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
|
||||||
|
* type expectation stack for needed arity.
|
||||||
|
*
|
||||||
|
* @return `true` if contains, else `false`.
|
||||||
|
*/
|
||||||
|
private fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||||
|
val theName = methodNameAdapters[name to arity] ?: name
|
||||||
|
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null
|
||||||
|
val t = if (primitiveMode && hasSpecific) primitiveMask else tType
|
||||||
|
repeat(arity) { expectationStack.push(t) }
|
||||||
|
return hasSpecific
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts
|
||||||
|
* [AsmBuilder.invokeAlgebraOperation] of this method.
|
||||||
|
*
|
||||||
|
* @return `true` if contains, else `false`.
|
||||||
|
*/
|
||||||
|
private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||||
|
val theName = methodNameAdapters[name to arity] ?: name
|
||||||
|
|
||||||
|
context.javaClass.methods.find {
|
||||||
|
var suitableSignature = it.name == theName && it.parameters.size == arity
|
||||||
|
|
||||||
|
if (primitiveMode && it.isBridge)
|
||||||
|
suitableSignature = false
|
||||||
|
|
||||||
|
suitableSignature
|
||||||
|
} ?: return false
|
||||||
|
|
||||||
|
val owner = context::class.asm
|
||||||
|
|
||||||
|
invokeAlgebraOperation(
|
||||||
|
owner = owner.internalName,
|
||||||
|
method = theName,
|
||||||
|
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }),
|
||||||
|
expectedArity = arity,
|
||||||
|
opcode = INVOKEVIRTUAL
|
||||||
|
)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
|
||||||
|
*/
|
||||||
|
internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
||||||
|
context: Algebra<T>,
|
||||||
|
name: String,
|
||||||
|
fallbackMethodName: String,
|
||||||
|
arity: Int,
|
||||||
|
parameters: AsmBuilder<T>.() -> Unit
|
||||||
|
) {
|
||||||
|
loadAlgebra()
|
||||||
|
if (!buildExpectationStack(context, name, arity)) loadStringConstant(name)
|
||||||
|
parameters()
|
||||||
|
|
||||||
|
if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation(
|
||||||
|
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
||||||
|
method = fallbackMethodName,
|
||||||
|
|
||||||
|
descriptor = Type.getMethodDescriptor(
|
||||||
|
AsmBuilder.OBJECT_TYPE,
|
||||||
|
AsmBuilder.STRING_TYPE,
|
||||||
|
*Array(arity) { AsmBuilder.OBJECT_TYPE }
|
||||||
|
),
|
||||||
|
|
||||||
|
expectedArity = arity
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,7 @@
|
|||||||
|
@file:JvmName("MapIntrinsics")
|
||||||
|
|
||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V {
|
||||||
|
return this[key] ?: default ?: error("Parameter not found: $key")
|
||||||
|
}
|
@ -1,9 +0,0 @@
|
|||||||
package scientifik.kmath.asm.internal
|
|
||||||
|
|
||||||
import org.objectweb.asm.MethodVisitor
|
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
|
||||||
|
|
||||||
internal fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
|
|
||||||
|
|
||||||
internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
|
|
||||||
instructionAdapter().apply(block)
|
|
@ -1,61 +0,0 @@
|
|||||||
package scientifik.kmath.asm.internal
|
|
||||||
|
|
||||||
import org.objectweb.asm.Opcodes
|
|
||||||
import org.objectweb.asm.Type
|
|
||||||
import scientifik.kmath.operations.Algebra
|
|
||||||
|
|
||||||
private val methodNameAdapters: Map<String, String> by lazy {
|
|
||||||
hashMapOf(
|
|
||||||
"+" to "add",
|
|
||||||
"*" to "multiply",
|
|
||||||
"/" to "divide"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
|
|
||||||
* type expectation stack for needed arity.
|
|
||||||
*
|
|
||||||
* @return `true` if contains, else `false`.
|
|
||||||
*/
|
|
||||||
internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
|
|
||||||
val aName = methodNameAdapters[name] ?: name
|
|
||||||
|
|
||||||
val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null
|
|
||||||
val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else tType
|
|
||||||
repeat(arity) { expectationStack.push(t) }
|
|
||||||
|
|
||||||
return hasSpecific
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts
|
|
||||||
* [AsmBuilder.invokeAlgebraOperation] of this method.
|
|
||||||
*
|
|
||||||
* @return `true` if contains, else `false`.
|
|
||||||
*/
|
|
||||||
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
|
||||||
val aName = methodNameAdapters[name] ?: name
|
|
||||||
|
|
||||||
val method =
|
|
||||||
context.javaClass.methods.find {
|
|
||||||
var suitableSignature = it.name == aName && it.parameters.size == arity
|
|
||||||
|
|
||||||
if (primitiveMode && it.isBridge)
|
|
||||||
suitableSignature = false
|
|
||||||
|
|
||||||
suitableSignature
|
|
||||||
} ?: return false
|
|
||||||
|
|
||||||
val owner = context::class.java.name.replace('.', '/')
|
|
||||||
|
|
||||||
invokeAlgebraOperation(
|
|
||||||
owner = owner,
|
|
||||||
method = aName,
|
|
||||||
descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }),
|
|
||||||
tArity = arity,
|
|
||||||
opcode = Opcodes.INVOKEVIRTUAL
|
|
||||||
)
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
@ -10,7 +10,7 @@ import scientifik.kmath.operations.RealField
|
|||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class TestAsmAlgebras {
|
internal class TestAsmAlgebras {
|
||||||
@Test
|
@Test
|
||||||
fun space() {
|
fun space() {
|
||||||
val res1 = ByteRing.mstInSpace {
|
val res1 = ByteRing.mstInSpace {
|
||||||
@ -92,8 +92,8 @@ class TestAsmAlgebras {
|
|||||||
"+",
|
"+",
|
||||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||||
+ number(1),
|
+ number(1),
|
||||||
1 / 2 + number(2.0) * one
|
number(1) / 2 + number(2.0) * one
|
||||||
)
|
) + zero
|
||||||
}("x" to 2.0)
|
}("x" to 2.0)
|
||||||
|
|
||||||
val res2 = RealField.mstInField {
|
val res2 = RealField.mstInField {
|
||||||
@ -101,8 +101,8 @@ class TestAsmAlgebras {
|
|||||||
"+",
|
"+",
|
||||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||||
+ number(1),
|
+ number(1),
|
||||||
1 / 2 + number(2.0) * one
|
number(1) / 2 + number(2.0) * one
|
||||||
)
|
) + zero
|
||||||
}.compile()("x" to 2.0)
|
}.compile()("x" to 2.0)
|
||||||
|
|
||||||
assertEquals(res1, res2)
|
assertEquals(res1, res2)
|
||||||
|
@ -8,7 +8,7 @@ import scientifik.kmath.operations.RealField
|
|||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class TestAsmExpressions {
|
internal class TestAsmExpressions {
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryOperationInvocation() {
|
fun testUnaryOperationInvocation() {
|
||||||
val expression = RealField.mstInSpace { -symbol("x") }.compile()
|
val expression = RealField.mstInSpace { -symbol("x") }.compile()
|
||||||
|
@ -0,0 +1,46 @@
|
|||||||
|
package scietifik.kmath.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.compile
|
||||||
|
import scientifik.kmath.ast.mstInField
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal class TestAsmSpecialization {
|
||||||
|
@Test
|
||||||
|
fun testUnaryPlus() {
|
||||||
|
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
|
||||||
|
assertEquals(2.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testUnaryMinus() {
|
||||||
|
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
|
||||||
|
assertEquals(-2.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAdd() {
|
||||||
|
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
|
||||||
|
assertEquals(4.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSine() {
|
||||||
|
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile()
|
||||||
|
assertEquals(0.0, expr("x" to 0.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMinus() {
|
||||||
|
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
|
||||||
|
assertEquals(0.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDivide() {
|
||||||
|
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
||||||
|
assertEquals(1.0, expr("x" to 2.0))
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,22 @@
|
|||||||
|
package scietifik.kmath.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.ast.mstInRing
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.ByteRing
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertFailsWith
|
||||||
|
|
||||||
|
internal class TestAsmVariables {
|
||||||
|
@Test
|
||||||
|
fun testVariableWithoutDefault() {
|
||||||
|
val expr = ByteRing.mstInRing { symbol("x") }
|
||||||
|
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testVariableWithoutDefaultFails() {
|
||||||
|
val expr = ByteRing.mstInRing { symbol("x") }
|
||||||
|
assertFailsWith<IllegalStateException> { expr() }
|
||||||
|
}
|
||||||
|
}
|
@ -10,7 +10,7 @@ import scientifik.kmath.operations.ComplexField
|
|||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class AsmTest {
|
internal class AsmTest {
|
||||||
@Test
|
@Test
|
||||||
fun `compile MST`() {
|
fun `compile MST`() {
|
||||||
val mst = "2+2*(2+2)".parseMath()
|
val mst = "2+2*(2+2)".parseMath()
|
||||||
|
@ -0,0 +1,38 @@
|
|||||||
|
package scientifik.kmath.commons.random
|
||||||
|
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
|
||||||
|
class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
|
||||||
|
org.apache.commons.math3.random.RandomGenerator {
|
||||||
|
private var generator = factory(intArrayOf())
|
||||||
|
|
||||||
|
override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||||
|
|
||||||
|
override fun nextFloat(): Float = generator.nextDouble().toFloat()
|
||||||
|
|
||||||
|
override fun setSeed(seed: Int) {
|
||||||
|
generator = factory(intArrayOf(seed))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun setSeed(seed: IntArray) {
|
||||||
|
generator = factory(seed)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun setSeed(seed: Long) {
|
||||||
|
setSeed(seed.toInt())
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun nextBytes(bytes: ByteArray) {
|
||||||
|
generator.fillBytes(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun nextInt(): Int = generator.nextInt()
|
||||||
|
|
||||||
|
override fun nextInt(n: Int): Int = generator.nextInt(n)
|
||||||
|
|
||||||
|
override fun nextGaussian(): Double = TODO()
|
||||||
|
|
||||||
|
override fun nextDouble(): Double = generator.nextDouble()
|
||||||
|
|
||||||
|
override fun nextLong(): Long = generator.nextLong()
|
||||||
|
}
|
@ -18,7 +18,7 @@ object Transformations {
|
|||||||
private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> =
|
private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> =
|
||||||
Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) }
|
Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) }
|
||||||
|
|
||||||
private fun Buffer<Double>.asArray() = if (this is DoubleBuffer) {
|
private fun Buffer<Double>.asArray() = if (this is RealBuffer) {
|
||||||
array
|
array
|
||||||
} else {
|
} else {
|
||||||
DoubleArray(size) { i -> get(i) }
|
DoubleArray(size) { i -> get(i) }
|
||||||
|
@ -0,0 +1,15 @@
|
|||||||
|
package scientifik.kmath.domains
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple geometric domain
|
||||||
|
*/
|
||||||
|
interface Domain<T : Any> {
|
||||||
|
operator fun contains(point: Point<T>): Boolean
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Number of hyperspace dimensions
|
||||||
|
*/
|
||||||
|
val dimension: Int
|
||||||
|
}
|
@ -0,0 +1,67 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Alexander Nozik.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package scientifik.kmath.domains
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
import scientifik.kmath.structures.RealBuffer
|
||||||
|
import scientifik.kmath.structures.indices
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* HyperSquareDomain class.
|
||||||
|
*
|
||||||
|
* @author Alexander Nozik
|
||||||
|
*/
|
||||||
|
class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain {
|
||||||
|
|
||||||
|
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
|
||||||
|
point[i] in lower[i]..upper[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
override val dimension: Int get() = lower.size
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num]
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int): Double? = lower[num]
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num]
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int): Double? = upper[num]
|
||||||
|
|
||||||
|
override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
||||||
|
val res: DoubleArray = DoubleArray(point.size) { i ->
|
||||||
|
when {
|
||||||
|
point[i] < lower[i] -> lower[i]
|
||||||
|
point[i] > upper[i] -> upper[i]
|
||||||
|
else -> point[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RealBuffer(*res)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun volume(): Double {
|
||||||
|
var res = 1.0
|
||||||
|
for (i in 0 until dimension) {
|
||||||
|
if (lower[i].isInfinite() || upper[i].isInfinite()) {
|
||||||
|
return Double.POSITIVE_INFINITY
|
||||||
|
}
|
||||||
|
if (upper[i] > lower[i]) {
|
||||||
|
res *= upper[i] - lower[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,65 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Alexander Nozik.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package scientifik.kmath.domains
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
|
||||||
|
/**
|
||||||
|
* n-dimensional volume
|
||||||
|
*
|
||||||
|
* @author Alexander Nozik
|
||||||
|
*/
|
||||||
|
interface RealDomain: Domain<Double> {
|
||||||
|
|
||||||
|
fun nearestInDomain(point: Point<Double>): Point<Double>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The lower edge for the domain going down from point
|
||||||
|
* @param num
|
||||||
|
* @param point
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
fun getLowerBound(num: Int, point: Point<Double>): Double?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The upper edge of the domain going up from point
|
||||||
|
* @param num
|
||||||
|
* @param point
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
fun getUpperBound(num: Int, point: Point<Double>): Double?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Global lower edge
|
||||||
|
* @param num
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
fun getLowerBound(num: Int): Double?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Global upper edge
|
||||||
|
* @param num
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
fun getUpperBound(num: Int): Double?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Hyper volume
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
fun volume(): Double
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,36 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Alexander Nozik.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package scientifik.kmath.domains
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
|
||||||
|
class UnconstrainedDomain(override val dimension: Int) : RealDomain {
|
||||||
|
|
||||||
|
override operator fun contains(point: Point<Double>): Boolean = true
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int, point: Point<Double>): Double? = Double.POSITIVE_INFINITY
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY
|
||||||
|
|
||||||
|
override fun nearestInDomain(point: Point<Double>): Point<Double> = point
|
||||||
|
|
||||||
|
override fun volume(): Double = Double.POSITIVE_INFINITY
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,48 @@
|
|||||||
|
package scientifik.kmath.domains
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
|
||||||
|
inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : RealDomain {
|
||||||
|
|
||||||
|
operator fun contains(d: Double): Boolean = range.contains(d)
|
||||||
|
|
||||||
|
override operator fun contains(point: Point<Double>): Boolean {
|
||||||
|
require(point.size == 0)
|
||||||
|
return contains(point[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
||||||
|
require(point.size == 1)
|
||||||
|
val value = point[0]
|
||||||
|
return when{
|
||||||
|
value in range -> point
|
||||||
|
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
|
||||||
|
else -> doubleArrayOf(range.start).asBuffer()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int, point: Point<Double>): Double? {
|
||||||
|
require(num == 0)
|
||||||
|
return range.start
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int, point: Point<Double>): Double? {
|
||||||
|
require(num == 0)
|
||||||
|
return range.endInclusive
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getLowerBound(num: Int): Double? {
|
||||||
|
require(num == 0)
|
||||||
|
return range.start
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int): Double? {
|
||||||
|
require(num == 0)
|
||||||
|
return range.endInclusive
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun volume(): Double = range.endInclusive - range.start
|
||||||
|
|
||||||
|
override val dimension: Int get() = 1
|
||||||
|
}
|
@ -30,11 +30,11 @@ object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
|||||||
override val elementContext get() = RealField
|
override val elementContext get() = RealField
|
||||||
|
|
||||||
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
||||||
val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||||
return BufferMatrix(rows, columns, buffer)
|
return BufferMatrix(rows, columns, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = DoubleBuffer(size,initializer)
|
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size,initializer)
|
||||||
}
|
}
|
||||||
|
|
||||||
class BufferMatrix<T : Any>(
|
class BufferMatrix<T : Any>(
|
||||||
@ -102,7 +102,7 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
|
|||||||
val array = DoubleArray(this.rowNum * other.colNum)
|
val array = DoubleArray(this.rowNum * other.colNum)
|
||||||
|
|
||||||
//convert to array to insure there is not memory indirection
|
//convert to array to insure there is not memory indirection
|
||||||
fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is DoubleBuffer) {
|
fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is RealBuffer) {
|
||||||
array
|
array
|
||||||
} else {
|
} else {
|
||||||
DoubleArray(size) { get(it) }
|
DoubleArray(size) { get(it) }
|
||||||
@ -119,6 +119,6 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val buffer = DoubleBuffer(array)
|
val buffer = RealBuffer(array)
|
||||||
return BufferMatrix(rowNum, other.colNum, buffer)
|
return BufferMatrix(rowNum, other.colNum, buffer)
|
||||||
}
|
}
|
@ -37,9 +37,9 @@ interface Buffer<T> {
|
|||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
|
||||||
inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer {
|
inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer {
|
||||||
val array = DoubleArray(size) { initializer(it) }
|
val array = DoubleArray(size) { initializer(it) }
|
||||||
return DoubleBuffer(array)
|
return RealBuffer(array)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -51,7 +51,7 @@ interface Buffer<T> {
|
|||||||
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
||||||
//TODO add resolution based on Annotation or companion resolution
|
//TODO add resolution based on Annotation or companion resolution
|
||||||
return when (type) {
|
return when (type) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
||||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
|
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
|
||||||
@ -93,7 +93,7 @@ interface MutableBuffer<T> : Buffer<T> {
|
|||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
||||||
return when (type) {
|
return when (type) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
||||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
|
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
|
||||||
@ -109,12 +109,11 @@ interface MutableBuffer<T> : Buffer<T> {
|
|||||||
auto(T::class, size, initializer)
|
auto(T::class, size, initializer)
|
||||||
|
|
||||||
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
|
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
|
||||||
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
RealBuffer(DoubleArray(size) { initializer(it) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
||||||
|
|
||||||
override val size: Int
|
override val size: Int
|
||||||
@ -163,57 +162,6 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
|||||||
|
|
||||||
fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
|
fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
|
||||||
|
|
||||||
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
|
||||||
override val size: Int get() = array.size
|
|
||||||
|
|
||||||
override fun get(index: Int): Short = array[index]
|
|
||||||
|
|
||||||
override fun set(index: Int, value: Short) {
|
|
||||||
array[index] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun iterator() = array.iterator()
|
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Short> = ShortBuffer(array.copyOf())
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
fun ShortArray.asBuffer() = ShortBuffer(this)
|
|
||||||
|
|
||||||
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
|
||||||
override val size: Int get() = array.size
|
|
||||||
|
|
||||||
override fun get(index: Int): Int = array[index]
|
|
||||||
|
|
||||||
override fun set(index: Int, value: Int) {
|
|
||||||
array[index] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun iterator() = array.iterator()
|
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Int> = IntBuffer(array.copyOf())
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
fun IntArray.asBuffer() = IntBuffer(this)
|
|
||||||
|
|
||||||
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
|
||||||
override val size: Int get() = array.size
|
|
||||||
|
|
||||||
override fun get(index: Int): Long = array[index]
|
|
||||||
|
|
||||||
override fun set(index: Int, value: Long) {
|
|
||||||
array[index] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun iterator() = array.iterator()
|
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Long> = LongBuffer(array.copyOf())
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
fun LongArray.asBuffer() = LongBuffer(this)
|
|
||||||
|
|
||||||
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
||||||
override val size: Int get() = buffer.size
|
override val size: Int get() = buffer.size
|
||||||
|
|
||||||
|
@ -0,0 +1,53 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import kotlin.experimental.and
|
||||||
|
|
||||||
|
enum class ValueFlag(val mask: Byte) {
|
||||||
|
NAN(0b0000_0001),
|
||||||
|
MISSING(0b0000_0010),
|
||||||
|
NEGATIVE_INFINITY(0b0000_0100),
|
||||||
|
POSITIVE_INFINITY(0b0000_1000)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A buffer with flagged values
|
||||||
|
*/
|
||||||
|
interface FlaggedBuffer<T> : Buffer<T> {
|
||||||
|
fun getFlag(index: Int): Byte
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The value is valid if all flags are down
|
||||||
|
*/
|
||||||
|
fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte()
|
||||||
|
|
||||||
|
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte()
|
||||||
|
|
||||||
|
fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A real buffer which supports flags for each value like NaN or Missing
|
||||||
|
*/
|
||||||
|
class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> {
|
||||||
|
init {
|
||||||
|
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getFlag(index: Int): Byte = flags[index]
|
||||||
|
|
||||||
|
override val size: Int get() = values.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Double? = if (isValid(index)) values[index] else null
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
|
||||||
|
if (isValid(it)) values[it] else null
|
||||||
|
}.iterator()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
|
||||||
|
for(i in indices){
|
||||||
|
if(isValid(i)){
|
||||||
|
block(values[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,20 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
||||||
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Int = array[index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: Int) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator() = array.iterator()
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<Int> =
|
||||||
|
IntBuffer(array.copyOf())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fun IntArray.asBuffer() = IntBuffer(this)
|
@ -0,0 +1,19 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
||||||
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Long = array[index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: Long) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator() = array.iterator()
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<Long> =
|
||||||
|
LongBuffer(array.copyOf())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fun LongArray.asBuffer() = LongBuffer(this)
|
@ -1,6 +1,6 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override fun get(index: Int): Double = array[index]
|
override fun get(index: Int): Double = array[index]
|
||||||
@ -12,23 +12,23 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
|||||||
override fun iterator() = array.iterator()
|
override fun iterator() = array.iterator()
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Double> =
|
override fun copy(): MutableBuffer<Double> =
|
||||||
DoubleBuffer(array.copyOf())
|
RealBuffer(array.copyOf())
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
inline fun DoubleBuffer(size: Int, init: (Int) -> Double): DoubleBuffer = DoubleBuffer(DoubleArray(size) { init(it) })
|
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) })
|
||||||
|
|
||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
fun DoubleBuffer(vararg doubles: Double): DoubleBuffer = DoubleBuffer(doubles)
|
fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Transform buffer of doubles into array for high performance operations
|
* Transform buffer of doubles into array for high performance operations
|
||||||
*/
|
*/
|
||||||
val MutableBuffer<out Double>.array: DoubleArray
|
val MutableBuffer<out Double>.array: DoubleArray
|
||||||
get() = if (this is DoubleBuffer) {
|
get() = if (this is RealBuffer) {
|
||||||
array
|
array
|
||||||
} else {
|
} else {
|
||||||
DoubleArray(size) { get(it) }
|
DoubleArray(size) { get(it) }
|
||||||
}
|
}
|
||||||
|
|
||||||
fun DoubleArray.asBuffer() = DoubleBuffer(this)
|
fun DoubleArray.asBuffer() = RealBuffer(this)
|
@ -16,7 +16,7 @@ class RealNDField(override val shape: IntArray) :
|
|||||||
override val one by lazy { produce { one } }
|
override val one by lazy { produce { one } }
|
||||||
|
|
||||||
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
||||||
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
RealBuffer(DoubleArray(size) { initializer(it) })
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inline transform an NDStructure to
|
* Inline transform an NDStructure to
|
||||||
@ -89,7 +89,7 @@ class RealNDField(override val shape: IntArray) :
|
|||||||
*/
|
*/
|
||||||
inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement {
|
inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement {
|
||||||
val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) }
|
val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) }
|
||||||
return BufferedNDFieldElement(this, DoubleBuffer(array))
|
return BufferedNDFieldElement(this, RealBuffer(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -103,7 +103,7 @@ inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: Int
|
|||||||
*/
|
*/
|
||||||
inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement {
|
inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement {
|
||||||
val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) }
|
val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) }
|
||||||
return BufferedNDFieldElement(context, DoubleBuffer(array))
|
return BufferedNDFieldElement(context, RealBuffer(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
||||||
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
|
override fun get(index: Int): Short = array[index]
|
||||||
|
|
||||||
|
override fun set(index: Int, value: Short) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator() = array.iterator()
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<Short> =
|
||||||
|
ShortBuffer(array.copyOf())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fun ShortArray.asBuffer() = ShortBuffer(this)
|
@ -5,7 +5,7 @@ import kotlinx.coroutines.flow.*
|
|||||||
import scientifik.kmath.chains.BlockingRealChain
|
import scientifik.kmath.chains.BlockingRealChain
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import scientifik.kmath.structures.BufferFactory
|
import scientifik.kmath.structures.BufferFactory
|
||||||
import scientifik.kmath.structures.DoubleBuffer
|
import scientifik.kmath.structures.RealBuffer
|
||||||
import scientifik.kmath.structures.asBuffer
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -45,7 +45,7 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
|
|||||||
/**
|
/**
|
||||||
* Specialized flow chunker for real buffer
|
* Specialized flow chunker for real buffer
|
||||||
*/
|
*/
|
||||||
fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
|
fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
|
||||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||||
|
|
||||||
if (this@chunked is BlockingRealChain) {
|
if (this@chunked is BlockingRealChain) {
|
||||||
@ -61,13 +61,13 @@ fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
|
|||||||
array[counter] = element
|
array[counter] = element
|
||||||
counter++
|
counter++
|
||||||
if (counter == bufferSize) {
|
if (counter == bufferSize) {
|
||||||
val buffer = DoubleBuffer(array)
|
val buffer = RealBuffer(array)
|
||||||
emit(buffer)
|
emit(buffer)
|
||||||
counter = 0
|
counter = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (counter > 0) {
|
if (counter > 0) {
|
||||||
emit(DoubleBuffer(counter) { array[it] })
|
emit(RealBuffer(counter) { array[it] })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,31 +7,28 @@ import scientifik.kmath.operations.Norm
|
|||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.SpaceElement
|
import scientifik.kmath.operations.SpaceElement
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import scientifik.kmath.structures.DoubleBuffer
|
import scientifik.kmath.structures.RealBuffer
|
||||||
import scientifik.kmath.structures.asBuffer
|
import scientifik.kmath.structures.asBuffer
|
||||||
import scientifik.kmath.structures.asIterable
|
import scientifik.kmath.structures.asIterable
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
typealias RealPoint = Point<Double>
|
||||||
|
|
||||||
fun DoubleArray.asVector() = RealVector(this.asBuffer())
|
fun DoubleArray.asVector() = RealVector(this.asBuffer())
|
||||||
fun List<Double>.asVector() = RealVector(this.asBuffer())
|
fun List<Double>.asVector() = RealVector(this.asBuffer())
|
||||||
|
|
||||||
|
|
||||||
object VectorL2Norm : Norm<Point<out Number>, Double> {
|
object VectorL2Norm : Norm<Point<out Number>, Double> {
|
||||||
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
|
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
|
||||||
}
|
}
|
||||||
|
|
||||||
inline class RealVector(private val point: Point<Double>) :
|
inline class RealVector(private val point: Point<Double>) :
|
||||||
SpaceElement<Point<Double>, RealVector, VectorSpace<Double, RealField>>, Point<Double> {
|
SpaceElement<RealPoint, RealVector, VectorSpace<Double, RealField>>, RealPoint {
|
||||||
|
|
||||||
override val context: VectorSpace<Double, RealField>
|
override val context: VectorSpace<Double, RealField> get() = space(point.size)
|
||||||
get() = space(
|
|
||||||
point.size
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun unwrap(): Point<Double> = point
|
override fun unwrap(): RealPoint = point
|
||||||
|
|
||||||
override fun Point<Double>.wrap(): RealVector =
|
override fun RealPoint.wrap(): RealVector = RealVector(this)
|
||||||
RealVector(this)
|
|
||||||
|
|
||||||
override val size: Int get() = point.size
|
override val size: Int get() = point.size
|
||||||
|
|
||||||
@ -44,16 +41,12 @@ inline class RealVector(private val point: Point<Double>) :
|
|||||||
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
||||||
|
|
||||||
inline operator fun invoke(dim: Int, initializer: (Int) -> Double) =
|
inline operator fun invoke(dim: Int, initializer: (Int) -> Double) =
|
||||||
RealVector(DoubleBuffer(dim, initializer))
|
RealVector(RealBuffer(dim, initializer))
|
||||||
|
|
||||||
operator fun invoke(vararg values: Double): RealVector = values.asVector()
|
operator fun invoke(vararg values: Double): RealVector = values.asVector()
|
||||||
|
|
||||||
fun space(dim: Int): BufferVectorSpace<Double, RealField> =
|
fun space(dim: Int): BufferVectorSpace<Double, RealField> = spaceCache.getOrPut(dim) {
|
||||||
spaceCache.getOrPut(dim) {
|
BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) }
|
||||||
BufferVectorSpace(
|
}
|
||||||
dim,
|
|
||||||
RealField
|
|
||||||
) { size, init -> Buffer.real(size, init) }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,8 +1,8 @@
|
|||||||
package scientifik.kmath.real
|
package scientifik.kmath.real
|
||||||
|
|
||||||
import scientifik.kmath.structures.DoubleBuffer
|
import scientifik.kmath.structures.RealBuffer
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Simplified [DoubleBuffer] to array comparison
|
* Simplified [RealBuffer] to array comparison
|
||||||
*/
|
*/
|
||||||
fun DoubleBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles)
|
fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles)
|
@ -5,8 +5,8 @@ import scientifik.kmath.linear.RealMatrixContext.elementContext
|
|||||||
import scientifik.kmath.linear.VirtualMatrix
|
import scientifik.kmath.linear.VirtualMatrix
|
||||||
import scientifik.kmath.operations.sum
|
import scientifik.kmath.operations.sum
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import scientifik.kmath.structures.DoubleBuffer
|
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
|
import scientifik.kmath.structures.RealBuffer
|
||||||
import scientifik.kmath.structures.asIterable
|
import scientifik.kmath.structures.asIterable
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
@ -27,6 +27,10 @@ typealias RealMatrix = Matrix<Double>
|
|||||||
fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
|
fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
|
||||||
MatrixContext.real.produce(rowNum, colNum, initializer)
|
MatrixContext.real.produce(rowNum, colNum, initializer)
|
||||||
|
|
||||||
|
fun Array<DoubleArray>.toMatrix(): RealMatrix{
|
||||||
|
return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] }
|
||||||
|
}
|
||||||
|
|
||||||
fun Sequence<DoubleArray>.toMatrix(): RealMatrix = toList().let {
|
fun Sequence<DoubleArray>.toMatrix(): RealMatrix = toList().let {
|
||||||
MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] }
|
MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] }
|
||||||
}
|
}
|
||||||
@ -129,22 +133,22 @@ fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix =
|
|||||||
fun Matrix<Double>.extractColumn(columnIndex: Int): RealMatrix =
|
fun Matrix<Double>.extractColumn(columnIndex: Int): RealMatrix =
|
||||||
extractColumns(columnIndex..columnIndex)
|
extractColumns(columnIndex..columnIndex)
|
||||||
|
|
||||||
fun Matrix<Double>.sumByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j ->
|
fun Matrix<Double>.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||||
val column = columns[j]
|
val column = columns[j]
|
||||||
with(elementContext) {
|
with(elementContext) {
|
||||||
sum(column.asIterable())
|
sum(column.asIterable())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Matrix<Double>.minByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j ->
|
fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||||
columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column")
|
columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Matrix<Double>.maxByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j ->
|
fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||||
columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column")
|
columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Matrix<Double>.averageByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j ->
|
fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||||
columns[j].asIterable().average()
|
columns[j].asIterable().average()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,17 +1,9 @@
|
|||||||
package scientifik.kmath.histogram
|
package scientifik.kmath.histogram
|
||||||
|
|
||||||
|
import scientifik.kmath.domains.Domain
|
||||||
import scientifik.kmath.linear.Point
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.structures.ArrayBuffer
|
import scientifik.kmath.structures.ArrayBuffer
|
||||||
import scientifik.kmath.structures.DoubleBuffer
|
import scientifik.kmath.structures.RealBuffer
|
||||||
|
|
||||||
/**
|
|
||||||
* A simple geometric domain
|
|
||||||
* TODO move to geometry module
|
|
||||||
*/
|
|
||||||
interface Domain<T : Any> {
|
|
||||||
operator fun contains(vector: Point<out T>): Boolean
|
|
||||||
val dimension: Int
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The bin in the histogram. The histogram is by definition always done in the real space
|
* The bin in the histogram. The histogram is by definition always done in the real space
|
||||||
@ -51,9 +43,9 @@ interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
|
|||||||
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
|
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
|
||||||
|
|
||||||
fun MutableHistogram<Double, *>.put(vararg point: Number) =
|
fun MutableHistogram<Double, *>.put(vararg point: Number) =
|
||||||
put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray()))
|
put(RealBuffer(point.map { it.toDouble() }.toDoubleArray()))
|
||||||
|
|
||||||
fun MutableHistogram<Double, *>.put(vararg point: Double) = put(DoubleBuffer(point))
|
fun MutableHistogram<Double, *>.put(vararg point: Double) = put(RealBuffer(point))
|
||||||
|
|
||||||
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }
|
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
package scientifik.kmath.histogram
|
package scientifik.kmath.histogram
|
||||||
|
|
||||||
import scientifik.kmath.linear.Point
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.real.asVector
|
|
||||||
import scientifik.kmath.operations.SpaceOperations
|
import scientifik.kmath.operations.SpaceOperations
|
||||||
|
import scientifik.kmath.real.asVector
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
import kotlin.math.floor
|
import kotlin.math.floor
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val c
|
|||||||
|
|
||||||
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
|
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
|
||||||
|
|
||||||
override fun contains(vector: Point<out T>): Boolean = def.contains(vector)
|
override fun contains(point: Point<T>): Boolean = def.contains(point)
|
||||||
|
|
||||||
override val dimension: Int
|
override val dimension: Int
|
||||||
get() = def.center.size
|
get() = def.center.size
|
||||||
@ -50,7 +50,7 @@ class RealHistogram(
|
|||||||
override val dimension: Int get() = lower.size
|
override val dimension: Int get() = lower.size
|
||||||
|
|
||||||
|
|
||||||
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
||||||
|
|
||||||
init {
|
init {
|
||||||
// argument checks
|
// argument checks
|
||||||
|
@ -16,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou
|
|||||||
|
|
||||||
operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
|
operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
|
||||||
|
|
||||||
override fun contains(vector: Buffer<out Double>): Boolean = contains(vector[0])
|
override fun contains(point: Buffer<Double>): Boolean = contains(point[0])
|
||||||
|
|
||||||
internal operator fun inc() = this.also { counter.increment() }
|
internal operator fun inc() = this.also { counter.increment() }
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ interface MemorySpec<T : Any> {
|
|||||||
val objectSize: Int
|
val objectSize: Int
|
||||||
|
|
||||||
fun MemoryReader.read(offset: Int): T
|
fun MemoryReader.read(offset: Int): T
|
||||||
|
//TODO consider thread safety
|
||||||
fun MemoryWriter.write(offset: Int, value: T)
|
fun MemoryWriter.write(offset: Int, value: T)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,10 +3,12 @@ pluginManagement {
|
|||||||
val toolsVersion = "0.5.0"
|
val toolsVersion = "0.5.0"
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
|
id("kotlinx.benchmark") version "0.2.0-dev-8"
|
||||||
id("scientifik.mpp") version toolsVersion
|
id("scientifik.mpp") version toolsVersion
|
||||||
id("scientifik.jvm") version toolsVersion
|
id("scientifik.jvm") version toolsVersion
|
||||||
id("scientifik.atomic") version toolsVersion
|
id("scientifik.atomic") version toolsVersion
|
||||||
id("scientifik.publish") version toolsVersion
|
id("scientifik.publish") version toolsVersion
|
||||||
|
kotlin("plugin.allopen") version "1.3.72"
|
||||||
}
|
}
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
|
Loading…
Reference in New Issue
Block a user