Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116

Merged
CommanderTvis merged 50 commits from nd4j into dev 2020-10-29 19:58:53 +03:00
69 changed files with 2430 additions and 374 deletions
Showing only changes of commit 4849f400ab - Show all commits

View File

@ -2,7 +2,7 @@ plugins {
id("scientifik.publish") apply false id("scientifik.publish") apply false
} }
val kmathVersion by extra("0.1.4-dev-7") val kmathVersion by extra("0.1.4-dev-8")
val bintrayRepo by extra("scientifik") val bintrayRepo by extra("scientifik")
val githubProject by extra("kmath") val githubProject by extra("kmath")

View File

@ -2,7 +2,7 @@
Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`). Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`).
There are different types of buffers: There are different types of buffers:
* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays. * Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays.
* Boxing `ListBuffer` wrapping a list * Boxing `ListBuffer` wrapping a list
* Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value * Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value
* `MemoryBuffer` allows direct allocation of objects in continuous memory block. * `MemoryBuffer` allows direct allocation of objects in continuous memory block.

View File

@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
plugins { plugins {
java java
kotlin("jvm") kotlin("jvm")
kotlin("plugin.allopen") version "1.3.71" kotlin("plugin.allopen") version "1.3.72"
id("kotlinx.benchmark") version "0.2.0-dev-7" id("kotlinx.benchmark") version "0.2.0-dev-8"
} }
configure<AllOpenExtension> { configure<AllOpenExtension> {
@ -24,6 +24,7 @@ sourceSets {
} }
dependencies { dependencies {
implementation(project(":kmath-ast"))
implementation(project(":kmath-core")) implementation(project(":kmath-core"))
implementation(project(":kmath-coroutines")) implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons")) implementation(project(":kmath-commons"))
@ -33,8 +34,8 @@ dependencies {
implementation(project(":kmath-dimensions")) implementation(project(":kmath-dimensions"))
implementation("com.kyonifer:koma-core-ejml:0.12") implementation("com.kyonifer:koma-core-ejml:0.12")
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-7") implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8")
"benchmarksCompile"(sourceSets.main.get().compileClasspath) "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath
} }
// Configure benchmark // Configure benchmark

View File

@ -10,8 +10,8 @@ import scientifik.kmath.operations.complex
class BufferBenchmark { class BufferBenchmark {
@Benchmark @Benchmark
fun genericDoubleBufferReadWrite() { fun genericRealBufferReadWrite() {
val buffer = DoubleBuffer(size){it.toDouble()} val buffer = RealBuffer(size){it.toDouble()}
(0 until size).forEach { (0 until size).forEach {
buffer[it] buffer[it]

View File

@ -20,48 +20,39 @@ class ViktorBenchmark {
final val viktorField = ViktorNDField(intArrayOf(dim, dim)) final val viktorField = ViktorNDField(intArrayOf(dim, dim))
@Benchmark @Benchmark
fun `Automatic field addition`() { fun automaticFieldAddition() {
autoField.run { autoField.run {
var res = one var res = one
repeat(n) { repeat(n) { res += one }
res += 1.0
}
} }
} }
@Benchmark @Benchmark
fun `Viktor field addition`() { fun viktorFieldAddition() {
viktorField.run { viktorField.run {
var res = one var res = one
repeat(n) { repeat(n) { res += one }
res += one
}
} }
} }
@Benchmark @Benchmark
fun `Raw Viktor`() { fun rawViktor() {
val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim))
var res = one var res = one
repeat(n) { repeat(n) { res = res + one }
res = res + one
}
} }
@Benchmark @Benchmark
fun `Real field log`() { fun realdFieldLog() {
realField.run { realField.run {
val fortyTwo = produce { 42.0 } val fortyTwo = produce { 42.0 }
var res = one var res = one
repeat(n) { res = ln(fortyTwo) }
repeat(n) {
res = ln(fortyTwo)
}
} }
} }
@Benchmark @Benchmark
fun `Raw Viktor log`() { fun rawViktorLog() {
val fortyTwo = F64Array.full(dim, dim, init = 42.0) val fortyTwo = F64Array.full(dim, dim, init = 42.0)
var res: F64Array var res: F64Array
repeat(n) { repeat(n) {

View File

@ -0,0 +1,70 @@
package scientifik.kmath.ast
import scientifik.kmath.asm.compile
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.expressionInField
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField
import kotlin.random.Random
import kotlin.system.measureTimeMillis
class ExpressionsInterpretersBenchmark {
private val algebra: Field<Double> = RealField
fun functionalExpression() {
val expr = algebra.expressionInField {
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
}
invokeAndSum(expr)
}
fun mstExpression() {
val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
}
invokeAndSum(expr)
}
fun asmExpression() {
val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
}.compile()
invokeAndSum(expr)
}
private fun invokeAndSum(expr: Expression<Double>) {
val random = Random(0)
var sum = 0.0
repeat(1000000) {
sum += expr("x" to random.nextDouble())
}
println(sum)
}
}
fun main() {
val benchmark = ExpressionsInterpretersBenchmark()
val fe = measureTimeMillis {
benchmark.functionalExpression()
}
println("fe=$fe")
val mst = measureTimeMillis {
benchmark.mstExpression()
}
println("mst=$mst")
val asm = measureTimeMillis {
benchmark.asmExpression()
}
println("asm=$asm")
}

View File

@ -27,7 +27,7 @@ fun main() {
val complexTime = measureTimeMillis { val complexTime = measureTimeMillis {
complexField.run { complexField.run {
var res = one var res: NDBuffer<Complex> = one
repeat(n) { repeat(n) {
res += 1.0 res += 1.0
} }

View File

@ -23,14 +23,14 @@ fun main() {
measureAndPrint("Automatic field addition") { measureAndPrint("Automatic field addition") {
autoField.run { autoField.run {
var res = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) {
res += 1.0 res += number(1.0)
} }
} }
} }
measureAndPrint("Element addition"){ measureAndPrint("Element addition") {
var res = genericField.one var res = genericField.one
repeat(n) { repeat(n) {
res += 1.0 res += 1.0
@ -63,7 +63,7 @@ fun main() {
genericField.run { genericField.run {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) {
res += 1.0 res += one // con't avoid using `one` due to resolution ambiguity
} }
} }
} }

View File

@ -6,7 +6,7 @@ fun main(args: Array<String>) {
val n = 6000 val n = 6000
val array = DoubleArray(n * n) { 1.0 } val array = DoubleArray(n * n) { 1.0 }
val buffer = DoubleBuffer(array) val buffer = RealBuffer(array)
val strides = DefaultStrides(intArrayOf(n, n)) val strides = DefaultStrides(intArrayOf(n, n))
val structure = BufferNDStructure(strides, buffer) val structure = BufferNDStructure(strides, buffer)

View File

@ -26,10 +26,10 @@ fun main(args: Array<String>) {
} }
println("Array mapping finished in $time2 millis") println("Array mapping finished in $time2 millis")
val buffer = DoubleBuffer(DoubleArray(n * n) { 1.0 }) val buffer = RealBuffer(DoubleArray(n * n) { 1.0 })
val time3 = measureTimeMillis { val time3 = measureTimeMillis {
val target = DoubleBuffer(DoubleArray(n * n)) val target = RealBuffer(DoubleArray(n * n))
val res = array.forEachIndexed { index, value -> val res = array.forEachIndexed { index, value ->
target[index] = value + 1 target[index] = value + 1
} }

62
kmath-ast/README.md Normal file
View File

@ -0,0 +1,62 @@
# AST-based expression representation and operations (`kmath-ast`)
This subproject implements the following features:
- Expression Language and its parser.
- MST as expression language's syntax intermediate representation.
- Type-safe builder of MST.
- Evaluating expressions by traversing MST.
## Dynamic expression code generation with OW2 ASM
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
a special implementation of `Expression<T>` with implemented `invoke` function.
For example, the following builder:
```kotlin
RealField.mstInField { symbol("x") + 2 }.compile()
```
… leads to generation of bytecode, which can be decompiled to the following Java class:
```java
package scientifik.kmath.asm.generated;
import java.util.Map;
import scientifik.kmath.asm.internal.MapIntrinsics;
import scientifik.kmath.expressions.Expression;
import scientifik.kmath.operations.RealField;
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
private final RealField algebra;
private final Object[] constants;
public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) {
this.algebra = algebra;
this.constants = constants;
}
public final Double invoke(Map<String, ? extends Double> arguments) {
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D);
}
}
```
### Example Usage
This API is an extension to MST and MstExpression APIs. You may optimize both MST and MSTExpression:
```kotlin
RealField.mstInField { symbol("x") + 2 }.compile()
RealField.expression("x+2".parseMath())
```
### Known issues
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
class loading overhead.
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis).

View File

@ -0,0 +1,37 @@
plugins {
id("scientifik.mpp")
}
repositories {
maven("https://dl.bintray.com/hotkeytlt/maven")
}
kotlin.sourceSets {
// all {
// languageSettings.apply{
// enableLanguageFeature("NewInference")
// }
// }
commonMain {
dependencies {
api(project(":kmath-core"))
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3")
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
}
}
jvmMain {
dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
implementation("org.ow2.asm:asm:8.0.1")
implementation("org.ow2.asm:asm-commons:8.0.1")
implementation(kotlin("reflect"))
}
}
jsMain {
dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
}
}
}

View File

@ -0,0 +1,67 @@
package scientifik.kmath.ast
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import scientifik.kmath.operations.RealField
/**
* A Mathematical Syntax Tree node for mathematical expressions
*/
sealed class MST {
/**
* A node containing unparsed string
*/
data class Symbolic(val value: String) : MST()
/**
* A node containing a number
*/
data class Numeric(val value: Number) : MST()
/**
* A node containing an unary operation
*/
data class Unary(val operation: String, val value: MST) : MST() {
companion object {
const val ABS_OPERATION = "abs"
//TODO add operations
}
}
/**
* A node containing binary operation
*/
data class Binary(val operation: String, val left: MST, val right: MST) : MST() {
companion object
}
}
//TODO add a function with positional arguments
//TODO add a function with named arguments
fun <T> Algebra<T>.evaluate(node: MST): T {
return when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation(
node.operation,
node.left.value.toDouble(),
node.right.value.toDouble()
)
number(number)
}
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
}
}
}
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)

View File

@ -0,0 +1,72 @@
package scientifik.kmath.ast
import scientifik.kmath.operations.*
object MstAlgebra : NumericAlgebra<MST> {
override fun number(value: Number): MST = MST.Numeric(value)
override fun symbol(value: String): MST = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST =
MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MST.Binary(operation, left, right)
}
object MstSpace : Space<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0)
override fun number(value: Number): MST = MstAlgebra.number(value)
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST =
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST =
binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstAlgebra.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
}
object MstRing : Ring<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0)
override val one: MST = number(1.0)
override fun number(value: Number): MST = MstAlgebra.number(value)
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST =
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstAlgebra.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
}
object MstField : Field<MST> {
override val zero: MST = number(0.0)
override val one: MST = number(1.0)
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
override fun number(value: Number): MST = MstAlgebra.number(value)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST =
binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k))
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstAlgebra.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
}

View File

@ -0,0 +1,55 @@
package scientifik.kmath.ast
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.FunctionalExpressionField
import scientifik.kmath.expressions.FunctionalExpressionRing
import scientifik.kmath.expressions.FunctionalExpressionSpace
import scientifik.kmath.operations.*
/**
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions.
*/
class MstExpression<T>(val algebra: Algebra<T>, val mst: MST) : Expression<T> {
/**
* Substitute algebra raw value
*/
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun number(value: Number): T = if (algebra is NumericAlgebra)
algebra.number(value)
else
error("Numeric nodes are not supported by $this")
}
override fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
}
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E,
block: E.() -> MST
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
MstExpression(this, MstSpace.block())
inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
MstExpression(this, MstRing.block())
inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> =
MstExpression(this, MstField.block())
inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> =
algebra.mstInSpace(block)
inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> =
algebra.mstInRing(block)
inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> =
algebra.mstInField(block)

View File

@ -0,0 +1,59 @@
package scientifik.kmath.ast
import com.github.h0tk3y.betterParse.combinators.*
import com.github.h0tk3y.betterParse.grammar.Grammar
import com.github.h0tk3y.betterParse.grammar.parseToEnd
import com.github.h0tk3y.betterParse.grammar.parser
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser
import scientifik.kmath.operations.FieldOperations
import scientifik.kmath.operations.PowerOperations
import scientifik.kmath.operations.RingOperations
import scientifik.kmath.operations.SpaceOperations
/**
* TODO move to common
*/
private object ArithmeticsEvaluator : Grammar<MST>() {
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
val lpar by token("\\(".toRegex())
val rpar by token("\\)".toRegex())
val mul by token("\\*".toRegex())
val pow by token("\\^".toRegex())
val div by token("/".toRegex())
val minus by token("-".toRegex())
val plus by token("\\+".toRegex())
val ws by token("\\s+".toRegex(), ignore = true)
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
val term: Parser<MST> by number or
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or
(skip(lpar) and parser(this::rootParser) and skip(rpar))
val powChain by leftAssociative(term, pow) { a, _, b ->
MST.Binary(PowerOperations.POW_OPERATION, a, b)
}
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b ->
if (op == div) {
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
} else {
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
}
}
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b ->
if (op == plus) {
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
} else {
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
}
}
override val rootParser: Parser<MST> by subSumChain
}
fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this)
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)

