Update tools and Kotlin, specify public explicitly, minor contracts refactor

This commit is contained in:
Iaroslav Postovalov 2020-09-09 09:55:26 +07:00
parent 5e4522bb06
commit 6b79e79d21
No known key found for this signature in database
GPG Key ID: 70D5F4DCB0972F1B
47 changed files with 422 additions and 528 deletions

View File

@ -3,7 +3,7 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
plugins { plugins {
java java
kotlin("jvm") kotlin("jvm")
kotlin("plugin.allopen") version "1.4.0" kotlin("plugin.allopen") version "1.4.20-dev-3898-14"
id("kotlinx.benchmark") version "0.2.0-dev-20" id("kotlinx.benchmark") version "0.2.0-dev-20"
} }
@ -13,6 +13,7 @@ repositories {
maven("http://dl.bintray.com/kyonifer/maven") maven("http://dl.bintray.com/kyonifer/maven")
maven("https://dl.bintray.com/mipt-npm/scientifik") maven("https://dl.bintray.com/mipt-npm/scientifik")
maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
mavenCentral() mavenCentral()
} }
@ -55,9 +56,4 @@ kotlin.sourceSets.all {
} }
} }
tasks.withType<KotlinCompile> { tasks.withType<KotlinCompile> { kotlinOptions.jvmTarget = "11" }
kotlinOptions {
jvmTarget = "11"
freeCompilerArgs = freeCompilerArgs + "-Xopt-in=kotlin.RequiresOptIn"
}
}

View File

@ -37,7 +37,7 @@ public object MstSpace : Space<MST>, NumericAlgebra<MST> {
/** /**
* [Ring] over [MST] nodes. * [Ring] over [MST] nodes.
*/ */
object MstRing : Ring<MST>, NumericAlgebra<MST> { public object MstRing : Ring<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0) override val zero: MST = number(0.0)
override val one: MST = number(1.0) override val one: MST = number(1.0)
@ -58,18 +58,18 @@ object MstRing : Ring<MST>, NumericAlgebra<MST> {
/** /**
* [Field] over [MST] nodes. * [Field] over [MST] nodes.
*/ */
object MstField : Field<MST> { public object MstField : Field<MST> {
override val zero: MST = number(0.0) public override val zero: MST = number(0.0)
override val one: MST = number(1.0) public override val one: MST = number(1.0)
override fun symbol(value: String): MST = MstRing.symbol(value) public override fun symbol(value: String): MST = MstRing.symbol(value)
override fun number(value: Number): MST = MstRing.number(value) public override fun number(value: Number): MST = MstRing.number(value)
override fun add(a: MST, b: MST): MST = MstRing.add(a, b) public override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = public override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstRing.binaryOperation(operation, left, right) MstRing.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
@ -78,7 +78,7 @@ object MstField : Field<MST> {
/** /**
* [ExtendedField] over [MST] nodes. * [ExtendedField] over [MST] nodes.
*/ */
object MstExtendedField : ExtendedField<MST> { public object MstExtendedField : ExtendedField<MST> {
override val zero: MST = number(0.0) override val zero: MST = number(0.0)
override val one: MST = number(1.0) override val one: MST = number(1.0)

View File

@ -26,22 +26,21 @@ import scientifik.kmath.structures.indices
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain { public class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain {
public override val dimension: Int get() = lower.size
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i -> public override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
point[i] in lower[i]..upper[i] point[i] in lower[i]..upper[i]
} }
override val dimension: Int get() = lower.size public override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num]
override fun getLowerBound(num: Int, point: Point<Double>): Double? = lower[num] public override fun getLowerBound(num: Int): Double? = lower[num]
override fun getLowerBound(num: Int): Double? = lower[num] public override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num]
override fun getUpperBound(num: Int, point: Point<Double>): Double? = upper[num] public override fun getUpperBound(num: Int): Double? = upper[num]
override fun getUpperBound(num: Int): Double? = upper[num] public override fun nearestInDomain(point: Point<Double>): Point<Double> {
override fun nearestInDomain(point: Point<Double>): Point<Double> {
val res = DoubleArray(point.size) { i -> val res = DoubleArray(point.size) { i ->
when { when {
point[i] < lower[i] -> lower[i] point[i] < lower[i] -> lower[i]
@ -53,16 +52,14 @@ public class HyperSquareDomain(private val lower: RealBuffer, private val upper:
return RealBuffer(*res) return RealBuffer(*res)
} }
override fun volume(): Double { public override fun volume(): Double {
var res = 1.0 var res = 1.0
for (i in 0 until dimension) { for (i in 0 until dimension) {
if (lower[i].isInfinite() || upper[i].isInfinite()) { if (lower[i].isInfinite() || upper[i].isInfinite()) return Double.POSITIVE_INFINITY
return Double.POSITIVE_INFINITY if (upper[i] > lower[i]) res *= upper[i] - lower[i]
}
if (upper[i] > lower[i]) {
res *= upper[i] - lower[i]
}
} }
return res return res
} }
} }

View File

@ -17,18 +17,18 @@ package scientifik.kmath.domains
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
public class UnconstrainedDomain(override val dimension: Int) : RealDomain { public class UnconstrainedDomain(public override val dimension: Int) : RealDomain {
override operator fun contains(point: Point<Double>): Boolean = true public override operator fun contains(point: Point<Double>): Boolean = true
override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY public override fun getLowerBound(num: Int, point: Point<Double>): Double? = Double.NEGATIVE_INFINITY
override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY public override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY
override fun getUpperBound(num: Int, point: Point<Double>): Double? = Double.POSITIVE_INFINITY public override fun getUpperBound(num: Int, point: Point<Double>): Double? = Double.POSITIVE_INFINITY
override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY public override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY
override fun nearestInDomain(point: Point<Double>): Point<Double> = point public override fun nearestInDomain(point: Point<Double>): Point<Double> = point
override fun volume(): Double = Double.POSITIVE_INFINITY public override fun volume(): Double = Double.POSITIVE_INFINITY
} }

View File

@ -4,16 +4,20 @@ import scientifik.kmath.linear.Point
import scientifik.kmath.structures.asBuffer import scientifik.kmath.structures.asBuffer
public inline class UnivariateDomain(public val range: ClosedFloatingPointRange<Double>) : RealDomain { public inline class UnivariateDomain(public val range: ClosedFloatingPointRange<Double>) : RealDomain {
public override val dimension: Int
get() = 1
public operator fun contains(d: Double): Boolean = range.contains(d) public operator fun contains(d: Double): Boolean = range.contains(d)
override operator fun contains(point: Point<Double>): Boolean { public override operator fun contains(point: Point<Double>): Boolean {
require(point.size == 0) require(point.size == 0)
return contains(point[0]) return contains(point[0])
} }
override fun nearestInDomain(point: Point<Double>): Point<Double> { public override fun nearestInDomain(point: Point<Double>): Point<Double> {
require(point.size == 1) require(point.size == 1)
val value = point[0] val value = point[0]
return when { return when {
value in range -> point value in range -> point
value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer() value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer()
@ -21,27 +25,25 @@ public inline class UnivariateDomain(public val range: ClosedFloatingPointRange<
} }
} }
override fun getLowerBound(num: Int, point: Point<Double>): Double? { public override fun getLowerBound(num: Int, point: Point<Double>): Double? {
require(num == 0) require(num == 0)
return range.start return range.start
} }
override fun getUpperBound(num: Int, point: Point<Double>): Double? { public override fun getUpperBound(num: Int, point: Point<Double>): Double? {
require(num == 0) require(num == 0)
return range.endInclusive return range.endInclusive
} }
override fun getLowerBound(num: Int): Double? { public override fun getLowerBound(num: Int): Double? {
require(num == 0) require(num == 0)
return range.start return range.start
} }
override fun getUpperBound(num: Int): Double? { public override fun getUpperBound(num: Int): Double? {
require(num == 0) require(num == 0)
return range.endInclusive return range.endInclusive
} }
override fun volume(): Double = range.endInclusive - range.start public override fun volume(): Double = range.endInclusive - range.start
override val dimension: Int get() = 1
} }

View File

@ -4,7 +4,8 @@ import scientifik.kmath.operations.*
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) : internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
Expression<T> { Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments)) public override operator fun invoke(arguments: Map<String, T>): T =
context.unaryOperation(name, expr.invoke(arguments))
} }
internal class FunctionalBinaryOperation<T>( internal class FunctionalBinaryOperation<T>(
@ -13,17 +14,17 @@ internal class FunctionalBinaryOperation<T>(
val first: Expression<T>, val first: Expression<T>,
val second: Expression<T> val second: Expression<T>
) : Expression<T> { ) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T = public override operator fun invoke(arguments: Map<String, T>): T =
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
} }
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> { internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T = public override operator fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name") arguments[name] ?: default ?: error("Parameter not found: $name")
} }
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> { internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T = value public override operator fun invoke(arguments: Map<String, T>): T = value
} }
internal class FunctionalConstProductExpression<T>( internal class FunctionalConstProductExpression<T>(
@ -31,7 +32,7 @@ internal class FunctionalConstProductExpression<T>(
private val expr: Expression<T>, private val expr: Expression<T>,
val const: Number val const: Number
) : Expression<T> { ) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const) public override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
} }
/** /**
@ -44,23 +45,23 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val
/** /**
* Builds an Expression of constant expression which does not depend on arguments. * Builds an Expression of constant expression which does not depend on arguments.
*/ */
override fun const(value: T): Expression<T> = FunctionalConstantExpression(value) public override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
/** /**
* Builds an Expression to access a variable. * Builds an Expression to access a variable.
*/ */
override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default) public override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
/** /**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/ */
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, operation, left, right) FunctionalBinaryOperation(algebra, operation, left, right)
/** /**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. * Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/ */
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
FunctionalUnaryOperation(algebra, operation, arg) FunctionalUnaryOperation(algebra, operation, arg)
} }
@ -69,18 +70,18 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val
*/ */
public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) : public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> { FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
override val zero: Expression<T> get() = const(algebra.zero) public override val zero: Expression<T> get() = const(algebra.zero)
/** /**
* Builds an Expression of addition of two another expressions. * Builds an Expression of addition of two another expressions.
*/ */
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
/** /**
* Builds an Expression of multiplication of expression by number. * Builds an Expression of multiplication of expression by number.
*/ */
override fun multiply(a: Expression<T>, k: Number): Expression<T> = public override fun multiply(a: Expression<T>, k: Number): Expression<T> =
FunctionalConstProductExpression(algebra, a, k) FunctionalConstProductExpression(algebra, a, k)
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg) public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
@ -88,31 +89,31 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg) super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right) super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
} }
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra), public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> { Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
override val one: Expression<T> public override val one: Expression<T>
get() = const(algebra.one) get() = const(algebra.one)
/** /**
* Builds an Expression of multiplication of two expressions. * Builds an Expression of multiplication of two expressions.
*/ */
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(RingOperations.TIMES_OPERATION, a, b) binaryOperation(RingOperations.TIMES_OPERATION, a, b)
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg) public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionSpace>.unaryOperation(operation, arg) super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right) super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
} }
@ -122,38 +123,38 @@ public open class FunctionalExpressionField<T, A>(algebra: A) :
/** /**
* Builds an Expression of division an expression by another one. * Builds an Expression of division an expression by another one.
*/ */
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
binaryOperation(FieldOperations.DIV_OPERATION, a, b) binaryOperation(FieldOperations.DIV_OPERATION, a, b)
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg) public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionRing>.unaryOperation(operation, arg) super<FunctionalExpressionRing>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionRing>.binaryOperation(operation, left, right) super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
} }
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) : public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
FunctionalExpressionField<T, A>(algebra), FunctionalExpressionField<T, A>(algebra),
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> { ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) public override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) public override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun asin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) public override fun asin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun acos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) public override fun acos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
override fun atan(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) public override fun atan(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun power(arg: Expression<T>, pow: Number): Expression<T> = public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
override fun exp(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) public override fun exp(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg) public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionField>.unaryOperation(operation, arg) super<FunctionalExpressionField>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionField>.binaryOperation(operation, left, right) super<FunctionalExpressionField>.binaryOperation(operation, left, right)
} }

