forked from kscience/kmath
Update tools and Kotlin, specify public
explicitly, minor contracts refactor
This commit is contained in:
parent
5e4522bb06
commit
6b79e79d21
@ -3,7 +3,7 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
|
|||||||
plugins {
|
plugins {
|
||||||
java
|
java
|
||||||
kotlin("jvm")
|
kotlin("jvm")
|
||||||
kotlin("plugin.allopen") version "1.4.0"
|
kotlin("plugin.allopen") version "1.4.20-dev-3898-14"
|
||||||
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -13,6 +13,7 @@ repositories {
|
|||||||
maven("http://dl.bintray.com/kyonifer/maven")
|
maven("http://dl.bintray.com/kyonifer/maven")
|
||||||
maven("https://dl.bintray.com/mipt-npm/scientifik")
|
maven("https://dl.bintray.com/mipt-npm/scientifik")
|
||||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,9 +56,4 @@ kotlin.sourceSets.all {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.withType<KotlinCompile> {
|
tasks.withType<KotlinCompile> { kotlinOptions.jvmTarget = "11" }
|
||||||
kotlinOptions {
|
|
||||||
jvmTarget = "11"
|
|
||||||
freeCompilerArgs = freeCompilerArgs + "-Xopt-in=kotlin.RequiresOptIn"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -37,7 +37,7 @@ public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
|||||||
/**
|
/**
|
||||||
* [Ring] over [MST] nodes.
|
* [Ring] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST = number(0.0)
|
override val zero: MST = number(0.0)
|
||||||
override val one: MST = number(1.0)
|
override val one: MST = number(1.0)
|
||||||
|
|
||||||
@ -58,18 +58,18 @@ object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
|||||||
/**
|
/**
|
||||||
* [Field] over [MST] nodes.
|
* [Field] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
object MstField : Field<MST> {
|
public object MstField : Field<MST> {
|
||||||
override val zero: MST = number(0.0)
|
public override val zero: MST = number(0.0)
|
||||||
override val one: MST = number(1.0)
|
public override val one: MST = number(1.0)
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MstRing.symbol(value)
|
public override fun symbol(value: String): MST = MstRing.symbol(value)
|
||||||
override fun number(value: Number): MST = MstRing.number(value)
|
public override fun number(value: Number): MST = MstRing.number(value)
|
||||||
override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
|
public override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
|
||||||
override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
|
public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
|
||||||
override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
|
public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
|
||||||
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
public override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||||
MstRing.binaryOperation(operation, left, right)
|
MstRing.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
|
||||||
@ -78,7 +78,7 @@ object MstField : Field<MST> {
|
|||||||
/**
|
/**
|
||||||
* [ExtendedField] over [MST] nodes.
|
* [ExtendedField] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
object MstExtendedField : ExtendedField<MST> {
|
public object MstExtendedField : ExtendedField<MST> {
|
||||||
override val zero: MST = number(0.0)
|
override val zero: MST = number(0.0)
|
||||||
override val one: MST = number(1.0)
|
override val one: MST = number(1.0)
|
||||||
|
|
||||||
|
@ -26,22 +26,21 @@ import scientifik.kmath.structures.indices
|
|||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain {
|
public class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain {
|
||||||
|
public override val dimension: Int get() = lower.size
|
||||||
|
|
||||||
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
|
public override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
|
||||||
point[i] in lower[i]..upper[i]
|
point[i] in lower[i]..upper[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
override val dimension: Int get() = lower.size
|
public override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num]
|
||||||
|
|
||||||
override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num]
|
public override fun getLowerBound(num: Int): Double? = lower[num]
|
||||||
|
|
||||||
override fun getLowerBound(num: Int): Double? = lower[num]
|
public override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num]
|
||||||
|
|
||||||
override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num]
|
public override fun getUpperBound(num: Int): Double? = upper[num]
|
||||||
|
|
||||||
override fun getUpperBound(num: Int): Double? = upper[num]
|
public override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
||||||
|
|
||||||
override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
|
||||||
val res = DoubleArray(point.size) { i ->
|
val res = DoubleArray(point.size) { i ->
|
||||||
when {
|
when {
|
||||||
point[i] < lower[i] -> lower[i]
|
point[i] < lower[i] -> lower[i]
|
||||||
@ -53,16 +52,14 @@ public class HyperSquareDomain(private val lower: RealBuffer, private val upper:
|
|||||||
return RealBuffer(*res)
|
return RealBuffer(*res)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun volume(): Double {
|
public override fun volume(): Double {
|
||||||
var res = 1.0
|
var res = 1.0
|
||||||
|
|
||||||
for (i in 0 until dimension) {
|
for (i in 0 until dimension) {
|
||||||
if (lower[i].isInfinite() || upper[i].isInfinite()) {
|
if (lower[i].isInfinite() || upper[i].isInfinite()) return Double.POSITIVE_INFINITY
|
||||||
return Double.POSITIVE_INFINITY
|
if (upper[i] > lower[i]) res *= upper[i] - lower[i]
|
||||||
}
|
|
||||||
if (upper[i] > lower[i]) {
|
|
||||||
res *= upper[i] - lower[i]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,18 +17,18 @@ package scientifik.kmath.domains
|
|||||||
|
|
||||||
import scientifik.kmath.linear.Point
|
import scientifik.kmath.linear.Point
|
||||||
|
|
||||||
public class UnconstrainedDomain(override val dimension: Int) : RealDomain {
|
public class UnconstrainedDomain(public override val dimension: Int) : RealDomain {
|
||||||
override operator fun contains(point: Point<Double>): Boolean = true
|
public override operator fun contains(point: Point<Double>): Boolean = true
|
||||||
|
|
||||||
override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
|
public override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
|
||||||
|
|
||||||
override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY
|
public override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY
|
||||||
|
|
||||||
override fun getUpperBound(num: Int, point: Point<Double>): Double? = Double.POSITIVE_INFINITY
|
public override fun getUpperBound(num: Int, point: Point<Double>): Double? = Double.POSITIVE_INFINITY
|
||||||
|
|
||||||
override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY
|
public override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY
|
||||||
|
|
||||||
override fun nearestInDomain(point: Point<Double>): Point<Double> = point
|
public override fun nearestInDomain(point: Point<Double>): Point<Double> = point
|
||||||
|
|
||||||
override fun volume(): Double = Double.POSITIVE_INFINITY
|
public override fun volume(): Double = Double.POSITIVE_INFINITY
|
||||||
}
|
}
|
||||||
|
@ -4,16 +4,20 @@ import scientifik.kmath.linear.Point
|
|||||||
import scientifik.kmath.structures.asBuffer
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
|
||||||
public inline class UnivariateDomain(public val range: ClosedFloatingPointRange<Double>) : RealDomain {
|
public inline class UnivariateDomain(public val range: ClosedFloatingPointRange<Double>) : RealDomain {
|
||||||
|
public override val dimension: Int
|
||||||
|
get() = 1
|
||||||
|
|
||||||
public operator fun contains(d: Double): Boolean = range.contains(d)
|
public operator fun contains(d: Double): Boolean = range.contains(d)
|
||||||
|
|
||||||
override operator fun contains(point: Point<Double>): Boolean {
|
public override operator fun contains(point: Point<Double>): Boolean {
|
||||||
require(point.size == 0)
|
require(point.size == 0)
|
||||||
return contains(point[0])
|
return contains(point[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
public override fun nearestInDomain(point: Point<Double>): Point<Double> {
|
||||||
require(point.size == 1)
|
require(point.size == 1)
|
||||||
val value = point[0]
|
val value = point[0]
|
||||||
|
|
||||||
return when {
|
return when {
|
||||||
value in range -> point
|
value in range -> point
|
||||||
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
|
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
|
||||||
@ -21,27 +25,25 @@ public inline class UnivariateDomain(public val range: ClosedFloatingPointRange<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun getLowerBound(num: Int, point: Point<Double>): Double? {
|
public override fun getLowerBound(num: Int, point: Point<Double>): Double? {
|
||||||
require(num == 0)
|
require(num == 0)
|
||||||
return range.start
|
return range.start
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun getUpperBound(num: Int, point: Point<Double>): Double? {
|
public override fun getUpperBound(num: Int, point: Point<Double>): Double? {
|
||||||
require(num == 0)
|
require(num == 0)
|
||||||
return range.endInclusive
|
return range.endInclusive
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun getLowerBound(num: Int): Double? {
|
public override fun getLowerBound(num: Int): Double? {
|
||||||
require(num == 0)
|
require(num == 0)
|
||||||
return range.start
|
return range.start
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun getUpperBound(num: Int): Double? {
|
public override fun getUpperBound(num: Int): Double? {
|
||||||
require(num == 0)
|
require(num == 0)
|
||||||
return range.endInclusive
|
return range.endInclusive
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun volume(): Double = range.endInclusive - range.start
|
public override fun volume(): Double = range.endInclusive - range.start
|
||||||
|
|
||||||
override val dimension: Int get() = 1
|
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,8 @@ import scientifik.kmath.operations.*
|
|||||||
|
|
||||||
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
||||||
Expression<T> {
|
Expression<T> {
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
|
public override operator fun invoke(arguments: Map<String, T>): T =
|
||||||
|
context.unaryOperation(name, expr.invoke(arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalBinaryOperation<T>(
|
internal class FunctionalBinaryOperation<T>(
|
||||||
@ -13,17 +14,17 @@ internal class FunctionalBinaryOperation<T>(
|
|||||||
val first: Expression<T>,
|
val first: Expression<T>,
|
||||||
val second: Expression<T>
|
val second: Expression<T>
|
||||||
) : Expression<T> {
|
) : Expression<T> {
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
public override operator fun invoke(arguments: Map<String, T>): T =
|
||||||
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
public override operator fun invoke(arguments: Map<String, T>): T =
|
||||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = value
|
public override operator fun invoke(arguments: Map<String, T>): T = value
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class FunctionalConstProductExpression<T>(
|
internal class FunctionalConstProductExpression<T>(
|
||||||
@ -31,7 +32,7 @@ internal class FunctionalConstProductExpression<T>(
|
|||||||
private val expr: Expression<T>,
|
private val expr: Expression<T>,
|
||||||
val const: Number
|
val const: Number
|
||||||
) : Expression<T> {
|
) : Expression<T> {
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
public override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -44,23 +45,23 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of constant expression which does not depend on arguments.
|
* Builds an Expression of constant expression which does not depend on arguments.
|
||||||
*/
|
*/
|
||||||
override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
|
public override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression to access a variable.
|
* Builds an Expression to access a variable.
|
||||||
*/
|
*/
|
||||||
override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
|
public override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||||
*/
|
*/
|
||||||
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
FunctionalBinaryOperation(algebra, operation, left, right)
|
FunctionalBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||||
*/
|
*/
|
||||||
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
FunctionalUnaryOperation(algebra, operation, arg)
|
FunctionalUnaryOperation(algebra, operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,18 +70,18 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val
|
|||||||
*/
|
*/
|
||||||
public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||||
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
|
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
|
||||||
override val zero: Expression<T> get() = const(algebra.zero)
|
public override val zero: Expression<T> get() = const(algebra.zero)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of addition of two another expressions.
|
* Builds an Expression of addition of two another expressions.
|
||||||
*/
|
*/
|
||||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of multiplication of expression by number.
|
* Builds an Expression of multiplication of expression by number.
|
||||||
*/
|
*/
|
||||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> =
|
public override fun multiply(a: Expression<T>, k: Number): Expression<T> =
|
||||||
FunctionalConstProductExpression(algebra, a, k)
|
FunctionalConstProductExpression(algebra, a, k)
|
||||||
|
|
||||||
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
||||||
@ -88,31 +89,31 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
|||||||
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
||||||
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
|
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
|
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
|
||||||
}
|
}
|
||||||
|
|
||||||
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
|
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
|
||||||
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
|
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
|
||||||
override val one: Expression<T>
|
public override val one: Expression<T>
|
||||||
get() = const(algebra.one)
|
get() = const(algebra.one)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of multiplication of two expressions.
|
* Builds an Expression of multiplication of two expressions.
|
||||||
*/
|
*/
|
||||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||||
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
|
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
|
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,38 +123,38 @@ public open class FunctionalExpressionField<T, A>(algebra: A) :
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of division an expression by another one.
|
* Builds an Expression of division an expression by another one.
|
||||||
*/
|
*/
|
||||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||||
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
|
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
|
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
|
||||||
}
|
}
|
||||||
|
|
||||||
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
||||||
FunctionalExpressionField<T, A>(algebra),
|
FunctionalExpressionField<T, A>(algebra),
|
||||||
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
||||||
override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
public override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||||
override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
public override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||||
override fun asin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
public override fun asin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||||
override fun acos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
public override fun acos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||||
override fun atan(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
public override fun atan(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||||
|
|
||||||
override fun power(arg: Expression<T>, pow: Number): Expression<T> =
|
public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
|
||||||
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||||
|
|
||||||
override fun exp(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
public override fun exp(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||||
override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
super<FunctionalExpressionField>.unaryOperation(operation, arg)
|
super<FunctionalExpressionField>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
super<FunctionalExpressionField>.binaryOperation(operation, left, right)
|
super<FunctionalExpressionField>.binaryOperation(operation, left, right)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,82 +8,82 @@ import scientifik.kmath.structures.*
|
|||||||
* Basic implementation of Matrix space based on [NDStructure]
|
* Basic implementation of Matrix space based on [NDStructure]
|
||||||
*/
|
*/
|
||||||
public class BufferMatrixContext<T : Any, R : Ring<T>>(
|
public class BufferMatrixContext<T : Any, R : Ring<T>>(
|
||||||
override val elementContext: R,
|
public override val elementContext: R,
|
||||||
private val bufferFactory: BufferFactory<T>
|
private val bufferFactory: BufferFactory<T>
|
||||||
) : GenericMatrixContext<T, R> {
|
) : GenericMatrixContext<T, R> {
|
||||||
|
public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix<T> {
|
||||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix<T> {
|
|
||||||
val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||||
return BufferMatrix(rows, columns, buffer)
|
return BufferMatrix(rows, columns, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
public override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
||||||
|
|
||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
public object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
public object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
||||||
|
public override val elementContext: RealField
|
||||||
|
get() = RealField
|
||||||
|
|
||||||
override val elementContext: RealField get() = RealField
|
public override inline fun produce(
|
||||||
|
rows: Int,
|
||||||
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
columns: Int,
|
||||||
|
initializer: (i: Int, j: Int) -> Double
|
||||||
|
): Matrix<Double> {
|
||||||
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||||
return BufferMatrix(rows, columns, buffer)
|
return BufferMatrix(rows, columns, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size, initializer)
|
public override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> =
|
||||||
|
RealBuffer(size, initializer)
|
||||||
}
|
}
|
||||||
|
|
||||||
public class BufferMatrix<T : Any>(
|
public class BufferMatrix<T : Any>(
|
||||||
override val rowNum: Int,
|
public override val rowNum: Int,
|
||||||
override val colNum: Int,
|
public override val colNum: Int,
|
||||||
public val buffer: Buffer<out T>,
|
public val buffer: Buffer<out T>,
|
||||||
override val features: Set<MatrixFeature> = emptySet()
|
public override val features: Set<MatrixFeature> = emptySet()
|
||||||
) : FeaturedMatrix<T> {
|
) : FeaturedMatrix<T> {
|
||||||
|
override val shape: IntArray
|
||||||
|
get() = intArrayOf(rowNum, colNum)
|
||||||
|
|
||||||
init {
|
init {
|
||||||
if (buffer.size != rowNum * colNum) {
|
require(buffer.size == rowNum * colNum) { "Dimension mismatch for matrix structure" }
|
||||||
error("Dimension mismatch for matrix structure")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
public override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
|
||||||
|
|
||||||
override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
|
|
||||||
BufferMatrix(rowNum, colNum, buffer, this.features + features)
|
BufferMatrix(rowNum, colNum, buffer, this.features + features)
|
||||||
|
|
||||||
override operator fun get(index: IntArray): T = get(index[0], index[1])
|
public override operator fun get(index: IntArray): T = get(index[0], index[1])
|
||||||
|
public override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
|
||||||
|
|
||||||
override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
|
public override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
|
||||||
for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
|
for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
public override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
|
|
||||||
return when (other) {
|
return when (other) {
|
||||||
is NDStructure<*> -> return NDStructure.equals(this, other)
|
is NDStructure<*> -> return NDStructure.equals(this, other)
|
||||||
else -> false
|
else -> false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
public override fun hashCode(): Int {
|
||||||
var result = buffer.hashCode()
|
var result = buffer.hashCode()
|
||||||
result = 31 * result + features.hashCode()
|
result = 31 * result + features.hashCode()
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String {
|
public override fun toString(): String {
|
||||||
return if (rowNum <= 5 && colNum <= 5) {
|
return if (rowNum <= 5 && colNum <= 5)
|
||||||
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
|
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
|
||||||
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
|
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
|
||||||
buffer.asSequence().joinToString(separator = "\t") { it.toString() }
|
buffer.asSequence().joinToString(separator = "\t") { it.toString() }
|
||||||
}
|
}
|
||||||
} else {
|
else "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
|
||||||
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,26 +92,21 @@ public class BufferMatrix<T : Any>(
|
|||||||
*/
|
*/
|
||||||
public infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> {
|
public infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> {
|
||||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||||
|
|
||||||
val array = DoubleArray(this.rowNum * other.colNum)
|
val array = DoubleArray(this.rowNum * other.colNum)
|
||||||
|
|
||||||
//convert to array to insure there is not memory indirection
|
//convert to array to insure there is not memory indirection
|
||||||
fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is RealBuffer) {
|
fun Buffer<out Double>.unsafeArray() = if (this is RealBuffer)
|
||||||
array
|
array
|
||||||
} else {
|
else
|
||||||
DoubleArray(size) { get(it) }
|
DoubleArray(size) { get(it) }
|
||||||
}
|
|
||||||
|
|
||||||
val a = this.buffer.unsafeArray()
|
val a = this.buffer.unsafeArray()
|
||||||
val b = other.buffer.unsafeArray()
|
val b = other.buffer.unsafeArray()
|
||||||
|
|
||||||
for (i in (0 until rowNum)) {
|
for (i in (0 until rowNum))
|
||||||
for (j in (0 until other.colNum)) {
|
for (j in (0 until other.colNum))
|
||||||
for (k in (0 until colNum)) {
|
for (k in (0 until colNum))
|
||||||
array[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j]
|
array[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j]
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val buffer = RealBuffer(array)
|
val buffer = RealBuffer(array)
|
||||||
return BufferMatrix(rowNum, other.colNum, buffer)
|
return BufferMatrix(rowNum, other.colNum, buffer)
|
||||||
|
@ -26,10 +26,8 @@ public interface FeaturedMatrix<T : Any> : Matrix<T> {
|
|||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> {
|
public inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> =
|
||||||
contract { callsInPlace(initializer) }
|
MatrixContext.real.produce(rows, columns, initializer)
|
||||||
return MatrixContext.real.produce(rows, columns, initializer)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build a square matrix from given elements.
|
* Build a square matrix from given elements.
|
||||||
|
@ -76,7 +76,6 @@ public inline fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.l
|
|||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
checkSingular: (T) -> Boolean
|
checkSingular: (T) -> Boolean
|
||||||
): LUPDecomposition<T> {
|
): LUPDecomposition<T> {
|
||||||
contract { callsInPlace(checkSingular) }
|
|
||||||
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
|
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
|
||||||
val m = matrix.colNum
|
val m = matrix.colNum
|
||||||
val pivot = IntArray(matrix.rowNum)
|
val pivot = IntArray(matrix.rowNum)
|
||||||
@ -153,10 +152,7 @@ public inline fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.l
|
|||||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
|
||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
checkSingular: (T) -> Boolean
|
checkSingular: (T) -> Boolean
|
||||||
): LUPDecomposition<T> {
|
): LUPDecomposition<T> = lup(T::class, matrix, checkSingular)
|
||||||
contract { callsInPlace(checkSingular) }
|
|
||||||
return lup(T::class, matrix, checkSingular)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> =
|
public fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> =
|
||||||
lup(Double::class, matrix) { it < 1e-11 }
|
lup(Double::class, matrix) { it < 1e-11 }
|
||||||
@ -216,7 +212,6 @@ public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext
|
|||||||
b: Matrix<T>,
|
b: Matrix<T>,
|
||||||
checkSingular: (T) -> Boolean
|
checkSingular: (T) -> Boolean
|
||||||
): Matrix<T> {
|
): Matrix<T> {
|
||||||
contract { callsInPlace(checkSingular) }
|
|
||||||
// Use existing decomposition if it is provided by matrix
|
// Use existing decomposition if it is provided by matrix
|
||||||
val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular)
|
val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular)
|
||||||
return decomposition.solve(T::class, b)
|
return decomposition.solve(T::class, b)
|
||||||
@ -227,10 +222,7 @@ public fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): Matrix
|
|||||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
|
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
|
||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
checkSingular: (T) -> Boolean
|
checkSingular: (T) -> Boolean
|
||||||
): Matrix<T> {
|
): Matrix<T> = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
||||||
contract { callsInPlace(checkSingular) }
|
|
||||||
return solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun RealMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> =
|
public fun RealMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> =
|
||||||
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }
|
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }
|
||||||
|
@ -19,10 +19,9 @@ public interface LinearSolver<T : Any> {
|
|||||||
* Convert matrix to vector if it is possible
|
* Convert matrix to vector if it is possible
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> Matrix<T>.asPoint(): Point<T> =
|
public fun <T : Any> Matrix<T>.asPoint(): Point<T> =
|
||||||
if (this.colNum == 1) {
|
if (this.colNum == 1)
|
||||||
VirtualBuffer(rowNum) { get(it, 0) }
|
VirtualBuffer(rowNum) { get(it, 0) }
|
||||||
} else {
|
else
|
||||||
error("Can't convert matrix with more than one column to vector")
|
error("Can't convert matrix with more than one column to vector")
|
||||||
}
|
|
||||||
|
|
||||||
public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
||||||
|
@ -12,10 +12,8 @@ import kotlin.jvm.JvmName
|
|||||||
* @param R the type of resulting iterable.
|
* @param R the type of resulting iterable.
|
||||||
* @param initial lazy evaluated.
|
* @param initial lazy evaluated.
|
||||||
*/
|
*/
|
||||||
public inline fun <T, R> Iterator<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator<R> {
|
public inline fun <T, R> Iterator<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator<R> =
|
||||||
contract { callsInPlace(operation) }
|
object : Iterator<R> {
|
||||||
|
|
||||||
return object : Iterator<R> {
|
|
||||||
var state: R = initial
|
var state: R = initial
|
||||||
|
|
||||||
override fun hasNext(): Boolean = this@cumulative.hasNext()
|
override fun hasNext(): Boolean = this@cumulative.hasNext()
|
||||||
@ -25,7 +23,6 @@ public inline fun <T, R> Iterator<T>.cumulative(initial: R, crossinline operatio
|
|||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <T, R> Iterable<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable<R> =
|
public inline fun <T, R> Iterable<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable<R> =
|
||||||
Iterable { this@cumulative.iterator().cumulative(initial, operation) }
|
Iterable { this@cumulative.iterator().cumulative(initial, operation) }
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stub for DSL the [Algebra] is.
|
* Stub for DSL the [Algebra] is.
|
||||||
*/
|
*/
|
||||||
|
@ -40,18 +40,17 @@ public class BigInt internal constructor(
|
|||||||
private val sign: Byte,
|
private val sign: Byte,
|
||||||
private val magnitude: Magnitude
|
private val magnitude: Magnitude
|
||||||
) : Comparable<BigInt> {
|
) : Comparable<BigInt> {
|
||||||
|
public override fun compareTo(other: BigInt): Int = when {
|
||||||
override fun compareTo(other: BigInt): Int = when {
|
|
||||||
(this.sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0
|
(this.sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0
|
||||||
this.sign < other.sign -> -1
|
this.sign < other.sign -> -1
|
||||||
this.sign > other.sign -> 1
|
this.sign > other.sign -> 1
|
||||||
else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude)
|
else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean =
|
public override fun equals(other: Any?): Boolean =
|
||||||
if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type")
|
if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type")
|
||||||
|
|
||||||
override fun hashCode(): Int = magnitude.hashCode() + sign
|
public override fun hashCode(): Int = magnitude.hashCode() + sign
|
||||||
|
|
||||||
public fun abs(): BigInt = if (sign == 0.toByte()) this else BigInt(1, magnitude)
|
public fun abs(): BigInt = if (sign == 0.toByte()) this else BigInt(1, magnitude)
|
||||||
|
|
||||||
@ -456,15 +455,11 @@ public fun String.parseBigInteger(): BigInt? {
|
|||||||
return res * sign
|
return res * sign
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> {
|
public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
||||||
contract { callsInPlace(initializer) }
|
boxing(size, initializer)
|
||||||
return boxing(size, initializer)
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> {
|
public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
|
||||||
contract { callsInPlace(initializer) }
|
boxing(size, initializer)
|
||||||
return boxing(size, initializer)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
|
public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
|
||||||
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
|
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
|
||||||
|
@ -6,7 +6,6 @@ import scientifik.kmath.structures.MutableBuffer
|
|||||||
import scientifik.memory.MemoryReader
|
import scientifik.memory.MemoryReader
|
||||||
import scientifik.memory.MemorySpec
|
import scientifik.memory.MemorySpec
|
||||||
import scientifik.memory.MemoryWriter
|
import scientifik.memory.MemoryWriter
|
||||||
import kotlin.contracts.contract
|
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -165,7 +164,8 @@ public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
|
|||||||
* @property re The real part.
|
* @property re The real part.
|
||||||
* @property im The imaginary part.
|
* @property im The imaginary part.
|
||||||
*/
|
*/
|
||||||
public data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
|
public data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>,
|
||||||
|
Comparable<Complex> {
|
||||||
public constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
|
public constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
|
||||||
|
|
||||||
override val context: ComplexField get() = ComplexField
|
override val context: ComplexField get() = ComplexField
|
||||||
@ -197,12 +197,8 @@ public data class Complex(val re: Double, val im: Double) : FieldElement<Complex
|
|||||||
*/
|
*/
|
||||||
public fun Number.toComplex(): Complex = Complex(this, 0.0)
|
public fun Number.toComplex(): Complex = Complex(this, 0.0)
|
||||||
|
|
||||||
public inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
public inline fun Buffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer<Complex> =
|
||||||
contract { callsInPlace(init) }
|
MemoryBuffer.create(Complex, size, init)
|
||||||
return MemoryBuffer.create(Complex, size, init)
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
public inline fun MutableBuffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer<Complex> =
|
||||||
contract { callsInPlace(init) }
|
MemoryBuffer.create(Complex, size, init)
|
||||||
return MemoryBuffer.create(Complex, size, init)
|
|
||||||
}
|
|
||||||
|
@ -62,7 +62,7 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
|||||||
*
|
*
|
||||||
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
|
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
|
||||||
*/
|
*/
|
||||||
inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
public inline class Real(public val value: Double) : FieldElement<Double, Real, RealField> {
|
||||||
override val context: RealField
|
override val context: RealField
|
||||||
get() = RealField
|
get() = RealField
|
||||||
|
|
||||||
@ -70,14 +70,14 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
|||||||
|
|
||||||
override fun Double.wrap(): Real = Real(value)
|
override fun Double.wrap(): Real = Real(value)
|
||||||
|
|
||||||
companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field for [Double] without boxing. Does not produce appropriate field element.
|
* A field for [Double] without boxing. Does not produce appropriate field element.
|
||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
public object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||||
override val zero: Double
|
override val zero: Double
|
||||||
get() = 0.0
|
get() = 0.0
|
||||||
|
|
||||||
@ -127,7 +127,7 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
|||||||
* A field for [Float] without boxing. Does not produce appropriate field element.
|
* A field for [Float] without boxing. Does not produce appropriate field element.
|
||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||||
override val zero: Float
|
override val zero: Float
|
||||||
get() = 0.0f
|
get() = 0.0f
|
||||||
|
|
||||||
@ -177,7 +177,7 @@ object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
|||||||
* A field for [Int] without boxing. Does not produce corresponding ring element.
|
* A field for [Int] without boxing. Does not produce corresponding ring element.
|
||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object IntRing : Ring<Int>, Norm<Int, Int> {
|
public object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||||
override val zero: Int
|
override val zero: Int
|
||||||
get() = 0
|
get() = 0
|
||||||
|
|
||||||
@ -201,7 +201,7 @@ object IntRing : Ring<Int>, Norm<Int, Int> {
|
|||||||
* A field for [Short] without boxing. Does not produce appropriate ring element.
|
* A field for [Short] without boxing. Does not produce appropriate ring element.
|
||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object ShortRing : Ring<Short>, Norm<Short, Short> {
|
public object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||||
override val zero: Short
|
override val zero: Short
|
||||||
get() = 0
|
get() = 0
|
||||||
|
|
||||||
@ -225,7 +225,7 @@ object ShortRing : Ring<Short>, Norm<Short, Short> {
|
|||||||
* A field for [Byte] without boxing. Does not produce appropriate ring element.
|
* A field for [Byte] without boxing. Does not produce appropriate ring element.
|
||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||||
override val zero: Byte
|
override val zero: Byte
|
||||||
get() = 0
|
get() = 0
|
||||||
|
|
||||||
@ -249,7 +249,7 @@ object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
|||||||
* A field for [Double] without boxing. Does not produce appropriate ring element.
|
* A field for [Double] without boxing. Does not produce appropriate ring element.
|
||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
object LongRing : Ring<Long>, Norm<Long, Long> {
|
public object LongRing : Ring<Long>, Norm<Long, Long> {
|
||||||
override val zero: Long
|
override val zero: Long
|
||||||
get() = 0
|
get() = 0
|
||||||
|
|
||||||
|
@ -59,6 +59,7 @@ public class BoxingNDRing<T, R : Ring<T>>(
|
|||||||
transform: R.(T, T) -> T
|
transform: R.(T, T) -> T
|
||||||
): BufferedNDRingElement<T, R> {
|
): BufferedNDRingElement<T, R> {
|
||||||
check(a, b)
|
check(a, b)
|
||||||
|
|
||||||
return BufferedNDRingElement(
|
return BufferedNDRingElement(
|
||||||
this,
|
this,
|
||||||
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
|
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
|
||||||
|
@ -5,7 +5,7 @@ import scientifik.kmath.operations.*
|
|||||||
public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
||||||
public val strides: Strides
|
public val strides: Strides
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>): Unit =
|
public override fun check(vararg elements: NDBuffer<T>): Unit =
|
||||||
require(elements.all { it.strides == strides }) { ("Strides mismatch") }
|
require(elements.all { it.strides == strides }) { ("Strides mismatch") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -29,7 +29,7 @@ public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
|||||||
|
|
||||||
|
|
||||||
public interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T, S> {
|
public interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T, S> {
|
||||||
override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>>
|
public override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>>
|
||||||
}
|
}
|
||||||
|
|
||||||
public interface BufferedNDRing<T, R : Ring<T>> : NDRing<T, R, NDBuffer<T>>, BufferedNDSpace<T, R> {
|
public interface BufferedNDRing<T, R : Ring<T>> : NDRing<T, R, NDBuffer<T>>, BufferedNDSpace<T, R> {
|
||||||
|
@ -2,8 +2,6 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.complex
|
import scientifik.kmath.operations.complex
|
||||||
import kotlin.contracts.ExperimentalContracts
|
|
||||||
import kotlin.contracts.contract
|
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -56,7 +54,8 @@ public interface Buffer<T> {
|
|||||||
/**
|
/**
|
||||||
* Create a boxing buffer of given type
|
* Create a boxing buffer of given type
|
||||||
*/
|
*/
|
||||||
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> = ListBuffer(List(size, initializer))
|
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
|
ListBuffer(List(size, initializer))
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
public inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
|
||||||
@ -115,11 +114,11 @@ public interface MutableBuffer<T> : Buffer<T> {
|
|||||||
/**
|
/**
|
||||||
* Create a boxing mutable buffer of given type
|
* Create a boxing mutable buffer of given type
|
||||||
*/
|
*/
|
||||||
inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
MutableListBuffer(MutableList(size, initializer))
|
MutableListBuffer(MutableList(size, initializer))
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
public inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
when (type) {
|
when (type) {
|
||||||
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||||
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
|
||||||
@ -132,12 +131,11 @@ public interface MutableBuffer<T> : Buffer<T> {
|
|||||||
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
public inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
auto(T::class, size, initializer)
|
auto(T::class, size, initializer)
|
||||||
|
|
||||||
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double ->
|
public val real: MutableBufferFactory<Double> =
|
||||||
RealBuffer(DoubleArray(size) { initializer(it) })
|
{ size, initializer -> RealBuffer(DoubleArray(size) { initializer(it) }) }
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,7 +145,7 @@ public interface MutableBuffer<T> : Buffer<T> {
|
|||||||
* @param T the type of elements contained in the buffer.
|
* @param T the type of elements contained in the buffer.
|
||||||
* @property list The underlying list.
|
* @property list The underlying list.
|
||||||
*/
|
*/
|
||||||
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
public inline class ListBuffer<T>(public val list: List<T>) : Buffer<T> {
|
||||||
override val size: Int
|
override val size: Int
|
||||||
get() = list.size
|
get() = list.size
|
||||||
|
|
||||||
@ -158,7 +156,7 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
|||||||
/**
|
/**
|
||||||
* Returns an [ListBuffer] that wraps the original list.
|
* Returns an [ListBuffer] that wraps the original list.
|
||||||
*/
|
*/
|
||||||
fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
|
public fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified
|
* Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified
|
||||||
@ -167,10 +165,7 @@ fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an array element given its index.
|
* It should return the value for an array element given its index.
|
||||||
*/
|
*/
|
||||||
inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> {
|
public inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = List(size, init).asBuffer()
|
||||||
contract { callsInPlace(init) }
|
|
||||||
return List(size, init).asBuffer()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [MutableBuffer] implementation over [MutableList].
|
* [MutableBuffer] implementation over [MutableList].
|
||||||
@ -178,7 +173,7 @@ inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> {
|
|||||||
* @param T the type of elements contained in the buffer.
|
* @param T the type of elements contained in the buffer.
|
||||||
* @property list The underlying list.
|
* @property list The underlying list.
|
||||||
*/
|
*/
|
||||||
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
public inline class MutableListBuffer<T>(public val list: MutableList<T>) : MutableBuffer<T> {
|
||||||
override val size: Int
|
override val size: Int
|
||||||
get() = list.size
|
get() = list.size
|
||||||
|
|
||||||
@ -198,7 +193,7 @@ inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
|||||||
* @param T the type of elements contained in the buffer.
|
* @param T the type of elements contained in the buffer.
|
||||||
* @property array The underlying array.
|
* @property array The underlying array.
|
||||||
*/
|
*/
|
||||||
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
public class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
||||||
// Can't inline because array is invariant
|
// Can't inline because array is invariant
|
||||||
override val size: Int
|
override val size: Int
|
||||||
get() = array.size
|
get() = array.size
|
||||||
|
@ -4,7 +4,6 @@ import scientifik.kmath.operations.Complex
|
|||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
import scientifik.kmath.operations.FieldElement
|
import scientifik.kmath.operations.FieldElement
|
||||||
import scientifik.kmath.operations.complex
|
import scientifik.kmath.operations.complex
|
||||||
import kotlin.contracts.ExperimentalContracts
|
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
@ -98,7 +97,7 @@ public class ComplexNDField(override val shape: IntArray) :
|
|||||||
/**
|
/**
|
||||||
* Fast element production using function inlining
|
* Fast element production using function inlining
|
||||||
*/
|
*/
|
||||||
inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline initializer: ComplexField.(Int) -> Complex): ComplexNDElement {
|
public inline fun BufferedNDField<Complex, ComplexField>.produceInline(initializer: ComplexField.(Int) -> Complex): ComplexNDElement {
|
||||||
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
|
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
|
||||||
return BufferedNDFieldElement(this, buffer)
|
return BufferedNDFieldElement(this, buffer)
|
||||||
}
|
}
|
||||||
@ -106,14 +105,13 @@ inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline init
|
|||||||
/**
|
/**
|
||||||
* Map one [ComplexNDElement] using function with indices.
|
* Map one [ComplexNDElement] using function with indices.
|
||||||
*/
|
*/
|
||||||
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement =
|
public inline fun ComplexNDElement.mapIndexed(transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement =
|
||||||
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
|
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map one [ComplexNDElement] using function without indices.
|
* Map one [ComplexNDElement] using function without indices.
|
||||||
*/
|
*/
|
||||||
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
|
public inline fun ComplexNDElement.map(transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
|
||||||
contract { callsInPlace(transform) }
|
|
||||||
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
|
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
|
||||||
return BufferedNDFieldElement(context, buffer)
|
return BufferedNDFieldElement(context, buffer)
|
||||||
}
|
}
|
||||||
@ -121,10 +119,9 @@ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) ->
|
|||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||||
*/
|
*/
|
||||||
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement): ComplexNDElement =
|
public operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement): ComplexNDElement =
|
||||||
ndElement.map { this@invoke(it) }
|
ndElement.map { this@invoke(it) }
|
||||||
|
|
||||||
|
|
||||||
/* plus and minus */
|
/* plus and minus */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -142,8 +139,10 @@ public operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement = map
|
|||||||
|
|
||||||
public fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
|
public fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
|
||||||
|
|
||||||
public fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(IntArray) -> Complex): ComplexNDElement =
|
public fun NDElement.Companion.complex(
|
||||||
NDField.complex(*shape).produce(initializer)
|
vararg shape: Int,
|
||||||
|
initializer: ComplexField.(IntArray) -> Complex
|
||||||
|
): ComplexNDElement = NDField.complex(*shape).produce(initializer)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Produce a context for n-dimensional operations inside this real field
|
* Produce a context for n-dimensional operations inside this real field
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import kotlin.contracts.ExperimentalContracts
|
|
||||||
import kotlin.contracts.contract
|
|
||||||
import kotlin.experimental.and
|
import kotlin.experimental.and
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -34,23 +32,23 @@ public enum class ValueFlag(public val mask: Byte) {
|
|||||||
/**
|
/**
|
||||||
* A buffer with flagged values.
|
* A buffer with flagged values.
|
||||||
*/
|
*/
|
||||||
interface FlaggedBuffer<T> : Buffer<T> {
|
public interface FlaggedBuffer<T> : Buffer<T> {
|
||||||
fun getFlag(index: Int): Byte
|
public fun getFlag(index: Int): Byte
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The value is valid if all flags are down
|
* The value is valid if all flags are down
|
||||||
*/
|
*/
|
||||||
fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte()
|
public fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte()
|
||||||
|
|
||||||
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte()
|
public fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte()
|
||||||
|
|
||||||
fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING)
|
public fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A real buffer which supports flags for each value like NaN or Missing
|
* A real buffer which supports flags for each value like NaN or Missing
|
||||||
*/
|
*/
|
||||||
class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> {
|
public class FlaggedRealBuffer(public val values: DoubleArray, public val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> {
|
||||||
init {
|
init {
|
||||||
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
|
require(values.size == flags.size) { "Values and flags must have the same dimensions" }
|
||||||
}
|
}
|
||||||
@ -66,9 +64,7 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged
|
|||||||
}.iterator()
|
}.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
|
public inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
|
||||||
contract { callsInPlace(block) }
|
|
||||||
|
|
||||||
indices
|
indices
|
||||||
.asSequence()
|
.asSequence()
|
||||||
.filter(::isValid)
|
.filter(::isValid)
|
||||||
|
@ -8,7 +8,7 @@ import kotlin.contracts.contract
|
|||||||
*
|
*
|
||||||
* @property array the underlying array.
|
* @property array the underlying array.
|
||||||
*/
|
*/
|
||||||
inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
|
public inline class FloatBuffer(public val array: FloatArray) : MutableBuffer<Float> {
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override operator fun get(index: Int): Float = array[index]
|
override operator fun get(index: Int): Float = array[index]
|
||||||
@ -30,20 +30,17 @@ inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an buffer element given its index.
|
* It should return the value for an buffer element given its index.
|
||||||
*/
|
*/
|
||||||
inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer {
|
public inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) })
|
||||||
contract { callsInPlace(init) }
|
|
||||||
return FloatBuffer(FloatArray(size) { init(it) })
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new [FloatBuffer] of given elements.
|
* Returns a new [FloatBuffer] of given elements.
|
||||||
*/
|
*/
|
||||||
fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats)
|
public fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a [FloatArray] containing all of the elements of this [MutableBuffer].
|
* Returns a [FloatArray] containing all of the elements of this [MutableBuffer].
|
||||||
*/
|
*/
|
||||||
val MutableBuffer<out Float>.array: FloatArray
|
public val MutableBuffer<out Float>.array: FloatArray
|
||||||
get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) })
|
get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) })
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -52,4 +49,4 @@ val MutableBuffer<out Float>.array: FloatArray
|
|||||||
* @receiver the array.
|
* @receiver the array.
|
||||||
* @return the new buffer.
|
* @return the new buffer.
|
||||||
*/
|
*/
|
||||||
fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this)
|
public fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this)
|
||||||
|
@ -1,9 +1,5 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import kotlin.contracts.ExperimentalContracts
|
|
||||||
import kotlin.contracts.InvocationKind
|
|
||||||
import kotlin.contracts.contract
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialized [MutableBuffer] implementation over [IntArray].
|
* Specialized [MutableBuffer] implementation over [IntArray].
|
||||||
*
|
*
|
||||||
@ -31,20 +27,17 @@ public inline class IntBuffer(public val array: IntArray) : MutableBuffer<Int> {
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an buffer element given its index.
|
* It should return the value for an buffer element given its index.
|
||||||
*/
|
*/
|
||||||
inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer {
|
public inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) })
|
||||||
contract { callsInPlace(init) }
|
|
||||||
return IntBuffer(IntArray(size) { init(it) })
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new [IntBuffer] of given elements.
|
* Returns a new [IntBuffer] of given elements.
|
||||||
*/
|
*/
|
||||||
fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints)
|
public fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a [IntArray] containing all of the elements of this [MutableBuffer].
|
* Returns a [IntArray] containing all of the elements of this [MutableBuffer].
|
||||||
*/
|
*/
|
||||||
val MutableBuffer<out Int>.array: IntArray
|
public val MutableBuffer<out Int>.array: IntArray
|
||||||
get() = (if (this is IntBuffer) array else IntArray(size) { get(it) })
|
get() = (if (this is IntBuffer) array else IntArray(size) { get(it) })
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -53,4 +46,4 @@ val MutableBuffer<out Int>.array: IntArray
|
|||||||
* @receiver the array.
|
* @receiver the array.
|
||||||
* @return the new buffer.
|
* @return the new buffer.
|
||||||
*/
|
*/
|
||||||
fun IntArray.asBuffer(): IntBuffer = IntBuffer(this)
|
public fun IntArray.asBuffer(): IntBuffer = IntBuffer(this)
|
||||||
|
@ -31,10 +31,7 @@ public inline class LongBuffer(public val array: LongArray) : MutableBuffer<Long
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an buffer element given its index.
|
* It should return the value for an buffer element given its index.
|
||||||
*/
|
*/
|
||||||
public inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer {
|
public inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) })
|
||||||
contract { callsInPlace(init) }
|
|
||||||
return LongBuffer(LongArray(size) { init(it) })
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new [LongBuffer] of given elements.
|
* Returns a new [LongBuffer] of given elements.
|
||||||
|
@ -9,7 +9,7 @@ import scientifik.memory.*
|
|||||||
* @property memory the underlying memory segment.
|
* @property memory the underlying memory segment.
|
||||||
* @property spec the spec of [T] type.
|
* @property spec the spec of [T] type.
|
||||||
*/
|
*/
|
||||||
open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> {
|
public open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> {
|
||||||
override val size: Int get() = memory.size / spec.objectSize
|
override val size: Int get() = memory.size / spec.objectSize
|
||||||
|
|
||||||
private val reader: MemoryReader = memory.reader()
|
private val reader: MemoryReader = memory.reader()
|
||||||
@ -17,19 +17,16 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
|
|||||||
override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
|
override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
|
||||||
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
|
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
|
public fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
|
||||||
MemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
|
MemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
|
||||||
|
|
||||||
inline fun <T : Any> create(
|
public inline fun <T : Any> create(
|
||||||
spec: MemorySpec<T>,
|
spec: MemorySpec<T>,
|
||||||
size: Int,
|
size: Int,
|
||||||
crossinline initializer: (Int) -> T
|
initializer: (Int) -> T
|
||||||
): MemoryBuffer<T> =
|
): MemoryBuffer<T> = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
|
||||||
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
|
(0 until size).forEach { buffer[it] = initializer(it) }
|
||||||
(0 until size).forEach {
|
|
||||||
buffer[it] = initializer(it)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -41,7 +38,7 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
|
|||||||
* @property memory the underlying memory segment.
|
* @property memory the underlying memory segment.
|
||||||
* @property spec the spec of [T] type.
|
* @property spec the spec of [T] type.
|
||||||
*/
|
*/
|
||||||
class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec),
|
public class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec),
|
||||||
MutableBuffer<T> {
|
MutableBuffer<T> {
|
||||||
|
|
||||||
private val writer: MemoryWriter = memory.writer()
|
private val writer: MemoryWriter = memory.writer()
|
||||||
@ -49,19 +46,16 @@ class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : Memory
|
|||||||
override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
|
override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
|
||||||
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
|
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MutableMemoryBuffer<T> =
|
public fun <T : Any> create(spec: MemorySpec<T>, size: Int): MutableMemoryBuffer<T> =
|
||||||
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
|
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
|
||||||
|
|
||||||
inline fun <T : Any> create(
|
public inline fun <T : Any> create(
|
||||||
spec: MemorySpec<T>,
|
spec: MemorySpec<T>,
|
||||||
size: Int,
|
size: Int,
|
||||||
crossinline initializer: (Int) -> T
|
crossinline initializer: (Int) -> T
|
||||||
): MutableMemoryBuffer<T> =
|
): MutableMemoryBuffer<T> = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
|
||||||
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
|
(0 until size).forEach { buffer[it] = initializer(it) }
|
||||||
(0 until size).forEach {
|
|
||||||
buffer[it] = initializer(it)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -115,19 +115,18 @@ public interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing
|
|||||||
|
|
||||||
public operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
|
public operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
|
private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf()
|
||||||
private val realNDFieldCache = HashMap<IntArray, RealNDField>()
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a nd-field for [Double] values or pull it from cache if it was created previously
|
* Create a nd-field for [Double] values or pull it from cache if it was created previously
|
||||||
*/
|
*/
|
||||||
fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a nd-field with boxing generic buffer
|
* Create a nd-field with boxing generic buffer
|
||||||
*/
|
*/
|
||||||
fun <T : Any, F : Field<T>> boxing(
|
public fun <T : Any, F : Field<T>> boxing(
|
||||||
field: F,
|
field: F,
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||||
@ -137,7 +136,7 @@ public interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing
|
|||||||
* Create a most suitable implementation for nd-field using reified class.
|
* Create a most suitable implementation for nd-field using reified class.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): BufferedNDField<T, F> =
|
public inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): BufferedNDField<T, F> =
|
||||||
when {
|
when {
|
||||||
T::class == Double::class -> real(*shape) as BufferedNDField<T, F>
|
T::class == Double::class -> real(*shape) as BufferedNDField<T, F>
|
||||||
T::class == Complex::class -> complex(*shape) as BufferedNDField<T, F>
|
T::class == Complex::class -> complex(*shape) as BufferedNDField<T, F>
|
||||||
|
@ -4,6 +4,7 @@ import scientifik.kmath.operations.Field
|
|||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The root for all [NDStructure] based algebra elements. Does not implement algebra element root because of problems with recursive self-types
|
* The root for all [NDStructure] based algebra elements. Does not implement algebra element root because of problems with recursive self-types
|
||||||
@ -11,31 +12,30 @@ import scientifik.kmath.operations.Space
|
|||||||
* @param C the type of the context for the element
|
* @param C the type of the context for the element
|
||||||
* @param N the type of the underlying [NDStructure]
|
* @param N the type of the underlying [NDStructure]
|
||||||
*/
|
*/
|
||||||
interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
public interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
||||||
|
public val context: NDAlgebra<T, C, N>
|
||||||
|
|
||||||
val context: NDAlgebra<T, C, N>
|
public fun unwrap(): N
|
||||||
|
|
||||||
fun unwrap(): N
|
public fun N.wrap(): NDElement<T, C, N>
|
||||||
|
|
||||||
fun N.wrap(): NDElement<T, C, N>
|
public companion object {
|
||||||
|
|
||||||
companion object {
|
|
||||||
/**
|
/**
|
||||||
* Create a optimized NDArray of doubles
|
* Create a optimized NDArray of doubles
|
||||||
*/
|
*/
|
||||||
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement =
|
public fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement =
|
||||||
NDField.real(*shape).produce(initializer)
|
NDField.real(*shape).produce(initializer)
|
||||||
|
|
||||||
inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
|
public inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
|
||||||
real(intArrayOf(dim)) { initializer(it[0]) }
|
real(intArrayOf(dim)) { initializer(it[0]) }
|
||||||
|
|
||||||
inline fun real2D(
|
public inline fun real2D(
|
||||||
dim1: Int,
|
dim1: Int,
|
||||||
dim2: Int,
|
dim2: Int,
|
||||||
crossinline initializer: (Int, Int) -> Double = { _, _ -> 0.0 }
|
crossinline initializer: (Int, Int) -> Double = { _, _ -> 0.0 }
|
||||||
): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||||
|
|
||||||
inline fun real3D(
|
public inline fun real3D(
|
||||||
dim1: Int,
|
dim1: Int,
|
||||||
dim2: Int,
|
dim2: Int,
|
||||||
dim3: Int,
|
dim3: Int,
|
||||||
@ -46,7 +46,7 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
|||||||
/**
|
/**
|
||||||
* Simple boxing NDArray
|
* Simple boxing NDArray
|
||||||
*/
|
*/
|
||||||
fun <T : Any, F : Field<T>> boxing(
|
public fun <T : Any, F : Field<T>> boxing(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
field: F,
|
field: F,
|
||||||
initializer: F.(IntArray) -> T
|
initializer: F.(IntArray) -> T
|
||||||
@ -55,7 +55,7 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
|||||||
return ndField.produce(initializer)
|
return ndField.produce(initializer)
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified T : Any, F : Field<T>> auto(
|
public inline fun <reified T : Any, F : Field<T>> auto(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
field: F,
|
field: F,
|
||||||
noinline initializer: F.(IntArray) -> T
|
noinline initializer: F.(IntArray) -> T
|
||||||
@ -66,17 +66,16 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement<T, C, N> =
|
||||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement<T, C, N> =
|
|
||||||
context.mapIndexed(unwrap(), transform).wrap()
|
context.mapIndexed(unwrap(), transform).wrap()
|
||||||
|
|
||||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> =
|
public fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> =
|
||||||
context.map(unwrap(), transform).wrap()
|
context.map(unwrap(), transform).wrap()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole [NDElement]
|
* Element by element application of any operation on elements to the whole [NDElement]
|
||||||
*/
|
*/
|
||||||
operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>): NDElement<T, C, N> =
|
public operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>): NDElement<T, C, N> =
|
||||||
ndElement.map { value -> this@invoke(value) }
|
ndElement.map { value -> this@invoke(value) }
|
||||||
|
|
||||||
/* plus and minus */
|
/* plus and minus */
|
||||||
@ -84,13 +83,13 @@ operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElem
|
|||||||
/**
|
/**
|
||||||
* Summation operation for [NDElement] and single element
|
* Summation operation for [NDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T): NDElement<T, S, N> =
|
public operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T): NDElement<T, S, N> =
|
||||||
map { value -> arg + value }
|
map { value -> arg + value }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtraction operation between [NDElement] and single element
|
* Subtraction operation between [NDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T): NDElement<T, S, N> =
|
public operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T): NDElement<T, S, N> =
|
||||||
map { value -> arg - value }
|
map { value -> arg - value }
|
||||||
|
|
||||||
/* prod and div */
|
/* prod and div */
|
||||||
@ -98,13 +97,13 @@ operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg:
|
|||||||
/**
|
/**
|
||||||
* Product operation for [NDElement] and single element
|
* Product operation for [NDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T): NDElement<T, R, N> =
|
public operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T): NDElement<T, R, N> =
|
||||||
map { value -> arg * value }
|
map { value -> arg * value }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Division operation between [NDElement] and single element
|
* Division operation between [NDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> =
|
public operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> =
|
||||||
map { value -> arg / value }
|
map { value -> arg / value }
|
||||||
|
|
||||||
// /**
|
// /**
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import kotlin.contracts.ExperimentalContracts
|
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
@ -12,17 +11,17 @@ import kotlin.reflect.KClass
|
|||||||
*
|
*
|
||||||
* @param T the type of items.
|
* @param T the type of items.
|
||||||
*/
|
*/
|
||||||
interface NDStructure<T> {
|
public interface NDStructure<T> {
|
||||||
/**
|
/**
|
||||||
* The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of
|
* The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of
|
||||||
* this structure.
|
* this structure.
|
||||||
*/
|
*/
|
||||||
val shape: IntArray
|
public val shape: IntArray
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The count of dimensions in this structure. It should be equal to size of [shape].
|
* The count of dimensions in this structure. It should be equal to size of [shape].
|
||||||
*/
|
*/
|
||||||
val dimension: Int get() = shape.size
|
public val dimension: Int get() = shape.size
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the value at the specified indices.
|
* Returns the value at the specified indices.
|
||||||
@ -30,24 +29,24 @@ interface NDStructure<T> {
|
|||||||
* @param index the indices.
|
* @param index the indices.
|
||||||
* @return the value.
|
* @return the value.
|
||||||
*/
|
*/
|
||||||
operator fun get(index: IntArray): T
|
public operator fun get(index: IntArray): T
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the sequence of all the elements associated by their indices.
|
* Returns the sequence of all the elements associated by their indices.
|
||||||
*
|
*
|
||||||
* @return the lazy sequence of pairs of indices to values.
|
* @return the lazy sequence of pairs of indices to values.
|
||||||
*/
|
*/
|
||||||
fun elements(): Sequence<Pair<IntArray, T>>
|
public fun elements(): Sequence<Pair<IntArray, T>>
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean
|
override fun equals(other: Any?): Boolean
|
||||||
|
|
||||||
override fun hashCode(): Int
|
override fun hashCode(): Int
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* Indicates whether some [NDStructure] is equal to another one.
|
* Indicates whether some [NDStructure] is equal to another one.
|
||||||
*/
|
*/
|
||||||
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
public fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||||
if (st1 === st2) return true
|
if (st1 === st2) return true
|
||||||
|
|
||||||
// fast comparison of buffers if possible
|
// fast comparison of buffers if possible
|
||||||
@ -68,7 +67,7 @@ interface NDStructure<T> {
|
|||||||
*
|
*
|
||||||
* Strides should be reused if possible.
|
* Strides should be reused if possible.
|
||||||
*/
|
*/
|
||||||
fun <T> build(
|
public fun <T> build(
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
initializer: (IntArray) -> T
|
initializer: (IntArray) -> T
|
||||||
@ -78,39 +77,39 @@ interface NDStructure<T> {
|
|||||||
/**
|
/**
|
||||||
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
||||||
*/
|
*/
|
||||||
inline fun <reified T : Any> auto(
|
public inline fun <reified T : Any> auto(
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
inline fun <T : Any> auto(
|
public inline fun <T : Any> auto(
|
||||||
type: KClass<T>,
|
type: KClass<T>,
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
fun <T> build(
|
public fun <T> build(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
initializer: (IntArray) -> T
|
initializer: (IntArray) -> T
|
||||||
): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
|
): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
|
||||||
|
|
||||||
inline fun <reified T : Any> auto(
|
public inline fun <reified T : Any> auto(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
auto(DefaultStrides(shape), initializer)
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
@JvmName("autoVarArg")
|
@JvmName("autoVarArg")
|
||||||
inline fun <reified T : Any> auto(
|
public inline fun <reified T : Any> auto(
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
auto(DefaultStrides(shape), initializer)
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
inline fun <T : Any> auto(
|
public inline fun <T : Any> auto(
|
||||||
type: KClass<T>,
|
type: KClass<T>,
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T
|
||||||
@ -125,68 +124,68 @@ interface NDStructure<T> {
|
|||||||
* @param index the indices.
|
* @param index the indices.
|
||||||
* @return the value.
|
* @return the value.
|
||||||
*/
|
*/
|
||||||
operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
|
public operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents mutable [NDStructure].
|
* Represents mutable [NDStructure].
|
||||||
*/
|
*/
|
||||||
interface MutableNDStructure<T> : NDStructure<T> {
|
public interface MutableNDStructure<T> : NDStructure<T> {
|
||||||
/**
|
/**
|
||||||
* Inserts an item at the specified indices.
|
* Inserts an item at the specified indices.
|
||||||
*
|
*
|
||||||
* @param index the indices.
|
* @param index the indices.
|
||||||
* @param value the value.
|
* @param value the value.
|
||||||
*/
|
*/
|
||||||
operator fun set(index: IntArray, value: T)
|
public operator fun set(index: IntArray, value: T)
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
public inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T): Unit =
|
||||||
contract { callsInPlace(action) }
|
|
||||||
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
|
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A way to convert ND index to linear one and back.
|
* A way to convert ND index to linear one and back.
|
||||||
*/
|
*/
|
||||||
interface Strides {
|
public interface Strides {
|
||||||
/**
|
/**
|
||||||
* Shape of NDstructure
|
* Shape of NDstructure
|
||||||
*/
|
*/
|
||||||
val shape: IntArray
|
public val shape: IntArray
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Array strides
|
* Array strides
|
||||||
*/
|
*/
|
||||||
val strides: List<Int>
|
public val strides: List<Int>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get linear index from multidimensional index
|
* Get linear index from multidimensional index
|
||||||
*/
|
*/
|
||||||
fun offset(index: IntArray): Int
|
public fun offset(index: IntArray): Int
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get multidimensional from linear
|
* Get multidimensional from linear
|
||||||
*/
|
*/
|
||||||
fun index(offset: Int): IntArray
|
public fun index(offset: Int): IntArray
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
||||||
*/
|
*/
|
||||||
val linearSize: Int
|
public val linearSize: Int
|
||||||
|
|
||||||
|
// TODO introduce a fast way to calculate index of the next element?
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Iterate over ND indices in a natural order
|
* Iterate over ND indices in a natural order
|
||||||
*/
|
*/
|
||||||
fun indices(): Sequence<IntArray> {
|
public fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map { index(it) }
|
||||||
//TODO introduce a fast way to calculate index of the next element?
|
|
||||||
return (0 until linearSize).asSequence().map { index(it) }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Simple implementation of [Strides].
|
* Simple implementation of [Strides].
|
||||||
*/
|
*/
|
||||||
class DefaultStrides private constructor(override val shape: IntArray) : Strides {
|
public class DefaultStrides private constructor(override val shape: IntArray) : Strides {
|
||||||
|
override val linearSize: Int
|
||||||
|
get() = strides[shape.size]
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Strides for memory access
|
* Strides for memory access
|
||||||
*/
|
*/
|
||||||
@ -194,6 +193,7 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
|||||||
sequence {
|
sequence {
|
||||||
var current = 1
|
var current = 1
|
||||||
yield(1)
|
yield(1)
|
||||||
|
|
||||||
shape.forEach {
|
shape.forEach {
|
||||||
current *= it
|
current *= it
|
||||||
yield(current)
|
yield(current)
|
||||||
@ -212,17 +212,16 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
|||||||
val res = IntArray(shape.size)
|
val res = IntArray(shape.size)
|
||||||
var current = offset
|
var current = offset
|
||||||
var strideIndex = strides.size - 2
|
var strideIndex = strides.size - 2
|
||||||
|
|
||||||
while (strideIndex >= 0) {
|
while (strideIndex >= 0) {
|
||||||
res[strideIndex] = (current / strides[strideIndex])
|
res[strideIndex] = (current / strides[strideIndex])
|
||||||
current %= strides[strideIndex]
|
current %= strides[strideIndex]
|
||||||
strideIndex--
|
strideIndex--
|
||||||
}
|
}
|
||||||
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
override val linearSize: Int
|
|
||||||
get() = strides[shape.size]
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
if (other !is DefaultStrides) return false
|
if (other !is DefaultStrides) return false
|
||||||
@ -232,13 +231,14 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
|||||||
|
|
||||||
override fun hashCode(): Int = shape.contentHashCode()
|
override fun hashCode(): Int = shape.contentHashCode()
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cached builder for default strides
|
* Cached builder for default strides
|
||||||
*/
|
*/
|
||||||
operator fun invoke(shape: IntArray): Strides = defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
public operator fun invoke(shape: IntArray): Strides =
|
||||||
|
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -247,16 +247,16 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
|||||||
*
|
*
|
||||||
* @param T the type of items.
|
* @param T the type of items.
|
||||||
*/
|
*/
|
||||||
abstract class NDBuffer<T> : NDStructure<T> {
|
public abstract class NDBuffer<T> : NDStructure<T> {
|
||||||
/**
|
/**
|
||||||
* The underlying buffer.
|
* The underlying buffer.
|
||||||
*/
|
*/
|
||||||
abstract val buffer: Buffer<T>
|
public abstract val buffer: Buffer<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The strides to access elements of [Buffer] by linear indices.
|
* The strides to access elements of [Buffer] by linear indices.
|
||||||
*/
|
*/
|
||||||
abstract val strides: Strides
|
public abstract val strides: Strides
|
||||||
|
|
||||||
override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
|
override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
|
||||||
|
|
||||||
@ -278,7 +278,7 @@ abstract class NDBuffer<T> : NDStructure<T> {
|
|||||||
/**
|
/**
|
||||||
* Boxing generic [NDStructure]
|
* Boxing generic [NDStructure]
|
||||||
*/
|
*/
|
||||||
class BufferNDStructure<T>(
|
public class BufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>
|
||||||
) : NDBuffer<T>() {
|
) : NDBuffer<T>() {
|
||||||
@ -292,13 +292,13 @@ class BufferNDStructure<T>(
|
|||||||
/**
|
/**
|
||||||
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
|
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
|
||||||
*/
|
*/
|
||||||
inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
||||||
factory: BufferFactory<R> = Buffer.Companion::auto,
|
factory: BufferFactory<R> = Buffer.Companion::auto,
|
||||||
crossinline transform: (T) -> R
|
crossinline transform: (T) -> R
|
||||||
): BufferNDStructure<R> {
|
): BufferNDStructure<R> {
|
||||||
return if (this is BufferNDStructure<T>) {
|
return if (this is BufferNDStructure<T>)
|
||||||
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
||||||
} else {
|
else {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||||
}
|
}
|
||||||
@ -307,7 +307,7 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
|||||||
/**
|
/**
|
||||||
* Mutable ND buffer based on linear [MutableBuffer].
|
* Mutable ND buffer based on linear [MutableBuffer].
|
||||||
*/
|
*/
|
||||||
class MutableBufferNDStructure<T>(
|
public class MutableBufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: MutableBuffer<T>
|
override val buffer: MutableBuffer<T>
|
||||||
) : NDBuffer<T>(), MutableNDStructure<T> {
|
) : NDBuffer<T>(), MutableNDStructure<T> {
|
||||||
@ -321,7 +321,7 @@ class MutableBufferNDStructure<T>(
|
|||||||
override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
|
override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified T : Any> NDStructure<T>.combine(
|
public inline fun <reified T : Any> NDStructure<T>.combine(
|
||||||
struct: NDStructure<T>,
|
struct: NDStructure<T>,
|
||||||
crossinline block: (T, T) -> T
|
crossinline block: (T, T) -> T
|
||||||
): NDStructure<T> {
|
): NDStructure<T> {
|
||||||
|
@ -1,14 +1,11 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import kotlin.contracts.ExperimentalContracts
|
|
||||||
import kotlin.contracts.contract
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialized [MutableBuffer] implementation over [DoubleArray].
|
* Specialized [MutableBuffer] implementation over [DoubleArray].
|
||||||
*
|
*
|
||||||
* @property array the underlying array.
|
* @property array the underlying array.
|
||||||
*/
|
*/
|
||||||
inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
public inline class RealBuffer(public val array: DoubleArray) : MutableBuffer<Double> {
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override operator fun get(index: Int): Double = array[index]
|
override operator fun get(index: Int): Double = array[index]
|
||||||
@ -30,20 +27,17 @@ inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an buffer element given its index.
|
* It should return the value for an buffer element given its index.
|
||||||
*/
|
*/
|
||||||
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer {
|
public inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) })
|
||||||
contract { callsInPlace(init) }
|
|
||||||
return RealBuffer(DoubleArray(size) { init(it) })
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new [RealBuffer] of given elements.
|
* Returns a new [RealBuffer] of given elements.
|
||||||
*/
|
*/
|
||||||
fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
|
public fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a [DoubleArray] containing all of the elements of this [MutableBuffer].
|
* Returns a [DoubleArray] containing all of the elements of this [MutableBuffer].
|
||||||
*/
|
*/
|
||||||
val MutableBuffer<out Double>.array: DoubleArray
|
public val MutableBuffer<out Double>.array: DoubleArray
|
||||||
get() = (if (this is RealBuffer) array else DoubleArray(size) { get(it) })
|
get() = (if (this is RealBuffer) array else DoubleArray(size) { get(it) })
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -52,4 +46,4 @@ val MutableBuffer<out Double>.array: DoubleArray
|
|||||||
* @receiver the array.
|
* @receiver the array.
|
||||||
* @return the new buffer.
|
* @return the new buffer.
|
||||||
*/
|
*/
|
||||||
fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this)
|
public fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this)
|
||||||
|
@ -4,11 +4,10 @@ import scientifik.kmath.operations.ExtendedField
|
|||||||
import scientifik.kmath.operations.ExtendedFieldOperations
|
import scientifik.kmath.operations.ExtendedFieldOperations
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [ExtendedFieldOperations] over [RealBuffer].
|
* [ExtendedFieldOperations] over [RealBuffer].
|
||||||
*/
|
*/
|
||||||
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||||
require(b.size == a.size) {
|
require(b.size == a.size) {
|
||||||
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
||||||
@ -73,9 +72,8 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
|||||||
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
RealBuffer(DoubleArray(arg.size) { asin(array[it]) })
|
RealBuffer(DoubleArray(arg.size) { asin(array[it]) })
|
||||||
} else {
|
} else
|
||||||
RealBuffer(DoubleArray(arg.size) { asin(arg[it]) })
|
RealBuffer(DoubleArray(arg.size) { asin(arg[it]) })
|
||||||
}
|
|
||||||
|
|
||||||
override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
@ -92,37 +90,44 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
|||||||
override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
RealBuffer(DoubleArray(arg.size) { sinh(array[it]) })
|
RealBuffer(DoubleArray(arg.size) { sinh(array[it]) })
|
||||||
} else RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) })
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) })
|
||||||
|
|
||||||
override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
RealBuffer(DoubleArray(arg.size) { cosh(array[it]) })
|
RealBuffer(DoubleArray(arg.size) { cosh(array[it]) })
|
||||||
} else RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) })
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) })
|
||||||
|
|
||||||
override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
RealBuffer(DoubleArray(arg.size) { tanh(array[it]) })
|
RealBuffer(DoubleArray(arg.size) { tanh(array[it]) })
|
||||||
} else RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) })
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) })
|
||||||
|
|
||||||
override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
RealBuffer(DoubleArray(arg.size) { asinh(array[it]) })
|
RealBuffer(DoubleArray(arg.size) { asinh(array[it]) })
|
||||||
} else RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) })
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) })
|
||||||
|
|
||||||
override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
RealBuffer(DoubleArray(arg.size) { acosh(array[it]) })
|
RealBuffer(DoubleArray(arg.size) { acosh(array[it]) })
|
||||||
} else RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) })
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) })
|
||||||
|
|
||||||
override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
RealBuffer(DoubleArray(arg.size) { atanh(array[it]) })
|
RealBuffer(DoubleArray(arg.size) { atanh(array[it]) })
|
||||||
} else RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) })
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) })
|
||||||
|
|
||||||
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
|
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
|
RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
|
||||||
} else RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
|
||||||
|
|
||||||
override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
@ -132,7 +137,8 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
|||||||
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
|
RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
|
||||||
} else RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
|
} else
|
||||||
|
RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -140,7 +146,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
|||||||
*
|
*
|
||||||
* @property size the size of buffers to operate on.
|
* @property size the size of buffers to operate on.
|
||||||
*/
|
*/
|
||||||
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
|
public class RealBufferField(public val size: Int) : ExtendedField<Buffer<Double>> {
|
||||||
override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
|
override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
|
||||||
override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
|
override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
|
||||||
|
|
||||||
|
@ -112,26 +112,22 @@ public inline fun RealNDElement.map(crossinline transform: RealField.(Double) ->
|
|||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole array. Just like in numpy.
|
* Element by element application of any operation on elements to the whole array. Just like in numpy.
|
||||||
*/
|
*/
|
||||||
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement): RealNDElement =
|
public operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement): RealNDElement =
|
||||||
ndElement.map { this@invoke(it) }
|
ndElement.map { this@invoke(it) }
|
||||||
|
|
||||||
|
|
||||||
/* plus and minus */
|
/* plus and minus */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Summation operation for [BufferedNDElement] and single element
|
* Summation operation for [BufferedNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun RealNDElement.plus(arg: Double): RealNDElement =
|
public operator fun RealNDElement.plus(arg: Double): RealNDElement = map { it + arg }
|
||||||
map { it + arg }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtraction operation between [BufferedNDElement] and single element
|
* Subtraction operation between [BufferedNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun RealNDElement.minus(arg: Double): RealNDElement =
|
public operator fun RealNDElement.minus(arg: Double): RealNDElement = map { it - arg }
|
||||||
map { it - arg }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Produce a context for n-dimensional operations inside this real field
|
* Produce a context for n-dimensional operations inside this real field
|
||||||
*/
|
*/
|
||||||
|
public inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action)
|
||||||
inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action)
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import kotlin.contracts.ExperimentalContracts
|
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -30,10 +29,7 @@ public inline class ShortBuffer(public val array: ShortArray) : MutableBuffer<Sh
|
|||||||
* The function [init] is called for each array element sequentially starting from the first one.
|
* The function [init] is called for each array element sequentially starting from the first one.
|
||||||
* It should return the value for an buffer element given its index.
|
* It should return the value for an buffer element given its index.
|
||||||
*/
|
*/
|
||||||
public inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer {
|
public inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) })
|
||||||
contract { callsInPlace(init) }
|
|
||||||
return ShortBuffer(ShortArray(size) { init(it) })
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new [ShortBuffer] of given elements.
|
* Returns a new [ShortBuffer] of given elements.
|
||||||
|
@ -2,7 +2,6 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import scientifik.kmath.operations.RingElement
|
import scientifik.kmath.operations.RingElement
|
||||||
import scientifik.kmath.operations.ShortRing
|
import scientifik.kmath.operations.ShortRing
|
||||||
import kotlin.contracts.contract
|
|
||||||
|
|
||||||
public typealias ShortNDElement = BufferedNDRingElement<Short, ShortRing>
|
public typealias ShortNDElement = BufferedNDRingElement<Short, ShortRing>
|
||||||
|
|
||||||
@ -69,11 +68,8 @@ public class ShortNDRing(override val shape: IntArray) :
|
|||||||
/**
|
/**
|
||||||
* Fast element production using function inlining.
|
* Fast element production using function inlining.
|
||||||
*/
|
*/
|
||||||
public inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement {
|
public inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement =
|
||||||
contract { callsInPlace(initializer) }
|
BufferedNDRingElement(this, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }))
|
||||||
val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }
|
|
||||||
return BufferedNDRingElement(this, ShortBuffer(array))
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element by element application of any operation on elements to the whole array.
|
* Element by element application of any operation on elements to the whole array.
|
||||||
|
@ -3,7 +3,7 @@ package scientifik.kmath.structures
|
|||||||
/**
|
/**
|
||||||
* A structure that is guaranteed to be one-dimensional
|
* A structure that is guaranteed to be one-dimensional
|
||||||
*/
|
*/
|
||||||
interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
public interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
||||||
override val dimension: Int get() = 1
|
override val dimension: Int get() = 1
|
||||||
|
|
||||||
override operator fun get(index: IntArray): T {
|
override operator fun get(index: IntArray): T {
|
||||||
@ -11,14 +11,13 @@ interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
|||||||
return get(index[0])
|
return get(index[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
|
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(::get).iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A 1D wrapper for nd-structure
|
* A 1D wrapper for nd-structure
|
||||||
*/
|
*/
|
||||||
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> {
|
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> {
|
||||||
|
|
||||||
override val shape: IntArray get() = structure.shape
|
override val shape: IntArray get() = structure.shape
|
||||||
override val size: Int get() = structure.shape[0]
|
override val size: Int get() = structure.shape[0]
|
||||||
|
|
||||||
@ -45,18 +44,12 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
|
|||||||
/**
|
/**
|
||||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||||
*/
|
*/
|
||||||
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
public fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||||
if (this is NDBuffer) {
|
if (this is NDBuffer) Buffer1DWrapper(this.buffer) else Structure1DWrapper(this)
|
||||||
Buffer1DWrapper(this.buffer)
|
} else
|
||||||
} else {
|
|
||||||
Structure1DWrapper(this)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
error("Can't create 1d-structure from ${shape.size}d-structure")
|
error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represent this buffer as 1D structure
|
* Represent this buffer as 1D structure
|
||||||
*/
|
*/
|
||||||
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
public fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
||||||
|
@ -21,11 +21,8 @@ public interface Structure2D<T> : NDStructure<T> {
|
|||||||
get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } }
|
get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } }
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||||
for (i in (0 until rowNum)) {
|
for (i in (0 until rowNum))
|
||||||
for (j in (0 until colNum)) {
|
for (j in (0 until colNum)) yield(intArrayOf(i, j) to get(i, j))
|
||||||
yield(intArrayOf(i, j) to get(i, j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public companion object
|
public companion object
|
||||||
@ -45,10 +42,9 @@ private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Stru
|
|||||||
/**
|
/**
|
||||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||||
*/
|
*/
|
||||||
public fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
public fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2)
|
||||||
Structure2DWrapper(this)
|
Structure2DWrapper(this)
|
||||||
} else {
|
else
|
||||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||||
}
|
|
||||||
|
|
||||||
public typealias Matrix<T> = Structure2D<T>
|
public typealias Matrix<T> = Structure2D<T>
|
||||||
|
@ -3,9 +3,8 @@ package scientifik.kmath.coroutines
|
|||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
import kotlinx.coroutines.channels.produce
|
import kotlinx.coroutines.channels.produce
|
||||||
import kotlinx.coroutines.flow.*
|
import kotlinx.coroutines.flow.*
|
||||||
import kotlin.contracts.contract
|
|
||||||
|
|
||||||
val Dispatchers.Math: CoroutineDispatcher
|
public val Dispatchers.Math: CoroutineDispatcher
|
||||||
get() = Default
|
get() = Default
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -15,31 +14,25 @@ internal class LazyDeferred<T>(val dispatcher: CoroutineDispatcher, val block: s
|
|||||||
private var deferred: Deferred<T>? = null
|
private var deferred: Deferred<T>? = null
|
||||||
|
|
||||||
internal fun start(scope: CoroutineScope) {
|
internal fun start(scope: CoroutineScope) {
|
||||||
if (deferred == null) {
|
if (deferred == null) deferred = scope.async(dispatcher, block = block)
|
||||||
deferred = scope.async(dispatcher, block = block)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
suspend fun await(): T = deferred?.await() ?: error("Coroutine not started")
|
suspend fun await(): T = deferred?.await() ?: error("Coroutine not started")
|
||||||
}
|
}
|
||||||
|
|
||||||
class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
|
public class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
|
||||||
override suspend fun collect(collector: FlowCollector<T>) {
|
override suspend fun collect(collector: FlowCollector<T>): Unit = deferredFlow.collect { collector.emit((it.await())) }
|
||||||
deferredFlow.collect { collector.emit((it.await())) }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T, R> Flow<T>.async(
|
public fun <T, R> Flow<T>.async(
|
||||||
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
||||||
block: suspend CoroutineScope.(T) -> R
|
block: suspend CoroutineScope.(T) -> R
|
||||||
): AsyncFlow<R> {
|
): AsyncFlow<R> {
|
||||||
val flow = map {
|
val flow = map { LazyDeferred(dispatcher) { block(it) } }
|
||||||
LazyDeferred(dispatcher) { block(it) }
|
|
||||||
}
|
|
||||||
return AsyncFlow(flow)
|
return AsyncFlow(flow)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> =
|
public fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> =
|
||||||
AsyncFlow(deferredFlow.map { input ->
|
AsyncFlow(deferredFlow.map { input ->
|
||||||
//TODO add function composition
|
//TODO add function composition
|
||||||
LazyDeferred(input.dispatcher) {
|
LazyDeferred(input.dispatcher) {
|
||||||
@ -48,7 +41,7 @@ fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> =
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>) {
|
public suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>) {
|
||||||
require(concurrency >= 1) { "Buffer size should be more than 1, but was $concurrency" }
|
require(concurrency >= 1) { "Buffer size should be more than 1, but was $concurrency" }
|
||||||
|
|
||||||
coroutineScope {
|
coroutineScope {
|
||||||
@ -76,18 +69,14 @@ suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
suspend inline fun <T> AsyncFlow<T>.collect(concurrency: Int, crossinline action: suspend (value: T) -> Unit) {
|
public suspend inline fun <T> AsyncFlow<T>.collect(
|
||||||
contract { callsInPlace(action) }
|
concurrency: Int,
|
||||||
|
crossinline action: suspend (value: T) -> Unit
|
||||||
collect(concurrency, object : FlowCollector<T> {
|
): Unit = collect(concurrency, object : FlowCollector<T> {
|
||||||
override suspend fun emit(value: T): Unit = action(value)
|
override suspend fun emit(value: T): Unit = action(value)
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
|
||||||
inline fun <T, R> Flow<T>.mapParallel(
|
public inline fun <T, R> Flow<T>.mapParallel(
|
||||||
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
||||||
crossinline transform: suspend (T) -> R
|
crossinline transform: suspend (T) -> R
|
||||||
): Flow<R> {
|
): Flow<R> = flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher)
|
||||||
contract { callsInPlace(transform) }
|
|
||||||
return flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher)
|
|
||||||
}
|
|
||||||
|
@ -3,38 +3,31 @@ package scientifik.kmath.structures
|
|||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
import scientifik.kmath.coroutines.Math
|
import scientifik.kmath.coroutines.Math
|
||||||
|
|
||||||
class LazyNDStructure<T>(
|
public class LazyNDStructure<T>(
|
||||||
val scope: CoroutineScope,
|
public val scope: CoroutineScope,
|
||||||
override val shape: IntArray,
|
public override val shape: IntArray,
|
||||||
val function: suspend (IntArray) -> T
|
public val function: suspend (IntArray) -> T
|
||||||
) : NDStructure<T> {
|
) : NDStructure<T> {
|
||||||
private val cache: MutableMap<IntArray, Deferred<T>> = hashMapOf()
|
private val cache: MutableMap<IntArray, Deferred<T>> = hashMapOf()
|
||||||
|
|
||||||
fun deferred(index: IntArray): Deferred<T> = cache.getOrPut(index) {
|
public fun deferred(index: IntArray): Deferred<T> = cache.getOrPut(index) {
|
||||||
scope.async(context = Dispatchers.Math) {
|
scope.async(context = Dispatchers.Math) { function(index) }
|
||||||
function(index)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
suspend fun await(index: IntArray): T = deferred(index).await()
|
public suspend fun await(index: IntArray): T = deferred(index).await()
|
||||||
|
public override operator fun get(index: IntArray): T = runBlocking { deferred(index).await() }
|
||||||
|
|
||||||
override operator fun get(index: IntArray): T = runBlocking {
|
public override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||||
deferred(index).await()
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> {
|
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
val res = runBlocking {
|
val res = runBlocking { strides.indices().toList().map { index -> index to await(index) } }
|
||||||
strides.indices().toList().map { index -> index to await(index) }
|
|
||||||
}
|
|
||||||
return res.asSequence()
|
return res.asSequence()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
public override fun equals(other: Any?): Boolean {
|
||||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
public override fun hashCode(): Int {
|
||||||
var result = scope.hashCode()
|
var result = scope.hashCode()
|
||||||
result = 31 * result + shape.contentHashCode()
|
result = 31 * result + shape.contentHashCode()
|
||||||
result = 31 * result + function.hashCode()
|
result = 31 * result + function.hashCode()
|
||||||
@ -43,21 +36,21 @@ class LazyNDStructure<T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T> NDStructure<T>.deferred(index: IntArray): Deferred<T> =
|
public fun <T> NDStructure<T>.deferred(index: IntArray): Deferred<T> =
|
||||||
if (this is LazyNDStructure<T>) this.deferred(index) else CompletableDeferred(get(index))
|
if (this is LazyNDStructure<T>) this.deferred(index) else CompletableDeferred(get(index))
|
||||||
|
|
||||||
suspend fun <T> NDStructure<T>.await(index: IntArray): T =
|
public suspend fun <T> NDStructure<T>.await(index: IntArray): T =
|
||||||
if (this is LazyNDStructure<T>) this.await(index) else get(index)
|
if (this is LazyNDStructure<T>) this.await(index) else get(index)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* PENDING would benefit from KEEP-176
|
* PENDING would benefit from KEEP-176
|
||||||
*/
|
*/
|
||||||
inline fun <T, R> NDStructure<T>.mapAsyncIndexed(
|
public inline fun <T, R> NDStructure<T>.mapAsyncIndexed(
|
||||||
scope: CoroutineScope,
|
scope: CoroutineScope,
|
||||||
crossinline function: suspend (T, index: IntArray) -> R
|
crossinline function: suspend (T, index: IntArray) -> R
|
||||||
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index), index) }
|
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index), index) }
|
||||||
|
|
||||||
inline fun <T, R> NDStructure<T>.mapAsync(
|
public inline fun <T, R> NDStructure<T>.mapAsync(
|
||||||
scope: CoroutineScope,
|
scope: CoroutineScope,
|
||||||
crossinline function: suspend (T) -> R
|
crossinline function: suspend (T) -> R
|
||||||
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index)) }
|
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index)) }
|
||||||
|
@ -13,35 +13,36 @@ import scientifik.kmath.structures.Structure2D
|
|||||||
/**
|
/**
|
||||||
* A matrix with compile-time controlled dimension
|
* A matrix with compile-time controlled dimension
|
||||||
*/
|
*/
|
||||||
interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
public interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
||||||
companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
|
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
|
||||||
*/
|
*/
|
||||||
inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
|
public inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
|
||||||
if (structure.rowNum != Dimension.dim<R>().toInt()) {
|
require(structure.rowNum == Dimension.dim<R>().toInt()) {
|
||||||
error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}")
|
"Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}"
|
||||||
}
|
}
|
||||||
if (structure.colNum != Dimension.dim<C>().toInt()) {
|
|
||||||
error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}")
|
require(structure.colNum == Dimension.dim<C>().toInt()) {
|
||||||
|
"Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}"
|
||||||
}
|
}
|
||||||
|
|
||||||
return DMatrixWrapper(structure)
|
return DMatrixWrapper(structure)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The same as [coerce] but without dimension checks. Use with caution
|
* The same as [DMatrix.coerce] but without dimension checks. Use with caution
|
||||||
*/
|
*/
|
||||||
fun <T, R : Dimension, C : Dimension> coerceUnsafe(structure: Structure2D<T>): DMatrix<T, R, C> {
|
public fun <T, R : Dimension, C : Dimension> coerceUnsafe(structure: Structure2D<T>): DMatrix<T, R, C> =
|
||||||
return DMatrixWrapper(structure)
|
DMatrixWrapper(structure)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An inline wrapper for a Matrix
|
* An inline wrapper for a Matrix
|
||||||
*/
|
*/
|
||||||
inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
public inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
||||||
val structure: Structure2D<T>
|
public val structure: Structure2D<T>
|
||||||
) : DMatrix<T, R, C> {
|
) : DMatrix<T, R, C> {
|
||||||
override val shape: IntArray get() = structure.shape
|
override val shape: IntArray get() = structure.shape
|
||||||
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||||
@ -50,25 +51,24 @@ inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
|||||||
/**
|
/**
|
||||||
* Dimension-safe point
|
* Dimension-safe point
|
||||||
*/
|
*/
|
||||||
interface DPoint<T, D : Dimension> : Point<T> {
|
public interface DPoint<T, D : Dimension> : Point<T> {
|
||||||
companion object {
|
public companion object {
|
||||||
inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
|
public inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
|
||||||
if (point.size != Dimension.dim<D>().toInt()) {
|
require(point.size == Dimension.dim<D>().toInt()) {
|
||||||
error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}")
|
"Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}"
|
||||||
}
|
}
|
||||||
|
|
||||||
return DPointWrapper(point)
|
return DPointWrapper(point)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T, D : Dimension> coerceUnsafe(point: Point<T>): DPoint<T, D> {
|
public fun <T, D : Dimension> coerceUnsafe(point: Point<T>): DPoint<T, D> = DPointWrapper(point)
|
||||||
return DPointWrapper(point)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dimension-safe point wrapper
|
* Dimension-safe point wrapper
|
||||||
*/
|
*/
|
||||||
inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) :
|
public inline class DPointWrapper<T, D : Dimension>(public val point: Point<T>) :
|
||||||
DPoint<T, D> {
|
DPoint<T, D> {
|
||||||
override val size: Int get() = point.size
|
override val size: Int get() = point.size
|
||||||
|
|
||||||
@ -81,16 +81,15 @@ inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) :
|
|||||||
/**
|
/**
|
||||||
* Basic operations on dimension-safe matrices. Operates on [Matrix]
|
* Basic operations on dimension-safe matrices. Operates on [Matrix]
|
||||||
*/
|
*/
|
||||||
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) {
|
public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: GenericMatrixContext<T, Ri>) {
|
||||||
|
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
|
||||||
|
require(rowNum == Dimension.dim<R>().toInt()) {
|
||||||
|
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
|
||||||
|
}
|
||||||
|
|
||||||
inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
|
require(colNum == Dimension.dim<C>().toInt()) {
|
||||||
check(
|
"Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum"
|
||||||
rowNum == Dimension.dim<R>().toInt()
|
}
|
||||||
) { "Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum" }
|
|
||||||
|
|
||||||
check(
|
|
||||||
colNum == Dimension.dim<C>().toInt()
|
|
||||||
) { "Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum" }
|
|
||||||
|
|
||||||
return DMatrix.coerceUnsafe(this)
|
return DMatrix.coerceUnsafe(this)
|
||||||
}
|
}
|
||||||
@ -98,13 +97,13 @@ inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixCon
|
|||||||
/**
|
/**
|
||||||
* Produce a matrix with this context and given dimensions
|
* Produce a matrix with this context and given dimensions
|
||||||
*/
|
*/
|
||||||
inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> {
|
public inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> {
|
||||||
val rows = Dimension.dim<R>()
|
val rows = Dimension.dim<R>()
|
||||||
val cols = Dimension.dim<C>()
|
val cols = Dimension.dim<C>()
|
||||||
return context.produce(rows.toInt(), cols.toInt(), initializer).coerce<R, C>()
|
return context.produce(rows.toInt(), cols.toInt(), initializer).coerce<R, C>()
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> {
|
public inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> {
|
||||||
val size = Dimension.dim<D>()
|
val size = Dimension.dim<D>()
|
||||||
|
|
||||||
return DPoint.coerceUnsafe(
|
return DPoint.coerceUnsafe(
|
||||||
@ -115,7 +114,7 @@ inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixCon
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
public inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
||||||
other: DMatrix<T, C1, C2>
|
other: DMatrix<T, C1, C2>
|
||||||
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
|
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
|
||||||
|
|
||||||
|
@ -116,16 +116,13 @@ public operator fun Matrix<Double>.minus(other: Matrix<Double>): RealMatrix =
|
|||||||
* Operations on columns
|
* Operations on columns
|
||||||
*/
|
*/
|
||||||
|
|
||||||
public inline fun Matrix<Double>.appendColumn(crossinline mapper: (Buffer<Double>) -> Double): Matrix<Double> {
|
public inline fun Matrix<Double>.appendColumn(crossinline mapper: (Buffer<Double>) -> Double): Matrix<Double> =
|
||||||
contract { callsInPlace(mapper) }
|
MatrixContext.real.produce(rowNum, colNum + 1) { row, col ->
|
||||||
|
|
||||||
return MatrixContext.real.produce(rowNum, colNum + 1) { row, col ->
|
|
||||||
if (col < colNum)
|
if (col < colNum)
|
||||||
this[row, col]
|
this[row, col]
|
||||||
else
|
else
|
||||||
mapper(rows[row])
|
mapper(rows[row])
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
public fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix =
|
public fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix =
|
||||||
MatrixContext.real.produce(rowNum, columnRange.count()) { row, col ->
|
MatrixContext.real.produce(rowNum, columnRange.count()) { row, col ->
|
||||||
|
@ -12,11 +12,10 @@ public fun interface PiecewisePolynomial<T : Any> :
|
|||||||
/**
|
/**
|
||||||
* Ordered list of pieces in piecewise function
|
* Ordered list of pieces in piecewise function
|
||||||
*/
|
*/
|
||||||
public class OrderedPiecewisePolynomial<T : Comparable<T>>(delimeter: T) :
|
public class OrderedPiecewisePolynomial<T : Comparable<T>>(delimiter: T) :
|
||||||
PiecewisePolynomial<T> {
|
PiecewisePolynomial<T> {
|
||||||
|
private val delimiters: MutableList<T> = arrayListOf(delimiter)
|
||||||
private val delimiters: ArrayList<T> = arrayListOf(delimeter)
|
private val pieces: MutableList<Polynomial<T>> = arrayListOf()
|
||||||
private val pieces: ArrayList<Polynomial<T>> = ArrayList()
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dynamically add a piece to the "right" side (beyond maximum argument value of previous piece)
|
* Dynamically add a piece to the "right" side (beyond maximum argument value of previous piece)
|
||||||
@ -35,14 +34,13 @@ public class OrderedPiecewisePolynomial<T : Comparable<T>>(delimeter: T) :
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun findPiece(arg: T): Polynomial<T>? {
|
override fun findPiece(arg: T): Polynomial<T>? {
|
||||||
if (arg < delimiters.first() || arg >= delimiters.last()) {
|
if (arg < delimiters.first() || arg >= delimiters.last())
|
||||||
return null
|
return null
|
||||||
} else {
|
else {
|
||||||
for (index in 1 until delimiters.size) {
|
for (index in 1 until delimiters.size)
|
||||||
if (arg < delimiters[index]) {
|
if (arg < delimiters[index])
|
||||||
return pieces[index - 1]
|
return pieces[index - 1]
|
||||||
}
|
|
||||||
}
|
|
||||||
error("Piece not found")
|
error("Piece not found")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -48,9 +48,9 @@ public fun <T : Any, C : Ring<T>> Polynomial<T>.asFunction(ring: C): (T) -> T =
|
|||||||
* An algebra for polynomials
|
* An algebra for polynomials
|
||||||
*/
|
*/
|
||||||
public class PolynomialSpace<T : Any, C : Ring<T>>(public val ring: C) : Space<Polynomial<T>> {
|
public class PolynomialSpace<T : Any, C : Ring<T>>(public val ring: C) : Space<Polynomial<T>> {
|
||||||
override val zero: Polynomial<T> = Polynomial(emptyList())
|
public override val zero: Polynomial<T> = Polynomial(emptyList())
|
||||||
|
|
||||||
override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
|
public override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
|
||||||
val dim = max(a.coefficients.size, b.coefficients.size)
|
val dim = max(a.coefficients.size, b.coefficients.size)
|
||||||
|
|
||||||
return ring {
|
return ring {
|
||||||
@ -60,7 +60,7 @@ public class PolynomialSpace<T : Any, C : Ring<T>>(public val ring: C) : Space<P
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Polynomial<T>, k: Number): Polynomial<T> =
|
public override fun multiply(a: Polynomial<T>, k: Number): Polynomial<T> =
|
||||||
ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) }
|
ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) }
|
||||||
|
|
||||||
public operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg)
|
public operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg)
|
||||||
|
@ -22,7 +22,7 @@ public interface PolynomialInterpolator<T : Comparable<T>> : Interpolator<T, T>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
||||||
x: Buffer<T>,
|
x: Buffer<T>,
|
||||||
y: Buffer<T>
|
y: Buffer<T>
|
||||||
): PiecewisePolynomial<T> {
|
): PiecewisePolynomial<T> {
|
||||||
@ -30,14 +30,14 @@ fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
|||||||
return interpolatePolynomials(pointSet)
|
return interpolatePolynomials(pointSet)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
||||||
data: Map<T, T>
|
data: Map<T, T>
|
||||||
): PiecewisePolynomial<T> {
|
): PiecewisePolynomial<T> {
|
||||||
val pointSet = BufferXYPointSet(data.keys.toList().asBuffer(), data.values.toList().asBuffer())
|
val pointSet = BufferXYPointSet(data.keys.toList().asBuffer(), data.values.toList().asBuffer())
|
||||||
return interpolatePolynomials(pointSet)
|
return interpolatePolynomials(pointSet)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
||||||
data: List<Pair<T, T>>
|
data: List<Pair<T, T>>
|
||||||
): PiecewisePolynomial<T> {
|
): PiecewisePolynomial<T> {
|
||||||
val pointSet = BufferXYPointSet(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer())
|
val pointSet = BufferXYPointSet(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer())
|
||||||
|
@ -9,8 +9,8 @@ import scientifik.kmath.operations.invoke
|
|||||||
/**
|
/**
|
||||||
* Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java
|
* Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java
|
||||||
*/
|
*/
|
||||||
public class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> {
|
public class LinearInterpolator<T : Comparable<T>>(public override val algebra: Field<T>) : PolynomialInterpolator<T> {
|
||||||
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
|
public override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
|
||||||
require(points.size > 0) { "Point array should not be empty" }
|
require(points.size > 0) { "Point array should not be empty" }
|
||||||
insureSorted(points)
|
insureSorted(points)
|
||||||
|
|
||||||
|
@ -12,13 +12,13 @@ import scientifik.kmath.structures.MutableBufferFactory
|
|||||||
* Based on https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/SplineInterpolator.java
|
* Based on https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/SplineInterpolator.java
|
||||||
*/
|
*/
|
||||||
public class SplineInterpolator<T : Comparable<T>>(
|
public class SplineInterpolator<T : Comparable<T>>(
|
||||||
override val algebra: Field<T>,
|
public override val algebra: Field<T>,
|
||||||
public val bufferFactory: MutableBufferFactory<T>
|
public val bufferFactory: MutableBufferFactory<T>
|
||||||
) : PolynomialInterpolator<T> {
|
) : PolynomialInterpolator<T> {
|
||||||
|
|
||||||
//TODO possibly optimize zeroed buffers
|
//TODO possibly optimize zeroed buffers
|
||||||
|
|
||||||
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
|
public override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
|
||||||
if (points.size < 3) {
|
if (points.size < 3) {
|
||||||
error("Can't use spline interpolator with less than 3 points")
|
error("Can't use spline interpolator with less than 3 points")
|
||||||
}
|
}
|
||||||
@ -41,8 +41,9 @@ public class SplineInterpolator<T : Comparable<T>>(
|
|||||||
|
|
||||||
// cubic spline coefficients -- b is linear, c quadratic, d is cubic (original y's are constants)
|
// cubic spline coefficients -- b is linear, c quadratic, d is cubic (original y's are constants)
|
||||||
|
|
||||||
OrderedPiecewisePolynomial<T>(points.x[points.size - 1]).apply {
|
OrderedPiecewisePolynomial(points.x[points.size - 1]).apply {
|
||||||
var cOld = zero
|
var cOld = zero
|
||||||
|
|
||||||
for (j in n - 1 downTo 0) {
|
for (j in n - 1 downTo 0) {
|
||||||
val c = z[j] - mu[j] * cOld
|
val c = z[j] - mu[j] * cOld
|
||||||
val a = points.y[j]
|
val a = points.y[j]
|
||||||
|
@ -14,32 +14,32 @@ public interface XYZPointSet<X, Y, Z> : XYPointSet<X, Y> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) {
|
internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) {
|
||||||
for (i in 0 until points.size - 1) require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" }
|
for (i in 0 until points.size - 1)
|
||||||
|
require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" }
|
||||||
}
|
}
|
||||||
|
|
||||||
public class NDStructureColumn<T>(public val structure: Structure2D<T>, public val column: Int) : Buffer<T> {
|
public class NDStructureColumn<T>(public val structure: Structure2D<T>, public val column: Int) : Buffer<T> {
|
||||||
|
public override val size: Int
|
||||||
|
get() = structure.rowNum
|
||||||
|
|
||||||
init {
|
init {
|
||||||
require(column < structure.colNum) { "Column index is outside of structure column range" }
|
require(column < structure.colNum) { "Column index is outside of structure column range" }
|
||||||
}
|
}
|
||||||
|
|
||||||
override val size: Int get() = structure.rowNum
|
public override operator fun get(index: Int): T = structure[index, column]
|
||||||
|
public override operator fun iterator(): Iterator<T> = sequence { repeat(size) { yield(get(it)) } }.iterator()
|
||||||
override operator fun get(index: Int): T = structure[index, column]
|
|
||||||
|
|
||||||
override operator fun iterator(): Iterator<T> = sequence {
|
|
||||||
repeat(size) {
|
|
||||||
yield(get(it))
|
|
||||||
}
|
|
||||||
}.iterator()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class BufferXYPointSet<X, Y>(override val x: Buffer<X>, override val y: Buffer<Y>) : XYPointSet<X, Y> {
|
public class BufferXYPointSet<X, Y>(
|
||||||
|
public override val x: Buffer<X>,
|
||||||
|
public override val y: Buffer<Y>
|
||||||
|
) : XYPointSet<X, Y> {
|
||||||
|
public override val size: Int
|
||||||
|
get() = x.size
|
||||||
|
|
||||||
init {
|
init {
|
||||||
require(x.size == y.size) { "Sizes of x and y buffers should be the same" }
|
require(x.size == y.size) { "Sizes of x and y buffers should be the same" }
|
||||||
}
|
}
|
||||||
|
|
||||||
override val size: Int
|
|
||||||
get() = x.size
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T> Structure2D<T>.asXYPointSet(): XYPointSet<T, T> {
|
public fun <T> Structure2D<T>.asXYPointSet(): XYPointSet<T, T> {
|
||||||
|
@ -6,7 +6,7 @@ import scientifik.kmath.operations.RealField
|
|||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class LinearInterpolatorTest {
|
internal class LinearInterpolatorTest {
|
||||||
@Test
|
@Test
|
||||||
fun testInterpolation() {
|
fun testInterpolation() {
|
||||||
val data = listOf(
|
val data = listOf(
|
||||||
@ -15,9 +15,9 @@ class LinearInterpolatorTest {
|
|||||||
2.0 to 3.0,
|
2.0 to 3.0,
|
||||||
3.0 to 4.0
|
3.0 to 4.0
|
||||||
)
|
)
|
||||||
|
|
||||||
val polynomial: PiecewisePolynomial<Double> = LinearInterpolator(RealField).interpolatePolynomials(data)
|
val polynomial: PiecewisePolynomial<Double> = LinearInterpolator(RealField).interpolatePolynomials(data)
|
||||||
val function = polynomial.asFunction(RealField)
|
val function = polynomial.asFunction(RealField)
|
||||||
|
|
||||||
assertEquals(null, function(-1.0))
|
assertEquals(null, function(-1.0))
|
||||||
assertEquals(0.5, function(0.5))
|
assertEquals(0.5, function(0.5))
|
||||||
assertEquals(2.0, function(1.5))
|
assertEquals(2.0, function(1.5))
|
||||||
|
@ -4,17 +4,16 @@ import org.jetbrains.bio.viktor.F64FlatArray
|
|||||||
import scientifik.kmath.structures.MutableBuffer
|
import scientifik.kmath.structures.MutableBuffer
|
||||||
|
|
||||||
@Suppress("NOTHING_TO_INLINE", "OVERRIDE_BY_INLINE")
|
@Suppress("NOTHING_TO_INLINE", "OVERRIDE_BY_INLINE")
|
||||||
inline class ViktorBuffer(val flatArray: F64FlatArray) : MutableBuffer<Double> {
|
public inline class ViktorBuffer(public val flatArray: F64FlatArray) : MutableBuffer<Double> {
|
||||||
override val size: Int get() = flatArray.size
|
public override val size: Int
|
||||||
|
get() = flatArray.size
|
||||||
|
|
||||||
|
public override inline fun get(index: Int): Double = flatArray[index]
|
||||||
|
|
||||||
override inline fun get(index: Int): Double = flatArray[index]
|
|
||||||
override inline fun set(index: Int, value: Double) {
|
override inline fun set(index: Int, value: Double) {
|
||||||
flatArray[index] = value
|
flatArray[index] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Double> {
|
public override fun copy(): MutableBuffer<Double> = ViktorBuffer(flatArray.copy().flatten())
|
||||||
return ViktorBuffer(flatArray.copy().flatten())
|
public override operator fun iterator(): Iterator<Double> = flatArray.data.iterator()
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun iterator(): Iterator<Double> = flatArray.data.iterator()
|
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
pluginManagement {
|
pluginManagement {
|
||||||
val toolsVersion = "0.6.0-dev-3"
|
val toolsVersion = "0.6.0-dev-5"
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
||||||
id("ru.mipt.npm.mpp") version toolsVersion
|
id("ru.mipt.npm.mpp") version toolsVersion
|
||||||
id("ru.mipt.npm.jvm") version toolsVersion
|
id("ru.mipt.npm.jvm") version toolsVersion
|
||||||
id("ru.mipt.npm.publish") version toolsVersion
|
id("ru.mipt.npm.publish") version toolsVersion
|
||||||
kotlin("plugin.allopen") version "1.4.0"
|
kotlin("plugin.allopen") version "1.4.20-dev-3898-14"
|
||||||
}
|
}
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
@ -17,6 +17,7 @@ pluginManagement {
|
|||||||
maven("https://dl.bintray.com/mipt-npm/scientifik")
|
maven("https://dl.bintray.com/mipt-npm/scientifik")
|
||||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user