Dev #127

Merged
altavir merged 214 commits from dev into master 2020-08-11 08:33:21 +03:00
52 changed files with 1043 additions and 509 deletions
Showing only changes of commit a71c02e9ed - Show all commits

View File

@ -2,7 +2,7 @@ plugins {
id("scientifik.publish") apply false id("scientifik.publish") apply false
} }
val kmathVersion by extra("0.1.4-dev-7") val kmathVersion by extra("0.1.4-dev-8")
val bintrayRepo by extra("scientifik") val bintrayRepo by extra("scientifik")
val githubProject by extra("kmath") val githubProject by extra("kmath")

View File

@ -2,7 +2,7 @@
Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`). Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`).
There are different types of buffers: There are different types of buffers:
* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays. * Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays.
* Boxing `ListBuffer` wrapping a list * Boxing `ListBuffer` wrapping a list
* Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value * Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value
* `MemoryBuffer` allows direct allocation of objects in continuous memory block. * `MemoryBuffer` allows direct allocation of objects in continuous memory block.

View File

@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
plugins { plugins {
java java
kotlin("jvm") kotlin("jvm")
kotlin("plugin.allopen") version "1.3.71" kotlin("plugin.allopen") version "1.3.72"
id("kotlinx.benchmark") version "0.2.0-dev-7" id("kotlinx.benchmark") version "0.2.0-dev-8"
} }
configure<AllOpenExtension> { configure<AllOpenExtension> {
@ -24,6 +24,7 @@ sourceSets {
} }
dependencies { dependencies {
implementation(project(":kmath-ast"))
implementation(project(":kmath-core")) implementation(project(":kmath-core"))
implementation(project(":kmath-coroutines")) implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons")) implementation(project(":kmath-commons"))
@ -33,8 +34,8 @@ dependencies {
implementation(project(":kmath-dimensions")) implementation(project(":kmath-dimensions"))
implementation("com.kyonifer:koma-core-ejml:0.12") implementation("com.kyonifer:koma-core-ejml:0.12")
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7") implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8")
"benchmarksCompile"(sourceSets.main.get().compileClasspath) "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath
} }
// Configure benchmark // Configure benchmark

View File

@ -10,8 +10,8 @@ import scientifik.kmath.operations.complex
class BufferBenchmark { class BufferBenchmark {
@Benchmark @Benchmark
fun genericDoubleBufferReadWrite() { fun genericRealBufferReadWrite() {
val buffer = DoubleBuffer(size){it.toDouble()} val buffer = RealBuffer(size){it.toDouble()}
(0 until size).forEach { (0 until size).forEach {
buffer[it] buffer[it]

View File

@ -20,48 +20,39 @@ class ViktorBenchmark {
final val viktorField = ViktorNDField(intArrayOf(dim, dim)) final val viktorField = ViktorNDField(intArrayOf(dim, dim))
@Benchmark @Benchmark
fun `Automatic field addition`() { fun automaticFieldAddition() {
autoField.run { autoField.run {
var res = one var res = one
repeat(n) { repeat(n) { res += one }
res += 1.0
}
} }
} }
@Benchmark @Benchmark
fun `Viktor field addition`() { fun viktorFieldAddition() {
viktorField.run { viktorField.run {
var res = one var res = one
repeat(n) { repeat(n) { res += one }
res += one
}
} }
} }
@Benchmark @Benchmark
fun `Raw Viktor`() { fun rawViktor() {
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
var res = one var res = one
repeat(n) { repeat(n) { res = res + one }
res = res + one
}
} }
@Benchmark @Benchmark
fun `Real field log`() { fun realdFieldLog() {
realField.run { realField.run {
val fortyTwo = produce { 42.0 } val fortyTwo = produce { 42.0 }
var res = one var res = one
repeat(n) { res = ln(fortyTwo) }
repeat(n) {
res = ln(fortyTwo)
}
} }
} }
@Benchmark @Benchmark
fun `Raw Viktor log`() { fun rawViktorLog() {
val fortyTwo = F64Array.full(dim, dim, init = 42.0) val fortyTwo = F64Array.full(dim, dim, init = 42.0)
var res: F64Array var res: F64Array
repeat(n) { repeat(n) {

View File

@ -0,0 +1,70 @@
package scientifik.kmath.ast
import scientifik.kmath.asm.compile
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.expressionInField
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField
import kotlin.random.Random
import kotlin.system.measureTimeMillis
class ExpressionsInterpretersBenchmark {
private val algebra: Field<Double> = RealField
fun functionalExpression() {
val expr = algebra.expressionInField {
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
}
invokeAndSum(expr)
}
fun mstExpression() {
val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
}
invokeAndSum(expr)
}
fun asmExpression() {
val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
}.compile()
invokeAndSum(expr)
}
private fun invokeAndSum(expr: Expression<Double>) {
val random = Random(0)
var sum = 0.0
repeat(1000000) {
sum += expr("x" to random.nextDouble())
}
println(sum)
}
}
fun main() {
val benchmark = ExpressionsInterpretersBenchmark()
val fe = measureTimeMillis {
benchmark.functionalExpression()
}
println("fe=$fe")
val mst = measureTimeMillis {
benchmark.mstExpression()
}
println("mst=$mst")
val asm = measureTimeMillis {
benchmark.asmExpression()
}
println("asm=$asm")
}

View File

@ -6,7 +6,7 @@ fun main(args: Array<String>) {
val n = 6000 val n = 6000
val array = DoubleArray(n * n) { 1.0 } val array = DoubleArray(n * n) { 1.0 }
val buffer = DoubleBuffer(array) val buffer = RealBuffer(array)
val strides = DefaultStrides(intArrayOf(n, n)) val strides = DefaultStrides(intArrayOf(n, n))
val structure = BufferNDStructure(strides, buffer) val structure = BufferNDStructure(strides, buffer)

View File

@ -26,10 +26,10 @@ fun main(args: Array<String>) {
} }
println("Array mapping finished in $time2 millis") println("Array mapping finished in $time2 millis")
val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 }) val buffer = RealBuffer(DoubleArray(n * n) { 1.0 })
val time3 = measureTimeMillis { val time3 = measureTimeMillis {
val target = DoubleBuffer(DoubleArray(n * n)) val target = RealBuffer(DoubleArray(n * n))
val res = array.forEachIndexed { index, value -> val res = array.forEachIndexed { index, value ->
target[index] = value + 1 target[index] = value + 1
} }

View File

@ -24,6 +24,7 @@ For example, the following builder:
package scientifik.kmath.asm.generated; package scientifik.kmath.asm.generated;
import java.util.Map; import java.util.Map;
import scientifik.kmath.asm.internal.MapIntrinsics;
import scientifik.kmath.expressions.Expression; import scientifik.kmath.expressions.Expression;
import scientifik.kmath.operations.RealField; import scientifik.kmath.operations.RealField;
@ -37,23 +38,23 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression<Doub
} }
public final Double invoke(Map<String, ? extends Double> arguments) { public final Double invoke(Map<String, ? extends Double> arguments) {
return (Double)this.algebra.add(((Double)arguments.get("x")).doubleValue(), 2.0D); return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D);
} }
} }
``` ```
### Example Usage ### Example Usage
This API is an extension to MST and MSTExpression APIs. You may optimize both MST and MSTExpression: This API is an extension to MST and MstExpression APIs. You may optimize both MST and MSTExpression:
```kotlin ```kotlin
RealField.mstInField { symbol("x") + 2 }.compile() RealField.mstInField { symbol("x") + 2 }.compile()
RealField.expression("2+2".parseMath()) RealField.expression("x+2".parseMath())
``` ```
### Known issues ### Known issues
- Using numeric algebras causes boxing and calling bridge methods.
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid - The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
class loading overhead. class loading overhead.
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. - This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.

View File

@ -1,76 +0,0 @@
package scientifik.kmath.ast
import scientifik.kmath.operations.*
object MSTAlgebra : NumericAlgebra<MST> {
override fun number(value: Number): MST = MST.Numeric(value)
override fun symbol(value: String): MST = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST =
MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MST.Binary(operation, left, right)
}
object MSTSpace : Space<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0)
override fun number(value: Number): MST = MST.Numeric(value)
override fun symbol(value: String): MST = MST.Symbolic(value)
override fun add(a: MST, b: MST): MST =
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST =
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MSTAlgebra.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg)
}
object MSTRing : Ring<MST>, NumericAlgebra<MST> {
override fun number(value: Number): MST = MST.Numeric(value)
override fun symbol(value: String): MST = MST.Symbolic(value)
override val zero: MST = MSTSpace.number(0.0)
override val one: MST = number(1.0)
override fun add(a: MST, b: MST): MST =
MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST =
MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k))
override fun multiply(a: MST, b: MST): MST =
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MSTAlgebra.binaryOperation(operation, left, right)
}
object MSTField : Field<MST>{
override fun symbol(value: String): MST = MST.Symbolic(value)
override fun number(value: Number): MST = MST.Numeric(value)
override val zero: MST = MSTSpace.number(0.0)
override val one: MST = number(1.0)
override fun add(a: MST, b: MST): MST =
MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST =
MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k))
override fun multiply(a: MST, b: MST): MST =
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun divide(a: MST, b: MST): MST =
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MSTAlgebra.binaryOperation(operation, left, right)
}