View File

@ -8,82 +8,82 @@ import scientifik.kmath.structures.*
* Basic implementation of Matrix space based on [NDStructure] * Basic implementation of Matrix space based on [NDStructure]
*/ */
public class BufferMatrixContext<T : Any, R : Ring<T>>( public class BufferMatrixContext<T : Any, R : Ring<T>>(
override val elementContext: R, public override val elementContext: R,
private val bufferFactory: BufferFactory<T> private val bufferFactory: BufferFactory<T>
) : GenericMatrixContext<T, R> { ) : GenericMatrixContext<T, R> {
public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix<T> {
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix<T> {
val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) } val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
} }
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer) public override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
public companion object public companion object
} }
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
public object RealMatrixContext : GenericMatrixContext<Double, RealField> { public object RealMatrixContext : GenericMatrixContext<Double, RealField> {
public override val elementContext: RealField
get() = RealField
override val elementContext: RealField get() = RealField public override inline fun produce(
rows: Int,
override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> { columns: Int,
initializer: (i: Int, j: Int) -> Double
): Matrix<Double> {
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
} }
override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> = RealBuffer(size, initializer) public override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> =
RealBuffer(size, initializer)
} }
public class BufferMatrix<T : Any>( public class BufferMatrix<T : Any>(
override val rowNum: Int, public override val rowNum: Int,
override val colNum: Int, public override val colNum: Int,
public val buffer: Buffer<out T>, public val buffer: Buffer<out T>,
override val features: Set<MatrixFeature> = emptySet() public override val features: Set<MatrixFeature> = emptySet()
) : FeaturedMatrix<T> { ) : FeaturedMatrix<T> {
override val shape: IntArray
get() = intArrayOf(rowNum, colNum)
init { init {
if (buffer.size != rowNum * colNum) { require(buffer.size == rowNum * colNum) { "Dimension mismatch for matrix structure" }
error("Dimension mismatch for matrix structure")
}
} }
override val shape: IntArray get() = intArrayOf(rowNum, colNum) public override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
BufferMatrix(rowNum, colNum, buffer, this.features + features) BufferMatrix(rowNum, colNum, buffer, this.features + features)
override operator fun get(index: IntArray): T = get(index[0], index[1]) public override operator fun get(index: IntArray): T = get(index[0], index[1])
public override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j] public override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j)) for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
} }
override fun equals(other: Any?): Boolean { public override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
return when (other) { return when (other) {
is NDStructure<*> -> return NDStructure.equals(this, other) is NDStructure<*> -> return NDStructure.equals(this, other)
else -> false else -> false
} }
} }
override fun hashCode(): Int { public override fun hashCode(): Int {
var result = buffer.hashCode() var result = buffer.hashCode()
result = 31 * result + features.hashCode() result = 31 * result + features.hashCode()
return result return result
} }
override fun toString(): String { public override fun toString(): String {
return if (rowNum <= 5 && colNum <= 5) { return if (rowNum <= 5 && colNum <= 5)
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" + "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer -> rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
buffer.asSequence().joinToString(separator = "\t") { it.toString() } buffer.asSequence().joinToString(separator = "\t") { it.toString() }
} }
} else { else "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
}
} }
} }
@ -92,26 +92,21 @@ public class BufferMatrix<T : Any>(
*/ */
public infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> { public infix fun BufferMatrix<Double>.dot(other: BufferMatrix<Double>): BufferMatrix<Double> {
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
val array = DoubleArray(this.rowNum * other.colNum) val array = DoubleArray(this.rowNum * other.colNum)
//convert to array to insure there is not memory indirection //convert to array to insure there is not memory indirection
fun Buffer<out Double>.unsafeArray(): DoubleArray = if (this is RealBuffer) { fun Buffer<out Double>.unsafeArray() = if (this is RealBuffer)
array array
} else { else
DoubleArray(size) { get(it) } DoubleArray(size) { get(it) }
}
val a = this.buffer.unsafeArray() val a = this.buffer.unsafeArray()
val b = other.buffer.unsafeArray() val b = other.buffer.unsafeArray()
for (i in (0 until rowNum)) { for (i in (0 until rowNum))
for (j in (0 until other.colNum)) { for (j in (0 until other.colNum))
for (k in (0 until colNum)) { for (k in (0 until colNum))
array[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j] array[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j]
}
}
}
val buffer = RealBuffer(array) val buffer = RealBuffer(array)
return BufferMatrix(rowNum, other.colNum, buffer) return BufferMatrix(rowNum, other.colNum, buffer)

View File

@ -26,10 +26,8 @@ public interface FeaturedMatrix<T : Any> : Matrix<T> {
public companion object public companion object
} }
public inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> { public inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix<Double> =
contract { callsInPlace(initializer) } MatrixContext.real.produce(rows, columns, initializer)
return MatrixContext.real.produce(rows, columns, initializer)
}
/** /**
* Build a square matrix from given elements. * Build a square matrix from given elements.

View File

@ -76,7 +76,6 @@ public inline fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.l
matrix: Matrix<T>, matrix: Matrix<T>,
checkSingular: (T) -> Boolean checkSingular: (T) -> Boolean
): LUPDecomposition<T> { ): LUPDecomposition<T> {
contract { callsInPlace(checkSingular) }
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" } require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
val m = matrix.colNum val m = matrix.colNum
val pivot = IntArray(matrix.rowNum) val pivot = IntArray(matrix.rowNum)
@ -153,10 +152,7 @@ public inline fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.l
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup( public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.lup(
matrix: Matrix<T>, matrix: Matrix<T>,
checkSingular: (T) -> Boolean checkSingular: (T) -> Boolean
): LUPDecomposition<T> { ): LUPDecomposition<T> = lup(T::class, matrix, checkSingular)
contract { callsInPlace(checkSingular) }
return lup(T::class, matrix, checkSingular)
}
public fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> = public fun GenericMatrixContext<Double, RealField>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> =
lup(Double::class, matrix) { it < 1e-11 } lup(Double::class, matrix) { it < 1e-11 }
@ -216,7 +212,6 @@ public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext
b: Matrix<T>, b: Matrix<T>,
checkSingular: (T) -> Boolean checkSingular: (T) -> Boolean
): Matrix<T> { ): Matrix<T> {
contract { callsInPlace(checkSingular) }
// Use existing decomposition if it is provided by matrix // Use existing decomposition if it is provided by matrix
val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular) val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular)
return decomposition.solve(T::class, b) return decomposition.solve(T::class, b)
@ -227,10 +222,7 @@ public fun RealMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): Matrix
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse( public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F>.inverse(
matrix: Matrix<T>, matrix: Matrix<T>,
checkSingular: (T) -> Boolean checkSingular: (T) -> Boolean
): Matrix<T> { ): Matrix<T> = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
contract { callsInPlace(checkSingular) }
return solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular)
}
public fun RealMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> = public fun RealMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> =
solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 } solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 }

View File

@ -19,10 +19,9 @@ public interface LinearSolver<T : Any> {
* Convert matrix to vector if it is possible * Convert matrix to vector if it is possible
*/ */
public fun <T : Any> Matrix<T>.asPoint(): Point<T> = public fun <T : Any> Matrix<T>.asPoint(): Point<T> =
if (this.colNum == 1) { if (this.colNum == 1)
VirtualBuffer(rowNum) { get(it, 0) } VirtualBuffer(rowNum) { get(it, 0) }
} else { else
error("Can't convert matrix with more than one column to vector") error("Can't convert matrix with more than one column to vector")
}
public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) } public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }

View File

@ -12,10 +12,8 @@ import kotlin.jvm.JvmName
* @param R the type of resulting iterable. * @param R the type of resulting iterable.
* @param initial lazy evaluated. * @param initial lazy evaluated.
*/ */
public inline fun <T, R> Iterator<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator<R> { public inline fun <T, R> Iterator<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator<R> =
contract { callsInPlace(operation) } object : Iterator<R> {
return object : Iterator<R> {
var state: R = initial var state: R = initial
override fun hasNext(): Boolean = this@cumulative.hasNext() override fun hasNext(): Boolean = this@cumulative.hasNext()
@ -25,7 +23,6 @@ public inline fun <T, R> Iterator<T>.cumulative(initial: R, crossinline operatio
return state return state
} }
} }
}
public inline fun <T, R> Iterable<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable<R> = public inline fun <T, R> Iterable<T>.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable<R> =
Iterable { this@cumulative.iterator().cumulative(initial, operation) } Iterable { this@cumulative.iterator().cumulative(initial, operation) }

View File

@ -1,5 +1,7 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import kotlin.contracts.contract
/** /**
* Stub for DSL the [Algebra] is. * Stub for DSL the [Algebra] is.
*/ */

View File

@ -40,18 +40,17 @@ public class BigInt internal constructor(
private val sign: Byte, private val sign: Byte,
private val magnitude: Magnitude private val magnitude: Magnitude
) : Comparable<BigInt> { ) : Comparable<BigInt> {
public override fun compareTo(other: BigInt): Int = when {
override fun compareTo(other: BigInt): Int = when {
(this.sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0 (this.sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0
this.sign < other.sign -> -1 this.sign < other.sign -> -1
this.sign > other.sign -> 1 this.sign > other.sign -> 1
else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude) else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude)
} }
override fun equals(other: Any?): Boolean = public override fun equals(other: Any?): Boolean =
if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type") if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type")
override fun hashCode(): Int = magnitude.hashCode() + sign public override fun hashCode(): Int = magnitude.hashCode() + sign
public fun abs(): BigInt = if (sign == 0.toByte()) this else BigInt(1, magnitude) public fun abs(): BigInt = if (sign == 0.toByte()) this else BigInt(1, magnitude)
@ -456,15 +455,11 @@ public fun String.parseBigInteger(): BigInt? {
return res * sign return res * sign
} }
public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> { public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
contract { callsInPlace(initializer) } boxing(size, initializer)
return boxing(size, initializer)
}
public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> { public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
contract { callsInPlace(initializer) } boxing(size, initializer)
return boxing(size, initializer)
}
public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> = public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt) BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)

