Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116
@ -2,7 +2,7 @@ plugins {
|
||||
id("scientifik.publish") apply false
|
||||
}
|
||||
|
||||
val kmathVersion by extra("0.1.4-dev-7")
|
||||
val kmathVersion by extra("0.1.4-dev-8")
|
||||
|
||||
val bintrayRepo by extra("scientifik")
|
||||
val githubProject by extra("kmath")
|
||||
|
@ -2,7 +2,7 @@
|
||||
Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`).
|
||||
There are different types of buffers:
|
||||
|
||||
* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays.
|
||||
* Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays.
|
||||
* Boxing `ListBuffer` wrapping a list
|
||||
* Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value
|
||||
* `MemoryBuffer` allows direct allocation of objects in continuous memory block.
|
||||
|
@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
|
||||
plugins {
|
||||
java
|
||||
kotlin("jvm")
|
||||
kotlin("plugin.allopen") version "1.3.71"
|
||||
id("kotlinx.benchmark") version "0.2.0-dev-7"
|
||||
kotlin("plugin.allopen") version "1.3.72"
|
||||
id("kotlinx.benchmark") version "0.2.0-dev-8"
|
||||
}
|
||||
|
||||
configure<AllOpenExtension> {
|
||||
@ -24,6 +24,7 @@ sourceSets {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation(project(":kmath-ast"))
|
||||
implementation(project(":kmath-core"))
|
||||
implementation(project(":kmath-coroutines"))
|
||||
implementation(project(":kmath-commons"))
|
||||
@ -33,8 +34,8 @@ dependencies {
|
||||
implementation(project(":kmath-dimensions"))
|
||||
implementation("com.kyonifer:koma-core-ejml:0.12")
|
||||
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
|
||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7")
|
||||
"benchmarksCompile"(sourceSets.main.get().compileClasspath)
|
||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8")
|
||||
"benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath
|
||||
}
|
||||
|
||||
// Configure benchmark
|
||||
|
@ -10,8 +10,8 @@ import scientifik.kmath.operations.complex
|
||||
class BufferBenchmark {
|
||||
|
||||
@Benchmark
|
||||
fun genericDoubleBufferReadWrite() {
|
||||
val buffer = DoubleBuffer(size){it.toDouble()}
|
||||
fun genericRealBufferReadWrite() {
|
||||
val buffer = RealBuffer(size){it.toDouble()}
|
||||
|
||||
(0 until size).forEach {
|
||||
buffer[it]
|
||||
|
@ -20,48 +20,39 @@ class ViktorBenchmark {
|
||||
final val viktorField = ViktorNDField(intArrayOf(dim, dim))
|
||||
|
||||
@Benchmark
|
||||
fun `Automatic field addition`() {
|
||||
fun automaticFieldAddition() {
|
||||
autoField.run {
|
||||
var res = one
|
||||
repeat(n) {
|
||||
res += 1.0
|
||||
}
|
||||
repeat(n) { res += one }
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun `Viktor field addition`() {
|
||||
fun viktorFieldAddition() {
|
||||
viktorField.run {
|
||||
var res = one
|
||||
repeat(n) {
|
||||
res += one
|
||||
}
|
||||
repeat(n) { res += one }
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun `Raw Viktor`() {
|
||||
fun rawViktor() {
|
||||
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
|
||||
var res = one
|
||||
repeat(n) {
|
||||
res = res + one
|
||||
}
|
||||
repeat(n) { res = res + one }
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun `Real field log`() {
|
||||
fun realdFieldLog() {
|
||||
realField.run {
|
||||
val fortyTwo = produce { 42.0 }
|
||||
var res = one
|
||||
|
||||
repeat(n) {
|
||||
res = ln(fortyTwo)
|
||||
}
|
||||
repeat(n) { res = ln(fortyTwo) }
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun `Raw Viktor log`() {
|
||||
fun rawViktorLog() {
|
||||
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
|
||||
var res: F64Array
|
||||
repeat(n) {
|
||||
|
@ -0,0 +1,70 @@
|
||||
package scientifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.asm.compile
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.expressions.expressionInField
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.random.Random
|
||||
import kotlin.system.measureTimeMillis
|
||||
|
||||
class ExpressionsInterpretersBenchmark {
|
||||
private val algebra: Field<Double> = RealField
|
||||
fun functionalExpression() {
|
||||
val expr = algebra.expressionInField {
|
||||
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
|
||||
}
|
||||
|
||||
invokeAndSum(expr)
|
||||
}
|
||||
|
||||
fun mstExpression() {
|
||||
val expr = algebra.mstInField {
|
||||
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||
}
|
||||
|
||||
invokeAndSum(expr)
|
||||
}
|
||||
|
||||
fun asmExpression() {
|
||||
val expr = algebra.mstInField {
|
||||
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||
}.compile()
|
||||
|
||||
invokeAndSum(expr)
|
||||
}
|
||||
|
||||
private fun invokeAndSum(expr: Expression<Double>) {
|
||||
val random = Random(0)
|
||||
var sum = 0.0
|
||||
|
||||
repeat(1000000) {
|
||||
sum += expr("x" to random.nextDouble())
|
||||
}
|
||||
|
||||
println(sum)
|
||||
}
|
||||
}
|
||||
|
||||
fun main() {
|
||||
val benchmark = ExpressionsInterpretersBenchmark()
|
||||
|
||||
val fe = measureTimeMillis {
|
||||
benchmark.functionalExpression()
|
||||
}
|
||||
|
||||
println("fe=$fe")
|
||||
|
||||
val mst = measureTimeMillis {
|
||||
benchmark.mstExpression()
|
||||
}
|
||||
|
||||
println("mst=$mst")
|
||||
|
||||
val asm = measureTimeMillis {
|
||||
benchmark.asmExpression()
|
||||
}
|
||||
|
||||
println("asm=$asm")
|
||||
}
|
@ -27,7 +27,7 @@ fun main() {
|
||||
|
||||
val complexTime = measureTimeMillis {
|
||||
complexField.run {
|
||||
var res = one
|
||||
var res: NDBuffer<Complex> = one
|
||||
repeat(n) {
|
||||
res += 1.0
|
||||
}
|
||||
|
@ -23,9 +23,9 @@ fun main() {
|
||||
|
||||
measureAndPrint("Automatic field addition") {
|
||||
autoField.run {
|
||||
var res = one
|
||||
var res: NDBuffer<Double> = one
|
||||
repeat(n) {
|
||||
res += 1.0
|
||||
res += number(1.0)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -63,7 +63,7 @@ fun main() {
|
||||
genericField.run {
|
||||
var res: NDBuffer<Double> = one
|
||||
repeat(n) {
|
||||
res += 1.0
|
||||
res += one // con't avoid using `one` due to resolution ambiguity
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ fun main(args: Array<String>) {
|
||||
val n = 6000
|
||||
|
||||
val array = DoubleArray(n * n) { 1.0 }
|
||||
val buffer = DoubleBuffer(array)
|
||||
val buffer = RealBuffer(array)
|
||||
val strides = DefaultStrides(intArrayOf(n, n))
|
||||
|
||||
val structure = BufferNDStructure(strides, buffer)
|
||||
|
@ -26,10 +26,10 @@ fun main(args: Array<String>) {
|
||||
}
|
||||
println("Array mapping finished in $time2 millis")
|
||||
|
||||
val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 })
|
||||
val buffer = RealBuffer(DoubleArray(n * n) { 1.0 })
|
||||
|
||||
val time3 = measureTimeMillis {
|
||||
val target = DoubleBuffer(DoubleArray(n * n))
|
||||
val target = RealBuffer(DoubleArray(n * n))
|
||||
val res = array.forEachIndexed { index, value ->
|
||||
target[index] = value + 1
|
||||
}
|
||||
|
62
kmath-ast/README.md
Normal file
62
kmath-ast/README.md
Normal file
@ -0,0 +1,62 @@
|
||||
# AST-based expression representation and operations (`kmath-ast`)
|
||||
|
||||
This subproject implements the following features:
|
||||
|
||||
- Expression Language and its parser.
|
||||
- MST as expression language's syntax intermediate representation.
|
||||
- Type-safe builder of MST.
|
||||
- Evaluating expressions by traversing MST.
|
||||
|
||||
## Dynamic expression code generation with OW2 ASM
|
||||
|
||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
|
||||
a special implementation of `Expression<T>` with implemented `invoke` function.
|
||||
|
||||
For example, the following builder:
|
||||
|
||||
```kotlin
|
||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
```
|
||||
|
||||
… leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||
|
||||
```java
|
||||
package scientifik.kmath.asm.generated;
|
||||
|
||||
import java.util.Map;
|
||||
import scientifik.kmath.asm.internal.MapIntrinsics;
|
||||
import scientifik.kmath.expressions.Expression;
|
||||
import scientifik.kmath.operations.RealField;
|
||||
|
||||
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
|
||||
private final RealField algebra;
|
||||
private final Object[] constants;
|
||||
|
||||
public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) {
|
||||
this.algebra = algebra;
|
||||
this.constants = constants;
|
||||
}
|
||||
|
||||
public final Double invoke(Map<String, ? extends Double> arguments) {
|
||||
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D);
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
### Example Usage
|
||||
|
||||
This API is an extension to MST and MstExpression APIs. You may optimize both MST and MSTExpression:
|
||||
|
||||
```kotlin
|
||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
RealField.expression("x+2".parseMath())
|
||||
```
|
||||
|
||||
### Known issues
|
||||
|
||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
|
||||
class loading overhead.
|
||||
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
||||
|
||||
Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis).
|
37
kmath-ast/build.gradle.kts
Normal file
37
kmath-ast/build.gradle.kts
Normal file
@ -0,0 +1,37 @@
|
||||
plugins {
|
||||
id("scientifik.mpp")
|
||||
}
|
||||
|
||||
repositories {
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
}
|
||||
|
||||
kotlin.sourceSets {
|
||||
// all {
|
||||
// languageSettings.apply{
|
||||
// enableLanguageFeature("NewInference")
|
||||
// }
|
||||
// }
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3")
|
||||
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
|
||||
}
|
||||
}
|
||||
|
||||
jvmMain {
|
||||
dependencies {
|
||||
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
|
||||
implementation("org.ow2.asm:asm:8.0.1")
|
||||
implementation("org.ow2.asm:asm-commons:8.0.1")
|
||||
implementation(kotlin("reflect"))
|
||||
}
|
||||
}
|
||||
|
||||
jsMain {
|
||||
dependencies {
|
||||
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
|
||||
}
|
||||
}
|
||||
}
|
67
kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt
Normal file
67
kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt
Normal file
@ -0,0 +1,67 @@
|
||||
package scientifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.NumericAlgebra
|
||||
import scientifik.kmath.operations.RealField
|
||||
|
||||
/**
|
||||
* A Mathematical Syntax Tree node for mathematical expressions
|
||||
*/
|
||||
sealed class MST {
|
||||
|
||||
/**
|
||||
* A node containing unparsed string
|
||||
*/
|
||||
data class Symbolic(val value: String) : MST()
|
||||
|
||||
/**
|
||||
* A node containing a number
|
||||
*/
|
||||
data class Numeric(val value: Number) : MST()
|
||||
|
||||
/**
|
||||
* A node containing an unary operation
|
||||
*/
|
||||
data class Unary(val operation: String, val value: MST) : MST() {
|
||||
companion object {
|
||||
const val ABS_OPERATION = "abs"
|
||||
//TODO add operations
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A node containing binary operation
|
||||
*/
|
||||
data class Binary(val operation: String, val left: MST, val right: MST) : MST() {
|
||||
companion object
|
||||
}
|
||||
}
|
||||
|
||||
//TODO add a function with positional arguments
|
||||
|
||||
//TODO add a function with named arguments
|
||||
|
||||
fun <T> Algebra<T>.evaluate(node: MST): T {
|
||||
return when (node) {
|
||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
is MST.Symbolic -> symbol(node.value)
|
||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||
is MST.Binary -> when {
|
||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||
val number = RealField.binaryOperation(
|
||||
node.operation,
|
||||
node.left.value.toDouble(),
|
||||
node.right.value.toDouble()
|
||||
)
|
||||
number(number)
|
||||
}
|
||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)
|
@ -0,0 +1,72 @@
|
||||
package scientifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
object MstAlgebra : NumericAlgebra<MST> {
|
||||
override fun number(value: Number): MST = MST.Numeric(value)
|
||||
|
||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST =
|
||||
MST.Unary(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
MST.Binary(operation, left, right)
|
||||
}
|
||||
|
||||
object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
|
||||
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||
|
||||
override fun add(a: MST, b: MST): MST =
|
||||
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
|
||||
override fun multiply(a: MST, k: Number): MST =
|
||||
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
MstAlgebra.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||
}
|
||||
|
||||
object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
override val one: MST = number(1.0)
|
||||
|
||||
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
|
||||
override fun multiply(a: MST, k: Number): MST =
|
||||
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
|
||||
|
||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
MstAlgebra.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||
}
|
||||
|
||||
object MstField : Field<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
override val one: MST = number(1.0)
|
||||
|
||||
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
|
||||
override fun multiply(a: MST, k: Number): MST =
|
||||
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
|
||||
|
||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
MstAlgebra.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||
}
|
@ -0,0 +1,55 @@
|
||||
package scientifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.expressions.FunctionalExpressionField
|
||||
import scientifik.kmath.expressions.FunctionalExpressionRing
|
||||
import scientifik.kmath.expressions.FunctionalExpressionSpace
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
/**
|
||||
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
|
||||
*/
|
||||
class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
|
||||
|
||||
/**
|
||||
* Substitute algebra raw value
|
||||
*/
|
||||
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
|
||||
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||
algebra.binaryOperation(operation, left, right)
|
||||
|
||||
override fun number(value: Number): T = if (algebra is NumericAlgebra)
|
||||
algebra.number(value)
|
||||
else
|
||||
error("Numeric nodes are not supported by $this")
|
||||
}
|
||||
|
||||
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||
}
|
||||
|
||||
|
||||
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||
mstAlgebra: E,
|
||||
block: E.() -> MST
|
||||
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||
|
||||
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
|
||||
MstExpression(this, MstSpace.block())
|
||||
|
||||
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
|
||||
MstExpression(this, MstRing.block())
|
||||
|
||||
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> =
|
||||
MstExpression(this, MstField.block())
|
||||
|
||||
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
|
||||
algebra.mstInSpace(block)
|
||||
|
||||
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
|
||||
algebra.mstInRing(block)
|
||||
|
||||
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> =
|
||||
algebra.mstInField(block)
|
@ -0,0 +1,59 @@
|
||||
package scientifik.kmath.ast
|
||||
|
||||
import com.github.h0tk3y.betterParse.combinators.*
|
||||
import com.github.h0tk3y.betterParse.grammar.Grammar
|
||||
import com.github.h0tk3y.betterParse.grammar.parseToEnd
|
||||
import com.github.h0tk3y.betterParse.grammar.parser
|
||||
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
||||
import com.github.h0tk3y.betterParse.parser.ParseResult
|
||||
import com.github.h0tk3y.betterParse.parser.Parser
|
||||
import scientifik.kmath.operations.FieldOperations
|
||||
import scientifik.kmath.operations.PowerOperations
|
||||
import scientifik.kmath.operations.RingOperations
|
||||
import scientifik.kmath.operations.SpaceOperations
|
||||
|
||||
/**
|
||||
* TODO move to common
|
||||
*/
|
||||
private object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
|
||||
val lpar by token("\\(".toRegex())
|
||||
val rpar by token("\\)".toRegex())
|
||||
val mul by token("\\*".toRegex())
|
||||
val pow by token("\\^".toRegex())
|
||||
val div by token("/".toRegex())
|
||||
val minus by token("-".toRegex())
|
||||
val plus by token("\\+".toRegex())
|
||||
val ws by token("\\s+".toRegex(), ignore = true)
|
||||
|
||||
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||
|
||||
val term: Parser<MST> by number or
|
||||
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or
|
||||
(skip(lpar) and parser(this::rootParser) and skip(rpar))
|
||||
|
||||
val powChain by leftAssociative(term, pow) { a, _, b ->
|
||||
MST.Binary(PowerOperations.POW_OPERATION, a, b)
|
||||
}
|
||||
|
||||
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b ->
|
||||
if (op == div) {
|
||||
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
||||
} else {
|
||||
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
||||
}
|
||||
}
|
||||
|
||||
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b ->
|
||||
if (op == plus) {
|
||||
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
} else {
|
||||
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
|
||||
}
|
||||
}
|
||||
|
||||
override val rootParser: Parser<MST> by subSumChain
|
||||
}
|
||||
|
||||
fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this)
|
||||
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)
|
60
kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt
Normal file
60
kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt
Normal file
@ -0,0 +1,60 @@
|
||||
package scientifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.internal.AsmBuilder
|
||||
import scientifik.kmath.asm.internal.buildAlgebraOperationCall
|
||||
import scientifik.kmath.asm.internal.buildName
|
||||
import scientifik.kmath.ast.MST
|
||||
import scientifik.kmath.ast.MstExpression
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.NumericAlgebra
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
* Compile given MST to an Expression using AST compiler
|
||||
*/
|
||||
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||
fun AsmBuilder<T>.visit(node: MST) {
|
||||
when (node) {
|
||||
is MST.Symbolic -> loadVariable(node.value)
|
||||
|
||||
is MST.Numeric -> {
|
||||
val constant = if (algebra is NumericAlgebra<T>)
|
||||
algebra.number(node.value)
|
||||
else
|
||||
error("Number literals are not supported in $algebra")
|
||||
|
||||
loadTConstant(constant)
|
||||
}
|
||||
|
||||
is MST.Unary -> buildAlgebraOperationCall(
|
||||
context = algebra,
|
||||
name = node.operation,
|
||||
fallbackMethodName = "unaryOperation",
|
||||
arity = 1
|
||||
) { visit(node.value) }
|
||||
|
||||
is MST.Binary -> buildAlgebraOperationCall(
|
||||
context = algebra,
|
||||
name = node.operation,
|
||||
fallbackMethodName = "binaryOperation",
|
||||
arity = 2
|
||||
) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
||||
}
|
||||
|
||||
/**
|
||||
* Compile an [MST] to ASM using given algebra
|
||||
*/
|
||||
inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
||||
|
||||
/**
|
||||
* Optimize performance of an [MstExpression] using ASM codegen
|
||||
*/
|
||||
inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
@ -0,0 +1,518 @@
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
import org.objectweb.asm.*
|
||||
import org.objectweb.asm.Opcodes.*
|
||||
import org.objectweb.asm.commons.InstructionAdapter
|
||||
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||
import scientifik.kmath.ast.MST
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import java.util.*
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
|
||||
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
|
||||
*
|
||||
* @param T the type of AsmExpression to unwrap.
|
||||
* @param algebra the algebra the applied AsmExpressions use.
|
||||
* @param className the unique class name of new loaded class.
|
||||
* @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
|
||||
*/
|
||||
internal class AsmBuilder<T> internal constructor(
|
||||
private val classOfT: KClass<*>,
|
||||
private val algebra: Algebra<T>,
|
||||
private val className: String,
|
||||
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
||||
) {
|
||||
/**
|
||||
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
||||
*/
|
||||
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||
}
|
||||
|
||||
/**
|
||||
* The instance of [ClassLoader] used by this builder.
|
||||
*/
|
||||
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||
|
||||
/**
|
||||
* ASM Type for [algebra]
|
||||
*/
|
||||
private val tAlgebraType: Type = algebra::class.asm
|
||||
|
||||
/**
|
||||
* ASM type for [T]
|
||||
*/
|
||||
internal val tType: Type = classOfT.asm
|
||||
|
||||
/**
|
||||
* ASM type for new class
|
||||
*/
|
||||
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
||||
|
||||
/**
|
||||
* Index of `this` variable in invoke method of the built subclass.
|
||||
*/
|
||||
private val invokeThisVar: Int = 0
|
||||
|
||||
/**
|
||||
* Index of `arguments` variable in invoke method of the built subclass.
|
||||
*/
|
||||
private val invokeArgumentsVar: Int = 1
|
||||
|
||||
/**
|
||||
* List of constants to provide to the subclass.
|
||||
*/
|
||||
private val constants: MutableList<Any> = mutableListOf()
|
||||
|
||||
/**
|
||||
* Method visitor of `invoke` method of the subclass.
|
||||
*/
|
||||
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||
|
||||
/**
|
||||
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
||||
*/
|
||||
internal var primitiveMode: Boolean = false
|
||||
|
||||
/**
|
||||
* Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||
*/
|
||||
internal var primitiveMask: Type = OBJECT_TYPE
|
||||
|
||||
/**
|
||||
* Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||
*/
|
||||
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
|
||||
|
||||
/**
|
||||
* Stack of useful objects types on stack to verify types.
|
||||
*/
|
||||
private val typeStack: ArrayDeque<Type> = ArrayDeque()
|
||||
|
||||
/**
|
||||
* Stack of useful objects types on stack expected by algebra calls.
|
||||
*/
|
||||
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>().apply { push(tType) }
|
||||
|
||||
/**
|
||||
* The cache for instance built by this builder.
|
||||
*/
|
||||
private var generatedInstance: Expression<T>? = null
|
||||
|
||||
/**
|
||||
* Subclasses, loads and instantiates [Expression] for given parameters.
|
||||
*
|
||||
* The built instance is cached.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun getInstance(): Expression<T> {
|
||||
generatedInstance?.let { return it }
|
||||
|
||||
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
|
||||
primitiveMode = true
|
||||
primitiveMask = SIGNATURE_LETTERS.getValue(classOfT)
|
||||
primitiveMaskBoxed = tType
|
||||
}
|
||||
|
||||
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||
visit(
|
||||
V1_8,
|
||||
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
|
||||
classType.internalName,
|
||||
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
|
||||
OBJECT_TYPE.internalName,
|
||||
arrayOf(EXPRESSION_TYPE.internalName)
|
||||
)
|
||||
|
||||
visitField(
|
||||
access = ACC_PRIVATE or ACC_FINAL,
|
||||
name = "algebra",
|
||||
descriptor = tAlgebraType.descriptor,
|
||||
signature = null,
|
||||
value = null,
|
||||
block = FieldVisitor::visitEnd
|
||||
)
|
||||
|
||||
visitField(
|
||||
access = ACC_PRIVATE or ACC_FINAL,
|
||||
name = "constants",
|
||||
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
||||
signature = null,
|
||||
value = null,
|
||||
block = FieldVisitor::visitEnd
|
||||
)
|
||||
|
||||
visitMethod(
|
||||
ACC_PUBLIC,
|
||||
"<init>",
|
||||
Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE),
|
||||
null,
|
||||
null
|
||||
).instructionAdapter {
|
||||
val thisVar = 0
|
||||
val algebraVar = 1
|
||||
val constantsVar = 2
|
||||
val l0 = label()
|
||||
load(thisVar, classType)
|
||||
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
|
||||
label()
|
||||
load(thisVar, classType)
|
||||
load(algebraVar, tAlgebraType)
|
||||
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
||||
label()
|
||||
load(thisVar, classType)
|
||||
load(constantsVar, OBJECT_ARRAY_TYPE)
|
||||
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||
label()
|
||||
visitInsn(RETURN)
|
||||
val l4 = label()
|
||||
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
|
||||
|
||||
visitLocalVariable(
|
||||
"algebra",
|
||||
tAlgebraType.descriptor,
|
||||
null,
|
||||
l0,
|
||||
l4,
|
||||
algebraVar
|
||||
)
|
||||
|
||||
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
|
||||
visitMaxs(0, 3)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
visitMethod(
|
||||
ACC_PUBLIC or ACC_FINAL,
|
||||
"invoke",
|
||||
Type.getMethodDescriptor(tType, MAP_TYPE),
|
||||
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
||||
null
|
||||
).instructionAdapter {
|
||||
invokeMethodVisitor = this
|
||||
visitCode()
|
||||
val l0 = label()
|
||||
invokeLabel0Visitor()
|
||||
areturn(tType)
|
||||
val l1 = label()
|
||||
|
||||
visitLocalVariable(
|
||||
"this",
|
||||
classType.descriptor,
|
||||
null,
|
||||
l0,
|
||||
l1,
|
||||
invokeThisVar
|
||||
)
|
||||
|
||||
visitLocalVariable(
|
||||
"arguments",
|
||||
MAP_TYPE.descriptor,
|
||||
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;",
|
||||
l0,
|
||||
l1,
|
||||
invokeArgumentsVar
|
||||
)
|
||||
|
||||
visitMaxs(0, 2)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
visitMethod(
|
||||
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
||||
"invoke",
|
||||
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
||||
null,
|
||||
null
|
||||
).instructionAdapter {
|
||||
val thisVar = 0
|
||||
val argumentsVar = 1
|
||||
visitCode()
|
||||
val l0 = label()
|
||||
load(thisVar, OBJECT_TYPE)
|
||||
load(argumentsVar, MAP_TYPE)
|
||||
invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false)
|
||||
areturn(tType)
|
||||
val l1 = label()
|
||||
|
||||
visitLocalVariable(
|
||||
"this",
|
||||
classType.descriptor,
|
||||
null,
|
||||
l0,
|
||||
l1,
|
||||
thisVar
|
||||
)
|
||||
|
||||
visitMaxs(0, 2)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
val new = classLoader
|
||||
.defineClass(className, classWriter.toByteArray())
|
||||
.constructors
|
||||
.first()
|
||||
.newInstance(algebra, constants.toTypedArray()) as Expression<T>
|
||||
|
||||
generatedInstance = new
|
||||
return new
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a [T] constant from [constants].
|
||||
*/
|
||||
internal fun loadTConstant(value: T) {
|
||||
if (classOfT in INLINABLE_NUMBERS) {
|
||||
val expectedType = expectationStack.pop()
|
||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||
loadNumberConstant(value as Number, mustBeBoxed)
|
||||
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
|
||||
return
|
||||
}
|
||||
|
||||
loadConstant(value as Any, tType)
|
||||
}
|
||||
|
||||
/**
|
||||
* Boxes the current value and pushes it.
|
||||
*/
|
||||
private fun box(): Unit = invokeMethodVisitor.invokestatic(
|
||||
tType.internalName,
|
||||
"valueOf",
|
||||
Type.getMethodDescriptor(tType, primitiveMask),
|
||||
false
|
||||
)
|
||||
|
||||
/**
|
||||
* Unboxes the current boxed value and pushes it.
|
||||
*/
|
||||
private fun unbox(): Unit = invokeMethodVisitor.invokevirtual(
|
||||
NUMBER_TYPE.internalName,
|
||||
NUMBER_CONVERTER_METHODS.getValue(primitiveMask),
|
||||
Type.getMethodDescriptor(primitiveMask),
|
||||
false
|
||||
)
|
||||
|
||||
/**
|
||||
* Loads [java.lang.Object] constant from constants.
|
||||
*/
|
||||
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||
loadThis()
|
||||
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||
iconst(idx)
|
||||
visitInsn(AALOAD)
|
||||
checkcast(type)
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads this variable.
|
||||
*/
|
||||
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
|
||||
|
||||
/**
|
||||
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
|
||||
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
|
||||
* from it).
|
||||
*/
|
||||
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
|
||||
val boxed = value::class.asm
|
||||
val primitive = BOXED_TO_PRIMITIVES[boxed]
|
||||
|
||||
if (primitive != null) {
|
||||
when (primitive) {
|
||||
Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||
Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
|
||||
Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
|
||||
Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
|
||||
Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||
}
|
||||
|
||||
if (mustBeBoxed) {
|
||||
box()
|
||||
invokeMethodVisitor.checkcast(tType)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
loadConstant(value, boxed)
|
||||
|
||||
if (!mustBeBoxed) unbox()
|
||||
else invokeMethodVisitor.checkcast(tType)
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
|
||||
* provided.
|
||||
*/
|
||||
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
||||
load(invokeArgumentsVar, MAP_TYPE)
|
||||
aconst(name)
|
||||
|
||||
if (defaultValue != null)
|
||||
loadTConstant(defaultValue)
|
||||
else
|
||||
aconst(null)
|
||||
|
||||
invokestatic(
|
||||
MAP_INTRINSICS_TYPE.internalName,
|
||||
"getOrFail",
|
||||
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
|
||||
false
|
||||
)
|
||||
|
||||
checkcast(tType)
|
||||
|
||||
val expectedType = expectationStack.pop()
|
||||
|
||||
if (expectedType.sort == Type.OBJECT)
|
||||
typeStack.push(tType)
|
||||
else {
|
||||
unbox()
|
||||
typeStack.push(primitiveMask)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads algebra from according field of the class and casts it to class of [algebra] provided.
|
||||
*/
|
||||
internal fun loadAlgebra() {
|
||||
loadThis()
|
||||
invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is
|
||||
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be
|
||||
* called before the arguments and this operation.
|
||||
*
|
||||
* The result is casted to [T] automatically.
|
||||
*/
|
||||
internal fun invokeAlgebraOperation(
|
||||
owner: String,
|
||||
method: String,
|
||||
descriptor: String,
|
||||
expectedArity: Int,
|
||||
opcode: Int = INVOKEINTERFACE
|
||||
) {
|
||||
run loop@{
|
||||
repeat(expectedArity) {
|
||||
if (typeStack.isEmpty()) return@loop
|
||||
typeStack.pop()
|
||||
}
|
||||
}
|
||||
|
||||
invokeMethodVisitor.visitMethodInsn(
|
||||
opcode,
|
||||
owner,
|
||||
method,
|
||||
descriptor,
|
||||
opcode == INVOKEINTERFACE
|
||||
)
|
||||
|
||||
invokeMethodVisitor.checkcast(tType)
|
||||
val isLastExpr = expectationStack.size == 1
|
||||
val expectedType = expectationStack.pop()
|
||||
|
||||
if (expectedType.sort == Type.OBJECT || isLastExpr)
|
||||
typeStack.push(tType)
|
||||
else {
|
||||
unbox()
|
||||
typeStack.push(primitiveMask)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes a LDC Instruction with string constant provided.
|
||||
*/
|
||||
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
|
||||
|
||||
internal companion object {
|
||||
/**
|
||||
* Maps JVM primitive numbers boxed types to their primitive ASM types.
|
||||
*/
|
||||
private val SIGNATURE_LETTERS: Map<KClass<out Any>, Type> by lazy {
|
||||
hashMapOf(
|
||||
java.lang.Byte::class to Type.BYTE_TYPE,
|
||||
java.lang.Short::class to Type.SHORT_TYPE,
|
||||
java.lang.Integer::class to Type.INT_TYPE,
|
||||
java.lang.Long::class to Type.LONG_TYPE,
|
||||
java.lang.Float::class to Type.FLOAT_TYPE,
|
||||
java.lang.Double::class to Type.DOUBLE_TYPE
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
||||
*/
|
||||
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
|
||||
|
||||
/**
|
||||
* Maps primitive ASM types to [Number] functions unboxing them.
|
||||
*/
|
||||
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
|
||||
hashMapOf(
|
||||
Type.BYTE_TYPE to "byteValue",
|
||||
Type.SHORT_TYPE to "shortValue",
|
||||
Type.INT_TYPE to "intValue",
|
||||
Type.LONG_TYPE to "longValue",
|
||||
Type.FLOAT_TYPE to "floatValue",
|
||||
Type.DOUBLE_TYPE to "doubleValue"
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
||||
*/
|
||||
private val INLINABLE_NUMBERS: Set<KClass<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||
|
||||
/**
|
||||
* ASM type for [Expression].
|
||||
*/
|
||||
internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm }
|
||||
|
||||
/**
|
||||
* ASM type for [java.lang.Number].
|
||||
*/
|
||||
internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm }
|
||||
|
||||
/**
|
||||
* ASM type for [java.util.Map].
|
||||
*/
|
||||
internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm }
|
||||
|
||||
/**
|
||||
* ASM type for [java.lang.Object].
|
||||
*/
|
||||
internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm }
|
||||
|
||||
/**
|
||||
* ASM type for array of [java.lang.Object].
|
||||
*/
|
||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
||||
internal val OBJECT_ARRAY_TYPE: Type by lazy { Array<java.lang.Object>::class.asm }
|
||||
|
||||
/**
|
||||
* ASM type for [Algebra].
|
||||
*/
|
||||
internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm }
|
||||
|
||||
/**
|
||||
* ASM type for [java.lang.String].
|
||||
*/
|
||||
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
|
||||
|
||||
/**
|
||||
* ASM type for MapIntrinsics.
|
||||
*/
|
||||
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") }
|
||||
}
|
||||
}
|
@ -0,0 +1,148 @@
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
import org.objectweb.asm.*
|
||||
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
|
||||
import org.objectweb.asm.commons.InstructionAdapter
|
||||
import scientifik.kmath.ast.MST
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||
hashMapOf(
|
||||
"+" to 2 to "add",
|
||||
"*" to 2 to "multiply",
|
||||
"/" to 2 to "divide",
|
||||
"+" to 1 to "unaryPlus",
|
||||
"-" to 1 to "unaryMinus",
|
||||
"-" to 2 to "minus"
|
||||
)
|
||||
}
|
||||
|
||||
internal val KClass<*>.asm: Type
|
||||
get() = Type.getType(java)
|
||||
|
||||
/**
|
||||
* Creates an [InstructionAdapter] from this [MethodVisitor].
|
||||
*/
|
||||
private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
|
||||
|
||||
/**
|
||||
* Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it.
|
||||
*/
|
||||
internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
|
||||
instructionAdapter().apply(block)
|
||||
|
||||
/**
|
||||
* Constructs a [Label], then applies it to this visitor.
|
||||
*/
|
||||
internal fun MethodVisitor.label(): Label {
|
||||
val l = Label()
|
||||
visitLabel(l)
|
||||
return l
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a class name for [Expression] subclassed to implement [mst] provided.
|
||||
*
|
||||
* This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there
|
||||
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
|
||||
*/
|
||||
internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
|
||||
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
||||
|
||||
try {
|
||||
Class.forName(name)
|
||||
} catch (ignored: ClassNotFoundException) {
|
||||
return name
|
||||
}
|
||||
|
||||
return buildName(mst, collision + 1)
|
||||
}
|
||||
|
||||
@Suppress("FunctionName")
|
||||
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter =
|
||||
ClassWriter(flags).apply(block)
|
||||
|
||||
internal inline fun ClassWriter.visitField(
|
||||
access: Int,
|
||||
name: String,
|
||||
descriptor: String,
|
||||
signature: String?,
|
||||
value: Any?,
|
||||
block: FieldVisitor.() -> Unit
|
||||
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
|
||||
|
||||
/**
|
||||
* Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
|
||||
* type expectation stack for needed arity.
|
||||
*
|
||||
* @return `true` if contains, else `false`.
|
||||
*/
|
||||
private fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||
val theName = methodNameAdapters[name to arity] ?: name
|
||||
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null
|
||||
val t = if (primitiveMode && hasSpecific) primitiveMask else tType
|
||||
repeat(arity) { expectationStack.push(t) }
|
||||
return hasSpecific
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts
|
||||
* [AsmBuilder.invokeAlgebraOperation] of this method.
|
||||
*
|
||||
* @return `true` if contains, else `false`.
|
||||
*/
|
||||
private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||
val theName = methodNameAdapters[name to arity] ?: name
|
||||
|
||||
context.javaClass.methods.find {
|
||||
var suitableSignature = it.name == theName && it.parameters.size == arity
|
||||
|
||||
if (primitiveMode && it.isBridge)
|
||||
suitableSignature = false
|
||||
|
||||
suitableSignature
|
||||
} ?: return false
|
||||
|
||||
val owner = context::class.asm
|
||||
|
||||
invokeAlgebraOperation(
|
||||
owner = owner.internalName,
|
||||
method = theName,
|
||||
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }),
|
||||
expectedArity = arity,
|
||||
opcode = INVOKEVIRTUAL
|
||||
)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
|
||||
*/
|
||||
internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
||||
context: Algebra<T>,
|
||||
name: String,
|
||||
fallbackMethodName: String,
|
||||
arity: Int,
|
||||
parameters: AsmBuilder<T>.() -> Unit
|
||||
) {
|
||||
loadAlgebra()
|
||||
if (!buildExpectationStack(context, name, arity)) loadStringConstant(name)
|
||||
parameters()
|
||||
|
||||
if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
||||
method = fallbackMethodName,
|
||||
|
||||
descriptor = Type.getMethodDescriptor(
|
||||
AsmBuilder.OBJECT_TYPE,
|
||||
AsmBuilder.STRING_TYPE,
|
||||
*Array(arity) { AsmBuilder.OBJECT_TYPE }
|
||||
),
|
||||
|
||||
expectedArity = arity
|
||||
)
|
||||
}
|
||||
|
@ -0,0 +1,10 @@
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
import org.objectweb.asm.Label
|
||||
import org.objectweb.asm.commons.InstructionAdapter
|
||||
|
||||
internal fun InstructionAdapter.label(): Label {
|
||||
val l = Label()
|
||||
visitLabel(l)
|
||||
return l
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
@file:JvmName("MapIntrinsics")
|
||||
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V {
|
||||
return this[key] ?: default ?: error("Parameter not found: $key")
|
||||
}
|
@ -0,0 +1,110 @@
|
||||
package scietifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.compile
|
||||
import scientifik.kmath.ast.mstInField
|
||||
import scientifik.kmath.ast.mstInRing
|
||||
import scientifik.kmath.ast.mstInSpace
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.ByteRing
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmAlgebras {
|
||||
@Test
|
||||
fun space() {
|
||||
val res1 = ByteRing.mstInSpace {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||
),
|
||||
|
||||
number(1)
|
||||
) + symbol("x") + zero
|
||||
}("x" to 2.toByte())
|
||||
|
||||
val res2 = ByteRing.mstInSpace {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||
),
|
||||
|
||||
number(1)
|
||||
) + symbol("x") + zero
|
||||
}.compile()("x" to 2.toByte())
|
||||
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun ring() {
|
||||
val res1 = ByteRing.mstInRing {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
(symbol("x") - (2.toByte() + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
|
||||
number(1)
|
||||
) * number(2)
|
||||
}("x" to 3.toByte())
|
||||
|
||||
val res2 = ByteRing.mstInRing {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
(symbol("x") - (2.toByte() + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
|
||||
number(1)
|
||||
) * number(2)
|
||||
}.compile()("x" to 3.toByte())
|
||||
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun field() {
|
||||
val res1 = RealField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
|
||||
"+",
|
||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}("x" to 2.0)
|
||||
|
||||
val res2 = RealField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
|
||||
"+",
|
||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}.compile()("x" to 2.0)
|
||||
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
package scietifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.compile
|
||||
import scientifik.kmath.ast.mstInField
|
||||
import scientifik.kmath.ast.mstInSpace
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmExpressions {
|
||||
@Test
|
||||
fun testUnaryOperationInvocation() {
|
||||
val expression = RealField.mstInSpace { -symbol("x") }.compile()
|
||||
val res = expression("x" to 2.0)
|
||||
assertEquals(-2.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBinaryOperationInvocation() {
|
||||
val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile()
|
||||
val res = expression("x" to 2.0)
|
||||
assertEquals(-1.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testConstProductInvocation() {
|
||||
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
|
||||
assertEquals(4.0, res)
|
||||
}
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
package scietifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.compile
|
||||
import scientifik.kmath.ast.mstInField
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmSpecialization {
|
||||
@Test
|
||||
fun testUnaryPlus() {
|
||||
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
|
||||
assertEquals(2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
|
||||
assertEquals(-2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSine() {
|
||||
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile()
|
||||
assertEquals(0.0, expr("x" to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus() {
|
||||
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(0.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() {
|
||||
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(1.0, expr("x" to 2.0))
|
||||
}
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
package scietifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.ast.mstInRing
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.ByteRing
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
internal class TestAsmVariables {
|
||||
@Test
|
||||
fun testVariableWithoutDefault() {
|
||||
val expr = ByteRing.mstInRing { symbol("x") }
|
||||
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testVariableWithoutDefaultFails() {
|
||||
val expr = ByteRing.mstInRing { symbol("x") }
|
||||
assertFailsWith<IllegalStateException> { expr() }
|
||||
}
|
||||
}
|
26
kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt
Normal file
26
kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt
Normal file
@ -0,0 +1,26 @@
|
||||
package scietifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.asm.compile
|
||||
import scientifik.kmath.asm.expression
|
||||
import scientifik.kmath.ast.mstInField
|
||||
import scientifik.kmath.ast.parseMath
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.Complex
|
||||
import scientifik.kmath.operations.ComplexField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class AsmTest {
|
||||
@Test
|
||||
fun `compile MST`() {
|
||||
val mst = "2+2*(2+2)".parseMath()
|
||||
val res = ComplexField.expression(mst)()
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `compile MSTExpression`() {
|
||||
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()()
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package scietifik.kmath.ast
|
||||
|
||||
import scientifik.kmath.ast.evaluate
|
||||
import scientifik.kmath.ast.mstInField
|
||||
import scientifik.kmath.ast.parseMath
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.Complex
|
||||
import scientifik.kmath.operations.ComplexField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class ParserTest {
|
||||
@Test
|
||||
fun `evaluate MST`() {
|
||||
val mst = "2+2*(2+2)".parseMath()
|
||||
val res = ComplexField.evaluate(mst)
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `evaluate MSTExpression`() {
|
||||
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
}
|
@ -2,7 +2,7 @@ package scientifik.kmath.commons.expressions
|
||||
|
||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.expressions.ExpressionContext
|
||||
import scientifik.kmath.expressions.ExpressionAlgebra
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
import scientifik.kmath.operations.Field
|
||||
import kotlin.properties.ReadOnlyProperty
|
||||
@ -59,8 +59,10 @@ class DerivativeStructureField(
|
||||
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
|
||||
|
||||
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
||||
|
||||
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
||||
override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
|
||||
override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
|
||||
override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
|
||||
|
||||
override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
|
||||
is Double -> arg.pow(pow)
|
||||
@ -74,10 +76,10 @@ class DerivativeStructureField(
|
||||
|
||||
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
|
||||
|
||||
operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble())
|
||||
operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble())
|
||||
operator fun Number.plus(s: DerivativeStructure) = s + this
|
||||
operator fun Number.minus(s: DerivativeStructure) = s - this
|
||||
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
|
||||
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
||||
override operator fun Number.plus(b: DerivativeStructure) = b + this
|
||||
override operator fun Number.minus(b: DerivativeStructure) = b - this
|
||||
}
|
||||
|
||||
/**
|
||||
@ -113,7 +115,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1)
|
||||
/**
|
||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
||||
*/
|
||||
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> {
|
||||
object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
|
||||
override fun variable(name: String, default: Double?) =
|
||||
DiffExpression { variable(name, default?.const()) }
|
||||
|
||||
@ -136,6 +138,3 @@ object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression>
|
||||
override fun divide(a: DiffExpression, b: DiffExpression) =
|
||||
DiffExpression { a.function(this) / b.function(this) }
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,38 @@
|
||||
package scientifik.kmath.commons.random
|
||||
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
|
||||
class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
|
||||
org.apache.commons.math3.random.RandomGenerator {
|
||||
private var generator = factory(intArrayOf())
|
||||
|
||||
override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||
|
||||
override fun nextFloat(): Float = generator.nextDouble().toFloat()
|
||||
|
||||
override fun setSeed(seed: Int) {
|
||||
generator = factory(intArrayOf(seed))
|
||||
}
|
||||
|
||||
override fun setSeed(seed: IntArray) {
|
||||
generator = factory(seed)
|
||||
}
|
||||
|
||||
override fun setSeed(seed: Long) {
|
||||
setSeed(seed.toInt())
|
||||
}
|
||||
|
||||
override fun nextBytes(bytes: ByteArray) {
|
||||
generator.fillBytes(bytes)
|
||||
}
|
||||
|
||||
override fun nextInt(): Int = generator.nextInt()
|
||||
|
||||
override fun nextInt(n: Int): Int = generator.nextInt(n)
|
||||
|
||||
override fun nextGaussian(): Double = TODO()
|
||||
|
||||
override fun nextDouble(): Double = generator.nextDouble()
|
||||
|
||||
override fun nextLong(): Long = generator.nextLong()
|
||||
}
|
@ -18,7 +18,7 @@ object Transformations {
|
||||
private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> =
|
||||
Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) }
|
||||
|
||||
private fun Buffer<Double>.asArray() = if (this is DoubleBuffer) {
|
||||
private fun Buffer<Double>.asArray() = if (this is RealBuffer) {
|
||||
array
|
||||
} else {
|
||||
DoubleArray(size) { i -> get(i) }
|
||||
|
@ -0,0 +1,15 @@
|
||||
package scientifik.kmath.domains
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
|
||||
/**
|
||||
* A simple geometric domain
|
||||
*/
|
||||
interface Domain<T : Any> {
|
||||
operator fun contains(point: Point<T>): Boolean
|
||||
|
||||
/**
|
||||
* Number of hyperspace dimensions
|
||||
*/
|
||||
val dimension: Int
|
||||
}
|
@ -0,0 +1,67 @@
|
||||
/*
|
||||
* Copyright 2015 Alexander Nozik.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package scientifik.kmath.domains
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.structures.RealBuffer
|
||||
import scientifik.kmath.structures.indices
|
||||
|
||||
/**
|
||||
*
|
||||
* HyperSquareDomain class.
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain {
|
||||
|
||||
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
|
||||
point[i] in lower[i]..upper[i]
|
||||
}
|
||||
|
||||
override val dimension: Int get() = lower.size
|
||||
|
||||
override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num]
|
||||
|
||||
override fun getLowerBound(num: Int): Double? = lower[num]
|
||||
|
||||
override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num]
|
||||
|
||||
override fun getUpperBound(num: Int): Double? = upper[num]
|
||||
|
||||
override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
||||
val res: DoubleArray = DoubleArray(point.size) { i ->
|
||||
when {
|
||||
point[i] < lower[i] -> lower[i]
|
||||
point[i] > upper[i] -> upper[i]
|
||||
else -> point[i]
|
||||
}
|
||||
}
|
||||
return RealBuffer(*res)
|
||||
}
|
||||
|
||||
override fun volume(): Double {
|
||||
var res = 1.0
|
||||
for (i in 0 until dimension) {
|
||||
if (lower[i].isInfinite() || upper[i].isInfinite()) {
|
||||
return Double.POSITIVE_INFINITY
|
||||
}
|
||||
if (upper[i] > lower[i]) {
|
||||
res *= upper[i] - lower[i]
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
/*
|
||||
* Copyright 2015 Alexander Nozik.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package scientifik.kmath.domains
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
|
||||
/**
|
||||
* n-dimensional volume
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
interface RealDomain: Domain<Double> {
|
||||
|
||||
fun nearestInDomain(point: Point<Double>): Point<Double>
|
||||
|
||||
/**
|
||||
* The lower edge for the domain going down from point
|
||||
* @param num
|
||||
* @param point
|
||||
* @return
|
||||
*/
|
||||
fun getLowerBound(num: Int, point: Point<Double>): Double?
|
||||
|
||||
/**
|
||||
* The upper edge of the domain going up from point
|
||||
* @param num
|
||||
* @param point
|
||||
* @return
|
||||
*/
|
||||
fun getUpperBound(num: Int, point: Point<Double>): Double?
|
||||
|
||||
/**
|
||||
* Global lower edge
|
||||
* @param num
|
||||
* @return
|
||||
*/
|
||||
fun getLowerBound(num: Int): Double?
|
||||
|
||||
/**
|
||||
* Global upper edge
|
||||
* @param num
|
||||
* @return
|
||||
*/
|
||||
fun getUpperBound(num: Int): Double?
|
||||
|
||||
/**
|
||||
* Hyper volume
|
||||
* @return
|
||||
*/
|
||||
fun volume(): Double
|
||||
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright 2015 Alexander Nozik.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package scientifik.kmath.domains
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
|
||||
class UnconstrainedDomain(override val dimension: Int) : RealDomain {
|
||||
|
||||
override operator fun contains(point: Point<Double>): Boolean = true
|
||||
|
||||
override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
|
||||
|
||||
override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY
|
||||
|
||||
override fun getUpperBound(num: Int, point: Point<Double>): Double? = Double.POSITIVE_INFINITY
|
||||
|
||||
override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY
|
||||
|
||||
override fun nearestInDomain(point: Point<Double>): Point<Double> = point
|
||||
|
||||
override fun volume(): Double = Double.POSITIVE_INFINITY
|
||||
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
package scientifik.kmath.domains
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
|
||||
inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : RealDomain {
|
||||
|
||||
operator fun contains(d: Double): Boolean = range.contains(d)
|
||||
|
||||
override operator fun contains(point: Point<Double>): Boolean {
|
||||
require(point.size == 0)
|
||||
return contains(point[0])
|
||||
}
|
||||
|
||||
override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
||||
require(point.size == 1)
|
||||
val value = point[0]
|
||||
return when{
|
||||
value in range -> point
|
||||
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
|
||||
else -> doubleArrayOf(range.start).asBuffer()
|
||||
}
|
||||
}
|
||||
|
||||
override fun getLowerBound(num: Int, point: Point<Double>): Double? {
|
||||
require(num == 0)
|
||||
return range.start
|
||||
}
|
||||
|
||||
override fun getUpperBound(num: Int, point: Point<Double>): Double? {
|
||||
require(num == 0)
|
||||
return range.endInclusive
|
||||
}
|
||||
|
||||
override fun getLowerBound(num: Int): Double? {
|
||||
require(num == 0)
|
||||
return range.start
|
||||
}
|
||||
|
||||
override fun getUpperBound(num: Int): Double? {
|
||||
require(num == 0)
|
||||
return range.endInclusive
|
||||
}
|
||||
|
||||
override fun volume(): Double = range.endInclusive - range.start
|
||||
|
||||
override val dimension: Int get() = 1
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.operations.Space
|
||||
|
||||
/**
|
||||
* Create a functional expression on this [Space]
|
||||
*/
|
||||
fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionSpace(this).run(block)
|
||||
|
||||
/**
|
||||
* Create a functional expression on this [Ring]
|
||||
*/
|
||||
fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionRing(this).run(block)
|
||||
|
||||
/**
|
||||
* Create a functional expression on this [Field]
|
||||
*/
|
||||
fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionField(this).run(block)
|
@ -1,14 +1,21 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.operations.Space
|
||||
import scientifik.kmath.operations.Algebra
|
||||
|
||||
/**
|
||||
* An elementary function that could be invoked on a map of arguments
|
||||
*/
|
||||
interface Expression<T> {
|
||||
operator fun invoke(arguments: Map<String, T>): T
|
||||
|
||||
companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* Create simple lazily evaluated expression inside given algebra
|
||||
*/
|
||||
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> = object: Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = block(arguments)
|
||||
}
|
||||
|
||||
operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
|
||||
@ -16,77 +23,14 @@ operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke
|
||||
/**
|
||||
* A context for expression construction
|
||||
*/
|
||||
interface ExpressionContext<T> {
|
||||
interface ExpressionAlgebra<T, E> : Algebra<E> {
|
||||
/**
|
||||
* Introduce a variable into expression context
|
||||
*/
|
||||
fun variable(name: String, default: T? = null): Expression<T>
|
||||
fun variable(name: String, default: T? = null): E
|
||||
|
||||
/**
|
||||
* A constant expression which does not depend on arguments
|
||||
*/
|
||||
fun const(value: T): Expression<T>
|
||||
}
|
||||
|
||||
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T =
|
||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||
}
|
||||
|
||||
internal class ConstantExpression<T>(val value: T) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = value
|
||||
}
|
||||
|
||||
internal class SumExpression<T>(val context: Space<T>, val first: Expression<T>, val second: Expression<T>) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<T>, val second: Expression<T>) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T =
|
||||
context.multiply(first.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||
}
|
||||
|
||||
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
open class ExpressionSpace<T>(val space: Space<T>) : Space<Expression<T>>, ExpressionContext<T> {
|
||||
override val zero: Expression<T> = ConstantExpression(space.zero)
|
||||
|
||||
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
||||
|
||||
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
||||
|
||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
|
||||
|
||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
|
||||
|
||||
|
||||
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
||||
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
||||
|
||||
operator fun T.plus(arg: Expression<T>) = arg + this
|
||||
operator fun T.minus(arg: Expression<T>) = arg - this
|
||||
}
|
||||
|
||||
|
||||
class ExpressionField<T>(val field: Field<T>) : Field<Expression<T>>, ExpressionSpace<T>(field) {
|
||||
override val one: Expression<T> = ConstantExpression(field.one)
|
||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
||||
|
||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
||||
|
||||
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
||||
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
||||
|
||||
operator fun T.times(arg: Expression<T>) = arg * this
|
||||
operator fun T.div(arg: Expression<T>) = arg / this
|
||||
fun const(value: T): E
|
||||
}
|
@ -0,0 +1,146 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
||||
Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class FunctionalBinaryOperation<T>(
|
||||
val context: Algebra<T>,
|
||||
val name: String,
|
||||
val first: Expression<T>,
|
||||
val second: Expression<T>
|
||||
) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T =
|
||||
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T =
|
||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||
}
|
||||
|
||||
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = value
|
||||
}
|
||||
|
||||
internal class FunctionalConstProductExpression<T>(
|
||||
val context: Space<T>,
|
||||
private val expr: Expression<T>,
|
||||
val const: Number
|
||||
) : Expression<T> {
|
||||
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||
}
|
||||
|
||||
/**
|
||||
* A context class for [Expression] construction.
|
||||
*
|
||||
* @param algebra The algebra to provide for Expressions built.
|
||||
*/
|
||||
abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(val algebra: A) : ExpressionAlgebra<T, Expression<T>> {
|
||||
|
||||
/**
|
||||
* Builds an Expression of constant expression which does not depend on arguments.
|
||||
*/
|
||||
override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
|
||||
|
||||
/**
|
||||
* Builds an Expression to access a variable.
|
||||
*/
|
||||
override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
|
||||
|
||||
/**
|
||||
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||
*/
|
||||
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
FunctionalBinaryOperation(algebra, operation, left, right)
|
||||
|
||||
/**
|
||||
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||
*/
|
||||
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
FunctionalUnaryOperation(algebra, operation, arg)
|
||||
}
|
||||
|
||||
/**
|
||||
* A context class for [Expression] construction for [Space] algebras.
|
||||
*/
|
||||
open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
|
||||
|
||||
override val zero: Expression<T> get() = const(algebra.zero)
|
||||
|
||||
/**
|
||||
* Builds an Expression of addition of two another expressions.
|
||||
*/
|
||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
FunctionalBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b)
|
||||
|
||||
/**
|
||||
* Builds an Expression of multiplication of expression by number.
|
||||
*/
|
||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> =
|
||||
FunctionalConstProductExpression(algebra, a, k)
|
||||
|
||||
operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
||||
operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
||||
operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
||||
operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
||||
|
||||
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
|
||||
}
|
||||
|
||||
open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
|
||||
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
|
||||
override val one: Expression<T>
|
||||
get() = const(algebra.one)
|
||||
|
||||
/**
|
||||
* Builds an Expression of multiplication of two expressions.
|
||||
*/
|
||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
FunctionalBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b)
|
||||
|
||||
operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||
operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||
|
||||
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
|
||||
}
|
||||
|
||||
open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||
FunctionalExpressionRing<T, A>(algebra),
|
||||
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
|
||||
/**
|
||||
* Builds an Expression of division an expression by another one.
|
||||
*/
|
||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
FunctionalBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b)
|
||||
|
||||
operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||
operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||
|
||||
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
|
||||
}
|
||||
|
||||
inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionSpace(this).block()
|
||||
|
||||
inline fun <T, A : Ring<T>> A.expressionInRing(block: FunctionalExpressionRing<T, A>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionRing(this).block()
|
||||
|
||||
inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> =
|
||||
FunctionalExpressionField(this).block()
|
@ -30,11 +30,11 @@ object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
||||
override val elementContext get() = RealField
|
||||
|
||||
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
||||
val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
}
|
||||
|
||||
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = DoubleBuffer(size,initializer)
|
||||
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size,initializer)
|
||||
}
|
||||
|
||||
class BufferMatrix<T : Any>(
|
||||
@ -102,7 +102,7 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
|
||||
val array = DoubleArray(this.rowNum * other.colNum)
|
||||
|
||||
//convert to array to insure there is not memory indirection
|
||||
fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is DoubleBuffer) {
|
||||
fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is RealBuffer) {
|
||||
array
|
||||
} else {
|
||||
DoubleArray(size) { get(it) }
|
||||
@ -119,6 +119,6 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
|
||||
}
|
||||
}
|
||||
|
||||
val buffer = DoubleBuffer(array)
|
||||
val buffer = RealBuffer(array)
|
||||
return BufferMatrix(rowNum, other.colNum, buffer)
|
||||
}
|
@ -128,14 +128,14 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||
luRow[col] = sum
|
||||
|
||||
// maintain best permutation choice
|
||||
if (abs(sum) > largest) {
|
||||
largest = abs(sum)
|
||||
if (this@lup.abs(sum) > largest) {
|
||||
largest = this@lup.abs(sum)
|
||||
max = row
|
||||
}
|
||||
}
|
||||
|
||||
// Singularity check
|
||||
if (checkSingular(abs(lu[max, col]))) {
|
||||
if (checkSingular(this@lup.abs(lu[max, col]))) {
|
||||
error("The matrix is singular")
|
||||
}
|
||||
|
||||
|
@ -90,20 +90,20 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
operator fun Number.plus(that: Variable<T>): Variable<T> =
|
||||
derive(variable { this@plus.toDouble() * one + that.value }) { z ->
|
||||
that.d += z.d
|
||||
override operator fun Number.plus(b: Variable<T>): Variable<T> =
|
||||
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
||||
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
||||
|
||||
operator fun Number.minus(that: Variable<T>): Variable<T> =
|
||||
derive(variable { this@minus.toDouble() * one - that.value }) { z ->
|
||||
that.d -= z.d
|
||||
override operator fun Number.minus(b: Variable<T>): Variable<T> =
|
||||
derive(variable { this@minus.toDouble() * one - b.value }) { z ->
|
||||
b.d -= z.d
|
||||
}
|
||||
|
||||
operator fun Variable<T>.minus(that: Number): Variable<T> =
|
||||
derive(variable { this@minus.value - one * that.toDouble() }) { z ->
|
||||
override operator fun Variable<T>.minus(b: Number): Variable<T> =
|
||||
derive(variable { this@minus.value - one * b.toDouble() }) { z ->
|
||||
this@minus.d += z.d
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
package scientifik.kmath.misc
|
||||
|
||||
import kotlin.math.abs
|
||||
|
||||
/**
|
||||
* Convert double range to sequence.
|
||||
*
|
||||
@ -8,8 +10,7 @@ package scientifik.kmath.misc
|
||||
*
|
||||
* If step is negative, the same goes from upper boundary downwards
|
||||
*/
|
||||
fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double> =
|
||||
when {
|
||||
fun ClosedFloatingPointRange<Double>.toSequenceWithStep(step: Double): Sequence<Double> = when {
|
||||
step == 0.0 -> error("Zero step in double progression")
|
||||
step > 0 -> sequence {
|
||||
var current = start
|
||||
@ -27,9 +28,18 @@ fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double>
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert double range to sequence with the fixed number of points
|
||||
*/
|
||||
fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Sequence<Double> {
|
||||
require(numPoints > 1) { "The number of points should be more than 2" }
|
||||
return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1))
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
|
||||
*/
|
||||
@Deprecated("Replace by 'toSequenceWithPoints'")
|
||||
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
|
||||
if (numPoints < 2) error("Can't create generic grid with less than two points")
|
||||
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
|
||||
|
@ -6,9 +6,43 @@ annotation class KMathContext
|
||||
/**
|
||||
* Marker interface for any algebra
|
||||
*/
|
||||
interface Algebra<T>
|
||||
interface Algebra<T> {
|
||||
/**
|
||||
* Wrap raw string or variable
|
||||
*/
|
||||
fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
|
||||
|
||||
inline operator fun <T : Algebra<*>, R> T.invoke(block: T.() -> R): R = run(block)
|
||||
/**
|
||||
* Dynamic call of unary operation with name [operation] on [arg]
|
||||
*/
|
||||
fun unaryOperation(operation: String, arg: T): T
|
||||
|
||||
/**
|
||||
* Dynamic call of binary operation [operation] on [left] and [right]
|
||||
*/
|
||||
fun binaryOperation(operation: String, left: T, right: T): T
|
||||
}
|
||||
|
||||
/**
|
||||
* An algebra with numeric representation of members
|
||||
*/
|
||||
interface NumericAlgebra<T> : Algebra<T> {
|
||||
/**
|
||||
* Wrap a number
|
||||
*/
|
||||
fun number(value: Number): T
|
||||
|
||||
fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
||||
binaryOperation(operation, number(left), right)
|
||||
|
||||
fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||
leftSideNumberOperation(operation, right, left)
|
||||
}
|
||||
|
||||
/**
|
||||
* Call a block with an [Algebra] as receiver
|
||||
*/
|
||||
inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block)
|
||||
|
||||
/**
|
||||
* Space-like operations without neutral element
|
||||
@ -24,14 +58,34 @@ interface SpaceOperations<T> : Algebra<T> {
|
||||
*/
|
||||
fun multiply(a: T, k: Number): T
|
||||
|
||||
//Operation to be performed in this context
|
||||
//Operation to be performed in this context. Could be moved to extensions in case of KEEP-176
|
||||
operator fun T.unaryMinus(): T = multiply(this, -1.0)
|
||||
|
||||
operator fun T.unaryPlus(): T = this
|
||||
|
||||
operator fun T.plus(b: T): T = add(this, b)
|
||||
operator fun T.minus(b: T): T = add(this, -b)
|
||||
operator fun T.times(k: Number) = multiply(this, k.toDouble())
|
||||
operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble())
|
||||
operator fun Number.times(b: T) = b * this
|
||||
|
||||
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||
PLUS_OPERATION -> arg
|
||||
MINUS_OPERATION -> -arg
|
||||
else -> error("Unary operation $operation not defined in $this")
|
||||
}
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
PLUS_OPERATION -> add(left, right)
|
||||
MINUS_OPERATION -> left - right
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val PLUS_OPERATION = "+"
|
||||
const val MINUS_OPERATION = "-"
|
||||
const val NOT_OPERATION = "!"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -60,22 +114,48 @@ interface RingOperations<T> : SpaceOperations<T> {
|
||||
fun multiply(a: T, b: T): T
|
||||
|
||||
operator fun T.times(b: T): T = multiply(this, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
TIMES_OPERATION -> multiply(left, right)
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val TIMES_OPERATION = "*"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The same as {@link Space} but with additional multiplication operation
|
||||
*/
|
||||
interface Ring<T> : Space<T>, RingOperations<T> {
|
||||
interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
|
||||
/**
|
||||
* neutral operation for multiplication
|
||||
*/
|
||||
val one: T
|
||||
|
||||
// operator fun T.plus(b: Number) = this.plus(b * one)
|
||||
// operator fun Number.plus(b: T) = b + this
|
||||
//
|
||||
// operator fun T.minus(b: Number) = this.minus(b * one)
|
||||
// operator fun Number.minus(b: T) = -b + this
|
||||
override fun number(value: Number): T = one * value.toDouble()
|
||||
|
||||
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left + right
|
||||
SpaceOperations.MINUS_OPERATION -> left - right
|
||||
RingOperations.TIMES_OPERATION -> left * right
|
||||
else -> super.leftSideNumberOperation(operation, left, right)
|
||||
}
|
||||
|
||||
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left + right
|
||||
SpaceOperations.MINUS_OPERATION -> left - right
|
||||
RingOperations.TIMES_OPERATION -> left * right
|
||||
else -> super.rightSideNumberOperation(operation, left, right)
|
||||
}
|
||||
|
||||
|
||||
operator fun T.plus(b: Number) = this.plus(number(b))
|
||||
operator fun Number.plus(b: T) = b + this
|
||||
|
||||
operator fun T.minus(b: Number) = this.minus(number(b))
|
||||
operator fun Number.minus(b: T) = -b + this
|
||||
}
|
||||
|
||||
/**
|
||||
@ -85,6 +165,15 @@ interface FieldOperations<T> : RingOperations<T> {
|
||||
fun divide(a: T, b: T): T
|
||||
|
||||
operator fun T.div(b: T): T = divide(this, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
DIV_OPERATION -> divide(left, right)
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val DIV_OPERATION = "/"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -8,10 +8,12 @@ import scientifik.memory.MemorySpec
|
||||
import scientifik.memory.MemoryWriter
|
||||
import kotlin.math.*
|
||||
|
||||
private val PI_DIV_2 = Complex(PI / 2, 0)
|
||||
|
||||
/**
|
||||
* A field for complex numbers
|
||||
*/
|
||||
object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
||||
object ComplexField : ExtendedField<Complex> {
|
||||
override val zero: Complex = Complex(0.0, 0.0)
|
||||
|
||||
override val one: Complex = Complex(1.0, 0.0)
|
||||
@ -30,9 +32,11 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
||||
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
|
||||
}
|
||||
|
||||
override fun sin(arg: Complex): Complex = i / 2 * (exp(-i * arg) - exp(i * arg))
|
||||
|
||||
override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2
|
||||
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
|
||||
override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg)
|
||||
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg)
|
||||
override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2
|
||||
|
||||
override fun power(arg: Complex, pow: Number): Complex =
|
||||
arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta))
|
||||
@ -50,6 +54,12 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
||||
operator fun Complex.minus(d: Double) = add(this, -d.toComplex())
|
||||
|
||||
operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this)
|
||||
|
||||
override fun symbol(value: String): Complex = if (value == "i") {
|
||||
i
|
||||
} else {
|
||||
super.symbol(value)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -7,12 +7,32 @@ import kotlin.math.pow as kpow
|
||||
* Advanced Number-like field that implements basic operations
|
||||
*/
|
||||
interface ExtendedFieldOperations<T> :
|
||||
FieldOperations<T>,
|
||||
TrigonometricOperations<T>,
|
||||
InverseTrigonometricOperations<T>,
|
||||
PowerOperations<T>,
|
||||
ExponentialOperations<T>
|
||||
ExponentialOperations<T> {
|
||||
|
||||
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>
|
||||
override fun tan(arg: T): T = sin(arg) / cos(arg)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||
TrigonometricOperations.COS_OPERATION -> cos(arg)
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(arg)
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(arg)
|
||||
InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg)
|
||||
InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg)
|
||||
InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg)
|
||||
PowerOperations.SQRT_OPERATION -> sqrt(arg)
|
||||
ExponentialOperations.EXP_OPERATION -> exp(arg)
|
||||
ExponentialOperations.LN_OPERATION -> ln(arg)
|
||||
else -> super.unaryOperation(operation, arg)
|
||||
}
|
||||
}
|
||||
|
||||
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
||||
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> power(left, right)
|
||||
else -> super.rightSideNumberOperation(operation, left, right)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Real field element wrapping double.
|
||||
@ -44,6 +64,10 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
|
||||
override inline fun sin(arg: Double) = kotlin.math.sin(arg)
|
||||
override inline fun cos(arg: Double) = kotlin.math.cos(arg)
|
||||
override inline fun tan(arg: Double): Double = kotlin.math.tan(arg)
|
||||
override inline fun acos(arg: Double): Double = kotlin.math.acos(arg)
|
||||
override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
|
||||
override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
|
||||
|
||||
override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble())
|
||||
|
||||
@ -75,6 +99,10 @@ object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
|
||||
override inline fun sin(arg: Float) = kotlin.math.sin(arg)
|
||||
override inline fun cos(arg: Float) = kotlin.math.cos(arg)
|
||||
override inline fun tan(arg: Float) = kotlin.math.tan(arg)
|
||||
override inline fun acos(arg: Float) = kotlin.math.acos(arg)
|
||||
override inline fun asin(arg: Float) = kotlin.math.asin(arg)
|
||||
override inline fun atan(arg: Float) = kotlin.math.atan(arg)
|
||||
|
||||
override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat())
|
||||
|
||||
|
@ -13,16 +13,33 @@ package scientifik.kmath.operations
|
||||
interface TrigonometricOperations<T> : FieldOperations<T> {
|
||||
fun sin(arg: T): T
|
||||
fun cos(arg: T): T
|
||||
fun tan(arg: T): T
|
||||
|
||||
fun tg(arg: T): T = sin(arg) / cos(arg)
|
||||
companion object {
|
||||
const val SIN_OPERATION = "sin"
|
||||
const val COS_OPERATION = "cos"
|
||||
const val TAN_OPERATION = "tan"
|
||||
}
|
||||
}
|
||||
|
||||
fun ctg(arg: T): T = cos(arg) / sin(arg)
|
||||
interface InverseTrigonometricOperations<T> : TrigonometricOperations<T> {
|
||||
fun asin(arg: T): T
|
||||
fun acos(arg: T): T
|
||||
fun atan(arg: T): T
|
||||
|
||||
companion object {
|
||||
const val ASIN_OPERATION = "asin"
|
||||
const val ACOS_OPERATION = "acos"
|
||||
const val ATAN_OPERATION = "atan"
|
||||
}
|
||||
}
|
||||
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> tg(arg: T): T = arg.context.tg(arg)
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> ctg(arg: T): T = arg.context.ctg(arg)
|
||||
fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
|
||||
fun <T : MathElement<out InverseTrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
|
||||
fun <T : MathElement<out InverseTrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
|
||||
fun <T : MathElement<out InverseTrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
|
||||
|
||||
/* Power and roots */
|
||||
|
||||
@ -34,6 +51,11 @@ interface PowerOperations<T> : Algebra<T> {
|
||||
fun sqrt(arg: T) = power(arg, 0.5)
|
||||
|
||||
infix fun T.pow(pow: Number) = power(this, pow)
|
||||
|
||||
companion object {
|
||||
const val POW_OPERATION = "pow"
|
||||
const val SQRT_OPERATION = "sqrt"
|
||||
}
|
||||
}
|
||||
|
||||
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
|
||||
@ -45,6 +67,11 @@ fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
|
||||
interface ExponentialOperations<T> : Algebra<T> {
|
||||
fun exp(arg: T): T
|
||||
fun ln(arg: T): T
|
||||
|
||||
companion object {
|
||||
const val EXP_OPERATION = "exp"
|
||||
const val LN_OPERATION = "ln"
|
||||
}
|
||||
}
|
||||
|
||||
fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
|
||||
|
@ -37,9 +37,9 @@ interface Buffer<T> {
|
||||
|
||||
companion object {
|
||||
|
||||
inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer {
|
||||
inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer {
|
||||
val array = DoubleArray(size) { initializer(it) }
|
||||
return DoubleBuffer(array)
|
||||
return RealBuffer(array)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -51,7 +51,7 @@ interface Buffer<T> {
|
||||
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
||||
//TODO add resolution based on Annotation or companion resolution
|
||||
return when (type) {
|
||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
||||
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
|
||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
|
||||
@ -93,7 +93,7 @@ interface MutableBuffer<T> : Buffer<T> {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
||||
return when (type) {
|
||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
|
||||
@ -109,12 +109,11 @@ interface MutableBuffer<T> : Buffer<T> {
|
||||
auto(T::class, size, initializer)
|
||||
|
||||
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
|
||||
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
||||
RealBuffer(DoubleArray(size) { initializer(it) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
||||
|
||||
override val size: Int
|
||||
@ -163,57 +162,6 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
||||
|
||||
fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
|
||||
|
||||
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
override fun get(index: Int): Short = array[index]
|
||||
|
||||
override fun set(index: Int, value: Short) {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator() = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Short> = ShortBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
fun ShortArray.asBuffer() = ShortBuffer(this)
|
||||
|
||||
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
override fun get(index: Int): Int = array[index]
|
||||
|
||||
override fun set(index: Int, value: Int) {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator() = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Int> = IntBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
fun IntArray.asBuffer() = IntBuffer(this)
|
||||
|
||||
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
override fun get(index: Int): Long = array[index]
|
||||
|
||||
override fun set(index: Int, value: Long) {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator() = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Long> = LongBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
fun LongArray.asBuffer() = LongBuffer(this)
|
||||
|
||||
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
|
||||
override val size: Int get() = buffer.size
|
||||
|
||||
|
@ -79,6 +79,13 @@ class ComplexNDField(override val shape: IntArray) :
|
||||
|
||||
override fun cos(arg: NDBuffer<Complex>) = map(arg) { cos(it) }
|
||||
|
||||
override fun tan(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) { tan(it) }
|
||||
|
||||
override fun asin(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) { asin(it) }
|
||||
|
||||
override fun acos(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) {acos(it)}
|
||||
|
||||
override fun atan(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) {atan(it)}
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,13 +1,8 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.*
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
|
||||
interface ExtendedNDField<T : Any, F, N : NDStructure<T>> :
|
||||
NDField<T, F, N>,
|
||||
TrigonometricOperations<N>,
|
||||
PowerOperations<N>,
|
||||
ExponentialOperations<N>
|
||||
where F : ExtendedFieldOperations<T>, F : Field<T>
|
||||
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N>
|
||||
|
||||
|
||||
///**
|
||||
|
@ -0,0 +1,53 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import kotlin.experimental.and
|
||||
|
||||
enum class ValueFlag(val mask: Byte) {
|
||||
NAN(0b0000_0001),
|
||||
MISSING(0b0000_0010),
|
||||
NEGATIVE_INFINITY(0b0000_0100),
|
||||
POSITIVE_INFINITY(0b0000_1000)
|
||||
}
|
||||
|
||||
/**
|
||||
* A buffer with flagged values
|
||||
*/
|
||||
interface FlaggedBuffer<T> : Buffer<T> {
|
||||
fun getFlag(index: Int): Byte
|
||||
}
|
||||
|
||||
/**
|
||||
* The value is valid if all flags are down
|
||||
*/
|
||||
fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte()
|
||||
|
||||
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte()
|
||||
|
||||
fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING)
|
||||
|
||||
/**
|
||||
* A real buffer which supports flags for each value like NaN or Missing
|
||||
*/
|
||||
class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> {
|
||||
init {
|
||||
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
|
||||
}
|
||||
|
||||
override fun getFlag(index: Int): Byte = flags[index]
|
||||
|
||||
override val size: Int get() = values.size
|
||||
|
||||
override fun get(index: Int): Double? = if (isValid(index)) values[index] else null
|
||||
|
||||
override fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
|
||||
if (isValid(it)) values[it] else null
|
||||
}.iterator()
|
||||
}
|
||||
|
||||
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
|
||||
for(i in indices){
|
||||
if(isValid(i)){
|
||||
block(values[i])
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
override fun get(index: Int): Int = array[index]
|
||||
|
||||
override fun set(index: Int, value: Int) {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator() = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Int> =
|
||||
IntBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
|
||||
fun IntArray.asBuffer() = IntBuffer(this)
|
@ -0,0 +1,19 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
override fun get(index: Int): Long = array[index]
|
||||
|
||||
override fun set(index: Int, value: Long) {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator() = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Long> =
|
||||
LongBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
fun LongArray.asBuffer() = LongBuffer(this)
|
@ -1,6 +1,6 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
||||
inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
override fun get(index: Int): Double = array[index]
|
||||
@ -12,23 +12,23 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
||||
override fun iterator() = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Double> =
|
||||
DoubleBuffer(array.copyOf())
|
||||
RealBuffer(array.copyOf())
|
||||
}
|
||||
|
||||
@Suppress("FunctionName")
|
||||
inline fun DoubleBuffer(size: Int, init: (Int) -> Double): DoubleBuffer = DoubleBuffer(DoubleArray(size) { init(it) })
|
||||
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) })
|
||||
|
||||
@Suppress("FunctionName")
|
||||
fun DoubleBuffer(vararg doubles: Double): DoubleBuffer = DoubleBuffer(doubles)
|
||||
fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
|
||||
|
||||
/**
|
||||
* Transform buffer of doubles into array for high performance operations
|
||||
*/
|
||||
val MutableBuffer<out Double>.array: DoubleArray
|
||||
get() = if (this is DoubleBuffer) {
|
||||
get() = if (this is RealBuffer) {
|
||||
array
|
||||
} else {
|
||||
DoubleArray(size) { get(it) }
|
||||
}
|
||||
|
||||
fun DoubleArray.asBuffer() = DoubleBuffer(this)
|
||||
fun DoubleArray.asBuffer() = RealBuffer(this)
|
@ -9,145 +9,172 @@ import kotlin.math.*
|
||||
* A simple field over linear buffers of [Double]
|
||||
*/
|
||||
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||
|
||||
return if (a is RealBuffer && b is RealBuffer) {
|
||||
val aArray = a.array
|
||||
val bArray = b.array
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] })
|
||||
}
|
||||
RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
|
||||
} else
|
||||
RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
|
||||
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
|
||||
val kValue = k.toDouble()
|
||||
return if (a is DoubleBuffer) {
|
||||
|
||||
return if (a is RealBuffer) {
|
||||
val aArray = a.array
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(a.size) { a[it] * kValue })
|
||||
}
|
||||
RealBuffer(DoubleArray(a.size) { aArray[it] * kValue })
|
||||
} else
|
||||
RealBuffer(DoubleArray(a.size) { a[it] * kValue })
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||
|
||||
return if (a is RealBuffer && b is RealBuffer) {
|
||||
val aArray = a.array
|
||||
val bArray = b.array
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
||||
}
|
||||
RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
|
||||
} else
|
||||
RealBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
||||
}
|
||||
|
||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||
|
||||
return if (a is RealBuffer && b is RealBuffer) {
|
||||
val aArray = a.array
|
||||
val bArray = b.array
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] })
|
||||
}
|
||||
RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
|
||||
} else
|
||||
RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
|
||||
}
|
||||
|
||||
override fun sin(arg: Buffer<Double>): DoubleBuffer {
|
||||
return if (arg is DoubleBuffer) {
|
||||
override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) })
|
||||
RealBuffer(DoubleArray(arg.size) { sin(array[it]) })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) })
|
||||
}
|
||||
RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
|
||||
}
|
||||
|
||||
override fun cos(arg: Buffer<Double>): DoubleBuffer {
|
||||
return if (arg is DoubleBuffer) {
|
||||
override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) })
|
||||
RealBuffer(DoubleArray(arg.size) { cos(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
|
||||
|
||||
override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { tan(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
|
||||
|
||||
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { asin(array[it]) })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) })
|
||||
}
|
||||
RealBuffer(DoubleArray(arg.size) { asin(arg[it]) })
|
||||
}
|
||||
|
||||
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer {
|
||||
return if (arg is DoubleBuffer) {
|
||||
override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
|
||||
}
|
||||
}
|
||||
RealBuffer(DoubleArray(arg.size) { acos(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { acos(arg[it]) })
|
||||
|
||||
override fun exp(arg: Buffer<Double>): DoubleBuffer {
|
||||
return if (arg is DoubleBuffer) {
|
||||
override fun atan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) })
|
||||
}
|
||||
}
|
||||
RealBuffer(DoubleArray(arg.size) { atan(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { atan(arg[it]) })
|
||||
|
||||
override fun ln(arg: Buffer<Double>): DoubleBuffer {
|
||||
return if (arg is DoubleBuffer) {
|
||||
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) })
|
||||
}
|
||||
}
|
||||
RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
|
||||
|
||||
override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { exp(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
|
||||
|
||||
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
|
||||
}
|
||||
|
||||
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
|
||||
override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
|
||||
override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
|
||||
|
||||
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
|
||||
|
||||
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
|
||||
|
||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.add(a, b)
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
|
||||
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.multiply(a, k)
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.multiply(a, b)
|
||||
}
|
||||
|
||||
|
||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.divide(a, b)
|
||||
}
|
||||
|
||||
override fun sin(arg: Buffer<Double>): DoubleBuffer {
|
||||
override fun sin(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.sin(arg)
|
||||
}
|
||||
|
||||
override fun cos(arg: Buffer<Double>): DoubleBuffer {
|
||||
override fun cos(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.cos(arg)
|
||||
}
|
||||
|
||||
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer {
|
||||
override fun tan(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.tan(arg)
|
||||
}
|
||||
|
||||
override fun asin(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.asin(arg)
|
||||
}
|
||||
|
||||
override fun acos(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.acos(arg)
|
||||
}
|
||||
|
||||
override fun atan(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.atan(arg)
|
||||
}
|
||||
|
||||
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.power(arg, pow)
|
||||
}
|
||||
|
||||
override fun exp(arg: Buffer<Double>): DoubleBuffer {
|
||||
override fun exp(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.exp(arg)
|
||||
}
|
||||
|
||||
override fun ln(arg: Buffer<Double>): DoubleBuffer {
|
||||
override fun ln(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.ln(arg)
|
||||
}
|
||||
|
||||
}
|
@ -16,7 +16,7 @@ class RealNDField(override val shape: IntArray) :
|
||||
override val one by lazy { produce { one } }
|
||||
|
||||
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
||||
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
||||
RealBuffer(DoubleArray(size) { initializer(it) })
|
||||
|
||||
/**
|
||||
* Inline transform an NDStructure to
|
||||
@ -74,6 +74,13 @@ class RealNDField(override val shape: IntArray) :
|
||||
|
||||
override fun cos(arg: NDBuffer<Double>) = map(arg) { cos(it) }
|
||||
|
||||
override fun tan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { tan(it) }
|
||||
|
||||
override fun asin(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { asin(it) }
|
||||
|
||||
override fun acos(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { acos(it) }
|
||||
|
||||
override fun atan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { atan(it) }
|
||||
}
|
||||
|
||||
|
||||
@ -82,7 +89,7 @@ class RealNDField(override val shape: IntArray) :
|
||||
*/
|
||||
inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement {
|
||||
val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) }
|
||||
return BufferedNDFieldElement(this, DoubleBuffer(array))
|
||||
return BufferedNDFieldElement(this, RealBuffer(array))
|
||||
}
|
||||
|
||||
/**
|
||||
@ -96,7 +103,7 @@ inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: Int
|
||||
*/
|
||||
inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement {
|
||||
val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) }
|
||||
return BufferedNDFieldElement(context, DoubleBuffer(array))
|
||||
return BufferedNDFieldElement(context, RealBuffer(array))
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -0,0 +1,20 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
|
||||
override val size: Int get() = array.size
|
||||
|
||||
override fun get(index: Int): Short = array[index]
|
||||
|
||||
override fun set(index: Int, value: Short) {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override fun iterator() = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Short> =
|
||||
ShortBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
|
||||
fun ShortArray.asBuffer() = ShortBuffer(this)
|
@ -9,7 +9,7 @@ import kotlin.test.assertEquals
|
||||
class ExpressionFieldTest {
|
||||
@Test
|
||||
fun testExpression() {
|
||||
val context = ExpressionField(RealField)
|
||||
val context = FunctionalExpressionField(RealField)
|
||||
val expression = with(context) {
|
||||
val x = variable("x", 2.0)
|
||||
x * x + 2 * x + one
|
||||
@ -20,7 +20,7 @@ class ExpressionFieldTest {
|
||||
|
||||
@Test
|
||||
fun testComplex() {
|
||||
val context = ExpressionField(ComplexField)
|
||||
val context = FunctionalExpressionField(ComplexField)
|
||||
val expression = with(context) {
|
||||
val x = variable("x", Complex(2.0, 0.0))
|
||||
x * x + 2 * x + one
|
||||
@ -31,23 +31,23 @@ class ExpressionFieldTest {
|
||||
|
||||
@Test
|
||||
fun separateContext() {
|
||||
fun <T> ExpressionField<T>.expression(): Expression<T> {
|
||||
fun <T> FunctionalExpressionField<T,*>.expression(): Expression<T> {
|
||||
val x = variable("x")
|
||||
return x * x + 2 * x + one
|
||||
}
|
||||
|
||||
val expression = ExpressionField(RealField).expression()
|
||||
val expression = FunctionalExpressionField(RealField).expression()
|
||||
assertEquals(expression("x" to 1.0), 4.0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun valueExpression() {
|
||||
val expressionBuilder: ExpressionField<Double>.() -> Expression<Double> = {
|
||||
val expressionBuilder: FunctionalExpressionField<Double,*>.() -> Expression<Double> = {
|
||||
val x = variable("x")
|
||||
x * x + 2 * x + one
|
||||
}
|
||||
|
||||
val expression = ExpressionField(RealField).expressionBuilder()
|
||||
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
||||
assertEquals(expression("x" to 1.0), 4.0)
|
||||
}
|
||||
}
|
@ -1,10 +1,12 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
import scientifik.kmath.structures.*
|
||||
import java.math.BigDecimal
|
||||
import java.math.BigInteger
|
||||
import java.math.MathContext
|
||||
|
||||
/**
|
||||
* A field wrapper for Java [BigInteger]
|
||||
*/
|
||||
object JBigIntegerField : Field<BigInteger> {
|
||||
override val zero: BigInteger = BigInteger.ZERO
|
||||
override val one: BigInteger = BigInteger.ONE
|
||||
@ -18,6 +20,9 @@ object JBigIntegerField : Field<BigInteger> {
|
||||
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
|
||||
}
|
||||
|
||||
/**
|
||||
* A Field wrapper for Java [BigDecimal]
|
||||
*/
|
||||
class JBigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> {
|
||||
override val zero: BigDecimal = BigDecimal.ZERO
|
||||
override val one: BigDecimal = BigDecimal.ONE
|
@ -5,7 +5,7 @@ import kotlinx.coroutines.flow.*
|
||||
import scientifik.kmath.chains.BlockingRealChain
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.BufferFactory
|
||||
import scientifik.kmath.structures.DoubleBuffer
|
||||
import scientifik.kmath.structures.RealBuffer
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
|
||||
/**
|
||||
@ -45,7 +45,7 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
|
||||
/**
|
||||
* Specialized flow chunker for real buffer
|
||||
*/
|
||||
fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
|
||||
fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
|
||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||
|
||||
if (this@chunked is BlockingRealChain) {
|
||||
@ -61,13 +61,13 @@ fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
|
||||
array[counter] = element
|
||||
counter++
|
||||
if (counter == bufferSize) {
|
||||
val buffer = DoubleBuffer(array)
|
||||
val buffer = RealBuffer(array)
|
||||
emit(buffer)
|
||||
counter = 0
|
||||
}
|
||||
}
|
||||
if (counter > 0) {
|
||||
emit(DoubleBuffer(counter) { array[it] })
|
||||
emit(RealBuffer(counter) { array[it] })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,31 +7,28 @@ import scientifik.kmath.operations.Norm
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.operations.SpaceElement
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.DoubleBuffer
|
||||
import scientifik.kmath.structures.RealBuffer
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
import scientifik.kmath.structures.asIterable
|
||||
import kotlin.math.sqrt
|
||||
|
||||
typealias RealPoint = Point<Double>
|
||||
|
||||
fun DoubleArray.asVector() = RealVector(this.asBuffer())
|
||||
fun List<Double>.asVector() = RealVector(this.asBuffer())
|
||||
|
||||
|
||||
object VectorL2Norm : Norm<Point<out Number>, Double> {
|
||||
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
|
||||
}
|
||||
|
||||
inline class RealVector(private val point: Point<Double>) :
|
||||
SpaceElement<Point<Double>, RealVector, VectorSpace<Double, RealField>>, Point<Double> {
|
||||
SpaceElement<RealPoint, RealVector, VectorSpace<Double, RealField>>, RealPoint {
|
||||
|
||||
override val context: VectorSpace<Double, RealField>
|
||||
get() = space(
|
||||
point.size
|
||||
)
|
||||
override val context: VectorSpace<Double, RealField> get() = space(point.size)
|
||||
|
||||
override fun unwrap(): Point<Double> = point
|
||||
override fun unwrap(): RealPoint = point
|
||||
|
||||
override fun Point<Double>.wrap(): RealVector =
|
||||
RealVector(this)
|
||||
override fun RealPoint.wrap(): RealVector = RealVector(this)
|
||||
|
||||
override val size: Int get() = point.size
|
||||
|
||||
@ -44,16 +41,12 @@ inline class RealVector(private val point: Point<Double>) :
|
||||
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
|
||||
|
||||
inline operator fun invoke(dim: Int, initializer: (Int) -> Double) =
|
||||
RealVector(DoubleBuffer(dim, initializer))
|
||||
RealVector(RealBuffer(dim, initializer))
|
||||
|
||||
operator fun invoke(vararg values: Double): RealVector = values.asVector()
|
||||
|
||||
fun space(dim: Int): BufferVectorSpace<Double, RealField> =
|
||||
spaceCache.getOrPut(dim) {
|
||||
BufferVectorSpace(
|
||||
dim,
|
||||
RealField
|
||||
) { size, init -> Buffer.real(size, init) }
|
||||
fun space(dim: Int): BufferVectorSpace<Double, RealField> = spaceCache.getOrPut(dim) {
|
||||
BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) }
|
||||
}
|
||||
}
|
||||
}
|
@ -1,8 +1,8 @@
|
||||
package scientifik.kmath.real
|
||||
|
||||
import scientifik.kmath.structures.DoubleBuffer
|
||||
import scientifik.kmath.structures.RealBuffer
|
||||
|
||||
/**
|
||||
* Simplified [DoubleBuffer] to array comparison
|
||||
* Simplified [RealBuffer] to array comparison
|
||||
*/
|
||||
fun DoubleBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles)
|
||||
fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles)
|
@ -5,8 +5,8 @@ import scientifik.kmath.linear.RealMatrixContext.elementContext
|
||||
import scientifik.kmath.linear.VirtualMatrix
|
||||
import scientifik.kmath.operations.sum
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.DoubleBuffer
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import scientifik.kmath.structures.RealBuffer
|
||||
import scientifik.kmath.structures.asIterable
|
||||
import kotlin.math.pow
|
||||
|
||||
@ -27,6 +27,10 @@ typealias RealMatrix = Matrix<Double>
|
||||
fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
|
||||
MatrixContext.real.produce(rowNum, colNum, initializer)
|
||||
|
||||
fun Array<DoubleArray>.toMatrix(): RealMatrix{
|
||||
return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] }
|
||||
}
|
||||
|
||||
fun Sequence<DoubleArray>.toMatrix(): RealMatrix = toList().let {
|
||||
MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] }
|
||||
}
|
||||
@ -129,22 +133,22 @@ fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix =
|
||||
fun Matrix<Double>.extractColumn(columnIndex: Int): RealMatrix =
|
||||
extractColumns(columnIndex..columnIndex)
|
||||
|
||||
fun Matrix<Double>.sumByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j ->
|
||||
fun Matrix<Double>.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
val column = columns[j]
|
||||
with(elementContext) {
|
||||
sum(column.asIterable())
|
||||
}
|
||||
}
|
||||
|
||||
fun Matrix<Double>.minByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j ->
|
||||
fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column")
|
||||
}
|
||||
|
||||
fun Matrix<Double>.maxByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j ->
|
||||
fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column")
|
||||
}
|
||||
|
||||
fun Matrix<Double>.averageByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j ->
|
||||
fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
columns[j].asIterable().average()
|
||||
}
|
||||
|
||||
|
@ -1,17 +1,9 @@
|
||||
package scientifik.kmath.histogram
|
||||
|
||||
import scientifik.kmath.domains.Domain
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.structures.ArrayBuffer
|
||||
import scientifik.kmath.structures.DoubleBuffer
|
||||
|
||||
/**
|
||||
* A simple geometric domain
|
||||
* TODO move to geometry module
|
||||
*/
|
||||
interface Domain<T : Any> {
|
||||
operator fun contains(vector: Point<out T>): Boolean
|
||||
val dimension: Int
|
||||
}
|
||||
import scientifik.kmath.structures.RealBuffer
|
||||
|
||||
/**
|
||||
* The bin in the histogram. The histogram is by definition always done in the real space
|
||||
@ -51,9 +43,9 @@ interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
|
||||
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
|
||||
|
||||
fun MutableHistogram<Double, *>.put(vararg point: Number) =
|
||||
put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray()))
|
||||
put(RealBuffer(point.map { it.toDouble() }.toDoubleArray()))
|
||||
|
||||
fun MutableHistogram<Double, *>.put(vararg point: Double) = put(DoubleBuffer(point))
|
||||
fun MutableHistogram<Double, *>.put(vararg point: Double) = put(RealBuffer(point))
|
||||
|
||||
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
package scientifik.kmath.histogram
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.real.asVector
|
||||
import scientifik.kmath.operations.SpaceOperations
|
||||
import scientifik.kmath.real.asVector
|
||||
import scientifik.kmath.structures.*
|
||||
import kotlin.math.floor
|
||||
|
||||
@ -21,7 +21,7 @@ data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val c
|
||||
|
||||
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
|
||||
|
||||
override fun contains(vector: Point<out T>): Boolean = def.contains(vector)
|
||||
override fun contains(point: Point<T>): Boolean = def.contains(point)
|
||||
|
||||
override val dimension: Int
|
||||
get() = def.center.size
|
||||
@ -50,7 +50,7 @@ class RealHistogram(
|
||||
override val dimension: Int get() = lower.size
|
||||
|
||||
|
||||
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
||||
private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
||||
|
||||
init {
|
||||
// argument checks
|
||||
|
@ -16,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou
|
||||
|
||||
operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
|
||||
|
||||
override fun contains(vector: Buffer<out Double>): Boolean = contains(vector[0])
|
||||
override fun contains(point: Buffer<Double>): Boolean = contains(point[0])
|
||||
|
||||
internal operator fun inc() = this.also { counter.increment() }
|
||||
|
||||
|
@ -10,6 +10,7 @@ interface MemorySpec<T : Any> {
|
||||
val objectSize: Int
|
||||
|
||||
fun MemoryReader.read(offset: Int): T
|
||||
//TODO consider thread safety
|
||||
fun MemoryWriter.write(offset: Int, value: T)
|
||||
}
|
||||
|
||||
|
@ -3,10 +3,12 @@ pluginManagement {
|
||||
val toolsVersion = "0.5.0"
|
||||
|
||||
plugins {
|
||||
id("kotlinx.benchmark") version "0.2.0-dev-8"
|
||||
id("scientifik.mpp") version toolsVersion
|
||||
id("scientifik.jvm") version toolsVersion
|
||||
id("scientifik.atomic") version toolsVersion
|
||||
id("scientifik.publish") version toolsVersion
|
||||
kotlin("plugin.allopen") version "1.3.72"
|
||||
}
|
||||
|
||||
repositories {
|
||||
@ -45,5 +47,6 @@ include(
|
||||
":kmath-dimensions",
|
||||
":kmath-for-real",
|
||||
":kmath-geometry",
|
||||
":kmath-ast",
|
||||
":examples"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user