View File

@ -0,0 +1,60 @@
package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.buildAlgebraOperationCall
import scientifik.kmath.asm.internal.buildName
import scientifik.kmath.ast.MST
import scientifik.kmath.ast.MstExpression
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import kotlin.reflect.KClass
/**
* Compile given MST to an Expression using AST compiler
*/
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST) {
when (node) {
is MST.Symbolic -> loadVariable(node.value)
is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>)
algebra.number(node.value)
else
error("Number literals are not supported in $algebra")
loadTConstant(constant)
}
is MST.Unary -> buildAlgebraOperationCall(
context = algebra,
name = node.operation,
fallbackMethodName = "unaryOperation",
arity = 1
) { visit(node.value) }
is MST.Binary -> buildAlgebraOperationCall(
context = algebra,
name = node.operation,
fallbackMethodName = "binaryOperation",
arity = 2
) {
visit(node.left)
visit(node.right)
}
}
}
return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
}
/**
* Compile an [MST] to ASM using given algebra
*/
inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
/**
* Optimize performance of an [MstExpression] using ASM codegen
*/
inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)

View File

@ -0,0 +1,518 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.commons.InstructionAdapter
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
import scientifik.kmath.ast.MST
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
import java.util.*
import kotlin.reflect.KClass
/**
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
*
* @param T the type of AsmExpression to unwrap.
* @param algebra the algebra the applied AsmExpressions use.
* @param className the unique class name of new loaded class.
* @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
*/
internal class AsmBuilder<T> internal constructor(
private val classOfT: KClass<*>,
private val algebra: Algebra<T>,
private val className: String,
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
) {
/**
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
*/
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
}
/**
* The instance of [ClassLoader] used by this builder.
*/
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
/**
* ASM Type for [algebra]
*/
private val tAlgebraType: Type = algebra::class.asm
/**
* ASM type for [T]
*/
internal val tType: Type = classOfT.asm
/**
* ASM type for new class
*/
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
/**
* Index of `this` variable in invoke method of the built subclass.
*/
private val invokeThisVar: Int = 0
/**
* Index of `arguments` variable in invoke method of the built subclass.
*/
private val invokeArgumentsVar: Int = 1
/**
* List of constants to provide to the subclass.
*/
private val constants: MutableList<Any> = mutableListOf()
/**
* Method visitor of `invoke` method of the subclass.
*/
private lateinit var invokeMethodVisitor: InstructionAdapter
/**
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
*/
internal var primitiveMode: Boolean = false
/**
* Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
*/
internal var primitiveMask: Type = OBJECT_TYPE
/**
* Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
*/
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
/**
* Stack of useful objects types on stack to verify types.
*/
private val typeStack: ArrayDeque<Type> = ArrayDeque()
/**
* Stack of useful objects types on stack expected by algebra calls.
*/
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>().apply { push(tType) }
/**
* The cache for instance built by this builder.
*/
private var generatedInstance: Expression<T>? = null
/**
* Subclasses, loads and instantiates [Expression] for given parameters.
*
* The built instance is cached.
*/
@Suppress("UNCHECKED_CAST")
fun getInstance(): Expression<T> {
generatedInstance?.let { return it }
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
primitiveMode = true
primitiveMask = SIGNATURE_LETTERS.getValue(classOfT)
primitiveMaskBoxed = tType
}
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit(
V1_8,
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
classType.internalName,
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
OBJECT_TYPE.internalName,
arrayOf(EXPRESSION_TYPE.internalName)
)
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "algebra",
descriptor = tAlgebraType.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "constants",
descriptor = OBJECT_ARRAY_TYPE.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
visitMethod(
ACC_PUBLIC,
"<init>",
Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE),
null,
null
).instructionAdapter {
val thisVar = 0
val algebraVar = 1
val constantsVar = 2
val l0 = label()
load(thisVar, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
label()
load(thisVar, classType)
load(algebraVar, tAlgebraType)
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
label()
load(thisVar, classType)
load(constantsVar, OBJECT_ARRAY_TYPE)
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
label()
visitInsn(RETURN)
val l4 = label()
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
visitLocalVariable(
"algebra",
tAlgebraType.descriptor,
null,
l0,
l4,
algebraVar
)
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
visitMaxs(0, 3)
visitEnd()
}
visitMethod(
ACC_PUBLIC or ACC_FINAL,
"invoke",
Type.getMethodDescriptor(tType, MAP_TYPE),
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
null
).instructionAdapter {
invokeMethodVisitor = this
visitCode()
val l0 = label()
invokeLabel0Visitor()
areturn(tType)
val l1 = label()
visitLocalVariable(
"this",
classType.descriptor,
null,
l0,
l1,
invokeThisVar
)
visitLocalVariable(
"arguments",
MAP_TYPE.descriptor,
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;",
l0,
l1,
invokeArgumentsVar
)
visitMaxs(0, 2)
visitEnd()
}
visitMethod(
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
"invoke",
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
null,
null
).instructionAdapter {
val thisVar = 0
val argumentsVar = 1
visitCode()
val l0 = label()
load(thisVar, OBJECT_TYPE)
load(argumentsVar, MAP_TYPE)
invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false)
areturn(tType)
val l1 = label()
visitLocalVariable(
"this",
classType.descriptor,
null,
l0,
l1,
thisVar
)
visitMaxs(0, 2)
visitEnd()
}
visitEnd()
}
val new = classLoader
.defineClass(className, classWriter.toByteArray())
.constructors
.first()
.newInstance(algebra, constants.toTypedArray()) as Expression<T>
generatedInstance = new
return new
}
/**
* Loads a [T] constant from [constants].
*/
internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) {
val expectedType = expectationStack.pop()
val mustBeBoxed = expectedType.sort == Type.OBJECT
loadNumberConstant(value as Number, mustBeBoxed)
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
return
}
loadConstant(value as Any, tType)
}
/**
* Boxes the current value and pushes it.
*/
private fun box(): Unit = invokeMethodVisitor.invokestatic(
tType.internalName,
"valueOf",
Type.getMethodDescriptor(tType, primitiveMask),
false
)
/**
* Unboxes the current boxed value and pushes it.
*/
private fun unbox(): Unit = invokeMethodVisitor.invokevirtual(
NUMBER_TYPE.internalName,
NUMBER_CONVERTER_METHODS.getValue(primitiveMask),
Type.getMethodDescriptor(primitiveMask),
false
)
/**
* Loads [java.lang.Object] constant from constants.
*/
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
loadThis()
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
iconst(idx)
visitInsn(AALOAD)
checkcast(type)
}
/**
* Loads this variable.
*/
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
/**
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
* from it).
*/
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
val boxed = value::class.asm
val primitive = BOXED_TO_PRIMITIVES[boxed]
if (primitive != null) {
when (primitive) {
Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
}
if (mustBeBoxed) {
box()
invokeMethodVisitor.checkcast(tType)
}
return
}
loadConstant(value, boxed)
if (!mustBeBoxed) unbox()
else invokeMethodVisitor.checkcast(tType)
}
/**
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
* provided.
*/
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
load(invokeArgumentsVar, MAP_TYPE)
aconst(name)
if (defaultValue != null)
loadTConstant(defaultValue)
else
aconst(null)
invokestatic(
MAP_INTRINSICS_TYPE.internalName,
"getOrFail",
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
false
)
checkcast(tType)
val expectedType = expectationStack.pop()
if (expectedType.sort == Type.OBJECT)
typeStack.push(tType)
else {
unbox()
typeStack.push(primitiveMask)
}
}
/**
* Loads algebra from according field of the class and casts it to class of [algebra] provided.
*/
internal fun loadAlgebra() {
loadThis()
invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor)
}
/**
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be
* called before the arguments and this operation.
*
* The result is casted to [T] automatically.
*/
internal fun invokeAlgebraOperation(
owner: String,
method: String,
descriptor: String,
expectedArity: Int,
opcode: Int = INVOKEINTERFACE
) {
run loop@{
repeat(expectedArity) {
if (typeStack.isEmpty()) return@loop
typeStack.pop()
}
}
invokeMethodVisitor.visitMethodInsn(
opcode,
owner,
method,
descriptor,
opcode == INVOKEINTERFACE
)
invokeMethodVisitor.checkcast(tType)
val isLastExpr = expectationStack.size == 1
val expectedType = expectationStack.pop()
if (expectedType.sort == Type.OBJECT || isLastExpr)
typeStack.push(tType)
else {
unbox()
typeStack.push(primitiveMask)
}
}
/**
* Writes a LDC Instruction with string constant provided.
*/
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
internal companion object {
/**
* Maps JVM primitive numbers boxed types to their primitive ASM types.
*/
private val SIGNATURE_LETTERS: Map<KClass<out Any>, Type> by lazy {
hashMapOf(
java.lang.Byte::class to Type.BYTE_TYPE,
java.lang.Short::class to Type.SHORT_TYPE,
java.lang.Integer::class to Type.INT_TYPE,
java.lang.Long::class to Type.LONG_TYPE,
java.lang.Float::class to Type.FLOAT_TYPE,
java.lang.Double::class to Type.DOUBLE_TYPE
)
}
/**
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
*/
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
/**
* Maps primitive ASM types to [Number] functions unboxing them.
*/
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
hashMapOf(
Type.BYTE_TYPE to "byteValue",
Type.SHORT_TYPE to "shortValue",
Type.INT_TYPE to "intValue",
Type.LONG_TYPE to "longValue",
Type.FLOAT_TYPE to "floatValue",
Type.DOUBLE_TYPE to "doubleValue"
)
}
/**
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
*/
private val INLINABLE_NUMBERS: Set<KClass<out Any>> by lazy { SIGNATURE_LETTERS.keys }
/**
* ASM type for [Expression].
*/
internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm }
/**
* ASM type for [java.lang.Number].
*/
internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm }
/**
* ASM type for [java.util.Map].
*/
internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm }
/**
* ASM type for [java.lang.Object].
*/
internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm }
/**
* ASM type for array of [java.lang.Object].
*/
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
internal val OBJECT_ARRAY_TYPE: Type by lazy { Array<java.lang.Object>::class.asm }
/**
* ASM type for [Algebra].
*/
internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm }
/**
* ASM type for [java.lang.String].
*/
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
/**
* ASM type for MapIntrinsics.
*/
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") }
}
}