View File

@ -6,7 +6,6 @@ import scientifik.kmath.structures.MutableBuffer
import scientifik.memory.MemoryReader import scientifik.memory.MemoryReader
import scientifik.memory.MemorySpec import scientifik.memory.MemorySpec
import scientifik.memory.MemoryWriter import scientifik.memory.MemoryWriter
import kotlin.contracts.contract
import kotlin.math.* import kotlin.math.*
/** /**
@ -165,7 +164,8 @@ public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
* @property re The real part. * @property re The real part.
* @property im The imaginary part. * @property im The imaginary part.
*/ */
public data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> { public data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>,
Comparable<Complex> {
public constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) public constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
override val context: ComplexField get() = ComplexField override val context: ComplexField get() = ComplexField
@ -197,12 +197,8 @@ public data class Complex(val re: Double, val im: Double) : FieldElement<Complex
*/ */
public fun Number.toComplex(): Complex = Complex(this, 0.0) public fun Number.toComplex(): Complex = Complex(this, 0.0)
public inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> { public inline fun Buffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer<Complex> =
contract { callsInPlace(init) } MemoryBuffer.create(Complex, size, init)
return MemoryBuffer.create(Complex, size, init)
}
public inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> { public inline fun MutableBuffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer<Complex> =
contract { callsInPlace(init) } MemoryBuffer.create(Complex, size, init)
return MemoryBuffer.create(Complex, size, init)
}

View File

@ -62,7 +62,7 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
* *
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
*/ */
inline class Real(val value: Double) : FieldElement<Double, Real, RealField> { public inline class Real(public val value: Double) : FieldElement<Double, Real, RealField> {
override val context: RealField override val context: RealField
get() = RealField get() = RealField
@ -70,14 +70,14 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
override fun Double.wrap(): Real = Real(value) override fun Double.wrap(): Real = Real(value)
companion object public companion object
} }
/** /**
* A field for [Double] without boxing. Does not produce appropriate field element. * A field for [Double] without boxing. Does not produce appropriate field element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object RealField : ExtendedField<Double>, Norm<Double, Double> { public object RealField : ExtendedField<Double>, Norm<Double, Double> {
override val zero: Double override val zero: Double
get() = 0.0 get() = 0.0
@ -127,7 +127,7 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
* A field for [Float] without boxing. Does not produce appropriate field element. * A field for [Float] without boxing. Does not produce appropriate field element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object FloatField : ExtendedField<Float>, Norm<Float, Float> { public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override val zero: Float override val zero: Float
get() = 0.0f get() = 0.0f
@ -177,7 +177,7 @@ object FloatField : ExtendedField<Float>, Norm<Float, Float> {
* A field for [Int] without boxing. Does not produce corresponding ring element. * A field for [Int] without boxing. Does not produce corresponding ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object IntRing : Ring<Int>, Norm<Int, Int> { public object IntRing : Ring<Int>, Norm<Int, Int> {
override val zero: Int override val zero: Int
get() = 0 get() = 0
@ -201,7 +201,7 @@ object IntRing : Ring<Int>, Norm<Int, Int> {
* A field for [Short] without boxing. Does not produce appropriate ring element. * A field for [Short] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ShortRing : Ring<Short>, Norm<Short, Short> { public object ShortRing : Ring<Short>, Norm<Short, Short> {
override val zero: Short override val zero: Short
get() = 0 get() = 0
@ -225,7 +225,7 @@ object ShortRing : Ring<Short>, Norm<Short, Short> {
* A field for [Byte] without boxing. Does not produce appropriate ring element. * A field for [Byte] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ByteRing : Ring<Byte>, Norm<Byte, Byte> { public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
override val zero: Byte override val zero: Byte
get() = 0 get() = 0
@ -249,7 +249,7 @@ object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
* A field for [Double] without boxing. Does not produce appropriate ring element. * A field for [Double] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object LongRing : Ring<Long>, Norm<Long, Long> { public object LongRing : Ring<Long>, Norm<Long, Long> {
override val zero: Long override val zero: Long
get() = 0 get() = 0

View File

@ -59,6 +59,7 @@ public class BoxingNDRing<T, R : Ring<T>>(
transform: R.(T, T) -> T transform: R.(T, T) -> T
): BufferedNDRingElement<T, R> { ): BufferedNDRingElement<T, R> {
check(a, b) check(a, b)
return BufferedNDRingElement( return BufferedNDRingElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })

View File

@ -5,7 +5,7 @@ import scientifik.kmath.operations.*
public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> { public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
public val strides: Strides public val strides: Strides
override fun check(vararg elements: NDBuffer<T>): Unit = public override fun check(vararg elements: NDBuffer<T>): Unit =
require(elements.all { it.strides == strides }) { ("Strides mismatch") } require(elements.all { it.strides == strides }) { ("Strides mismatch") }
/** /**
@ -29,7 +29,7 @@ public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
public interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T, S> { public interface BufferedNDSpace<T, S : Space<T>> : NDSpace<T, S, NDBuffer<T>>, BufferedNDAlgebra<T, S> {
override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>> public override fun NDBuffer<T>.toElement(): SpaceElement<NDBuffer<T>, *, out BufferedNDSpace<T, S>>
} }
public interface BufferedNDRing<T, R : Ring<T>> : NDRing<T, R, NDBuffer<T>>, BufferedNDSpace<T, R> { public interface BufferedNDRing<T, R : Ring<T>> : NDRing<T, R, NDBuffer<T>>, BufferedNDSpace<T, R> {

View File

@ -2,8 +2,6 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.complex import scientifik.kmath.operations.complex
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -56,7 +54,8 @@ public interface Buffer<T> {
/** /**
* Create a boxing buffer of given type * Create a boxing buffer of given type
*/ */
public inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> = ListBuffer(List(size, initializer)) public inline fun <T> boxing(size: Int, initializer: (Int) -> T): Buffer<T> =
ListBuffer(List(size, initializer))
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> { public inline fun <T : Any> auto(type: KClass<T>, size: Int, crossinline initializer: (Int) -> T): Buffer<T> {
@ -115,11 +114,11 @@ public interface MutableBuffer<T> : Buffer<T> {
/** /**
* Create a boxing mutable buffer of given type * Create a boxing mutable buffer of given type
*/ */
inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> = public inline fun <T> boxing(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
MutableListBuffer(MutableList(size, initializer)) MutableListBuffer(MutableList(size, initializer))
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> = public inline fun <T : Any> auto(type: KClass<out T>, size: Int, initializer: (Int) -> T): MutableBuffer<T> =
when (type) { when (type) {
Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T> Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T> Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer<T>
@ -132,12 +131,11 @@ public interface MutableBuffer<T> : Buffer<T> {
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible * Create most appropriate mutable buffer for given type avoiding boxing wherever possible
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> = public inline fun <reified T : Any> auto(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
auto(T::class, size, initializer) auto(T::class, size, initializer)
val real: MutableBufferFactory<Double> = { size: Int, initializer: (Int) -> Double -> public val real: MutableBufferFactory<Double> =
RealBuffer(DoubleArray(size) { initializer(it) }) { size, initializer -> RealBuffer(DoubleArray(size) { initializer(it) }) }
}
} }
} }
@ -147,7 +145,7 @@ public interface MutableBuffer<T> : Buffer<T> {
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
* @property list The underlying list. * @property list The underlying list.
*/ */
inline class ListBuffer<T>(val list: List<T>) : Buffer<T> { public inline class ListBuffer<T>(public val list: List<T>) : Buffer<T> {
override val size: Int override val size: Int
get() = list.size get() = list.size
@ -158,7 +156,7 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
/** /**
* Returns an [ListBuffer] that wraps the original list. * Returns an [ListBuffer] that wraps the original list.
*/ */
fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this) public fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
/** /**
* Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified * Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified
@ -167,10 +165,7 @@ fun <T> List<T>.asBuffer(): ListBuffer<T> = ListBuffer(this)
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an array element given its index. * It should return the value for an array element given its index.
*/ */
inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> { public inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = List(size, init).asBuffer()
contract { callsInPlace(init) }
return List(size, init).asBuffer()
}
/** /**
* [MutableBuffer] implementation over [MutableList]. * [MutableBuffer] implementation over [MutableList].
@ -178,7 +173,7 @@ inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> {
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
* @property list The underlying list. * @property list The underlying list.
*/ */
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> { public inline class MutableListBuffer<T>(public val list: MutableList<T>) : MutableBuffer<T> {
override val size: Int override val size: Int
get() = list.size get() = list.size
@ -198,7 +193,7 @@ inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
* @param T the type of elements contained in the buffer. * @param T the type of elements contained in the buffer.
* @property array The underlying array. * @property array The underlying array.
*/ */
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> { public class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
// Can't inline because array is invariant // Can't inline because array is invariant
override val size: Int override val size: Int
get() = array.size get() = array.size

View File

@ -4,7 +4,6 @@ import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.FieldElement import scientifik.kmath.operations.FieldElement
import scientifik.kmath.operations.complex import scientifik.kmath.operations.complex
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@ -98,7 +97,7 @@ public class ComplexNDField(override val shape: IntArray) :
/** /**
* Fast element production using function inlining * Fast element production using function inlining
*/ */
inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline initializer: ComplexField.(Int) -> Complex): ComplexNDElement { public inline fun BufferedNDField<Complex, ComplexField>.produceInline(initializer: ComplexField.(Int) -> Complex): ComplexNDElement {
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) } val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
return BufferedNDFieldElement(this, buffer) return BufferedNDFieldElement(this, buffer)
} }
@ -106,14 +105,13 @@ inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline init
/** /**
* Map one [ComplexNDElement] using function with indices. * Map one [ComplexNDElement] using function with indices.
*/ */
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement = public inline fun ComplexNDElement.mapIndexed(transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement =
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
/** /**
* Map one [ComplexNDElement] using function without indices. * Map one [ComplexNDElement] using function without indices.
*/ */
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement { public inline fun ComplexNDElement.map(transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
contract { callsInPlace(transform) }
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) } val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
return BufferedNDFieldElement(context, buffer) return BufferedNDFieldElement(context, buffer)
} }
@ -121,10 +119,9 @@ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) ->
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy
*/ */
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement): ComplexNDElement = public operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement): ComplexNDElement =
ndElement.map { this@invoke(it) } ndElement.map { this@invoke(it) }
/* plus and minus */ /* plus and minus */
/** /**
@ -142,8 +139,10 @@ public operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement = map
public fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape) public fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
public fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(IntArray) -> Complex): ComplexNDElement = public fun NDElement.Companion.complex(
NDField.complex(*shape).produce(initializer) vararg shape: Int,
initializer: ComplexField.(IntArray) -> Complex
): ComplexNDElement = NDField.complex(*shape).produce(initializer)
/** /**
* Produce a context for n-dimensional operations inside this real field * Produce a context for n-dimensional operations inside this real field

View File

@ -1,7 +1,5 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.experimental.and import kotlin.experimental.and
/** /**
@ -34,23 +32,23 @@ public enum class ValueFlag(public val mask: Byte) {
/** /**
* A buffer with flagged values. * A buffer with flagged values.
*/ */
interface FlaggedBuffer<T> : Buffer<T> { public interface FlaggedBuffer<T> : Buffer<T> {
fun getFlag(index: Int): Byte public fun getFlag(index: Int): Byte
} }
/** /**
* The value is valid if all flags are down * The value is valid if all flags are down
*/ */
fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte() public fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte()
fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte() public fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte()
fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING) public fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING)
/** /**
* A real buffer which supports flags for each value like NaN or Missing * A real buffer which supports flags for each value like NaN or Missing
*/ */
class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> { public class FlaggedRealBuffer(public val values: DoubleArray, public val flags: ByteArray) : FlaggedBuffer<Double?>, Buffer<Double?> {
init { init {
require(values.size == flags.size) { "Values and flags must have the same dimensions" } require(values.size == flags.size) { "Values and flags must have the same dimensions" }
} }
@ -66,9 +64,7 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged
}.iterator() }.iterator()
} }
inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { public inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) {
contract { callsInPlace(block) }
indices indices
.asSequence() .asSequence()
.filter(::isValid) .filter(::isValid)

View File

@ -8,7 +8,7 @@ import kotlin.contracts.contract
* *
* @property array the underlying array. * @property array the underlying array.
*/ */
inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> { public inline class FloatBuffer(public val array: FloatArray) : MutableBuffer<Float> {
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Float = array[index] override operator fun get(index: Int): Float = array[index]
@ -30,20 +30,17 @@ inline class FloatBuffer(val array: FloatArray) : MutableBuffer<Float> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer { public inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) })
contract { callsInPlace(init) }
return FloatBuffer(FloatArray(size) { init(it) })
}
/** /**
* Returns a new [FloatBuffer] of given elements. * Returns a new [FloatBuffer] of given elements.
*/ */
fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats) public fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats)
/** /**
* Returns a [FloatArray] containing all of the elements of this [MutableBuffer]. * Returns a [FloatArray] containing all of the elements of this [MutableBuffer].
*/ */
val MutableBuffer<out Float>.array: FloatArray public val MutableBuffer<out Float>.array: FloatArray
get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) }) get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) })
/** /**
@ -52,4 +49,4 @@ val MutableBuffer<out Float>.array: FloatArray
* @receiver the array. * @receiver the array.
* @return the new buffer. * @return the new buffer.
*/ */
fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this) public fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this)