View File

@ -1,55 +0,0 @@
package scientifik.kmath.ast
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.FunctionalExpressionField
import scientifik.kmath.expressions.FunctionalExpressionRing
import scientifik.kmath.expressions.FunctionalExpressionSpace
import scientifik.kmath.operations.*
/**
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
*/
class MSTExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
/**
* Substitute algebra raw value
*/
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T>{
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =algebra.binaryOperation(operation, left, right)
override fun number(value: Number): T = if(algebra is NumericAlgebra){
algebra.number(value)
} else{
error("Numeric nodes are not supported by $this")
}
}
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
}
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E,
block: E.() -> MST
): MSTExpression<T> = MSTExpression(this, mstAlgebra.block())
inline fun <reified T : Any> Space<T>.mstInSpace(block: MSTSpace.() -> MST): MSTExpression<T> =
MSTExpression(this, MSTSpace.block())
inline fun <reified T : Any> Ring<T>.mstInRing(block: MSTRing.() -> MST): MSTExpression<T> =
MSTExpression(this, MSTRing.block())
inline fun <reified T : Any> Field<T>.mstInField(block: MSTField.() -> MST): MSTExpression<T> =
MSTExpression(this, MSTField.block())
inline fun <reified T: Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MSTSpace.() -> MST): MSTExpression<T> =
algebra.mstInSpace(block)
inline fun <reified T: Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MSTRing.() -> MST): MSTExpression<T> =
algebra.mstInRing(block)
inline fun <reified T: Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MSTField.() -> MST): MSTExpression<T> =
algebra.mstInField(block)

View File

@ -0,0 +1,72 @@
package scientifik.kmath.ast
import scientifik.kmath.operations.*
object MstAlgebra : NumericAlgebra<MST> {
override fun number(value: Number): MST = MST.Numeric(value)
override fun symbol(value: String): MST = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST =
MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MST.Binary(operation, left, right)
}
object MstSpace : Space<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0)
override fun number(value: Number): MST = MstAlgebra.number(value)
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST =
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST =
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstAlgebra.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
}
object MstRing : Ring<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0)
override val one: MST = number(1.0)
override fun number(value: Number): MST = MstAlgebra.number(value)
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST =
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstAlgebra.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
}
object MstField : Field<MST> {
override val zero: MST = number(0.0)
override val one: MST = number(1.0)
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
override fun number(value: Number): MST = MstAlgebra.number(value)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST =
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstAlgebra.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
}

View File

@ -0,0 +1,55 @@
package scientifik.kmath.ast
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.FunctionalExpressionField
import scientifik.kmath.expressions.FunctionalExpressionRing
import scientifik.kmath.expressions.FunctionalExpressionSpace
import scientifik.kmath.operations.*
/**
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
*/
class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
/**
* Substitute algebra raw value
*/
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun number(value: Number): T = if (algebra is NumericAlgebra)
algebra.number(value)
else
error("Numeric nodes are not supported by $this")
}
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
}
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E,
block: E.() -> MST
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
MstExpression(this, MstSpace.block())
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
MstExpression(this, MstRing.block())
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> =
MstExpression(this, MstField.block())
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
algebra.mstInSpace(block)
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
algebra.mstInRing(block)
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> =
algebra.mstInField(block)

View File

@ -16,15 +16,15 @@ import scientifik.kmath.operations.SpaceOperations
* TODO move to common * TODO move to common
*/ */
private object ArithmeticsEvaluator : Grammar<MST>() { private object ArithmeticsEvaluator : Grammar<MST>() {
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?") val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
val lpar by token("\\(") val lpar by token("\\(".toRegex())
val rpar by token("\\)") val rpar by token("\\)".toRegex())
val mul by token("\\*") val mul by token("\\*".toRegex())
val pow by token("\\^") val pow by token("\\^".toRegex())
val div by token("/") val div by token("/".toRegex())
val minus by token("-") val minus by token("-".toRegex())
val plus by token("\\+") val plus by token("\\+".toRegex())
val ws by token("\\s+", ignore = true) val ws by token("\\s+".toRegex(), ignore = true)
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) } val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }

View File