View File

@ -0,0 +1,148 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.*
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
import org.objectweb.asm.commons.InstructionAdapter
import scientifik.kmath.ast.MST
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
import kotlin.reflect.KClass
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
hashMapOf(
"+" to 2 to "add",
"*" to 2 to "multiply",
"/" to 2 to "divide",
"+" to 1 to "unaryPlus",
"-" to 1 to "unaryMinus",
"-" to 2 to "minus"
)
}
internal val KClass<*>.asm: Type
get() = Type.getType(java)
/**
* Creates an [InstructionAdapter] from this [MethodVisitor].
*/
private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
/**
* Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it.
*/
internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
instructionAdapter().apply(block)
/**
* Constructs a [Label], then applies it to this visitor.
*/
internal fun MethodVisitor.label(): Label {
val l = Label()
visitLabel(l)
return l
}
/**
* Creates a class name for [Expression] subclassed to implement [mst] provided.
*
* This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
*/
internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try {
Class.forName(name)
} catch (ignored: ClassNotFoundException) {
return name
}
return buildName(mst, collision + 1)
}
@Suppress("FunctionName")
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter =
ClassWriter(flags).apply(block)
internal inline fun ClassWriter.visitField(
access: Int,
name: String,
descriptor: String,
signature: String?,
value: Any?,
block: FieldVisitor.() -> Unit
): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block)
/**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
* type expectation stack for needed arity.
*
* @return `true` if contains, else `false`.
*/
private fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
val theName = methodNameAdapters[name to arity] ?: name
val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null
val t = if (primitiveMode && hasSpecific) primitiveMask else tType
repeat(arity) { expectationStack.push(t) }
return hasSpecific
}
/**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts
* [AsmBuilder.invokeAlgebraOperation] of this method.
*
* @return `true` if contains, else `false`.
*/
private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
val theName = methodNameAdapters[name to arity] ?: name
context.javaClass.methods.find {
var suitableSignature = it.name == theName && it.parameters.size == arity
if (primitiveMode && it.isBridge)
suitableSignature = false
suitableSignature
} ?: return false
val owner = context::class.asm
invokeAlgebraOperation(
owner = owner.internalName,
method = theName,
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }),
expectedArity = arity,
opcode = INVOKEVIRTUAL
)
return true
}
/**
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
*/
internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
context: Algebra<T>,
name: String,
fallbackMethodName: String,
arity: Int,
parameters: AsmBuilder<T>.() -> Unit
) {
loadAlgebra()
if (!buildExpectationStack(context, name, arity)) loadStringConstant(name)
parameters()
if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
method = fallbackMethodName,
descriptor = Type.getMethodDescriptor(
AsmBuilder.OBJECT_TYPE,
AsmBuilder.STRING_TYPE,
*Array(arity) { AsmBuilder.OBJECT_TYPE }
),
expectedArity = arity
)
}

View File

@ -0,0 +1,10 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.Label
import org.objectweb.asm.commons.InstructionAdapter
internal fun InstructionAdapter.label(): Label {
val l = Label()
visitLabel(l)
return l
}

View File

@ -0,0 +1,7 @@
@file:JvmName("MapIntrinsics")
package scientifik.kmath.asm.internal
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V {
return this[key] ?: default ?: error("Parameter not found: $key")
}

View File