View File

@ -1,9 +1,5 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [IntArray]. * Specialized [MutableBuffer] implementation over [IntArray].
* *
@ -31,20 +27,17 @@ public inline class IntBuffer(public val array: IntArray) : MutableBuffer<Int> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer { public inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) })
contract { callsInPlace(init) }
return IntBuffer(IntArray(size) { init(it) })
}
/** /**
* Returns a new [IntBuffer] of given elements. * Returns a new [IntBuffer] of given elements.
*/ */
fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints) public fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints)
/** /**
* Returns a [IntArray] containing all of the elements of this [MutableBuffer]. * Returns a [IntArray] containing all of the elements of this [MutableBuffer].
*/ */
val MutableBuffer<out Int>.array: IntArray public val MutableBuffer<out Int>.array: IntArray
get() = (if (this is IntBuffer) array else IntArray(size) { get(it) }) get() = (if (this is IntBuffer) array else IntArray(size) { get(it) })
/** /**
@ -53,4 +46,4 @@ val MutableBuffer<out Int>.array: IntArray
* @receiver the array. * @receiver the array.
* @return the new buffer. * @return the new buffer.
*/ */
fun IntArray.asBuffer(): IntBuffer = IntBuffer(this) public fun IntArray.asBuffer(): IntBuffer = IntBuffer(this)

View File

@ -31,10 +31,7 @@ public inline class LongBuffer(public val array: LongArray) : MutableBuffer<Long
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
public inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer { public inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) })
contract { callsInPlace(init) }
return LongBuffer(LongArray(size) { init(it) })
}
/** /**
* Returns a new [LongBuffer] of given elements. * Returns a new [LongBuffer] of given elements.

View File

@ -9,7 +9,7 @@ import scientifik.memory.*
* @property memory the underlying memory segment. * @property memory the underlying memory segment.
* @property spec the spec of [T] type. * @property spec the spec of [T] type.
*/ */
open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> { public open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spec: MemorySpec<T>) : Buffer<T> {
override val size: Int get() = memory.size / spec.objectSize override val size: Int get() = memory.size / spec.objectSize
private val reader: MemoryReader = memory.reader() private val reader: MemoryReader = memory.reader()
@ -17,20 +17,17 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index) override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator() override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
companion object { public companion object {
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> = public fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
MemoryBuffer(Memory.allocate(size * spec.objectSize), spec) MemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
inline fun <T : Any> create( public inline fun <T : Any> create(
spec: MemorySpec<T>, spec: MemorySpec<T>,
size: Int, size: Int,
crossinline initializer: (Int) -> T initializer: (Int) -> T
): MemoryBuffer<T> = ): MemoryBuffer<T> = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> (0 until size).forEach { buffer[it] = initializer(it) }
(0 until size).forEach { }
buffer[it] = initializer(it)
}
}
} }
} }
@ -41,7 +38,7 @@ open class MemoryBuffer<T : Any>(protected val memory: Memory, protected val spe
* @property memory the underlying memory segment. * @property memory the underlying memory segment.
* @property spec the spec of [T] type. * @property spec the spec of [T] type.
*/ */
class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec), public class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : MemoryBuffer<T>(memory, spec),
MutableBuffer<T> { MutableBuffer<T> {
private val writer: MemoryWriter = memory.writer() private val writer: MemoryWriter = memory.writer()
@ -49,19 +46,16 @@ class MutableMemoryBuffer<T : Any>(memory: Memory, spec: MemorySpec<T>) : Memory
override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value)
override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec) override fun copy(): MutableBuffer<T> = MutableMemoryBuffer(memory.copy(), spec)
companion object { public companion object {
fun <T : Any> create(spec: MemorySpec<T>, size: Int): MutableMemoryBuffer<T> = public fun <T : Any> create(spec: MemorySpec<T>, size: Int): MutableMemoryBuffer<T> =
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec) MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec)
inline fun <T : Any> create( public inline fun <T : Any> create(
spec: MemorySpec<T>, spec: MemorySpec<T>,
size: Int, size: Int,
crossinline initializer: (Int) -> T crossinline initializer: (Int) -> T
): MutableMemoryBuffer<T> = ): MutableMemoryBuffer<T> = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer ->
MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> (0 until size).forEach { buffer[it] = initializer(it) }
(0 until size).forEach { }
buffer[it] = initializer(it)
}
}
} }
} }

View File

@ -115,19 +115,18 @@ public interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing
public operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) } public operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
companion object { public companion object {
private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf()
private val realNDFieldCache = HashMap<IntArray, RealNDField>()
/** /**
* Create a nd-field for [Double] values or pull it from cache if it was created previously * Create a nd-field for [Double] values or pull it from cache if it was created previously
*/ */
fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) } public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
/** /**
* Create a nd-field with boxing generic buffer * Create a nd-field with boxing generic buffer
*/ */
fun <T : Any, F : Field<T>> boxing( public fun <T : Any, F : Field<T>> boxing(
field: F, field: F,
vararg shape: Int, vararg shape: Int,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
@ -137,7 +136,7 @@ public interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing
* Create a most suitable implementation for nd-field using reified class. * Create a most suitable implementation for nd-field using reified class.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): BufferedNDField<T, F> = public inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): BufferedNDField<T, F> =
when { when {
T::class == Double::class -> real(*shape) as BufferedNDField<T, F> T::class == Double::class -> real(*shape) as BufferedNDField<T, F>
T::class == Complex::class -> complex(*shape) as BufferedNDField<T, F> T::class == Complex::class -> complex(*shape) as BufferedNDField<T, F>

View File

@ -4,6 +4,7 @@ import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import kotlin.contracts.contract
/** /**
* The root for all [NDStructure] based algebra elements. Does not implement algebra element root because of problems with recursive self-types * The root for all [NDStructure] based algebra elements. Does not implement algebra element root because of problems with recursive self-types
@ -11,31 +12,30 @@ import scientifik.kmath.operations.Space
* @param C the type of the context for the element * @param C the type of the context for the element
* @param N the type of the underlying [NDStructure] * @param N the type of the underlying [NDStructure]
*/ */
interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> { public interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
public val context: NDAlgebra<T, C, N>
val context: NDAlgebra<T, C, N> public fun unwrap(): N
fun unwrap(): N public fun N.wrap(): NDElement<T, C, N>
fun N.wrap(): NDElement<T, C, N> public companion object {
companion object {
/** /**
* Create a optimized NDArray of doubles * Create a optimized NDArray of doubles
*/ */
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement = public fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement =
NDField.real(*shape).produce(initializer) NDField.real(*shape).produce(initializer)
inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement = public inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement =
real(intArrayOf(dim)) { initializer(it[0]) } real(intArrayOf(dim)) { initializer(it[0]) }
inline fun real2D( public inline fun real2D(
dim1: Int, dim1: Int,
dim2: Int, dim2: Int,
crossinline initializer: (Int, Int) -> Double = { _, _ -> 0.0 } crossinline initializer: (Int, Int) -> Double = { _, _ -> 0.0 }
): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } ): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
inline fun real3D( public inline fun real3D(
dim1: Int, dim1: Int,
dim2: Int, dim2: Int,
dim3: Int, dim3: Int,
@ -46,7 +46,7 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
/** /**
* Simple boxing NDArray * Simple boxing NDArray
*/ */
fun <T : Any, F : Field<T>> boxing( public fun <T : Any, F : Field<T>> boxing(
shape: IntArray, shape: IntArray,
field: F, field: F,
initializer: F.(IntArray) -> T initializer: F.(IntArray) -> T
@ -55,7 +55,7 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
return ndField.produce(initializer) return ndField.produce(initializer)
} }
inline fun <reified T : Any, F : Field<T>> auto( public inline fun <reified T : Any, F : Field<T>> auto(
shape: IntArray, shape: IntArray,
field: F, field: F,
noinline initializer: F.(IntArray) -> T noinline initializer: F.(IntArray) -> T
@ -66,17 +66,16 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
} }
} }
public fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement<T, C, N> =
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement<T, C, N> =
context.mapIndexed(unwrap(), transform).wrap() context.mapIndexed(unwrap(), transform).wrap()
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> = public fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T): NDElement<T, C, N> =
context.map(unwrap(), transform).wrap() context.map(unwrap(), transform).wrap()
/** /**
* Element by element application of any operation on elements to the whole [NDElement] * Element by element application of any operation on elements to the whole [NDElement]
*/ */
operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>): NDElement<T, C, N> = public operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>): NDElement<T, C, N> =
ndElement.map { value -> this@invoke(value) } ndElement.map { value -> this@invoke(value) }
/* plus and minus */ /* plus and minus */
@ -84,13 +83,13 @@ operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElem
/** /**
* Summation operation for [NDElement] and single element * Summation operation for [NDElement] and single element
*/ */
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T): NDElement<T, S, N> = public operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T): NDElement<T, S, N> =
map { value -> arg + value } map { value -> arg + value }
/** /**
* Subtraction operation between [NDElement] and single element * Subtraction operation between [NDElement] and single element
*/ */
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T): NDElement<T, S, N> = public operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T): NDElement<T, S, N> =
map { value -> arg - value } map { value -> arg - value }
/* prod and div */ /* prod and div */
@ -98,13 +97,13 @@ operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg:
/** /**
* Product operation for [NDElement] and single element * Product operation for [NDElement] and single element
*/ */
operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T): NDElement<T, R, N> = public operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T): NDElement<T, R, N> =
map { value -> arg * value } map { value -> arg * value }
/** /**
* Division operation between [NDElement] and single element * Division operation between [NDElement] and single element
*/ */
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> = public operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T): NDElement<T, F, N> =
map { value -> arg / value } map { value -> arg / value }
// /** // /**