@ -1,12 +1,10 @@
package scientifik.kmath.asm package scientifik.kmath.asm
import org.objectweb.asm.Type
import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.buildExpectationStack import scientifik.kmath.asm.internal.buildAlgebraOperationCall
import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.buildName
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.ast.MST import scientifik.kmath.ast.MST
import scientifik.kmath.ast.MSTExpression import scientifik.kmath.ast.MstExpression
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra import scientifik.kmath.operations.NumericAlgebra
@ -29,43 +27,21 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
loadTConstant(constant) loadTConstant(constant)
} }
is MST.Unary -> { is MST.Unary -> buildAlgebraOperationCall(
loadAlgebra() context = algebra,
if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation) name = node.operation,
visit(node.value) fallbackMethodName = "unaryOperation",
arity = 1
) { visit(node.value) }
if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation( is MST.Binary -> buildAlgebraOperationCall(
owner = AsmBuilder.ALGEBRA_TYPE.internalName, context = algebra,
method = "unaryOperation", name = node.operation,
fallbackMethodName = "binaryOperation",
descriptor = Type.getMethodDescriptor( arity = 2
AsmBuilder.OBJECT_TYPE, ) {
AsmBuilder.STRING_TYPE,
AsmBuilder.OBJECT_TYPE
),
tArity = 1
)
}
is MST.Binary -> {
loadAlgebra()
if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation)
visit(node.left) visit(node.left)
visit(node.right) visit(node.right)
if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
method = "binaryOperation",
descriptor = Type.getMethodDescriptor(
AsmBuilder.OBJECT_TYPE,
AsmBuilder.STRING_TYPE,
AsmBuilder.OBJECT_TYPE,
AsmBuilder.OBJECT_TYPE
),
tArity = 2
)
} }
} }
} }
@ -79,6 +55,6 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this) inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
/** /**
* Optimize performance of an [MSTExpression] using ASM codegen * Optimize performance of an [MstExpression] using ASM codegen
*/ */
inline fun <reified T : Any> MSTExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra) inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)

View File