@ -0,0 +1,110 @@
package scietifik.kmath.asm
import scientifik.kmath.asm.compile
import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.mstInRing
import scientifik.kmath.ast.mstInSpace
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.ByteRing
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmAlgebras {
@Test
fun space() {
val res1 = ByteRing.mstInSpace {
binaryOperation(
"+",
unaryOperation(
"+",
number(3.toByte()) - (number(2.toByte()) + (multiply(
add(number(1), number(1)),
2
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + symbol("x") + zero
}("x" to 2.toByte())
val res2 = ByteRing.mstInSpace {
binaryOperation(
"+",
unaryOperation(
"+",
number(3.toByte()) - (number(2.toByte()) + (multiply(
add(number(1), number(1)),
2
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + symbol("x") + zero
}.compile()("x" to 2.toByte())
assertEquals(res1, res2)
}
@Test
fun ring() {
val res1 = ByteRing.mstInRing {
binaryOperation(
"+",
unaryOperation(
"+",
(symbol("x") - (2.toByte() + (multiply(
add(number(1), number(1)),
2
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}("x" to 3.toByte())
val res2 = ByteRing.mstInRing {
binaryOperation(
"+",
unaryOperation(
"+",
(symbol("x") - (2.toByte() + (multiply(
add(number(1), number(1)),
2
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}.compile()("x" to 3.toByte())
assertEquals(res1, res2)
}
@Test
fun field() {
val res1 = RealField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
"+",
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}("x" to 2.0)
val res2 = RealField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
"+",
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0)
assertEquals(res1, res2)
}
}

View File

@ -0,0 +1,31 @@
package scietifik.kmath.asm
import scientifik.kmath.asm.compile
import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.mstInSpace
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmExpressions {
@Test
fun testUnaryOperationInvocation() {
val expression = RealField.mstInSpace { -symbol("x") }.compile()
val res = expression("x" to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile()
val res = expression("x" to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
assertEquals(4.0, res)
}
}

View File

@ -0,0 +1,46 @@
package scietifik.kmath.asm
import scientifik.kmath.asm.compile
import scientifik.kmath.ast.mstInField
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmSpecialization {
@Test
fun testUnaryPlus() {
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
assertEquals(2.0, expr("x" to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
assertEquals(-2.0, expr("x" to 2.0))
}
@Test
fun testAdd() {
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
assertEquals(4.0, expr("x" to 2.0))
}
@Test
fun testSine() {
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile()
assertEquals(0.0, expr("x" to 0.0))
}
@Test
fun testMinus() {
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
assertEquals(0.0, expr("x" to 2.0))
}
@Test
fun testDivide() {
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
assertEquals(1.0, expr("x" to 2.0))
}
}

View File

@ -0,0 +1,22 @@
package scietifik.kmath.asm
import scientifik.kmath.ast.mstInRing
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.ByteRing
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
internal class TestAsmVariables {
@Test
fun testVariableWithoutDefault() {
val expr = ByteRing.mstInRing { symbol("x") }
assertEquals(1.toByte(), expr("x" to 1.toByte()))
}
@Test
fun testVariableWithoutDefaultFails() {
val expr = ByteRing.mstInRing { symbol("x") }
assertFailsWith<IllegalStateException> { expr() }
}
}

View File

@ -0,0 +1,26 @@
package scietifik.kmath.ast
import scientifik.kmath.asm.compile
import scientifik.kmath.asm.expression
import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class AsmTest {
@Test
fun `compile MST`() {
val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.expression(mst)()
assertEquals(Complex(10.0, 0.0), res)
}
@Test
fun `compile MSTExpression`() {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()()
assertEquals(Complex(10.0, 0.0), res)
}
}

View File

@ -0,0 +1,25 @@
package scietifik.kmath.ast
import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ParserTest {
@Test
fun `evaluate MST`() {
val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.evaluate(mst)
assertEquals(Complex(10.0, 0.0), res)
}
@Test
fun `evaluate MSTExpression`() {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
assertEquals(Complex(10.0, 0.0), res)
}
}

View File

@ -2,7 +2,7 @@ package scientifik.kmath.commons.expressions
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionContext import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import kotlin.properties.ReadOnlyProperty import kotlin.properties.ReadOnlyProperty
@ -59,8 +59,10 @@ class DerivativeStructureField(
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
is Double -> arg.pow(pow) is Double -> arg.pow(pow)
@ -74,10 +76,10 @@ class DerivativeStructureField(
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble()) override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble()) override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
operator fun Number.plus(s: DerivativeStructure) = s + this override operator fun Number.plus(b: DerivativeStructure) = b + this
operator fun Number.minus(s: DerivativeStructure) = s - this override operator fun Number.minus(b: DerivativeStructure) = b - this
} }
/** /**
@ -113,7 +115,7 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1)
/** /**
* A context for [DiffExpression] (not to be confused with [DerivativeStructure]) * A context for [DiffExpression] (not to be confused with [DerivativeStructure])
*/ */
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> { object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
override fun variable(name: String, default: Double?) = override fun variable(name: String, default: Double?) =
DiffExpression { variable(name, default?.const()) } DiffExpression { variable(name, default?.const()) }
@ -136,6 +138,3 @@ object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression>
override fun divide(a: DiffExpression, b: DiffExpression) = override fun divide(a: DiffExpression, b: DiffExpression) =
DiffExpression { a.function(this) / b.function(this) } DiffExpression { a.function(this) / b.function(this) }
} }

View File

@ -0,0 +1,38 @@
package scientifik.kmath.commons.random
import scientifik.kmath.prob.RandomGenerator
class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) :
org.apache.commons.math3.random.RandomGenerator {
private var generator = factory(intArrayOf())
override fun nextBoolean(): Boolean = generator.nextBoolean()
override fun nextFloat(): Float = generator.nextDouble().toFloat()
override fun setSeed(seed: Int) {
generator = factory(intArrayOf(seed))
}
override fun setSeed(seed: IntArray) {
generator = factory(seed)
}
override fun setSeed(seed: Long) {
setSeed(seed.toInt())
}
override fun nextBytes(bytes: ByteArray) {
generator.fillBytes(bytes)
}
override fun nextInt(): Int = generator.nextInt()
override fun nextInt(n: Int): Int = generator.nextInt(n)
override fun nextGaussian(): Double = TODO()
override fun nextDouble(): Double = generator.nextDouble()
override fun nextLong(): Long = generator.nextLong()
}

View File

@ -18,7 +18,7 @@ object Transformations {
private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> = private fun Buffer<Complex>.toArray(): Array<org.apache.commons.math3.complex.Complex> =
Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) } Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) }
private fun Buffer<Double>.asArray() = if (this is DoubleBuffer) { private fun Buffer<Double>.asArray() = if (this is RealBuffer) {
array array
} else { } else {
DoubleArray(size) { i -> get(i) } DoubleArray(size) { i -> get(i) }

View File

@ -8,4 +8,4 @@ kotlin.sourceSets {
api(project(":kmath-memory")) api(project(":kmath-memory"))
} }
} }
} }

View File

@ -0,0 +1,15 @@
package scientifik.kmath.domains
import scientifik.kmath.linear.Point
/**
* A simple geometric domain
*/
interface Domain<T : Any> {
operator fun contains(point: Point<T>): Boolean
/**
* Number of hyperspace dimensions
*/
val dimension: Int
}

View File

@ -0,0 +1,67 @@
/*
* Copyright 2015 Alexander Nozik.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package scientifik.kmath.domains
import scientifik.kmath.linear.Point
import scientifik.kmath.structures.RealBuffer
import scientifik.kmath.structures.indices
/**
*
* HyperSquareDomain class.
*
* @author Alexander Nozik
*/
class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain {
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
point[i] in lower[i]..upper[i]
}
override val dimension: Int get() = lower.size
override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num]
override fun getLowerBound(num: Int): Double? = lower[num]
override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num]
override fun getUpperBound(num: Int): Double? = upper[num]
override fun nearestInDomain(point: Point<Double>): Point<Double> {
val res: DoubleArray = DoubleArray(point.size) { i ->
when {
point[i] < lower[i] -> lower[i]
point[i] > upper[i] -> upper[i]
else -> point[i]
}
}
return RealBuffer(*res)
}
override fun volume(): Double {
var res = 1.0
for (i in 0 until dimension) {
if (lower[i].isInfinite() || upper[i].isInfinite()) {
return Double.POSITIVE_INFINITY
}
if (upper[i] > lower[i]) {
res *= upper[i] - lower[i]
}
}
return res
}
}

View File

@ -0,0 +1,65 @@
/*
* Copyright 2015 Alexander Nozik.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package scientifik.kmath.domains
import scientifik.kmath.linear.Point
/**
* n-dimensional volume
*
* @author Alexander Nozik
*/
interface RealDomain: Domain<Double> {
fun nearestInDomain(point: Point<Double>): Point<Double>
/**
* The lower edge for the domain going down from point
* @param num
* @param point
* @return
*/
fun getLowerBound(num: Int, point: Point<Double>): Double?
/**
* The upper edge of the domain going up from point
* @param num
* @param point
* @return
*/
fun getUpperBound(num: Int, point: Point<Double>): Double?
/**
* Global lower edge
* @param num
* @return
*/
fun getLowerBound(num: Int): Double?
/**
* Global upper edge
* @param num
* @return
*/
fun getUpperBound(num: Int): Double?
/**
* Hyper volume
* @return
*/
fun volume(): Double
}

View File

@ -0,0 +1,36 @@
/*
* Copyright 2015 Alexander Nozik.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package scientifik.kmath.domains
import scientifik.kmath.linear.Point
class UnconstrainedDomain(override val dimension: Int) : RealDomain {
override operator fun contains(point: Point<Double>): Boolean = true
override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY
override fun getUpperBound(num: Int, point: Point<Double>): Double? = Double.POSITIVE_INFINITY
override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY
override fun nearestInDomain(point: Point<Double>): Point<Double> = point
override fun volume(): Double = Double.POSITIVE_INFINITY
}

View File

@ -0,0 +1,48 @@
package scientifik.kmath.domains
import scientifik.kmath.linear.Point
import scientifik.kmath.structures.asBuffer
inline class UnivariateDomain(val range: ClosedFloatingPointRange<Double>) : RealDomain {
operator fun contains(d: Double): Boolean = range.contains(d)
override operator fun contains(point: Point<Double>): Boolean {
require(point.size == 0)
return contains(point[0])
}
override fun nearestInDomain(point: Point<Double>): Point<Double> {
require(point.size == 1)
val value = point[0]
return when{
value in range -> point
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
else -> doubleArrayOf(range.start).asBuffer()
}
}
override fun getLowerBound(num: Int, point: Point<Double>): Double? {
require(num == 0)
return range.start
}
override fun getUpperBound(num: Int, point: Point<Double>): Double? {
require(num == 0)
return range.endInclusive
}
override fun getLowerBound(num: Int): Double? {
require(num == 0)
return range.start
}
override fun getUpperBound(num: Int): Double? {
require(num == 0)
return range.endInclusive
}
override fun volume(): Double = range.endInclusive - range.start
override val dimension: Int get() = 1
}

View File

@ -0,0 +1,23 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
/**
* Create a functional expression on this [Space]
*/
fun <T> Space<T>.spaceExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> =
FunctionalExpressionSpace(this).run(block)
/**
* Create a functional expression on this [Ring]
*/
fun <T> Ring<T>.ringExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> =
FunctionalExpressionRing(this).run(block)
/**
* Create a functional expression on this [Field]
*/
fun <T> Field<T>.fieldExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> =
FunctionalExpressionField(this).run(block)

View File

@ -1,14 +1,21 @@
package scientifik.kmath.expressions package scientifik.kmath.expressions
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
/** /**
* An elementary function that could be invoked on a map of arguments * An elementary function that could be invoked on a map of arguments
*/ */
interface Expression<T> { interface Expression<T> {
operator fun invoke(arguments: Map<String, T>): T operator fun invoke(arguments: Map<String, T>): T
companion object
}
/**
* Create simple lazily evaluated expression inside given algebra
*/
fun <T> Algebra<T>.expression(block: Algebra<T>.(arguments: Map<String, T>) -> T): Expression<T> = object: Expression<T> {
override fun invoke(arguments: Map<String, T>): T = block(arguments)
} }
operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs)) operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
@ -16,77 +23,14 @@ operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke
/** /**
* A context for expression construction * A context for expression construction
*/ */
interface ExpressionContext<T> { interface ExpressionAlgebra<T, E> : Algebra<E> {
/** /**
* Introduce a variable into expression context * Introduce a variable into expression context
*/ */
fun variable(name: String, default: T? = null): Expression<T> fun variable(name: String, default: T? = null): E
/** /**
* A constant expression which does not depend on arguments * A constant expression which does not depend on arguments
*/ */
fun const(value: T): Expression<T> fun const(value: T): E
}
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class ConstantExpression<T>(val value: T) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = value
}
internal class SumExpression<T>(val context: Space<T>, val first: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
}
internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
context.multiply(first.invoke(arguments), second.invoke(arguments))
}
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
}
open class ExpressionSpace<T>(val space: Space<T>) : Space<Expression<T>>, ExpressionContext<T> {
override val zero: Expression<T> = ConstantExpression(space.zero)
override fun const(value: T): Expression<T> = ConstantExpression(value)
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
operator fun Expression<T>.plus(arg: T) = this + const(arg)
operator fun Expression<T>.minus(arg: T) = this - const(arg)
operator fun T.plus(arg: Expression<T>) = arg + this
operator fun T.minus(arg: Expression<T>) = arg - this
}
class ExpressionField<T>(val field: Field<T>) : Field<Expression<T>>, ExpressionSpace<T>(field) {
override val one: Expression<T> = ConstantExpression(field.one)
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
operator fun Expression<T>.times(arg: T) = this * const(arg)
operator fun Expression<T>.div(arg: T) = this / const(arg)
operator fun T.times(arg: Expression<T>) = arg * this
operator fun T.div(arg: Expression<T>) = arg / this
} }

View File

@ -0,0 +1,146 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.*
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
}
internal class FunctionalBinaryOperation<T>(
val context: Algebra<T>,
val name: String,
val first: Expression<T>,
val second: Expression<T>
) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
}
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = value
}
internal class FunctionalConstProductExpression<T>(
val context: Space<T>,
private val expr: Expression<T>,
val const: Number
) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
/**
* A context class for [Expression] construction.
*
* @param algebra The algebra to provide for Expressions built.
*/
abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(val algebra: A) : ExpressionAlgebra<T, Expression<T>> {
/**
* Builds an Expression of constant expression which does not depend on arguments.
*/
override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
/**
* Builds an Expression to access a variable.
*/
override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
/**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, operation, left, right)
/**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
FunctionalUnaryOperation(algebra, operation, arg)
}
/**
* A context class for [Expression] construction for [Space] algebras.
*/
open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
override val zero: Expression<T> get() = const(algebra.zero)
/**
* Builds an Expression of addition of two another expressions.
*/
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b)
/**
* Builds an Expression of multiplication of expression by number.
*/
override fun multiply(a: Expression<T>, k: Number): Expression<T> =
FunctionalConstProductExpression(algebra, a, k)
operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
}
open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
override val one: Expression<T>
get() = const(algebra.one)
/**
* Builds an Expression of multiplication of two expressions.
*/
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b)
operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
}
open class FunctionalExpressionField<T, A>(algebra: A) :
FunctionalExpressionRing<T, A>(algebra),
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
/**
* Builds an Expression of division an expression by another one.
*/
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b)
operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
}
inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionSpace(this).block()
inline fun <T, A : Ring<T>> A.expressionInRing(block: FunctionalExpressionRing<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionRing(this).block()
inline fun <T, A : Field<T>> A.expressionInField(block: FunctionalExpressionField<T, A>.() -> Expression<T>): Expression<T> =
FunctionalExpressionField(this).block()

View File

@ -30,11 +30,11 @@ object RealMatrixContext : GenericMatrixContext<Double, RealField> {
override val elementContext get() = RealField override val elementContext get() = RealField
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> { override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
val buffer = DoubleBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
} }
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = DoubleBuffer(size,initializer) override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size,initializer)
} }
class BufferMatrix<T : Any>( class BufferMatrix<T : Any>(
@ -102,7 +102,7 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
val array = DoubleArray(this.rowNum * other.colNum) val array = DoubleArray(this.rowNum * other.colNum)
//convert to array to insure there is not memory indirection //convert to array to insure there is not memory indirection
fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is DoubleBuffer) { fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is RealBuffer) {
array array
} else { } else {
DoubleArray(size) { get(it) } DoubleArray(size) { get(it) }
@ -119,6 +119,6 @@ infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Do
} }
} }
val buffer = DoubleBuffer(array) val buffer = RealBuffer(array)
return BufferMatrix(rowNum, other.colNum, buffer) return BufferMatrix(rowNum, other.colNum, buffer)
} }