View File

@ -1,6 +1,5 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
import kotlin.reflect.KClass import kotlin.reflect.KClass
@ -12,17 +11,17 @@ import kotlin.reflect.KClass
* *
* @param T the type of items. * @param T the type of items.
*/ */
interface NDStructure<T> { public interface NDStructure<T> {
/** /**
* The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of * The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of
* this structure. * this structure.
*/ */
val shape: IntArray public val shape: IntArray
/** /**
* The count of dimensions in this structure. It should be equal to size of [shape]. * The count of dimensions in this structure. It should be equal to size of [shape].
*/ */
val dimension: Int get() = shape.size public val dimension: Int get() = shape.size
/** /**
* Returns the value at the specified indices. * Returns the value at the specified indices.
@ -30,24 +29,24 @@ interface NDStructure<T> {
* @param index the indices. * @param index the indices.
* @return the value. * @return the value.
*/ */
operator fun get(index: IntArray): T public operator fun get(index: IntArray): T
/** /**
* Returns the sequence of all the elements associated by their indices. * Returns the sequence of all the elements associated by their indices.
* *
* @return the lazy sequence of pairs of indices to values. * @return the lazy sequence of pairs of indices to values.
*/ */
fun elements(): Sequence<Pair<IntArray, T>> public fun elements(): Sequence<Pair<IntArray, T>>
override fun equals(other: Any?): Boolean override fun equals(other: Any?): Boolean
override fun hashCode(): Int override fun hashCode(): Int
companion object { public companion object {
/** /**
* Indicates whether some [NDStructure] is equal to another one. * Indicates whether some [NDStructure] is equal to another one.
*/ */
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { public fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
if (st1 === st2) return true if (st1 === st2) return true
// fast comparison of buffers if possible // fast comparison of buffers if possible
@ -68,7 +67,7 @@ interface NDStructure<T> {
* *
* Strides should be reused if possible. * Strides should be reused if possible.
*/ */
fun <T> build( public fun <T> build(
strides: Strides, strides: Strides,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T initializer: (IntArray) -> T
@ -78,39 +77,39 @@ interface NDStructure<T> {
/** /**
* Inline create NDStructure with non-boxing buffer implementation if it is possible * Inline create NDStructure with non-boxing buffer implementation if it is possible
*/ */
inline fun <reified T : Any> auto( public inline fun <reified T : Any> auto(
strides: Strides, strides: Strides,
crossinline initializer: (IntArray) -> T crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> = ): BufferNDStructure<T> =
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
inline fun <T : Any> auto( public inline fun <T : Any> auto(
type: KClass<T>, type: KClass<T>,
strides: Strides, strides: Strides,
crossinline initializer: (IntArray) -> T crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> = ): BufferNDStructure<T> =
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
fun <T> build( public fun <T> build(
shape: IntArray, shape: IntArray,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T initializer: (IntArray) -> T
): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer) ): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> auto( public inline fun <reified T : Any> auto(
shape: IntArray, shape: IntArray,
crossinline initializer: (IntArray) -> T crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> = ): BufferNDStructure<T> =
auto(DefaultStrides(shape), initializer) auto(DefaultStrides(shape), initializer)
@JvmName("autoVarArg") @JvmName("autoVarArg")
inline fun <reified T : Any> auto( public inline fun <reified T : Any> auto(
vararg shape: Int, vararg shape: Int,
crossinline initializer: (IntArray) -> T crossinline initializer: (IntArray) -> T
): BufferNDStructure<T> = ): BufferNDStructure<T> =
auto(DefaultStrides(shape), initializer) auto(DefaultStrides(shape), initializer)
inline fun <T : Any> auto( public inline fun <T : Any> auto(
type: KClass<T>, type: KClass<T>,
vararg shape: Int, vararg shape: Int,
crossinline initializer: (IntArray) -> T crossinline initializer: (IntArray) -> T
@ -125,68 +124,68 @@ interface NDStructure<T> {
* @param index the indices. * @param index the indices.
* @return the value. * @return the value.
*/ */
operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index) public operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
/** /**
* Represents mutable [NDStructure]. * Represents mutable [NDStructure].
*/ */
interface MutableNDStructure<T> : NDStructure<T> { public interface MutableNDStructure<T> : NDStructure<T> {
/** /**
* Inserts an item at the specified indices. * Inserts an item at the specified indices.
* *
* @param index the indices. * @param index the indices.
* @param value the value. * @param value the value.
*/ */
operator fun set(index: IntArray, value: T) public operator fun set(index: IntArray, value: T)
} }
inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) { public inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T): Unit =
contract { callsInPlace(action) }
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
}
/** /**
* A way to convert ND index to linear one and back. * A way to convert ND index to linear one and back.
*/ */
interface Strides { public interface Strides {
/** /**
* Shape of NDstructure * Shape of NDstructure
*/ */
val shape: IntArray public val shape: IntArray
/** /**
* Array strides * Array strides
*/ */
val strides: List<Int> public val strides: List<Int>
/** /**
* Get linear index from multidimensional index * Get linear index from multidimensional index
*/ */
fun offset(index: IntArray): Int public fun offset(index: IntArray): Int
/** /**
* Get multidimensional from linear * Get multidimensional from linear
*/ */
fun index(offset: Int): IntArray public fun index(offset: Int): IntArray
/** /**
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides * The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
*/ */
val linearSize: Int public val linearSize: Int
// TODO introduce a fast way to calculate index of the next element?
/** /**
* Iterate over ND indices in a natural order * Iterate over ND indices in a natural order
*/ */
fun indices(): Sequence<IntArray> { public fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map { index(it) }
//TODO introduce a fast way to calculate index of the next element?
return (0 until linearSize).asSequence().map { index(it) }
}
} }
/** /**
* Simple implementation of [Strides]. * Simple implementation of [Strides].
*/ */
class DefaultStrides private constructor(override val shape: IntArray) : Strides { public class DefaultStrides private constructor(override val shape: IntArray) : Strides {
override val linearSize: Int
get() = strides[shape.size]
/** /**
* Strides for memory access * Strides for memory access
*/ */
@ -194,6 +193,7 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
sequence { sequence {
var current = 1 var current = 1
yield(1) yield(1)
shape.forEach { shape.forEach {
current *= it current *= it
yield(current) yield(current)
@ -212,17 +212,16 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
val res = IntArray(shape.size) val res = IntArray(shape.size)
var current = offset var current = offset
var strideIndex = strides.size - 2 var strideIndex = strides.size - 2
while (strideIndex >= 0) { while (strideIndex >= 0) {
res[strideIndex] = (current / strides[strideIndex]) res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex] current %= strides[strideIndex]
strideIndex-- strideIndex--
} }
return res return res
} }
override val linearSize: Int
get() = strides[shape.size]
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other !is DefaultStrides) return false if (other !is DefaultStrides) return false
@ -232,13 +231,14 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
override fun hashCode(): Int = shape.contentHashCode() override fun hashCode(): Int = shape.contentHashCode()
companion object { public companion object {
private val defaultStridesCache = HashMap<IntArray, Strides>() private val defaultStridesCache = HashMap<IntArray, Strides>()
/** /**
* Cached builder for default strides * Cached builder for default strides
*/ */
operator fun invoke(shape: IntArray): Strides = defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) } public operator fun invoke(shape: IntArray): Strides =
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
} }
} }
@ -247,16 +247,16 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
* *
* @param T the type of items. * @param T the type of items.
*/ */
abstract class NDBuffer<T> : NDStructure<T> { public abstract class NDBuffer<T> : NDStructure<T> {
/** /**
* The underlying buffer. * The underlying buffer.
*/ */
abstract val buffer: Buffer<T> public abstract val buffer: Buffer<T>
/** /**
* The strides to access elements of [Buffer] by linear indices. * The strides to access elements of [Buffer] by linear indices.
*/ */
abstract val strides: Strides public abstract val strides: Strides
override operator fun get(index: IntArray): T = buffer[strides.offset(index)] override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
@ -278,7 +278,7 @@ abstract class NDBuffer<T> : NDStructure<T> {
/** /**
* Boxing generic [NDStructure] * Boxing generic [NDStructure]
*/ */
class BufferNDStructure<T>( public class BufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : NDBuffer<T>() { ) : NDBuffer<T>() {
@ -292,13 +292,13 @@ class BufferNDStructure<T>(
/** /**
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure] * Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
*/ */
inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer( public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
factory: BufferFactory<R> = Buffer.Companion::auto, factory: BufferFactory<R> = Buffer.Companion::auto,
crossinline transform: (T) -> R crossinline transform: (T) -> R
): BufferNDStructure<R> { ): BufferNDStructure<R> {
return if (this is BufferNDStructure<T>) { return if (this is BufferNDStructure<T>)
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
} else { else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
} }
@ -307,7 +307,7 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
/** /**
* Mutable ND buffer based on linear [MutableBuffer]. * Mutable ND buffer based on linear [MutableBuffer].
*/ */
class MutableBufferNDStructure<T>( public class MutableBufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: MutableBuffer<T> override val buffer: MutableBuffer<T>
) : NDBuffer<T>(), MutableNDStructure<T> { ) : NDBuffer<T>(), MutableNDStructure<T> {
@ -321,7 +321,7 @@ class MutableBufferNDStructure<T>(
override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value) override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
} }
inline fun <reified T : Any> NDStructure<T>.combine( public inline fun <reified T : Any> NDStructure<T>.combine(
struct: NDStructure<T>, struct: NDStructure<T>,
crossinline block: (T, T) -> T crossinline block: (T, T) -> T
): NDStructure<T> { ): NDStructure<T> {

View File

@ -1,14 +1,11 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [DoubleArray]. * Specialized [MutableBuffer] implementation over [DoubleArray].
* *
* @property array the underlying array. * @property array the underlying array.
*/ */
inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> { public inline class RealBuffer(public val array: DoubleArray) : MutableBuffer<Double> {
override val size: Int get() = array.size override val size: Int get() = array.size
override operator fun get(index: Int): Double = array[index] override operator fun get(index: Int): Double = array[index]
@ -30,20 +27,17 @@ inline class RealBuffer(val array: DoubleArray) : MutableBuffer<Double> {
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer { public inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) })
contract { callsInPlace(init) }
return RealBuffer(DoubleArray(size) { init(it) })
}
/** /**
* Returns a new [RealBuffer] of given elements. * Returns a new [RealBuffer] of given elements.
*/ */
fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles) public fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
/** /**
* Returns a [DoubleArray] containing all of the elements of this [MutableBuffer]. * Returns a [DoubleArray] containing all of the elements of this [MutableBuffer].
*/ */
val MutableBuffer<out Double>.array: DoubleArray public val MutableBuffer<out Double>.array: DoubleArray
get() = (if (this is RealBuffer) array else DoubleArray(size) { get(it) }) get() = (if (this is RealBuffer) array else DoubleArray(size) { get(it) })
/** /**
@ -52,4 +46,4 @@ val MutableBuffer<out Double>.array: DoubleArray
* @receiver the array. * @receiver the array.
* @return the new buffer. * @return the new buffer.
*/ */
fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this) public fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this)

View File

@ -4,11 +4,10 @@ import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.ExtendedFieldOperations import scientifik.kmath.operations.ExtendedFieldOperations
import kotlin.math.* import kotlin.math.*
/** /**
* [ExtendedFieldOperations] over [RealBuffer]. * [ExtendedFieldOperations] over [RealBuffer].
*/ */
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> { public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
@ -73,9 +72,8 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { asin(array[it]) }) RealBuffer(DoubleArray(arg.size) { asin(array[it]) })
} else { } else
RealBuffer(DoubleArray(arg.size) { asin(arg[it]) }) RealBuffer(DoubleArray(arg.size) { asin(arg[it]) })
}
override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
@ -92,37 +90,44 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { sinh(array[it]) }) RealBuffer(DoubleArray(arg.size) { sinh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) }) } else
RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) })
override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { cosh(array[it]) }) RealBuffer(DoubleArray(arg.size) { cosh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) }) } else
RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) })
override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { tanh(array[it]) }) RealBuffer(DoubleArray(arg.size) { tanh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) }) } else
RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) })
override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { asinh(array[it]) }) RealBuffer(DoubleArray(arg.size) { asinh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) }) } else
RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) })
override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { acosh(array[it]) }) RealBuffer(DoubleArray(arg.size) { acosh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) }) } else
RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) })
override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { atanh(array[it]) }) RealBuffer(DoubleArray(arg.size) { atanh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) }) } else
RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) })
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) { override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
} else RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) } else
RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
@ -132,7 +137,8 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } else
RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
} }
/** /**
@ -140,7 +146,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
* *
* @property size the size of buffers to operate on. * @property size the size of buffers to operate on.
*/ */
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> { public class RealBufferField(public val size: Int) : ExtendedField<Buffer<Double>> {
override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } } override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } } override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }

View File

@ -112,26 +112,22 @@ public inline fun RealNDElement.map(crossinline transform: RealField.(Double) ->
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy. * Element by element application of any operation on elements to the whole array. Just like in numpy.
*/ */
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement): RealNDElement = public operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement): RealNDElement =
ndElement.map { this@invoke(it) } ndElement.map { this@invoke(it) }
/* plus and minus */ /* plus and minus */
/** /**
* Summation operation for [BufferedNDElement] and single element * Summation operation for [BufferedNDElement] and single element
*/ */
operator fun RealNDElement.plus(arg: Double): RealNDElement = public operator fun RealNDElement.plus(arg: Double): RealNDElement = map { it + arg }
map { it + arg }
/** /**
* Subtraction operation between [BufferedNDElement] and single element * Subtraction operation between [BufferedNDElement] and single element
*/ */
operator fun RealNDElement.minus(arg: Double): RealNDElement = public operator fun RealNDElement.minus(arg: Double): RealNDElement = map { it - arg }
map { it - arg }
/** /**
* Produce a context for n-dimensional operations inside this real field * Produce a context for n-dimensional operations inside this real field
*/ */
public inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action)
inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action)

View File

@ -1,6 +1,5 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract import kotlin.contracts.contract
/** /**
@ -30,10 +29,7 @@ public inline class ShortBuffer(public val array: ShortArray) : MutableBuffer<Sh
* The function [init] is called for each array element sequentially starting from the first one. * The function [init] is called for each array element sequentially starting from the first one.
* It should return the value for an buffer element given its index. * It should return the value for an buffer element given its index.
*/ */
public inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer { public inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) })
contract { callsInPlace(init) }
return ShortBuffer(ShortArray(size) { init(it) })
}
/** /**
* Returns a new [ShortBuffer] of given elements. * Returns a new [ShortBuffer] of given elements.

View File

@ -2,7 +2,6 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.RingElement import scientifik.kmath.operations.RingElement
import scientifik.kmath.operations.ShortRing import scientifik.kmath.operations.ShortRing
import kotlin.contracts.contract
public typealias ShortNDElement = BufferedNDRingElement<Short, ShortRing> public typealias ShortNDElement = BufferedNDRingElement<Short, ShortRing>
@ -69,11 +68,8 @@ public class ShortNDRing(override val shape: IntArray) :
/** /**
* Fast element production using function inlining. * Fast element production using function inlining.
*/ */
public inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement { public inline fun BufferedNDRing<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement =
contract { callsInPlace(initializer) } BufferedNDRingElement(this, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }))
val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }
return BufferedNDRingElement(this, ShortBuffer(array))
}
/** /**
* Element by element application of any operation on elements to the whole array. * Element by element application of any operation on elements to the whole array.

View File

@ -3,7 +3,7 @@ package scientifik.kmath.structures
/** /**
* A structure that is guaranteed to be one-dimensional * A structure that is guaranteed to be one-dimensional
*/ */
interface Structure1D<T> : NDStructure<T>, Buffer<T> { public interface Structure1D<T> : NDStructure<T>, Buffer<T> {
override val dimension: Int get() = 1 override val dimension: Int get() = 1
override operator fun get(index: IntArray): T { override operator fun get(index: IntArray): T {
@ -11,14 +11,13 @@ interface Structure1D<T> : NDStructure<T>, Buffer<T> {
return get(index[0]) return get(index[0])
} }
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator() override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(::get).iterator()
} }
/** /**
* A 1D wrapper for nd-structure * A 1D wrapper for nd-structure
*/ */
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> { private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> {
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override val size: Int get() = structure.shape[0] override val size: Int get() = structure.shape[0]
@ -45,18 +44,12 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
/** /**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/ */
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) { public fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
if (this is NDBuffer) { if (this is NDBuffer) Buffer1DWrapper(this.buffer) else Structure1DWrapper(this)
Buffer1DWrapper(this.buffer) } else
} else {
Structure1DWrapper(this)
}
} else {
error("Can't create 1d-structure from ${shape.size}d-structure") error("Can't create 1d-structure from ${shape.size}d-structure")
}
/** /**
* Represent this buffer as 1D structure * Represent this buffer as 1D structure
*/ */
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this) public fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)

View File

@ -21,11 +21,8 @@ public interface Structure2D<T> : NDStructure<T> {
get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } } get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } }
override fun elements(): Sequence<Pair<IntArray, T>> = sequence { override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
for (i in (0 until rowNum)) { for (i in (0 until rowNum))
for (j in (0 until colNum)) { for (j in (0 until colNum)) yield(intArrayOf(i, j) to get(i, j))
yield(intArrayOf(i, j) to get(i, j))
}
}
} }
public companion object public companion object
@ -45,10 +42,9 @@ private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Stru
/** /**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/ */
public fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) { public fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2)
Structure2DWrapper(this) Structure2DWrapper(this)
} else { else
error("Can't create 2d-structure from ${shape.size}d-structure") error("Can't create 2d-structure from ${shape.size}d-structure")
}
public typealias Matrix<T> = Structure2D<T> public typealias Matrix<T> = Structure2D<T>

View File