@ -1,8 +1,7 @@
package scientifik.kmath.asm.internal package scientifik.kmath.asm.internal
import org.objectweb.asm.* import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.AALOAD import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Opcodes.RETURN
import org.objectweb.asm.commons.InstructionAdapter import org.objectweb.asm.commons.InstructionAdapter
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
import scientifik.kmath.ast.MST import scientifik.kmath.ast.MST
@ -18,6 +17,7 @@ import kotlin.reflect.KClass
* @param T the type of AsmExpression to unwrap. * @param T the type of AsmExpression to unwrap.
* @param algebra the algebra the applied AsmExpressions use. * @param algebra the algebra the applied AsmExpressions use.
* @param className the unique class name of new loaded class. * @param className the unique class name of new loaded class.
* @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
*/ */
internal class AsmBuilder<T> internal constructor( internal class AsmBuilder<T> internal constructor(
private val classOfT: KClass<*>, private val classOfT: KClass<*>,
@ -37,8 +37,19 @@ internal class AsmBuilder<T> internal constructor(
*/ */
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
/**
* ASM Type for [algebra]
*/
private val tAlgebraType: Type = algebra::class.asm private val tAlgebraType: Type = algebra::class.asm
/**
* ASM type for [T]
*/
internal val tType: Type = classOfT.asm internal val tType: Type = classOfT.asm
/**
* ASM type for new class
*/
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
/** /**
@ -60,15 +71,31 @@ internal class AsmBuilder<T> internal constructor(
* Method visitor of `invoke` method of the subclass. * Method visitor of `invoke` method of the subclass.
*/ */
private lateinit var invokeMethodVisitor: InstructionAdapter private lateinit var invokeMethodVisitor: InstructionAdapter
internal var primitiveMode = false
@Suppress("PropertyName") /**
internal var PRIMITIVE_MASK: Type = OBJECT_TYPE * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
*/
internal var primitiveMode: Boolean = false
@Suppress("PropertyName") /**
internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE * Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
private val typeStack = Stack<Type>() */
internal val expectationStack: Stack<Type> = Stack<Type>().apply { push(tType) } internal var primitiveMask: Type = OBJECT_TYPE
/**
* Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
*/
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
/**
* Stack of useful objects types on stack to verify types.
*/
private val typeStack: ArrayDeque<Type> = ArrayDeque()
/**
* Stack of useful objects types on stack expected by algebra calls.
*/
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>().apply { push(tType) }
/** /**
* The cache for instance built by this builder. * The cache for instance built by this builder.
@ -86,14 +113,14 @@ internal class AsmBuilder<T> internal constructor(
if (SIGNATURE_LETTERS.containsKey(classOfT)) { if (SIGNATURE_LETTERS.containsKey(classOfT)) {
primitiveMode = true primitiveMode = true
PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT) primitiveMask = SIGNATURE_LETTERS.getValue(classOfT)
PRIMITIVE_MASK_BOXED = tType primitiveMaskBoxed = tType
} }
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit( visit(
Opcodes.V1_8, V1_8,
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
classType.internalName, classType.internalName,
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
OBJECT_TYPE.internalName, OBJECT_TYPE.internalName,
@ -101,7 +128,7 @@ internal class AsmBuilder<T> internal constructor(
) )
visitField( visitField(
access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, access = ACC_PRIVATE or ACC_FINAL,
name = "algebra", name = "algebra",
descriptor = tAlgebraType.descriptor, descriptor = tAlgebraType.descriptor,
signature = null, signature = null,
@ -110,7 +137,7 @@ internal class AsmBuilder<T> internal constructor(
) )
visitField( visitField(
access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, access = ACC_PRIVATE or ACC_FINAL,
name = "constants", name = "constants",
descriptor = OBJECT_ARRAY_TYPE.descriptor, descriptor = OBJECT_ARRAY_TYPE.descriptor,
signature = null, signature = null,
@ -119,7 +146,7 @@ internal class AsmBuilder<T> internal constructor(
) )
visitMethod( visitMethod(
Opcodes.ACC_PUBLIC, ACC_PUBLIC,
"<init>", "<init>",
Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE), Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE),
null, null,
@ -159,7 +186,7 @@ internal class AsmBuilder<T> internal constructor(
} }
visitMethod( visitMethod(
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, ACC_PUBLIC or ACC_FINAL,
"invoke", "invoke",
Type.getMethodDescriptor(tType, MAP_TYPE), Type.getMethodDescriptor(tType, MAP_TYPE),
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
@ -195,7 +222,7 @@ internal class AsmBuilder<T> internal constructor(
} }
visitMethod( visitMethod(
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
"invoke", "invoke",
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
null, null,
@ -238,34 +265,43 @@ internal class AsmBuilder<T> internal constructor(
} }
/** /**
* Loads a constant from * Loads a [T] constant from [constants].
*/ */
internal fun loadTConstant(value: T) { internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) { if (classOfT in INLINABLE_NUMBERS) {
val expectedType = expectationStack.pop()!! val expectedType = expectationStack.pop()
val mustBeBoxed = expectedType.sort == Type.OBJECT val mustBeBoxed = expectedType.sort == Type.OBJECT
loadNumberConstant(value as Number, mustBeBoxed) loadNumberConstant(value as Number, mustBeBoxed)
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(PRIMITIVE_MASK) if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
return return
} }
loadConstant(value as Any, tType) loadConstant(value as Any, tType)
} }
/**
* Boxes the current value and pushes it.
*/
private fun box(): Unit = invokeMethodVisitor.invokestatic( private fun box(): Unit = invokeMethodVisitor.invokestatic(
tType.internalName, tType.internalName,
"valueOf", "valueOf",
Type.getMethodDescriptor(tType, PRIMITIVE_MASK), Type.getMethodDescriptor(tType, primitiveMask),
false false
) )
/**
* Unboxes the current boxed value and pushes it.
*/
private fun unbox(): Unit = invokeMethodVisitor.invokevirtual( private fun unbox(): Unit = invokeMethodVisitor.invokevirtual(
NUMBER_TYPE.internalName, NUMBER_TYPE.internalName,
NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK), NUMBER_CONVERTER_METHODS.getValue(primitiveMask),
Type.getMethodDescriptor(PRIMITIVE_MASK), Type.getMethodDescriptor(primitiveMask),
false false
) )
/**
* Loads [java.lang.Object] constant from constants.
*/
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
loadThis() loadThis()
@ -275,6 +311,9 @@ internal class AsmBuilder<T> internal constructor(
checkcast(type) checkcast(type)
} }
/**
* Loads this variable.
*/
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
/** /**
@ -305,46 +344,40 @@ internal class AsmBuilder<T> internal constructor(
} }
loadConstant(value, boxed) loadConstant(value, boxed)
if (!mustBeBoxed) unbox() if (!mustBeBoxed) unbox()
else invokeMethodVisitor.checkcast(tType) else invokeMethodVisitor.checkcast(tType)
} }
/** /**
* Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
* provided.
*/ */
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) load(invokeArgumentsVar, MAP_TYPE)
aconst(name)
if (defaultValue != null) { if (defaultValue != null)
loadStringConstant(name)
loadTConstant(defaultValue) loadTConstant(defaultValue)
else
aconst(null)
invokeinterface( invokestatic(
MAP_TYPE.internalName, MAP_INTRINSICS_TYPE.internalName,
"getOrDefault", "getOrFail",
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
) false
invokeMethodVisitor.checkcast(tType)
return
}
loadStringConstant(name)
invokeinterface(
MAP_TYPE.internalName,
"get",
Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE)
) )
invokeMethodVisitor.checkcast(tType) checkcast(tType)
val expectedType = expectationStack.pop()!!
val expectedType = expectationStack.pop()
if (expectedType.sort == Type.OBJECT) if (expectedType.sort == Type.OBJECT)
typeStack.push(tType) typeStack.push(tType)
else { else {
unbox() unbox()
typeStack.push(PRIMITIVE_MASK) typeStack.push(primitiveMask)
} }
} }
@ -358,7 +391,7 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interface. [loadAlgebra] should be * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be
* called before the arguments and this operation. * called before the arguments and this operation.
* *
* The result is casted to [T] automatically. * The result is casted to [T] automatically.
@ -367,12 +400,12 @@ internal class AsmBuilder<T> internal constructor(
owner: String, owner: String,
method: String, method: String,
descriptor: String, descriptor: String,
tArity: Int, expectedArity: Int,
opcode: Int = Opcodes.INVOKEINTERFACE opcode: Int = INVOKEINTERFACE
) { ) {
run loop@{ run loop@{
repeat(tArity) { repeat(expectedArity) {
if (typeStack.empty()) return@loop if (typeStack.isEmpty()) return@loop
typeStack.pop() typeStack.pop()
} }
} }
@ -382,18 +415,18 @@ internal class AsmBuilder<T> internal constructor(
owner, owner,
method, method,
descriptor, descriptor,
opcode == Opcodes.INVOKEINTERFACE opcode == INVOKEINTERFACE
) )
invokeMethodVisitor.checkcast(tType) invokeMethodVisitor.checkcast(tType)
val isLastExpr = expectationStack.size == 1 val isLastExpr = expectationStack.size == 1
val expectedType = expectationStack.pop()!! val expectedType = expectationStack.pop()
if (expectedType.sort == Type.OBJECT || isLastExpr) if (expectedType.sort == Type.OBJECT || isLastExpr)
typeStack.push(tType) typeStack.push(tType)
else { else {
unbox() unbox()
typeStack.push(PRIMITIVE_MASK) typeStack.push(primitiveMask)
} }
} }
@ -404,7 +437,7 @@ internal class AsmBuilder<T> internal constructor(
internal companion object { internal companion object {
/** /**
* Maps JVM primitive numbers boxed types to their letters of JVM signature convention. * Maps JVM primitive numbers boxed types to their primitive ASM types.
*/ */
private val SIGNATURE_LETTERS: Map<KClass<out Any>, Type> by lazy { private val SIGNATURE_LETTERS: Map<KClass<out Any>, Type> by lazy {
hashMapOf( hashMapOf(
@ -417,8 +450,14 @@ internal class AsmBuilder<T> internal constructor(
) )
} }
/**
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
*/
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
/**
* Maps primitive ASM types to [Number] functions unboxing them.
*/
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy { private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
hashMapOf( hashMapOf(
Type.BYTE_TYPE to "byteValue", Type.BYTE_TYPE to "byteValue",
@ -434,14 +473,46 @@ internal class AsmBuilder<T> internal constructor(
* Provides boxed number types values of which can be stored in JVM bytecode constant pool. * Provides boxed number types values of which can be stored in JVM bytecode constant pool.
*/ */
private val INLINABLE_NUMBERS: Set<KClass<out Any>> by lazy { SIGNATURE_LETTERS.keys } private val INLINABLE_NUMBERS: Set<KClass<out Any>> by lazy { SIGNATURE_LETTERS.keys }
/**
* ASM type for [Expression].
*/
internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm }
/**
* ASM type for [java.lang.Number].
*/
internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm }
/**
* ASM type for [java.util.Map].
*/
internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm }
/**
* ASM type for [java.lang.Object].
*/
internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm }
/**
* ASM type for array of [java.lang.Object].
*/
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
internal val OBJECT_ARRAY_TYPE: Type by lazy { Array<java.lang.Object>::class.asm } internal val OBJECT_ARRAY_TYPE: Type by lazy { Array<java.lang.Object>::class.asm }
/**
* ASM type for [Algebra].
*/
internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm }
/**
* ASM type for [java.lang.String].
*/
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
/**
* ASM type for MapIntrinsics.
*/
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") }
} }
} }

View File

@ -1,22 +0,0 @@
package scientifik.kmath.asm.internal
import scientifik.kmath.ast.MST
import scientifik.kmath.expressions.Expression
/**
* Creates a class name for [Expression] subclassed to implement [mst] provided.
*
* This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
*/
internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try {
Class.forName(name)
} catch (ignored: ClassNotFoundException) {
return name
}
return buildName(mst, collision + 1)
}

View File

@ -1,17 +0,0 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.FieldVisitor
import org.objectweb.asm.MethodVisitor
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter =
ClassWriter(flags).apply(block)
internal inline fun ClassWriter.visitField(
access: Int,
name: String,
descriptor: String,
signature: String?,
value: Any?,
block: FieldVisitor.() -> Unit
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)

View File

@ -1,7 +0,0 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.Type
import kotlin.reflect.KClass
internal val KClass<*>.asm: Type
get() = Type.getType(java)

View File