View File

@ -128,14 +128,14 @@ fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
luRow[col] = sum luRow[col] = sum
// maintain best permutation choice // maintain best permutation choice
if (abs(sum) > largest) { if (this@lup.abs(sum) > largest) {
largest = abs(sum) largest = this@lup.abs(sum)
max = row max = row
} }
} }
// Singularity check // Singularity check
if (checkSingular(abs(lu[max, col]))) { if (checkSingular(this@lup.abs(lu[max, col]))) {
error("The matrix is singular") error("The matrix is singular")
} }

View File

@ -90,20 +90,20 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
// Overloads for Double constants // Overloads for Double constants
operator fun Number.plus(that: Variable<T>): Variable<T> = override operator fun Number.plus(b: Variable<T>): Variable<T> =
derive(variable { this@plus.toDouble() * one + that.value }) { z -> derive(variable { this@plus.toDouble() * one + b.value }) { z ->
that.d += z.d b.d += z.d
} }
operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this) override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
operator fun Number.minus(that: Variable<T>): Variable<T> = override operator fun Number.minus(b: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - that.value }) { z -> derive(variable { this@minus.toDouble() * one - b.value }) { z ->
that.d -= z.d b.d -= z.d
} }
operator fun Variable<T>.minus(that: Number): Variable<T> = override operator fun Variable<T>.minus(b: Number): Variable<T> =
derive(variable { this@minus.value - one * that.toDouble() }) { z -> derive(variable { this@minus.value - one * b.toDouble() }) { z ->
this@minus.d += z.d this@minus.d += z.d
} }
} }

View File

@ -1,5 +1,7 @@
package scientifik.kmath.misc package scientifik.kmath.misc
import kotlin.math.abs
/** /**
* Convert double range to sequence. * Convert double range to sequence.
* *
@ -8,28 +10,36 @@ package scientifik.kmath.misc
* *
* If step is negative, the same goes from upper boundary downwards * If step is negative, the same goes from upper boundary downwards
*/ */
fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double> = fun ClosedFloatingPointRange<Double>.toSequenceWithStep(step: Double): Sequence<Double> = when {
when { step == 0.0 -> error("Zero step in double progression")
step == 0.0 -> error("Zero step in double progression") step > 0 -> sequence {
step > 0 -> sequence { var current = start
var current = start while (current <= endInclusive) {
while (current <= endInclusive) { yield(current)
yield(current) current += step
current += step
}
}
else -> sequence {
var current = endInclusive
while (current >= start) {
yield(current)
current += step
}
}
} }
}
else -> sequence {
var current = endInclusive
while (current >= start) {
yield(current)
current += step
}
}
}
/**
* Convert double range to sequence with the fixed number of points
*/
fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int): Sequence<Double> {
require(numPoints > 1) { "The number of points should be more than 2" }
return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1))
}
/** /**
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints] * Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
*/ */
@Deprecated("Replace by 'toSequenceWithPoints'")
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray { fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
if (numPoints < 2) error("Can't create generic grid with less than two points") if (numPoints < 2) error("Can't create generic grid with less than two points")
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }

View File