@ -3,9 +3,8 @@ package scientifik.kmath.coroutines
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.channels.produce import kotlinx.coroutines.channels.produce
import kotlinx.coroutines.flow.* import kotlinx.coroutines.flow.*
import kotlin.contracts.contract
val Dispatchers.Math: CoroutineDispatcher public val Dispatchers.Math: CoroutineDispatcher
get() = Default get() = Default
/** /**
@ -15,31 +14,25 @@ internal class LazyDeferred<T>(val dispatcher: CoroutineDispatcher, val block: s
private var deferred: Deferred<T>? = null private var deferred: Deferred<T>? = null
internal fun start(scope: CoroutineScope) { internal fun start(scope: CoroutineScope) {
if (deferred == null) { if (deferred == null) deferred = scope.async(dispatcher, block = block)
deferred = scope.async(dispatcher, block = block)
}
} }
suspend fun await(): T = deferred?.await() ?: error("Coroutine not started") suspend fun await(): T = deferred?.await() ?: error("Coroutine not started")
} }
class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> { public class AsyncFlow<T> internal constructor(internal val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
override suspend fun collect(collector: FlowCollector<T>) { override suspend fun collect(collector: FlowCollector<T>): Unit = deferredFlow.collect { collector.emit((it.await())) }
deferredFlow.collect { collector.emit((it.await())) }
}
} }
fun <T, R> Flow<T>.async( public fun <T, R> Flow<T>.async(
dispatcher: CoroutineDispatcher = Dispatchers.Default, dispatcher: CoroutineDispatcher = Dispatchers.Default,
block: suspend CoroutineScope.(T) -> R block: suspend CoroutineScope.(T) -> R
): AsyncFlow<R> { ): AsyncFlow<R> {
val flow = map { val flow = map { LazyDeferred(dispatcher) { block(it) } }
LazyDeferred(dispatcher) { block(it) }
}
return AsyncFlow(flow) return AsyncFlow(flow)
} }
fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> = public fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> =
AsyncFlow(deferredFlow.map { input -> AsyncFlow(deferredFlow.map { input ->
//TODO add function composition //TODO add function composition
LazyDeferred(input.dispatcher) { LazyDeferred(input.dispatcher) {
@ -48,7 +41,7 @@ fun <T, R> AsyncFlow<T>.map(action: (T) -> R): AsyncFlow<R> =
} }
}) })
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>) { public suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>) {
require(concurrency >= 1) { "Buffer size should be more than 1, but was $concurrency" } require(concurrency >= 1) { "Buffer size should be more than 1, but was $concurrency" }
coroutineScope { coroutineScope {
@ -76,18 +69,14 @@ suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<
} }
} }
suspend inline fun <T> AsyncFlow<T>.collect(concurrency: Int, crossinline action: suspend (value: T) -> Unit) { public suspend inline fun <T> AsyncFlow<T>.collect(
contract { callsInPlace(action) } concurrency: Int,
crossinline action: suspend (value: T) -> Unit
): Unit = collect(concurrency, object : FlowCollector<T> {
override suspend fun emit(value: T): Unit = action(value)
})
collect(concurrency, object : FlowCollector<T> { public inline fun <T, R> Flow<T>.mapParallel(
override suspend fun emit(value: T): Unit = action(value)
})
}
inline fun <T, R> Flow<T>.mapParallel(
dispatcher: CoroutineDispatcher = Dispatchers.Default, dispatcher: CoroutineDispatcher = Dispatchers.Default,
crossinline transform: suspend (T) -> R crossinline transform: suspend (T) -> R
): Flow<R> { ): Flow<R> = flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher)
contract { callsInPlace(transform) }
return flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher)
}

View File

@ -3,38 +3,31 @@ package scientifik.kmath.structures
import kotlinx.coroutines.* import kotlinx.coroutines.*
import scientifik.kmath.coroutines.Math import scientifik.kmath.coroutines.Math
class LazyNDStructure<T>( public class LazyNDStructure<T>(
val scope: CoroutineScope, public val scope: CoroutineScope,
override val shape: IntArray, public override val shape: IntArray,
val function: suspend (IntArray) -> T public val function: suspend (IntArray) -> T
) : NDStructure<T> { ) : NDStructure<T> {
private val cache: MutableMap<IntArray, Deferred<T>> = hashMapOf() private val cache: MutableMap<IntArray, Deferred<T>> = hashMapOf()
fun deferred(index: IntArray): Deferred<T> = cache.getOrPut(index) { public fun deferred(index: IntArray): Deferred<T> = cache.getOrPut(index) {
scope.async(context = Dispatchers.Math) { scope.async(context = Dispatchers.Math) { function(index) }
function(index)
}
} }
suspend fun await(index: IntArray): T = deferred(index).await() public suspend fun await(index: IntArray): T = deferred(index).await()
public override operator fun get(index: IntArray): T = runBlocking { deferred(index).await() }
override operator fun get(index: IntArray): T = runBlocking { public override fun elements(): Sequence<Pair<IntArray, T>> {
deferred(index).await()
}
override fun elements(): Sequence<Pair<IntArray, T>> {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
val res = runBlocking { val res = runBlocking { strides.indices().toList().map { index -> index to await(index) } }
strides.indices().toList().map { index -> index to await(index) }
}
return res.asSequence() return res.asSequence()
} }
override fun equals(other: Any?): Boolean { public override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false) return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
} }
override fun hashCode(): Int { public override fun hashCode(): Int {
var result = scope.hashCode() var result = scope.hashCode()
result = 31 * result + shape.contentHashCode() result = 31 * result + shape.contentHashCode()
result = 31 * result + function.hashCode() result = 31 * result + function.hashCode()
@ -43,21 +36,21 @@ class LazyNDStructure<T>(
} }
} }
fun <T> NDStructure<T>.deferred(index: IntArray): Deferred<T> = public fun <T> NDStructure<T>.deferred(index: IntArray): Deferred<T> =
if (this is LazyNDStructure<T>) this.deferred(index) else CompletableDeferred(get(index)) if (this is LazyNDStructure<T>) this.deferred(index) else CompletableDeferred(get(index))
suspend fun <T> NDStructure<T>.await(index: IntArray): T = public suspend fun <T> NDStructure<T>.await(index: IntArray): T =
if (this is LazyNDStructure<T>) this.await(index) else get(index) if (this is LazyNDStructure<T>) this.await(index) else get(index)
/** /**
* PENDING would benefit from KEEP-176 * PENDING would benefit from KEEP-176
*/ */
inline fun <T, R> NDStructure<T>.mapAsyncIndexed( public inline fun <T, R> NDStructure<T>.mapAsyncIndexed(
scope: CoroutineScope, scope: CoroutineScope,
crossinline function: suspend (T, index: IntArray) -> R crossinline function: suspend (T, index: IntArray) -> R
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index), index) } ): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index), index) }
inline fun <T, R> NDStructure<T>.mapAsync( public inline fun <T, R> NDStructure<T>.mapAsync(
scope: CoroutineScope, scope: CoroutineScope,
crossinline function: suspend (T) -> R crossinline function: suspend (T) -> R
): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index)) } ): LazyNDStructure<R> = LazyNDStructure(scope, shape) { index -> function(get(index)) }

View File

@ -13,35 +13,36 @@ import scientifik.kmath.structures.Structure2D
/** /**
* A matrix with compile-time controlled dimension * A matrix with compile-time controlled dimension
*/ */
interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> { public interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
companion object { public companion object {
/** /**
* Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed * Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed
*/ */
inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> { public inline fun <T, reified R : Dimension, reified C : Dimension> coerce(structure: Structure2D<T>): DMatrix<T, R, C> {
if (structure.rowNum != Dimension.dim<R>().toInt()) { require(structure.rowNum == Dimension.dim<R>().toInt()) {
error("Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}") "Row number mismatch: expected ${Dimension.dim<R>()} but found ${structure.rowNum}"
} }
if (structure.colNum != Dimension.dim<C>().toInt()) {
error("Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}") require(structure.colNum == Dimension.dim<C>().toInt()) {
"Column number mismatch: expected ${Dimension.dim<C>()} but found ${structure.colNum}"
} }
return DMatrixWrapper(structure) return DMatrixWrapper(structure)
} }
/** /**
* The same as [coerce] but without dimension checks. Use with caution * The same as [DMatrix.coerce] but without dimension checks. Use with caution
*/ */
fun <T, R : Dimension, C : Dimension> coerceUnsafe(structure: Structure2D<T>): DMatrix<T, R, C> { public fun <T, R : Dimension, C : Dimension> coerceUnsafe(structure: Structure2D<T>): DMatrix<T, R, C> =
return DMatrixWrapper(structure) DMatrixWrapper(structure)
}
} }
} }
/** /**
* An inline wrapper for a Matrix * An inline wrapper for a Matrix
*/ */
inline class DMatrixWrapper<T, R : Dimension, C : Dimension>( public inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
val structure: Structure2D<T> public val structure: Structure2D<T>
) : DMatrix<T, R, C> { ) : DMatrix<T, R, C> {
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override operator fun get(i: Int, j: Int): T = structure[i, j] override operator fun get(i: Int, j: Int): T = structure[i, j]
@ -50,25 +51,24 @@ inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
/** /**
* Dimension-safe point * Dimension-safe point
*/ */
interface DPoint<T, D : Dimension> : Point<T> { public interface DPoint<T, D : Dimension> : Point<T> {
companion object { public companion object {
inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> { public inline fun <T, reified D : Dimension> coerce(point: Point<T>): DPoint<T, D> {
if (point.size != Dimension.dim<D>().toInt()) { require(point.size == Dimension.dim<D>().toInt()) {
error("Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}") "Vector dimension mismatch: expected ${Dimension.dim<D>()}, but found ${point.size}"
} }
return DPointWrapper(point) return DPointWrapper(point)
} }
fun <T, D : Dimension> coerceUnsafe(point: Point<T>): DPoint<T, D> { public fun <T, D : Dimension> coerceUnsafe(point: Point<T>): DPoint<T, D> = DPointWrapper(point)
return DPointWrapper(point)
}
} }
} }
/** /**
* Dimension-safe point wrapper * Dimension-safe point wrapper
*/ */
inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) : public inline class DPointWrapper<T, D : Dimension>(public val point: Point<T>) :
DPoint<T, D> { DPoint<T, D> {
override val size: Int get() = point.size override val size: Int get() = point.size
@ -81,16 +81,15 @@ inline class DPointWrapper<T, D : Dimension>(val point: Point<T>) :
/** /**
* Basic operations on dimension-safe matrices. Operates on [Matrix] * Basic operations on dimension-safe matrices. Operates on [Matrix]
*/ */
inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixContext<T, Ri>) { public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: GenericMatrixContext<T, Ri>) {
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
require(rowNum == Dimension.dim<R>().toInt()) {
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
}
inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> { require(colNum == Dimension.dim<C>().toInt()) {
check( "Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum"
rowNum == Dimension.dim<R>().toInt() }
) { "Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum" }
check(
colNum == Dimension.dim<C>().toInt()
) { "Column number mismatch: expected ${Dimension.dim<C>()} but found $colNum" }
return DMatrix.coerceUnsafe(this) return DMatrix.coerceUnsafe(this)
} }
@ -98,13 +97,13 @@ inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixCon
/** /**
* Produce a matrix with this context and given dimensions * Produce a matrix with this context and given dimensions
*/ */
inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> { public inline fun <reified R : Dimension, reified C : Dimension> produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix<T, R, C> {
val rows = Dimension.dim<R>() val rows = Dimension.dim<R>()
val cols = Dimension.dim<C>() val cols = Dimension.dim<C>()
return context.produce(rows.toInt(), cols.toInt(), initializer).coerce<R, C>() return context.produce(rows.toInt(), cols.toInt(), initializer).coerce<R, C>()
} }
inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> { public inline fun <reified D : Dimension> point(noinline initializer: (Int) -> T): DPoint<T, D> {
val size = Dimension.dim<D>() val size = Dimension.dim<D>()
return DPoint.coerceUnsafe( return DPoint.coerceUnsafe(
@ -115,7 +114,7 @@ inline class DMatrixContext<T : Any, Ri : Ring<T>>(val context: GenericMatrixCon
) )
} }
inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot( public inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
other: DMatrix<T, C1, C2> other: DMatrix<T, C1, C2>
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce() ): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()

View File

@ -116,16 +116,13 @@ public operator fun Matrix<Double>.minus(other: Matrix<Double>): RealMatrix =
* Operations on columns * Operations on columns
*/ */
public inline fun Matrix<Double>.appendColumn(crossinline mapper: (Buffer<Double>) -> Double): Matrix<Double> { public inline fun Matrix<Double>.appendColumn(crossinline mapper: (Buffer<Double>) -> Double): Matrix<Double> =
contract { callsInPlace(mapper) } MatrixContext.real.produce(rowNum, colNum + 1) { row, col ->
return MatrixContext.real.produce(rowNum, colNum + 1) { row, col ->
if (col < colNum) if (col < colNum)
this[row, col] this[row, col]
else else
mapper(rows[row]) mapper(rows[row])
} }
}
public fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix = public fun Matrix<Double>.extractColumns(columnRange: IntRange): RealMatrix =
MatrixContext.real.produce(rowNum, columnRange.count()) { row, col -> MatrixContext.real.produce(rowNum, columnRange.count()) { row, col ->

View File

@ -12,11 +12,10 @@ public fun interface PiecewisePolynomial<T : Any> :
/** /**
* Ordered list of pieces in piecewise function * Ordered list of pieces in piecewise function
*/ */
public class OrderedPiecewisePolynomial<T : Comparable<T>>(delimeter: T) : public class OrderedPiecewisePolynomial<T : Comparable<T>>(delimiter: T) :
PiecewisePolynomial<T> { PiecewisePolynomial<T> {
private val delimiters: MutableList<T> = arrayListOf(delimiter)
private val delimiters: ArrayList<T> = arrayListOf(delimeter) private val pieces: MutableList<Polynomial<T>> = arrayListOf()
private val pieces: ArrayList<Polynomial<T>> = ArrayList()
/** /**
* Dynamically add a piece to the "right" side (beyond maximum argument value of previous piece) * Dynamically add a piece to the "right" side (beyond maximum argument value of previous piece)
@ -35,14 +34,13 @@ public class OrderedPiecewisePolynomial<T : Comparable<T>>(delimeter: T) :
} }
override fun findPiece(arg: T): Polynomial<T>? { override fun findPiece(arg: T): Polynomial<T>? {
if (arg < delimiters.first() || arg >= delimiters.last()) { if (arg < delimiters.first() || arg >= delimiters.last())
return null return null
} else { else {
for (index in 1 until delimiters.size) { for (index in 1 until delimiters.size)
if (arg < delimiters[index]) { if (arg < delimiters[index])
return pieces[index - 1] return pieces[index - 1]
}
}
error("Piece not found") error("Piece not found")
} }
} }

View File

@ -48,9 +48,9 @@ public fun <T : Any, C : Ring<T>> Polynomial<T>.asFunction(ring: C): (T) -> T =
* An algebra for polynomials * An algebra for polynomials
*/ */
public class PolynomialSpace<T : Any, C : Ring<T>>(public val ring: C) : Space<Polynomial<T>> { public class PolynomialSpace<T : Any, C : Ring<T>>(public val ring: C) : Space<Polynomial<T>> {
override val zero: Polynomial<T> = Polynomial(emptyList()) public override val zero: Polynomial<T> = Polynomial(emptyList())
override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> { public override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
val dim = max(a.coefficients.size, b.coefficients.size) val dim = max(a.coefficients.size, b.coefficients.size)
return ring { return ring {
@ -60,7 +60,7 @@ public class PolynomialSpace<T : Any, C : Ring<T>>(public val ring: C) : Space<P
} }
} }
override fun multiply(a: Polynomial<T>, k: Number): Polynomial<T> = public override fun multiply(a: Polynomial<T>, k: Number): Polynomial<T> =
ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) } ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) }
public operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg) public operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg)

View File

@ -22,7 +22,7 @@ public interface PolynomialInterpolator<T : Comparable<T>> : Interpolator<T, T>
} }
} }
fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials( public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
x: Buffer<T>, x: Buffer<T>,
y: Buffer<T> y: Buffer<T>
): PiecewisePolynomial<T> { ): PiecewisePolynomial<T> {
@ -30,16 +30,16 @@ fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
return interpolatePolynomials(pointSet) return interpolatePolynomials(pointSet)
} }
fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials( public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
data: Map<T, T> data: Map<T, T>
): PiecewisePolynomial<T> { ): PiecewisePolynomial<T> {
val pointSet = BufferXYPointSet(data.keys.toList().asBuffer(), data.values.toList().asBuffer()) val pointSet = BufferXYPointSet(data.keys.toList().asBuffer(), data.values.toList().asBuffer())
return interpolatePolynomials(pointSet) return interpolatePolynomials(pointSet)
} }
fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials( public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
data: List<Pair<T, T>> data: List<Pair<T, T>>
): PiecewisePolynomial<T> { ): PiecewisePolynomial<T> {
val pointSet = BufferXYPointSet(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer()) val pointSet = BufferXYPointSet(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer())
return interpolatePolynomials(pointSet) return interpolatePolynomials(pointSet)
} }