@ -0,0 +1,148 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
import org.objectweb.asm.commons.InstructionAdapter
import scientifik.kmath.ast.MST
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
import kotlin.reflect.KClass
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
hashMapOf(
"+" to 2 to "add",
"*" to 2 to "multiply",
"/" to 2 to "divide",
"+" to 1 to "unaryPlus",
"-" to 1 to "unaryMinus",
"-" to 2 to "minus"
)
}
internal val KClass<*>.asm: Type
get() = Type.getType(java)
/**
* Creates an [InstructionAdapter] from this [MethodVisitor].
*/
private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
/**
* Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it.
*/
internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
instructionAdapter().apply(block)
/**
* Constructs a [Label], then applies it to this visitor.
*/
internal fun MethodVisitor.label(): Label {
val l = Label()
visitLabel(l)
return l
}
/**
* Creates a class name for [Expression] subclassed to implement [mst] provided.
*
* This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
*/
internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try {
Class.forName(name)
} catch (ignored: ClassNotFoundException) {
return name
}
return buildName(mst, collision + 1)
}
@Suppress("FunctionName")
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter =
ClassWriter(flags).apply(block)
internal inline fun ClassWriter.visitField(
access: Int,
name: String,
descriptor: String,
signature: String?,
value: Any?,
block: FieldVisitor.() -> Unit
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
/**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
* type expectation stack for needed arity.
*
* @return `true` if contains, else `false`.
*/
private fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
val theName = methodNameAdapters[name to arity] ?: name
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null
val t = if (primitiveMode && hasSpecific) primitiveMask else tType
repeat(arity) { expectationStack.push(t) }
return hasSpecific
}
/**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts
* [AsmBuilder.invokeAlgebraOperation] of this method.
*
* @return `true` if contains, else `false`.
*/
private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
val theName = methodNameAdapters[name to arity] ?: name
context.javaClass.methods.find {
var suitableSignature = it.name == theName && it.parameters.size == arity
if (primitiveMode && it.isBridge)
suitableSignature = false
suitableSignature
} ?: return false
val owner = context::class.asm
invokeAlgebraOperation(
owner = owner.internalName,
method = theName,
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }),
expectedArity = arity,
opcode = INVOKEVIRTUAL
)
return true
}
/**
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
*/
internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
context: Algebra<T>,
name: String,
fallbackMethodName: String,
arity: Int,
parameters: AsmBuilder<T>.() -> Unit
) {
loadAlgebra()
if (!buildExpectationStack(context, name, arity)) loadStringConstant(name)
parameters()
if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
method = fallbackMethodName,
descriptor = Type.getMethodDescriptor(
AsmBuilder.OBJECT_TYPE,
AsmBuilder.STRING_TYPE,
*Array(arity) { AsmBuilder.OBJECT_TYPE }
),
expectedArity = arity
)
}

View File

@ -0,0 +1,7 @@
@file:JvmName("MapIntrinsics")
package scientifik.kmath.asm.internal
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V {
return this[key] ?: default ?: error("Parameter not found: $key")
}

View File

@ -1,9 +0,0 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.commons.InstructionAdapter
internal fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
instructionAdapter().apply(block)

View File

@ -1,61 +0,0 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import scientifik.kmath.operations.Algebra
private val methodNameAdapters: Map<String, String> by lazy {
hashMapOf(
"+" to "add",
"*" to "multiply",
"/" to "divide"
)
}
/**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
* type expectation stack for needed arity.
*
* @return `true` if contains, else `false`.
*/
internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
val aName = methodNameAdapters[name] ?: name
val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null
val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else tType
repeat(arity) { expectationStack.push(t) }
return hasSpecific
}
/**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts
* [AsmBuilder.invokeAlgebraOperation] of this method.
*
* @return `true` if contains, else `false`.
*/
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
val aName = methodNameAdapters[name] ?: name
val method =
context.javaClass.methods.find {
var suitableSignature = it.name == aName && it.parameters.size == arity
if (primitiveMode && it.isBridge)
suitableSignature = false
suitableSignature
} ?: return false
val owner = context::class.java.name.replace('.', '/')
invokeAlgebraOperation(
owner = owner,
method = aName,
descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }),
tArity = arity,
opcode = Opcodes.INVOKEVIRTUAL
)
return true
}

View File