@ -6,9 +6,43 @@ annotation class KMathContext
/** /**
* Marker interface for any algebra * Marker interface for any algebra
*/ */
interface Algebra<T> interface Algebra<T> {
/**
* Wrap raw string or variable
*/
fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
inline operator fun <T : Algebra<*>, R> T.invoke(block: T.() -> R): R = run(block) /**
* Dynamic call of unary operation with name [operation] on [arg]
*/
fun unaryOperation(operation: String, arg: T): T
/**
* Dynamic call of binary operation [operation] on [left] and [right]
*/
fun binaryOperation(operation: String, left: T, right: T): T
}
/**
* An algebra with numeric representation of members
*/
interface NumericAlgebra<T> : Algebra<T> {
/**
* Wrap a number
*/
fun number(value: Number): T
fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
binaryOperation(operation, number(left), right)
fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
leftSideNumberOperation(operation, right, left)
}
/**
* Call a block with an [Algebra] as receiver
*/
inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block)
/** /**
* Space-like operations without neutral element * Space-like operations without neutral element
@ -24,14 +58,34 @@ interface SpaceOperations<T> : Algebra<T> {
*/ */
fun multiply(a: T, k: Number): T fun multiply(a: T, k: Number): T
//Operation to be performed in this context //Operation to be performed in this context. Could be moved to extensions in case of KEEP-176
operator fun T.unaryMinus(): T = multiply(this, -1.0) operator fun T.unaryMinus(): T = multiply(this, -1.0)
operator fun T.unaryPlus(): T = this
operator fun T.plus(b: T): T = add(this, b) operator fun T.plus(b: T): T = add(this, b)
operator fun T.minus(b: T): T = add(this, -b) operator fun T.minus(b: T): T = add(this, -b)
operator fun T.times(k: Number) = multiply(this, k.toDouble()) operator fun T.times(k: Number) = multiply(this, k.toDouble())
operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble()) operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble())
operator fun Number.times(b: T) = b * this operator fun Number.times(b: T) = b * this
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
PLUS_OPERATION -> arg
MINUS_OPERATION -> -arg
else -> error("Unary operation $operation not defined in $this")
}
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
PLUS_OPERATION -> add(left, right)
MINUS_OPERATION -> left - right
else -> error("Binary operation $operation not defined in $this")
}
companion object {
const val PLUS_OPERATION = "+"
const val MINUS_OPERATION = "-"
const val NOT_OPERATION = "!"
}
} }
@ -60,22 +114,48 @@ interface RingOperations<T> : SpaceOperations<T> {
fun multiply(a: T, b: T): T fun multiply(a: T, b: T): T
operator fun T.times(b: T): T = multiply(this, b) operator fun T.times(b: T): T = multiply(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
TIMES_OPERATION -> multiply(left, right)
else -> super.binaryOperation(operation, left, right)
}
companion object {
const val TIMES_OPERATION = "*"
}
} }
/** /**
* The same as {@link Space} but with additional multiplication operation * The same as {@link Space} but with additional multiplication operation
*/ */
interface Ring<T> : Space<T>, RingOperations<T> { interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
/** /**
* neutral operation for multiplication * neutral operation for multiplication
*/ */
val one: T val one: T
// operator fun T.plus(b: Number) = this.plus(b * one) override fun number(value: Number): T = one * value.toDouble()
// operator fun Number.plus(b: T) = b + this
// override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
// operator fun T.minus(b: Number) = this.minus(b * one) SpaceOperations.PLUS_OPERATION -> left + right
// operator fun Number.minus(b: T) = -b + this SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.leftSideNumberOperation(operation, left, right)
}
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.rightSideNumberOperation(operation, left, right)
}
operator fun T.plus(b: Number) = this.plus(number(b))
operator fun Number.plus(b: T) = b + this
operator fun T.minus(b: Number) = this.minus(number(b))
operator fun Number.minus(b: T) = -b + this
} }
/** /**
@ -85,6 +165,15 @@ interface FieldOperations<T> : RingOperations<T> {
fun divide(a: T, b: T): T fun divide(a: T, b: T): T
operator fun T.div(b: T): T = divide(this, b) operator fun T.div(b: T): T = divide(this, b)
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
DIV_OPERATION -> divide(left, right)
else -> super.binaryOperation(operation, left, right)
}
companion object {
const val DIV_OPERATION = "/"
}
} }
/** /**

View File

@ -8,10 +8,12 @@ import scientifik.memory.MemorySpec
import scientifik.memory.MemoryWriter import scientifik.memory.MemoryWriter
import kotlin.math.* import kotlin.math.*
private val PI_DIV_2 = Complex(PI / 2, 0)
/** /**
* A field for complex numbers * A field for complex numbers
*/ */
object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> { object ComplexField : ExtendedField<Complex> {
override val zero: Complex = Complex(0.0, 0.0) override val zero: Complex = Complex(0.0, 0.0)
override val one: Complex = Complex(1.0, 0.0) override val one: Complex = Complex(1.0, 0.0)
@ -30,9 +32,11 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm) return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
} }
override fun sin(arg: Complex): Complex = i / 2 * (exp(-i * arg) - exp(i * arg)) override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg)
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg)
override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2
override fun power(arg: Complex, pow: Number): Complex = override fun power(arg: Complex, pow: Number): Complex =
arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta)) arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta))
@ -50,6 +54,12 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
operator fun Complex.minus(d: Double) = add(this, -d.toComplex()) operator fun Complex.minus(d: Double) = add(this, -d.toComplex())
operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this) operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this)
override fun symbol(value: String): Complex = if (value == "i") {
i
} else {
super.symbol(value)
}
} }
/** /**

View File

@ -7,12 +7,32 @@ import kotlin.math.pow as kpow
* Advanced Number-like field that implements basic operations * Advanced Number-like field that implements basic operations
*/ */
interface ExtendedFieldOperations<T> : interface ExtendedFieldOperations<T> :
FieldOperations<T>, InverseTrigonometricOperations<T>,
TrigonometricOperations<T>,
PowerOperations<T>, PowerOperations<T>,
ExponentialOperations<T> ExponentialOperations<T> {
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> override fun tan(arg: T): T = sin(arg) / cos(arg)
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
TrigonometricOperations.COS_OPERATION -> cos(arg)
TrigonometricOperations.SIN_OPERATION -> sin(arg)
TrigonometricOperations.TAN_OPERATION -> tan(arg)
InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg)
InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg)
InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg)
PowerOperations.SQRT_OPERATION -> sqrt(arg)
ExponentialOperations.EXP_OPERATION -> exp(arg)
ExponentialOperations.LN_OPERATION -> ln(arg)
else -> super.unaryOperation(operation, arg)
}
}
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
PowerOperations.POW_OPERATION -> power(left, right)
else -> super.rightSideNumberOperation(operation, left, right)
}
}
/** /**
* Real field element wrapping double. * Real field element wrapping double.
@ -44,6 +64,10 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
override inline fun sin(arg: Double) = kotlin.math.sin(arg) override inline fun sin(arg: Double) = kotlin.math.sin(arg)
override inline fun cos(arg: Double) = kotlin.math.cos(arg) override inline fun cos(arg: Double) = kotlin.math.cos(arg)
override inline fun tan(arg: Double): Double = kotlin.math.tan(arg)
override inline fun acos(arg: Double): Double = kotlin.math.acos(arg)
override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble()) override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble())
@ -75,6 +99,10 @@ object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override inline fun sin(arg: Float) = kotlin.math.sin(arg) override inline fun sin(arg: Float) = kotlin.math.sin(arg)
override inline fun cos(arg: Float) = kotlin.math.cos(arg) override inline fun cos(arg: Float) = kotlin.math.cos(arg)
override inline fun tan(arg: Float) = kotlin.math.tan(arg)
override inline fun acos(arg: Float) = kotlin.math.acos(arg)
override inline fun asin(arg: Float) = kotlin.math.asin(arg)
override inline fun atan(arg: Float) = kotlin.math.atan(arg)
override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat()) override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat())

View File

@ -13,16 +13,33 @@ package scientifik.kmath.operations
interface TrigonometricOperations<T> : FieldOperations<T> { interface TrigonometricOperations<T> : FieldOperations<T> {
fun sin(arg: T): T fun sin(arg: T): T
fun cos(arg: T): T fun cos(arg: T): T
fun tan(arg: T): T
fun tg(arg: T): T = sin(arg) / cos(arg) companion object {
const val SIN_OPERATION = "sin"
const val COS_OPERATION = "cos"
const val TAN_OPERATION = "tan"
}
}
fun ctg(arg: T): T = cos(arg) / sin(arg) interface InverseTrigonometricOperations<T> : TrigonometricOperations<T> {
fun asin(arg: T): T
fun acos(arg: T): T
fun atan(arg: T): T
companion object {
const val ASIN_OPERATION = "asin"
const val ACOS_OPERATION = "acos"
const val ATAN_OPERATION = "atan"
}
} }
fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg) fun <T : MathElement<out TrigonometricOperations<T>>> sin(arg: T): T = arg.context.sin(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg) fun <T : MathElement<out TrigonometricOperations<T>>> cos(arg: T): T = arg.context.cos(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> tg(arg: T): T = arg.context.tg(arg) fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.context.tan(arg)
fun <T : MathElement<out TrigonometricOperations<T>>> ctg(arg: T): T = arg.context.ctg(arg) fun <T : MathElement<out InverseTrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
fun <T : MathElement<out InverseTrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
fun <T : MathElement<out InverseTrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
/* Power and roots */ /* Power and roots */
@ -34,6 +51,11 @@ interface PowerOperations<T> : Algebra<T> {
fun sqrt(arg: T) = power(arg, 0.5) fun sqrt(arg: T) = power(arg, 0.5)
infix fun T.pow(pow: Number) = power(this, pow) infix fun T.pow(pow: Number) = power(this, pow)
companion object {
const val POW_OPERATION = "pow"
const val SQRT_OPERATION = "sqrt"
}
} }
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power) infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
@ -42,9 +64,14 @@ fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
/* Exponential */ /* Exponential */
interface ExponentialOperations<T>: Algebra<T> { interface ExponentialOperations<T> : Algebra<T> {
fun exp(arg: T): T fun exp(arg: T): T
fun ln(arg: T): T fun ln(arg: T): T
companion object {
const val EXP_OPERATION = "exp"
const val LN_OPERATION = "ln"
}
} }
fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg) fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
@ -54,4 +81,4 @@ interface Norm<in T : Any, out R> {
fun norm(arg: T): R fun norm(arg: T): R
} }
fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg) fun <T : MathElement<out Norm<T, R>>, R> norm(arg: T): R = arg.context.norm(arg)

View File

@ -37,9 +37,9 @@ interface Buffer<T> {
companion object { companion object {
inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer { inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer {
val array = DoubleArray(size) { initializer(it) } val array = DoubleArray(size) { initializer(it) }
return DoubleBuffer(array) return RealBuffer(array)
} }
/** /**
@ -51,7 +51,7 @@ interface Buffer<T> {
inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> { inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
//TODO add resolution based on Annotation or companion resolution //TODO add resolution based on Annotation or companion resolution
return when (type) { return when (type) {
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T> Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T> Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T> Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T> Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
@ -93,7 +93,7 @@ interface MutableBuffer<T> : Buffer<T> {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> { inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> {
return when (type) { return when (type) {
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T> Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T> Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T> Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T> Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
@ -109,12 +109,11 @@ interface MutableBuffer<T> : Buffer<T> {
auto(T::class, size, initializer) auto(T::class, size, initializer)
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double -> val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
DoubleBuffer(DoubleArray(size) { initializer(it) }) RealBuffer(DoubleArray(size) { initializer(it) })
} }
} }
} }
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> { inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
override val size: Int override val size: Int
@ -163,57 +162,6 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this) fun <T> Array<T>.asBuffer(): ArrayBuffer<T> = ArrayBuffer(this)
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
override val size: Int get() = array.size
override fun get(index: Int): Short = array[index]
override fun set(index: Int, value: Short) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Short> = ShortBuffer(array.copyOf())
}
fun ShortArray.asBuffer() = ShortBuffer(this)
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
override val size: Int get() = array.size
override fun get(index: Int): Int = array[index]
override fun set(index: Int, value: Int) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Int> = IntBuffer(array.copyOf())
}
fun IntArray.asBuffer() = IntBuffer(this)
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
override val size: Int get() = array.size
override fun get(index: Int): Long = array[index]
override fun set(index: Int, value: Long) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Long> = LongBuffer(array.copyOf())
}
fun LongArray.asBuffer() = LongBuffer(this)
inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> { inline class ReadOnlyBuffer<T>(val buffer: MutableBuffer<T>) : Buffer<T> {
override val size: Int get() = buffer.size override val size: Int get() = buffer.size

View File

@ -79,6 +79,13 @@ class ComplexNDField(override val shape: IntArray) :
override fun cos(arg: NDBuffer<Complex>) = map(arg) { cos(it) } override fun cos(arg: NDBuffer<Complex>) = map(arg) { cos(it) }
override fun tan(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) { tan(it) }
override fun asin(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) { asin(it) }
override fun acos(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) {acos(it)}
override fun atan(arg: NDBuffer<Complex>): NDBuffer<Complex> = map(arg) {atan(it)}
} }

View File

@ -1,13 +1,8 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.* import scientifik.kmath.operations.ExtendedField
interface ExtendedNDField<T : Any, F, N : NDStructure<T>> : interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : NDField<T, F, N>, ExtendedField<N>
NDField<T, F, N>,
TrigonometricOperations<N>,
PowerOperations<N>,
ExponentialOperations<N>
where F : ExtendedFieldOperations<T>, F : Field<T>
///** ///**

View File

@ -0,0 +1,53 @@
package scientifik.kmath.structures
import kotlin.experimental.and
enum class ValueFlag(val mask: Byte) {
NAN(0b0000_0001),
MISSING(0b0000_0010),
NEGATIVE_INFINITY(0b0000_0100),
POSITIVE_INFINITY(0b0000_1000)
}
/**
* A buffer with flagged values
*/
interface FlaggedBuffer<T> : Buffer<T> {
fun getFlag(index: Int): Byte
}
/**
* The value is valid if all flags are down
*/
fun FlaggedBuffer<*>.isValid(index: Int) = getFlag(index) != 0.toByte()
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag) = (getFlag(index) and flag.mask) != 0.toByte()
fun FlaggedBuffer<*>.isMissing(index: Int) = hasFlag(index, ValueFlag.MISSING)
/**
* A real buffer which supports flags for each value like NaN or Missing
*/
class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> {
init {
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
}
override fun getFlag(index: Int): Byte = flags[index]
override val size: Int get() = values.size
override fun get(index: Int): Double? = if (isValid(index)) values[index] else null
override fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
if (isValid(it)) values[it] else null
}.iterator()
}
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
for(i in indices){
if(isValid(i)){
block(values[i])
}
}
}