View File

@ -9,8 +9,8 @@ import scientifik.kmath.operations.invoke
/** /**
* Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java * Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java
*/ */
public class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> { public class LinearInterpolator<T : Comparable<T>>(public override val algebra: Field<T>) : PolynomialInterpolator<T> {
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra { public override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
require(points.size > 0) { "Point array should not be empty" } require(points.size > 0) { "Point array should not be empty" }
insureSorted(points) insureSorted(points)

View File

@ -12,13 +12,13 @@ import scientifik.kmath.structures.MutableBufferFactory
* Based on https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/SplineInterpolator.java * Based on https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/SplineInterpolator.java
*/ */
public class SplineInterpolator<T : Comparable<T>>( public class SplineInterpolator<T : Comparable<T>>(
override val algebra: Field<T>, public override val algebra: Field<T>,
public val bufferFactory: MutableBufferFactory<T> public val bufferFactory: MutableBufferFactory<T>
) : PolynomialInterpolator<T> { ) : PolynomialInterpolator<T> {
//TODO possibly optimize zeroed buffers //TODO possibly optimize zeroed buffers
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra { public override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
if (points.size < 3) { if (points.size < 3) {
error("Can't use spline interpolator with less than 3 points") error("Can't use spline interpolator with less than 3 points")
} }
@ -41,8 +41,9 @@ public class SplineInterpolator<T : Comparable<T>>(
// cubic spline coefficients -- b is linear, c quadratic, d is cubic (original y's are constants) // cubic spline coefficients -- b is linear, c quadratic, d is cubic (original y's are constants)
OrderedPiecewisePolynomial<T>(points.x[points.size - 1]).apply { OrderedPiecewisePolynomial(points.x[points.size - 1]).apply {
var cOld = zero var cOld = zero
for (j in n - 1 downTo 0) { for (j in n - 1 downTo 0) {
val c = z[j] - mu[j] * cOld val c = z[j] - mu[j] * cOld
val a = points.y[j] val a = points.y[j]

View File

@ -14,32 +14,32 @@ public interface XYZPointSet<X, Y, Z> : XYPointSet<X, Y> {
} }
internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) { internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) {
for (i in 0 until points.size - 1) require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" } for (i in 0 until points.size - 1)
require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" }
} }
public class NDStructureColumn<T>(public val structure: Structure2D<T>, public val column: Int) : Buffer<T> { public class NDStructureColumn<T>(public val structure: Structure2D<T>, public val column: Int) : Buffer<T> {
public override val size: Int
get() = structure.rowNum
init { init {
require(column < structure.colNum) { "Column index is outside of structure column range" } require(column < structure.colNum) { "Column index is outside of structure column range" }
} }
override val size: Int get() = structure.rowNum public override operator fun get(index: Int): T = structure[index, column]
public override operator fun iterator(): Iterator<T> = sequence { repeat(size) { yield(get(it)) } }.iterator()
override operator fun get(index: Int): T = structure[index, column]
override operator fun iterator(): Iterator<T> = sequence {
repeat(size) {
yield(get(it))
}
}.iterator()
} }
public class BufferXYPointSet<X, Y>(override val x: Buffer<X>, override val y: Buffer<Y>) : XYPointSet<X, Y> { public class BufferXYPointSet<X, Y>(
public override val x: Buffer<X>,
public override val y: Buffer<Y>
) : XYPointSet<X, Y> {
public override val size: Int
get() = x.size
init { init {
require(x.size == y.size) { "Sizes of x and y buffers should be the same" } require(x.size == y.size) { "Sizes of x and y buffers should be the same" }
} }
override val size: Int
get() = x.size
} }
public fun <T> Structure2D<T>.asXYPointSet(): XYPointSet<T, T> { public fun <T> Structure2D<T>.asXYPointSet(): XYPointSet<T, T> {

View File

@ -6,7 +6,7 @@ import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class LinearInterpolatorTest { internal class LinearInterpolatorTest {
@Test @Test
fun testInterpolation() { fun testInterpolation() {
val data = listOf( val data = listOf(
@ -15,9 +15,9 @@ class LinearInterpolatorTest {
2.0 to 3.0, 2.0 to 3.0,
3.0 to 4.0 3.0 to 4.0
) )
val polynomial: PiecewisePolynomial<Double> = LinearInterpolator(RealField).interpolatePolynomials(data) val polynomial: PiecewisePolynomial<Double> = LinearInterpolator(RealField).interpolatePolynomials(data)
val function = polynomial.asFunction(RealField) val function = polynomial.asFunction(RealField)
assertEquals(null, function(-1.0)) assertEquals(null, function(-1.0))
assertEquals(0.5, function(0.5)) assertEquals(0.5, function(0.5))
assertEquals(2.0, function(1.5)) assertEquals(2.0, function(1.5))

View File

@ -4,17 +4,16 @@ import org.jetbrains.bio.viktor.F64FlatArray
import scientifik.kmath.structures.MutableBuffer import scientifik.kmath.structures.MutableBuffer
@Suppress("NOTHING_TO_INLINE", "OVERRIDE_BY_INLINE") @Suppress("NOTHING_TO_INLINE", "OVERRIDE_BY_INLINE")
inline class ViktorBuffer(val flatArray: F64FlatArray) : MutableBuffer<Double> { public inline class ViktorBuffer(public val flatArray: F64FlatArray) : MutableBuffer<Double> {
override val size: Int get() = flatArray.size public override val size: Int
get() = flatArray.size
public override inline fun get(index: Int): Double = flatArray[index]
override inline fun get(index: Int): Double = flatArray[index]
override inline fun set(index: Int, value: Double) { override inline fun set(index: Int, value: Double) {
flatArray[index] = value flatArray[index] = value
} }
override fun copy(): MutableBuffer<Double> { public override fun copy(): MutableBuffer<Double> = ViktorBuffer(flatArray.copy().flatten())
return ViktorBuffer(flatArray.copy().flatten()) public override operator fun iterator(): Iterator<Double> = flatArray.data.iterator()
}
override operator fun iterator(): Iterator<Double> = flatArray.data.iterator()
} }

View File

@ -1,12 +1,12 @@
pluginManagement { pluginManagement {
val toolsVersion = "0.6.0-dev-3" val toolsVersion = "0.6.0-dev-5"
plugins { plugins {
id("kotlinx.benchmark") version "0.2.0-dev-20" id("kotlinx.benchmark") version "0.2.0-dev-20"
id("ru.mipt.npm.mpp") version toolsVersion id("ru.mipt.npm.mpp") version toolsVersion
id("ru.mipt.npm.jvm") version toolsVersion id("ru.mipt.npm.jvm") version toolsVersion
id("ru.mipt.npm.publish") version toolsVersion id("ru.mipt.npm.publish") version toolsVersion
kotlin("plugin.allopen") version "1.4.0" kotlin("plugin.allopen") version "1.4.20-dev-3898-14"
} }
repositories { repositories {
@ -17,6 +17,7 @@ pluginManagement {
maven("https://dl.bintray.com/mipt-npm/scientifik") maven("https://dl.bintray.com/mipt-npm/scientifik")
maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
} }
} }