@ -10,7 +10,7 @@ import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class TestAsmAlgebras { internal class TestAsmAlgebras {
@Test @Test
fun space() { fun space() {
val res1 = ByteRing.mstInSpace { val res1 = ByteRing.mstInSpace {
@ -92,8 +92,8 @@ class TestAsmAlgebras {
"+", "+",
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
+ number(1), + number(1),
1 / 2 + number(2.0) * one number(1) / 2 + number(2.0) * one
) ) + zero
}("x" to 2.0) }("x" to 2.0)
val res2 = RealField.mstInField { val res2 = RealField.mstInField {
@ -101,8 +101,8 @@ class TestAsmAlgebras {
"+", "+",
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
+ number(1), + number(1),
1 / 2 + number(2.0) * one number(1) / 2 + number(2.0) * one
) ) + zero
}.compile()("x" to 2.0) }.compile()("x" to 2.0)
assertEquals(res1, res2) assertEquals(res1, res2)

View File

@ -8,7 +8,7 @@ import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class TestAsmExpressions { internal class TestAsmExpressions {
@Test @Test
fun testUnaryOperationInvocation() { fun testUnaryOperationInvocation() {
val expression = RealField.mstInSpace { -symbol("x") }.compile() val expression = RealField.mstInSpace { -symbol("x") }.compile()

View File

@ -0,0 +1,46 @@
package scietifik.kmath.asm
import scientifik.kmath.asm.compile
import scientifik.kmath.ast.mstInField
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmSpecialization {
@Test
fun testUnaryPlus() {
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
assertEquals(2.0, expr("x" to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
assertEquals(-2.0, expr("x" to 2.0))
}
@Test
fun testAdd() {
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
assertEquals(4.0, expr("x" to 2.0))
}
@Test
fun testSine() {
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile()
assertEquals(0.0, expr("x" to 0.0))
}
@Test
fun testMinus() {
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
assertEquals(0.0, expr("x" to 2.0))
}
@Test
fun testDivide() {
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
assertEquals(1.0, expr("x" to 2.0))
}
}

View File

@ -0,0 +1,22 @@
package scietifik.kmath.asm
import scientifik.kmath.ast.mstInRing
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.ByteRing
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
internal class TestAsmVariables {
@Test
fun testVariableWithoutDefault() {
val expr = ByteRing.mstInRing { symbol("x") }
assertEquals(1.toByte(), expr("x" to 1.toByte()))
}
@Test
fun testVariableWithoutDefaultFails() {
val expr = ByteRing.mstInRing { symbol("x") }
assertFailsWith<IllegalStateException> { expr() }
}
}

View File

@ -10,7 +10,7 @@ import scientifik.kmath.operations.ComplexField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class AsmTest { internal class AsmTest {
@Test @Test
fun `compile MST`() { fun `compile MST`() {
val mst = "2+2*(2+2)".parseMath() val mst = "2+2*(2+2)".parseMath()

View File

@ -0,0 +1,38 @@
package scientifik.kmath.commons.random
import scientifik.kmath.prob.RandomGenerator
class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
org.apache.commons.math3.random.RandomGenerator {
private var generator = factory(intArrayOf())
override fun nextBoolean(): Boolean = generator.nextBoolean()
override fun nextFloat(): Float = generator.nextDouble().toFloat()
override fun setSeed(seed: Int) {
generator = factory(intArrayOf(seed))
}
override fun setSeed(seed: IntArray) {
generator = factory(seed)
}
override fun setSeed(seed: Long) {
setSeed(seed.toInt())
}
override fun nextBytes(bytes: ByteArray) {
generator.fillBytes(bytes)
}
override fun nextInt(): Int = generator.nextInt()
override fun nextInt(n: Int): Int = generator.nextInt(n)
override fun nextGaussian(): Double = TODO()
override fun nextDouble(): Double = generator.nextDouble()
override fun nextLong(): Long = generator.nextLong()
}

View File

@ -18,7 +18,7 @@ object Transformations {
private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> = private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> =
Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) } Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) }
private fun Buffer<Double>.asArray() = if (this is DoubleBuffer) { private fun Buffer<Double>.asArray() = if (this is RealBuffer) {
array array
} else { } else {
DoubleArray(size) { i -> get(i) } DoubleArray(size) { i -> get(i) }

View File

@ -0,0 +1,15 @@
package scientifik.kmath.domains
import scientifik.kmath.linear.Point
/**
* A simple geometric domain
*/
interface Domain<T : Any> {
operator fun contains(point: Point<T>): Boolean
/**
* Number of hyperspace dimensions
*/
val dimension: Int
}

View File

@ -0,0 +1,67 @@
/*
* Copyright 2015 Alexander Nozik.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package scientifik.kmath.domains
import scientifik.kmath.linear.Point
import scientifik.kmath.structures.RealBuffer
import scientifik.kmath.structures.indices
/**
*
* HyperSquareDomain class.
*
* @author Alexander Nozik
*/
class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain {
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
point[i] in lower[i]..upper[i]
}
override val dimension: Int get() = lower.size
override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num]
override fun getLowerBound(num: Int): Double? = lower[num]
override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num]
override fun getUpperBound(num: Int): Double? = upper[num]
override fun nearestInDomain(point: Point<Double>): Point<Double> {
val res: DoubleArray = DoubleArray(point.size) { i ->
when {
point[i] < lower[i] -> lower[i]
point[i] > upper[i] -> upper[i]
else -> point[i]
}
}
return RealBuffer(*res)
}
override fun volume(): Double {
var res = 1.0
for (i in 0 until dimension) {
if (lower[i].isInfinite() || upper[i].isInfinite()) {
return Double.POSITIVE_INFINITY
}
if (upper[i] > lower[i]) {
res *= upper[i] - lower[i]
}
}
return res
}
}

View File

@ -0,0 +1,65 @@
/*
* Copyright 2015 Alexander Nozik.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package scientifik.kmath.domains
import scientifik.kmath.linear.Point
/**
* n-dimensional volume
*
* @author Alexander Nozik
*/
interface RealDomain: Domain<Double> {
fun nearestInDomain(point: Point<Double>): Point<Double>
/**
* The lower edge for the domain going down from point
* @param num
* @param point
* @return
*/
fun getLowerBound(num: Int, point: Point<Double>): Double?
/**
* The upper edge of the domain going up from point
* @param num
* @param point
* @return
*/
fun getUpperBound(num: Int, point: Point<Double>): Double?
/**
* Global lower edge
* @param num
* @return
*/
fun getLowerBound(num: Int): Double?
/**
* Global upper edge
* @param num
* @return
*/
fun getUpperBound(num: Int): Double?
/**
* Hyper volume
* @return
*/
fun volume(): Double
}

View File

@ -0,0 +1,36 @@
/*
* Copyright 2015 Alexander Nozik.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package scientifik.kmath.domains
import scientifik.kmath.linear.Point
class UnconstrainedDomain(override val dimension: Int) : RealDomain {
override operator fun contains(point: Point<Double>): Boolean = true
override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY
override fun getUpperBound(num: Int, point: Point<Double>): Double? = Double.POSITIVE_INFINITY
override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY
override fun nearestInDomain(point: Point<Double>): Point<Double> = point
override fun volume(): Double = Double.POSITIVE_INFINITY
}

View File

@ -0,0 +1,48 @@
package scientifik.kmath.domains
import scientifik.kmath.linear.Point
import scientifik.kmath.structures.asBuffer
inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : RealDomain {
operator fun contains(d: Double): Boolean = range.contains(d)
override operator fun contains(point: Point<Double>): Boolean {
require(point.size == 0)
return contains(point[0])
}
override fun nearestInDomain(point: Point<Double>): Point<Double> {
require(point.size == 1)
val value = point[0]
return when{
value in range -> point
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
else -> doubleArrayOf(range.start).asBuffer()
}
}
override fun getLowerBound(num: Int, point: Point<Double>): Double? {
require(num == 0)
return range.start
}
override fun getUpperBound(num: Int, point: Point<Double>): Double? {
require(num == 0)
return range.endInclusive
}
override fun getLowerBound(num: Int): Double? {
require(num == 0)
return range.start
}
override fun getUpperBound(num: Int): Double? {
require(num == 0)
return range.endInclusive
}
override fun volume(): Double = range.endInclusive - range.start
override val dimension: Int get() = 1
}

View File

@ -30,11 +30,11 @@ object RealMatrixContext : GenericMatrixContext<Double, RealField> {
override val elementContext get() = RealField override val elementContext get() = RealField
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> { override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
} }
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = DoubleBuffer(size,initializer) override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size,initializer)
} }
class BufferMatrix<T : Any>( class BufferMatrix<T : Any>(
@ -102,7 +102,7 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
val array = DoubleArray(this.rowNum * other.colNum) val array = DoubleArray(this.rowNum * other.colNum)
//convert to array to insure there is not memory indirection //convert to array to insure there is not memory indirection
fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is DoubleBuffer) { fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is RealBuffer) {
array array
} else { } else {
DoubleArray(size) { get(it) } DoubleArray(size) { get(it) }
@ -119,6 +119,6 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
} }
} }
val buffer = DoubleBuffer(array) val buffer = RealBuffer(array)
return BufferMatrix(rowNum, other.colNum, buffer) return BufferMatrix(rowNum, other.colNum, buffer)
} }

View File

@ -37,9 +37,9 @@ interface Buffer<T> {
companion object { companion object {
inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer { inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer {
val array = DoubleArray(size) { initializer(it) } val array = DoubleArray(size) { initializer(it) }
return DoubleBuffer(array) return RealBuffer(array)
} }
/** /**
@ -51,7 +51,7 @@ interface Buffer<T> {
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> { inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
//TODO add resolution based on Annotation or companion resolution //TODO add resolution based on Annotation or companion resolution
return when (type) { return when (type) {
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T> Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T> Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T> Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T> Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
@ -93,7 +93,7 @@ interface MutableBuffer<T> : Buffer<T> {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> { inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
return when (type) { return when (type) {
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T> Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T> Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T> Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T> Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
@ -109,12 +109,11 @@ interface MutableBuffer<T> : Buffer<T> {
auto(T::class, size, initializer) auto(T::class, size, initializer)
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double -> val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
DoubleBuffer(DoubleArray(size) { initializer(it) }) RealBuffer(DoubleArray(size) { initializer(it) })
} }
} }
} }
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> { inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
override val size: Int override val size: Int
@ -163,57 +162,6 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this) fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
override val size: Int get() = array.size
override fun get(index: Int): Short = array[index]
override fun set(index: Int, value: Short) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Short> = ShortBuffer(array.copyOf())
}
fun ShortArray.asBuffer() = ShortBuffer(this)
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
override val size: Int get() = array.size
override fun get(index: Int): Int = array[index]
override fun set(index: Int, value: Int) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Int> = IntBuffer(array.copyOf())
}
fun IntArray.asBuffer() = IntBuffer(this)
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
override val size: Int get() = array.size
override fun get(index: Int): Long = array[index]
override fun set(index: Int, value: Long) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Long> = LongBuffer(array.copyOf())
}
fun LongArray.asBuffer() = LongBuffer(this)
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> { inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
override val size: Int get() = buffer.size override val size: Int get() = buffer.size

View File

@ -0,0 +1,53 @@
package scientifik.kmath.structures
import kotlin.experimental.and
enum class ValueFlag(val mask: Byte) {
NAN(0b0000_0001),
MISSING(0b0000_0010),
NEGATIVE_INFINITY(0b0000_0100),
POSITIVE_INFINITY(0b0000_1000)
}
/**
* A buffer with flagged values
*/
interface FlaggedBuffer<T> : Buffer<T> {
fun getFlag(index: Int): Byte
}
/**
* The value is valid if all flags are down
*/
fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte()
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte()
fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING)
/**
* A real buffer which supports flags for each value like NaN or Missing
*/
class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> {
init {
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
}
override fun getFlag(index: Int): Byte = flags[index]
override val size: Int get() = values.size
override fun get(index: Int): Double? = if (isValid(index)) values[index] else null
override fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
if (isValid(it)) values[it] else null
}.iterator()
}
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
for(i in indices){
if(isValid(i)){
block(values[i])
}
}
}

View File

@ -0,0 +1,20 @@
package scientifik.kmath.structures
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
override val size: Int get() = array.size
override fun get(index: Int): Int = array[index]
override fun set(index: Int, value: Int) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Int> =
IntBuffer(array.copyOf())
}
fun IntArray.asBuffer() = IntBuffer(this)

View File

@ -0,0 +1,19 @@
package scientifik.kmath.structures
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
override val size: Int get() = array.size
override fun get(index: Int): Long = array[index]
override fun set(index: Int, value: Long) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Long> =
LongBuffer(array.copyOf())
}
fun LongArray.asBuffer() = LongBuffer(this)

View File

@ -1,6 +1,6 @@
package scientifik.kmath.structures package scientifik.kmath.structures
inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> { inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Double = array[index] override fun get(index: Int): Double = array[index]
@ -12,23 +12,23 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> {
override fun iterator() = array.iterator() override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Double> = override fun copy(): MutableBuffer<Double> =
DoubleBuffer(array.copyOf()) RealBuffer(array.copyOf())
} }
@Suppress("FunctionName") @Suppress("FunctionName")
inline fun DoubleBuffer(size: Int, init: (Int) -> Double): DoubleBuffer = DoubleBuffer(DoubleArray(size) { init(it) }) inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) })
@Suppress("FunctionName") @Suppress("FunctionName")
fun DoubleBuffer(vararg doubles: Double): DoubleBuffer = DoubleBuffer(doubles) fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
/** /**
* Transform buffer of doubles into array for high performance operations * Transform buffer of doubles into array for high performance operations
*/ */
val MutableBuffer<out Double>.array: DoubleArray val MutableBuffer<out Double>.array: DoubleArray
get() = if (this is DoubleBuffer) { get() = if (this is RealBuffer) {
array array
} else { } else {
DoubleArray(size) { get(it) } DoubleArray(size) { get(it) }
} }
fun DoubleArray.asBuffer() = DoubleBuffer(this) fun DoubleArray.asBuffer() = RealBuffer(this)