View File

@ -0,0 +1,20 @@
package scientifik.kmath.structures
inline class IntBuffer(val array: IntArray) : MutableBuffer<Int> {
override val size: Int get() = array.size
override fun get(index: Int): Int = array[index]
override fun set(index: Int, value: Int) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Int> =
IntBuffer(array.copyOf())
}
fun IntArray.asBuffer() = IntBuffer(this)

View File

@ -0,0 +1,19 @@
package scientifik.kmath.structures
inline class LongBuffer(val array: LongArray) : MutableBuffer<Long> {
override val size: Int get() = array.size
override fun get(index: Int): Long = array[index]
override fun set(index: Int, value: Long) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Long> =
LongBuffer(array.copyOf())
}
fun LongArray.asBuffer() = LongBuffer(this)

View File

@ -1,6 +1,6 @@
package scientifik.kmath.structures package scientifik.kmath.structures
inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> { inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
override val size: Int get() = array.size override val size: Int get() = array.size
override fun get(index: Int): Double = array[index] override fun get(index: Int): Double = array[index]
@ -12,23 +12,23 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> {
override fun iterator() = array.iterator() override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Double> = override fun copy(): MutableBuffer<Double> =
DoubleBuffer(array.copyOf()) RealBuffer(array.copyOf())
} }
@Suppress("FunctionName") @Suppress("FunctionName")
inline fun DoubleBuffer(size: Int, init: (Int) -> Double): DoubleBuffer = DoubleBuffer(DoubleArray(size) { init(it) }) inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) })
@Suppress("FunctionName") @Suppress("FunctionName")
fun DoubleBuffer(vararg doubles: Double): DoubleBuffer = DoubleBuffer(doubles) fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
/** /**
* Transform buffer of doubles into array for high performance operations * Transform buffer of doubles into array for high performance operations
*/ */
val MutableBuffer<out Double>.array: DoubleArray val MutableBuffer<out Double>.array: DoubleArray
get() = if (this is DoubleBuffer) { get() = if (this is RealBuffer) {
array array
} else { } else {
DoubleArray(size) { get(it) } DoubleArray(size) { get(it) }
} }
fun DoubleArray.asBuffer() = DoubleBuffer(this) fun DoubleArray.asBuffer() = RealBuffer(this)

View File

@ -9,145 +9,172 @@ import kotlin.math.*
* A simple field over linear buffers of [Double] * A simple field over linear buffers of [Double]
*/ */
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> { object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
return if (a is DoubleBuffer && b is DoubleBuffer) {
return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
} else { } else
DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
}
} }
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer { override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
val kValue = k.toDouble() val kValue = k.toDouble()
return if (a is DoubleBuffer) {
return if (a is RealBuffer) {
val aArray = a.array val aArray = a.array
DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue }) RealBuffer(DoubleArray(a.size) { aArray[it] * kValue })
} else { } else
DoubleBuffer(DoubleArray(a.size) { a[it] * kValue }) RealBuffer(DoubleArray(a.size) { a[it] * kValue })
}
} }
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
return if (a is DoubleBuffer && b is DoubleBuffer) {
return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
} else { } else
DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) RealBuffer(DoubleArray(a.size) { a[it] * b[it] })
}
} }
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
return if (a is DoubleBuffer && b is DoubleBuffer) {
return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
} else { } else
DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
}
} }
override fun sin(arg: Buffer<Double>): DoubleBuffer { override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
return if (arg is DoubleBuffer) { val array = arg.array
val array = arg.array RealBuffer(DoubleArray(arg.size) { sin(array[it]) })
DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) } else {
} else { RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) })
}
} }
override fun cos(arg: Buffer<Double>): DoubleBuffer { override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
return if (arg is DoubleBuffer) { val array = arg.array
val array = arg.array RealBuffer(DoubleArray(arg.size) { cos(array[it]) })
DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) } else
} else { RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) })
} override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { tan(array[it]) })
} else
RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { asin(array[it]) })
} else {
RealBuffer(DoubleArray(arg.size) { asin(arg[it]) })
} }
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer { override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
return if (arg is DoubleBuffer) { val array = arg.array
val array = arg.array RealBuffer(DoubleArray(arg.size) { acos(array[it]) })
DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) } else
} else { RealBuffer(DoubleArray(arg.size) { acos(arg[it]) })
DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
}
}
override fun exp(arg: Buffer<Double>): DoubleBuffer { override fun atan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
return if (arg is DoubleBuffer) { val array = arg.array
val array = arg.array RealBuffer(DoubleArray(arg.size) { atan(array[it]) })
DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) } else
} else { RealBuffer(DoubleArray(arg.size) { atan(arg[it]) })
DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) })
}
}
override fun ln(arg: Buffer<Double>): DoubleBuffer { override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
return if (arg is DoubleBuffer) { val array = arg.array
val array = arg.array RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) } else
} else { RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) })
} override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
} val array = arg.array
RealBuffer(DoubleArray(arg.size) { exp(array[it]) })
} else
RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
} else
RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
} }
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> { class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } } override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.add(a, b) return RealBufferFieldOperations.add(a, b)
} }
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer { override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.multiply(a, k) return RealBufferFieldOperations.multiply(a, k)
} }
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.multiply(a, b) return RealBufferFieldOperations.multiply(a, b)
} }
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.divide(a, b) return RealBufferFieldOperations.divide(a, b)
} }
override fun sin(arg: Buffer<Double>): DoubleBuffer { override fun sin(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.sin(arg) return RealBufferFieldOperations.sin(arg)
} }
override fun cos(arg: Buffer<Double>): DoubleBuffer { override fun cos(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.cos(arg) return RealBufferFieldOperations.cos(arg)
} }
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer { override fun tan(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.tan(arg)
}
override fun asin(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.asin(arg)
}
override fun acos(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.acos(arg)
}
override fun atan(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.atan(arg)
}
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.power(arg, pow) return RealBufferFieldOperations.power(arg, pow)
} }
override fun exp(arg: Buffer<Double>): DoubleBuffer { override fun exp(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.exp(arg) return RealBufferFieldOperations.exp(arg)
} }
override fun ln(arg: Buffer<Double>): DoubleBuffer { override fun ln(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.ln(arg) return RealBufferFieldOperations.ln(arg)
} }
}
}

View File

@ -16,7 +16,7 @@ class RealNDField(override val shape: IntArray) :
override val one by lazy { produce { one } } override val one by lazy { produce { one } }
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> = inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
DoubleBuffer(DoubleArray(size) { initializer(it) }) RealBuffer(DoubleArray(size) { initializer(it) })
/** /**
* Inline transform an NDStructure to * Inline transform an NDStructure to
@ -74,6 +74,13 @@ class RealNDField(override val shape: IntArray) :
override fun cos(arg: NDBuffer<Double>) = map(arg) { cos(it) } override fun cos(arg: NDBuffer<Double>) = map(arg) { cos(it) }
override fun tan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { tan(it) }
override fun asin(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { asin(it) }
override fun acos(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { acos(it) }
override fun atan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { atan(it) }
} }
@ -82,7 +89,7 @@ class RealNDField(override val shape: IntArray) :
*/ */
inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement { inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement {
val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) } val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) }
return BufferedNDFieldElement(this, DoubleBuffer(array)) return BufferedNDFieldElement(this, RealBuffer(array))
} }
/** /**
@ -96,7 +103,7 @@ inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: Int
*/ */
inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement {
val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) } val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) }
return BufferedNDFieldElement(context, DoubleBuffer(array)) return BufferedNDFieldElement(context, RealBuffer(array))
} }
/** /**

View File

@ -0,0 +1,20 @@
package scientifik.kmath.structures
inline class ShortBuffer(val array: ShortArray) : MutableBuffer<Short> {
override val size: Int get() = array.size
override fun get(index: Int): Short = array[index]
override fun set(index: Int, value: Short) {
array[index] = value
}
override fun iterator() = array.iterator()
override fun copy(): MutableBuffer<Short> =
ShortBuffer(array.copyOf())
}
fun ShortArray.asBuffer() = ShortBuffer(this)

View File

@ -9,7 +9,7 @@ import kotlin.test.assertEquals
class ExpressionFieldTest { class ExpressionFieldTest {
@Test @Test
fun testExpression() { fun testExpression() {
val context = ExpressionField(RealField) val context = FunctionalExpressionField(RealField)
val expression = with(context) { val expression = with(context) {
val x = variable("x", 2.0) val x = variable("x", 2.0)
x * x + 2 * x + one x * x + 2 * x + one
@ -20,7 +20,7 @@ class ExpressionFieldTest {
@Test @Test
fun testComplex() { fun testComplex() {
val context = ExpressionField(ComplexField) val context = FunctionalExpressionField(ComplexField)
val expression = with(context) { val expression = with(context) {
val x = variable("x", Complex(2.0, 0.0)) val x = variable("x", Complex(2.0, 0.0))
x * x + 2 * x + one x * x + 2 * x + one
@ -31,23 +31,23 @@ class ExpressionFieldTest {
@Test @Test
fun separateContext() { fun separateContext() {
fun <T> ExpressionField<T>.expression(): Expression<T> { fun <T> FunctionalExpressionField<T,*>.expression(): Expression<T> {
val x = variable("x") val x = variable("x")
return x * x + 2 * x + one return x * x + 2 * x + one
} }
val expression = ExpressionField(RealField).expression() val expression = FunctionalExpressionField(RealField).expression()
assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression("x" to 1.0), 4.0)
} }
@Test @Test
fun valueExpression() { fun valueExpression() {
val expressionBuilder: ExpressionField<Double>.() -> Expression<Double> = { val expressionBuilder: FunctionalExpressionField<Double,*>.() -> Expression<Double> = {
val x = variable("x") val x = variable("x")
x * x + 2 * x + one x * x + 2 * x + one
} }
val expression = ExpressionField(RealField).expressionBuilder() val expression = FunctionalExpressionField(RealField).expressionBuilder()
assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression("x" to 1.0), 4.0)
} }
} }

View File

@ -1,10 +1,12 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.structures.*
import java.math.BigDecimal import java.math.BigDecimal
import java.math.BigInteger import java.math.BigInteger
import java.math.MathContext import java.math.MathContext
/**
* A field wrapper for Java [BigInteger]
*/
object JBigIntegerField : Field<BigInteger> { object JBigIntegerField : Field<BigInteger> {
override val zero: BigInteger = BigInteger.ZERO override val zero: BigInteger = BigInteger.ZERO
override val one: BigInteger = BigInteger.ONE override val one: BigInteger = BigInteger.ONE
@ -18,6 +20,9 @@ object JBigIntegerField : Field<BigInteger> {
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
} }
/**
* A Field wrapper for Java [BigDecimal]
*/
class JBigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> { class JBigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> {
override val zero: BigDecimal = BigDecimal.ZERO override val zero: BigDecimal = BigDecimal.ZERO
override val one: BigDecimal = BigDecimal.ONE override val one: BigDecimal = BigDecimal.ONE

View File

@ -5,7 +5,7 @@ import kotlinx.coroutines.flow.*
import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.BufferFactory
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.RealBuffer
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
/** /**
@ -45,7 +45,7 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
/** /**
* Specialized flow chunker for real buffer * Specialized flow chunker for real buffer
*/ */
fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow { fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
require(bufferSize > 0) { "Resulting chunk size must be more than zero" } require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
if (this@chunked is BlockingRealChain) { if (this@chunked is BlockingRealChain) {
@ -61,13 +61,13 @@ fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
array[counter] = element array[counter] = element
counter++ counter++
if (counter == bufferSize) { if (counter == bufferSize) {
val buffer = DoubleBuffer(array) val buffer = RealBuffer(array)
emit(buffer) emit(buffer)
counter = 0 counter = 0
} }
} }
if (counter > 0) { if (counter > 0) {
emit(DoubleBuffer(counter) { array[it] }) emit(RealBuffer(counter) { array[it] })
} }
} }
} }