View File

@ -16,7 +16,7 @@ class RealNDField(override val shape: IntArray) :
override val one by lazy { produce { one } } override val one by lazy { produce { one } }
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> = inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
DoubleBuffer(DoubleArray(size) { initializer(it) }) RealBuffer(DoubleArray(size) { initializer(it) })
/** /**
* Inline transform an NDStructure to * Inline transform an NDStructure to
@ -89,7 +89,7 @@ class RealNDField(override val shape: IntArray) :
*/ */
inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement { inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement {
val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) } val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) }
return BufferedNDFieldElement(this, DoubleBuffer(array)) return BufferedNDFieldElement(this, RealBuffer(array))
} }
/** /**
@ -103,7 +103,7 @@ inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: Int
*/ */
inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement {
val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) } val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) }
return BufferedNDFieldElement(context, DoubleBuffer(array)) return BufferedNDFieldElement(context, RealBuffer(array))
} }
/** /**

View File

@ -0,0 +1,20 @@
package scientifik.kmath.structures
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
override val size: Int get() = array.size
override fun get(index: Int): Short = array[index]
override fun set(index: Int, value: Short) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Short> =
ShortBuffer(array.copyOf())
}
fun ShortArray.asBuffer() = ShortBuffer(this)

View File

@ -5,7 +5,7 @@ import kotlinx.coroutines.flow.*
import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.BufferFactory
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.RealBuffer
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
/** /**
@ -45,7 +45,7 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
/** /**
* Specialized flow chunker for real buffer * Specialized flow chunker for real buffer
*/ */
fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow { fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
require(bufferSize > 0) { "Resulting chunk size must be more than zero" } require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
if (this@chunked is BlockingRealChain) { if (this@chunked is BlockingRealChain) {
@ -61,13 +61,13 @@ fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
array[counter] = element array[counter] = element
counter++ counter++
if (counter == bufferSize) { if (counter == bufferSize) {
val buffer = DoubleBuffer(array) val buffer = RealBuffer(array)
emit(buffer) emit(buffer)
counter = 0 counter = 0
} }
} }
if (counter > 0) { if (counter > 0) {
emit(DoubleBuffer(counter) { array[it] }) emit(RealBuffer(counter) { array[it] })
} }
} }
} }

View File

@ -7,31 +7,28 @@ import scientifik.kmath.operations.Norm
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.SpaceElement import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.RealBuffer
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
import scientifik.kmath.structures.asIterable import scientifik.kmath.structures.asIterable
import kotlin.math.sqrt import kotlin.math.sqrt
typealias RealPoint = Point<Double>
fun DoubleArray.asVector() = RealVector(this.asBuffer()) fun DoubleArray.asVector() = RealVector(this.asBuffer())
fun List<Double>.asVector() = RealVector(this.asBuffer()) fun List<Double>.asVector() = RealVector(this.asBuffer())
object VectorL2Norm : Norm<Point<out Number>, Double> { object VectorL2Norm : Norm<Point<out Number>, Double> {
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
} }
inline class RealVector(private val point: Point<Double>) : inline class RealVector(private val point: Point<Double>) :
SpaceElement<Point<Double>, RealVector, VectorSpace<Double, RealField>>, Point<Double> { SpaceElement<RealPoint, RealVector, VectorSpace<Double, RealField>>, RealPoint {
override val context: VectorSpace<Double, RealField> override val context: VectorSpace<Double, RealField> get() = space(point.size)
get() = space(
point.size
)
override fun unwrap(): Point<Double> = point override fun unwrap(): RealPoint = point
override fun Point<Double>.wrap(): RealVector = override fun RealPoint.wrap(): RealVector = RealVector(this)
RealVector(this)
override val size: Int get() = point.size override val size: Int get() = point.size
@ -44,16 +41,12 @@ inline class RealVector(private val point: Point<Double>) :
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>() private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
inline operator fun invoke(dim: Int, initializer: (Int) -> Double) = inline operator fun invoke(dim: Int, initializer: (Int) -> Double) =
RealVector(DoubleBuffer(dim, initializer)) RealVector(RealBuffer(dim, initializer))
operator fun invoke(vararg values: Double): RealVector = values.asVector() operator fun invoke(vararg values: Double): RealVector = values.asVector()
fun space(dim: Int): BufferVectorSpace<Double, RealField> = fun space(dim: Int): BufferVectorSpace<Double, RealField> = spaceCache.getOrPut(dim) {
spaceCache.getOrPut(dim) { BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) }
BufferVectorSpace( }
dim,
RealField
) { size, init -> Buffer.real(size, init) }
}
} }
} }

View File

@ -1,8 +1,8 @@
package scientifik.kmath.real package scientifik.kmath.real
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.RealBuffer
/** /**
* Simplified [DoubleBuffer] to array comparison * Simplified [RealBuffer] to array comparison
*/ */
fun DoubleBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles)

View File

@ -5,8 +5,8 @@ import scientifik.kmath.linear.RealMatrixContext.elementContext
import scientifik.kmath.linear.VirtualMatrix import scientifik.kmath.linear.VirtualMatrix
import scientifik.kmath.operations.sum import scientifik.kmath.operations.sum
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.DoubleBuffer
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.RealBuffer
import scientifik.kmath.structures.asIterable import scientifik.kmath.structures.asIterable
import kotlin.math.pow import kotlin.math.pow
@ -27,6 +27,10 @@ typealias RealMatrix = Matrix<Double>
fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
MatrixContext.real.produce(rowNum, colNum, initializer) MatrixContext.real.produce(rowNum, colNum, initializer)
fun Array<DoubleArray>.toMatrix(): RealMatrix{
return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] }
}
fun Sequence<DoubleArray>.toMatrix(): RealMatrix = toList().let { fun Sequence<DoubleArray>.toMatrix(): RealMatrix = toList().let {
MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] } MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] }
} }
@ -129,22 +133,22 @@ fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix =
fun Matrix<Double>.extractColumn(columnIndex: Int): RealMatrix = fun Matrix<Double>.extractColumn(columnIndex: Int): RealMatrix =
extractColumns(columnIndex..columnIndex) extractColumns(columnIndex..columnIndex)
fun Matrix<Double>.sumByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> fun Matrix<Double>.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
val column = columns[j] val column = columns[j]
with(elementContext) { with(elementContext) {
sum(column.asIterable()) sum(column.asIterable())
} }
} }
fun Matrix<Double>.minByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column") columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column")
} }
fun Matrix<Double>.maxByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j ->
columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column") columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column")
} }
fun Matrix<Double>.averageByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
columns[j].asIterable().average() columns[j].asIterable().average()
} }

View File

@ -1,17 +1,9 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.domains.Domain
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.structures.ArrayBuffer import scientifik.kmath.structures.ArrayBuffer
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.RealBuffer
/**
* A simple geometric domain
* TODO move to geometry module
*/
interface Domain<T : Any> {
operator fun contains(vector: Point<out T>): Boolean
val dimension: Int
}
/** /**
* The bin in the histogram. The histogram is by definition always done in the real space * The bin in the histogram. The histogram is by definition always done in the real space
@ -51,9 +43,9 @@ interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point)) fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
fun MutableHistogram<Double, *>.put(vararg point: Number) = fun MutableHistogram<Double, *>.put(vararg point: Number) =
put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) put(RealBuffer(point.map { it.toDouble() }.toDoubleArray()))
fun MutableHistogram<Double, *>.put(vararg point: Double) = put(DoubleBuffer(point)) fun MutableHistogram<Double, *>.put(vararg point: Double) = put(RealBuffer(point))
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) } fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }

View File

@ -1,8 +1,8 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.real.asVector
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
import scientifik.kmath.real.asVector
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import kotlin.math.floor import kotlin.math.floor
@ -21,7 +21,7 @@ data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val c
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> { class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
override fun contains(vector: Point<out T>): Boolean = def.contains(vector) override fun contains(point: Point<T>): Boolean = def.contains(point)
override val dimension: Int override val dimension: Int
get() = def.center.size get() = def.center.size
@ -50,7 +50,7 @@ class RealHistogram(
override val dimension: Int get() = lower.size override val dimension: Int get() = lower.size
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
init { init {
// argument checks // argument checks

View File

@ -16,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou
operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
override fun contains(vector: Buffer<out Double>): Boolean = contains(vector[0]) override fun contains(point: Buffer<Double>): Boolean = contains(point[0])
internal operator fun inc() = this.also { counter.increment() } internal operator fun inc() = this.also { counter.increment() }

View File

@ -10,6 +10,7 @@ interface MemorySpec<T : Any> {
val objectSize: Int val objectSize: Int
fun MemoryReader.read(offset: Int): T fun MemoryReader.read(offset: Int): T
//TODO consider thread safety
fun MemoryWriter.write(offset: Int, value: T) fun MemoryWriter.write(offset: Int, value: T)
} }

View File

@ -3,10 +3,12 @@ pluginManagement {
val toolsVersion = "0.5.0" val toolsVersion = "0.5.0"
plugins { plugins {
id("kotlinx.benchmark") version "0.2.0-dev-8"
id("scientifik.mpp") version toolsVersion id("scientifik.mpp") version toolsVersion
id("scientifik.jvm") version toolsVersion id("scientifik.jvm") version toolsVersion
id("scientifik.atomic") version toolsVersion id("scientifik.atomic") version toolsVersion
id("scientifik.publish") version toolsVersion id("scientifik.publish") version toolsVersion
kotlin("plugin.allopen") version "1.3.72"
} }
repositories { repositories {