View File

@ -7,31 +7,28 @@ import scientifik.kmath.operations.Norm
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.SpaceElement import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.RealBuffer
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
import scientifik.kmath.structures.asIterable import scientifik.kmath.structures.asIterable
import kotlin.math.sqrt import kotlin.math.sqrt
typealias RealPoint = Point<Double>
fun DoubleArray.asVector() = RealVector(this.asBuffer()) fun DoubleArray.asVector() = RealVector(this.asBuffer())
fun List<Double>.asVector() = RealVector(this.asBuffer()) fun List<Double>.asVector() = RealVector(this.asBuffer())
object VectorL2Norm : Norm<Point<out Number>, Double> { object VectorL2Norm : Norm<Point<out Number>, Double> {
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
} }
inline class RealVector(private val point: Point<Double>) : inline class RealVector(private val point: Point<Double>) :
SpaceElement<Point<Double>, RealVector, VectorSpace<Double, RealField>>, Point<Double> { SpaceElement<RealPoint, RealVector, VectorSpace<Double, RealField>>, RealPoint {
override val context: VectorSpace<Double, RealField> override val context: VectorSpace<Double, RealField> get() = space(point.size)
get() = space(
point.size
)
override fun unwrap(): Point<Double> = point override fun unwrap(): RealPoint = point
override fun Point<Double>.wrap(): RealVector = override fun RealPoint.wrap(): RealVector = RealVector(this)
RealVector(this)
override val size: Int get() = point.size override val size: Int get() = point.size
@ -44,16 +41,12 @@ inline class RealVector(private val point: Point<Double>) :
private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>() private val spaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
inline operator fun invoke(dim: Int, initializer: (Int) -> Double) = inline operator fun invoke(dim: Int, initializer: (Int) -> Double) =
RealVector(DoubleBuffer(dim, initializer)) RealVector(RealBuffer(dim, initializer))
operator fun invoke(vararg values: Double): RealVector = values.asVector() operator fun invoke(vararg values: Double): RealVector = values.asVector()
fun space(dim: Int): BufferVectorSpace<Double, RealField> = fun space(dim: Int): BufferVectorSpace<Double, RealField> = spaceCache.getOrPut(dim) {
spaceCache.getOrPut(dim) { BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) }
BufferVectorSpace( }
dim,
RealField
) { size, init -> Buffer.real(size, init) }
}
} }
} }

View File

@ -1,8 +1,8 @@
package scientifik.kmath.real package scientifik.kmath.real
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.RealBuffer
/** /**
* Simplified [DoubleBuffer] to array comparison * Simplified [RealBuffer] to array comparison
*/ */
fun DoubleBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles)

View File

@ -5,8 +5,8 @@ import scientifik.kmath.linear.RealMatrixContext.elementContext
import scientifik.kmath.linear.VirtualMatrix import scientifik.kmath.linear.VirtualMatrix
import scientifik.kmath.operations.sum import scientifik.kmath.operations.sum
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.DoubleBuffer
import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.RealBuffer
import scientifik.kmath.structures.asIterable import scientifik.kmath.structures.asIterable
import kotlin.math.pow import kotlin.math.pow
@ -27,6 +27,10 @@ typealias RealMatrix = Matrix<Double>
fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
MatrixContext.real.produce(rowNum, colNum, initializer) MatrixContext.real.produce(rowNum, colNum, initializer)
fun Array<DoubleArray>.toMatrix(): RealMatrix{
return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] }
}
fun Sequence<DoubleArray>.toMatrix(): RealMatrix = toList().let { fun Sequence<DoubleArray>.toMatrix(): RealMatrix = toList().let {
MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] } MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] }
} }
@ -129,22 +133,22 @@ fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix =
fun Matrix<Double>.extractColumn(columnIndex: Int): RealMatrix = fun Matrix<Double>.extractColumn(columnIndex: Int): RealMatrix =
extractColumns(columnIndex..columnIndex) extractColumns(columnIndex..columnIndex)
fun Matrix<Double>.sumByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> fun Matrix<Double>.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
val column = columns[j] val column = columns[j]
with(elementContext) { with(elementContext) {
sum(column.asIterable()) sum(column.asIterable())
} }
} }
fun Matrix<Double>.minByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column") columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column")
} }
fun Matrix<Double>.maxByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j ->
columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column") columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column")
} }
fun Matrix<Double>.averageByColumn(): DoubleBuffer = DoubleBuffer(colNum) { j -> fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
columns[j].asIterable().average() columns[j].asIterable().average()
} }

View File

@ -1,17 +1,9 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.domains.Domain
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.structures.ArrayBuffer import scientifik.kmath.structures.ArrayBuffer
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.RealBuffer
/**
* A simple geometric domain
* TODO move to geometry module
*/
interface Domain<T : Any> {
operator fun contains(vector: Point<out T>): Boolean
val dimension: Int
}
/** /**
* The bin in the histogram. The histogram is by definition always done in the real space * The bin in the histogram. The histogram is by definition always done in the real space
@ -51,9 +43,9 @@ interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point)) fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
fun MutableHistogram<Double, *>.put(vararg point: Number) = fun MutableHistogram<Double, *>.put(vararg point: Number) =
put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) put(RealBuffer(point.map { it.toDouble() }.toDoubleArray()))
fun MutableHistogram<Double, *>.put(vararg point: Double) = put(DoubleBuffer(point)) fun MutableHistogram<Double, *>.put(vararg point: Double) = put(RealBuffer(point))
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) } fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }

View File

@ -1,8 +1,8 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.real.asVector
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
import scientifik.kmath.real.asVector
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import kotlin.math.floor import kotlin.math.floor
@ -21,7 +21,7 @@ data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val c
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> { class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
override fun contains(vector: Point<out T>): Boolean = def.contains(vector) override fun contains(point: Point<T>): Boolean = def.contains(point)
override val dimension: Int override val dimension: Int
get() = def.center.size get() = def.center.size
@ -50,7 +50,7 @@ class RealHistogram(
override val dimension: Int get() = lower.size override val dimension: Int get() = lower.size
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
init { init {
// argument checks // argument checks

View File

@ -16,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou
operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
override fun contains(vector: Buffer<out Double>): Boolean = contains(vector[0]) override fun contains(point: Buffer<Double>): Boolean = contains(point[0])
internal operator fun inc() = this.also { counter.increment() } internal operator fun inc() = this.also { counter.increment() }

View File

@ -10,6 +10,7 @@ interface MemorySpec<T : Any> {
val objectSize: Int val objectSize: Int
fun MemoryReader.read(offset: Int): T fun MemoryReader.read(offset: Int): T
//TODO consider thread safety
fun MemoryWriter.write(offset: Int, value: T) fun MemoryWriter.write(offset: Int, value: T)
} }

View File

@ -3,10 +3,12 @@ pluginManagement {
val toolsVersion = "0.5.0" val toolsVersion = "0.5.0"
plugins { plugins {
id("kotlinx.benchmark") version "0.2.0-dev-8"
id("scientifik.mpp") version toolsVersion id("scientifik.mpp") version toolsVersion
id("scientifik.jvm") version toolsVersion id("scientifik.jvm") version toolsVersion
id("scientifik.atomic") version toolsVersion id("scientifik.atomic") version toolsVersion
id("scientifik.publish") version toolsVersion id("scientifik.publish") version toolsVersion
kotlin("plugin.allopen") version "1.3.72"
} }
repositories { repositories {
@ -45,5 +47,6 @@ include(
":kmath-dimensions", ":kmath-dimensions",
":kmath-for-real", ":kmath-for-real",
":kmath-geometry", ":kmath-geometry",
":kmath-ast",
":examples" ":examples"
) )