diff --git a/CHANGELOG.md b/CHANGELOG.md index 26f9e33ec..3944c673e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,11 +20,14 @@ - Norm support for `Complex` ### Changed +- `readAsMemory` now has `throws IOException` in JVM signature. +- Several functions taking functional types were made `inline`. +- Several functions taking functional types now have `callsInPlace` contracts. - BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations - `power(T, Int)` extension function has preconditions and supports `Field` - Memory objects have more preconditions (overflow checking) - `tg` function is renamed to `tan` (https://github.com/mipt-npm/kmath/pull/114) -- Gradle version: 6.3 -> 6.5.1 +- Gradle version: 6.3 -> 6.6 - Moved probability distributions to commons-rng and to `kmath-prob` ### Fixed diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 73def3572..9f842024d 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -60,5 +60,6 @@ benchmark { tasks.withType { kotlinOptions { jvmTarget = Scientifik.JVM_TARGET.toString() + freeCompilerArgs = freeCompilerArgs + "-Xopt-in=kotlin.RequiresOptIn" } -} \ No newline at end of file +} diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt index ae27620f7..46da6c6d8 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt @@ -4,46 +4,38 @@ import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke @State(Scope.Benchmark) class NDFieldBenchmark { - @Benchmark fun autoFieldAdd() { - bufferedField.run { + bufferedField { var res: NDBuffer = one - repeat(n) { - res += one - } + repeat(n) { res += one } } } @Benchmark fun autoElementAdd() { var res = genericField.one - repeat(n) { - res += 1.0 - } + repeat(n) { res += 1.0 } } @Benchmark fun specializedFieldAdd() { - specializedField.run { + specializedField { var res: NDBuffer = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += 1.0 } } } @Benchmark fun boxingFieldAdd() { - genericField.run { + genericField { var res: NDBuffer = one - repeat(n) { - res += one - } + repeat(n) { res += one } } } diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt index f7b9661ef..9627743c9 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt @@ -5,23 +5,22 @@ import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke import scientifik.kmath.viktor.ViktorNDField - @State(Scope.Benchmark) class ViktorBenchmark { final val dim = 1000 final val n = 100 // automatically build context most suited for given type. - final val autoField = NDField.auto(RealField, dim, dim) - final val realField = NDField.real(dim, dim) - - final val viktorField = ViktorNDField(intArrayOf(dim, dim)) + final val autoField: BufferedNDField = NDField.auto(RealField, dim, dim) + final val realField: RealNDField = NDField.real(dim, dim) + final val viktorField: ViktorNDField = ViktorNDField(intArrayOf(dim, dim)) @Benchmark fun automaticFieldAddition() { - autoField.run { + autoField { var res = one repeat(n) { res += one } } @@ -29,7 +28,7 @@ class ViktorBenchmark { @Benchmark fun viktorFieldAddition() { - viktorField.run { + viktorField { var res = one repeat(n) { res += one } } @@ -44,7 +43,7 @@ class ViktorBenchmark { @Benchmark fun realdFieldLog() { - realField.run { + realField { val fortyTwo = produce { 42.0 } var res = one repeat(n) { res = ln(fortyTwo) } diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt index 6ec9e9c17..305f21d4f 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt @@ -1,8 +1,13 @@ package scientifik.kmath.utils +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.system.measureTimeMillis +@OptIn(ExperimentalContracts::class) internal inline fun measureAndPrint(title: String, block: () -> Unit) { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } val time = measureTimeMillis(block) println("$title completed in $time millis") -} \ No newline at end of file +} diff --git a/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt index 960f03de3..6cc5411b8 100644 --- a/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt @@ -5,6 +5,7 @@ import scientifik.kmath.commons.linear.CMMatrixContext import scientifik.kmath.commons.linear.inverse import scientifik.kmath.commons.linear.toCM import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.Matrix import kotlin.contracts.ExperimentalContracts import kotlin.random.Random @@ -21,29 +22,18 @@ fun main() { val n = 5000 // iterations - MatrixContext.real.run { - - repeat(50) { - val res = inverse(matrix) - } - - val inverseTime = measureTimeMillis { - repeat(n) { - val res = inverse(matrix) - } - } - + MatrixContext.real { + repeat(50) { val res = inverse(matrix) } + val inverseTime = measureTimeMillis { repeat(n) { val res = inverse(matrix) } } println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis") } //commons-math val commonsTime = measureTimeMillis { - CMMatrixContext.run { + CMMatrixContext { val cm = matrix.toCM() //avoid overhead on conversion - repeat(n) { - val res = inverse(cm) - } + repeat(n) { val res = inverse(cm) } } } @@ -53,7 +43,7 @@ fun main() { //koma-ejml val komaTime = measureTimeMillis { - KomaMatrixContext(EJMLMatrixFactory(), RealField).run { + (KomaMatrixContext(EJMLMatrixFactory(), RealField)) { val km = matrix.toKoma() //avoid overhead on conversion repeat(n) { val res = inverse(km) diff --git a/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt index 03bd0001c..3ae550682 100644 --- a/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt @@ -4,6 +4,7 @@ import koma.matrix.ejml.EJMLMatrixFactory import scientifik.kmath.commons.linear.CMMatrixContext import scientifik.kmath.commons.linear.toCM import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.Matrix import kotlin.random.Random import kotlin.system.measureTimeMillis @@ -18,7 +19,7 @@ fun main() { // //warmup // matrix1 dot matrix2 - CMMatrixContext.run { + CMMatrixContext { val cmMatrix1 = matrix1.toCM() val cmMatrix2 = matrix2.toCM() @@ -29,8 +30,7 @@ fun main() { println("CM implementation time: $cmTime") } - - KomaMatrixContext(EJMLMatrixFactory(), RealField).run { + (KomaMatrixContext(EJMLMatrixFactory(), RealField)) { val komaMatrix1 = matrix1.toKoma() val komaMatrix2 = matrix2.toKoma() diff --git a/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt b/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt index 4841f9dd8..6dbfebce1 100644 --- a/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt +++ b/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt @@ -9,13 +9,11 @@ fun main() { Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) } - - val compute = NDField.complex(8).run { + val compute = (NDField.complex(8)) { val a = produce { (it) -> i * it - it.toDouble() } val b = 3 val c = Complex(1.0, 1.0) (a pow b) + c } - -} \ No newline at end of file +} diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt index 991cd34a1..2329f3fc3 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt @@ -13,9 +13,8 @@ fun main() { val realField = NDField.real(dim, dim) val complexField = NDField.complex(dim, dim) - val realTime = measureTimeMillis { - realField.run { + realField { var res: NDBuffer = one repeat(n) { res += 1.0 @@ -26,18 +25,15 @@ fun main() { println("Real addition completed in $realTime millis") val complexTime = measureTimeMillis { - complexField.run { + complexField { var res: NDBuffer = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += 1.0 } } } println("Complex addition completed in $complexTime millis") } - fun complexExample() { //Create a context for 2-d structure with complex values ComplexField { @@ -46,10 +42,7 @@ fun complexExample() { val x = one * 2.5 operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im) //a structure generator specific to this context - val matrix = produce { (k, l) -> - k + l * i - } - + val matrix = produce { (k, l) -> k + l * i } //Perform sum val sum = matrix + x + 1.0 diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt index 2aafb504d..1ce74b2a3 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt @@ -2,14 +2,19 @@ package scientifik.kmath.structures import kotlinx.coroutines.GlobalScope import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.system.measureTimeMillis +@OptIn(ExperimentalContracts::class) internal inline fun measureAndPrint(title: String, block: () -> Unit) { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } val time = measureTimeMillis(block) println("$title completed in $time millis") } - fun main() { val dim = 1000 val n = 1000 @@ -22,27 +27,21 @@ fun main() { val genericField = NDField.boxing(RealField, dim, dim) measureAndPrint("Automatic field addition") { - autoField.run { + autoField { var res: NDBuffer = one - repeat(n) { - res += number(1.0) - } + repeat(n) { res += number(1.0) } } } measureAndPrint("Element addition") { var res = genericField.one - repeat(n) { - res += 1.0 - } + repeat(n) { res += 1.0 } } measureAndPrint("Specialized addition") { - specializedField.run { + specializedField { var res: NDBuffer = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += 1.0 } } } @@ -60,12 +59,11 @@ fun main() { measureAndPrint("Generic addition") { //genericField.run(action) - genericField.run { + genericField { var res: NDBuffer = one repeat(n) { - res += one // con't avoid using `one` due to resolution ambiguity + res += one // couldn't avoid using `one` due to resolution ambiguity } } } } - -} \ No newline at end of file +} diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt b/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt index fdc09ed5d..e627cecaa 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt @@ -16,6 +16,7 @@ fun DMatrixContext.simple() { object D5 : Dimension { + @OptIn(ExperimentalUnsignedTypes::class) override val dim: UInt = 5u } @@ -23,13 +24,10 @@ fun DMatrixContext.custom() { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i - j).toDouble() } val m3 = produce { i, j -> (i - j).toDouble() } - (m1 dot m2) + m3 } -fun main() { - DMatrixContext.real.run { - simple() - custom() - } -} \ No newline at end of file +fun main(): Unit = with(DMatrixContext.real) { + simple() + custom() +} diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt index 59f3f15d8..635dc940d 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt @@ -2,6 +2,9 @@ package scientifik.kmath.ast import scientifik.kmath.expressions.* import scientifik.kmath.operations.* +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than @@ -24,7 +27,7 @@ class MstExpression(val algebra: Algebra, val mst: MST) : Expression { error("Numeric nodes are not supported by $this") } - override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) + override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) } /** @@ -38,51 +41,72 @@ inline fun , E : Algebra> A.mst( /** * Builds [MstExpression] over [Space]. */ -inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression = - MstExpression(this, MstSpace.block()) +@OptIn(ExperimentalContracts::class) +inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstSpace.block()) +} /** * Builds [MstExpression] over [Ring]. */ -inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression = - MstExpression(this, MstRing.block()) +@OptIn(ExperimentalContracts::class) +inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstRing.block()) +} /** * Builds [MstExpression] over [Field]. */ -inline fun Field.mstInField(block: MstField.() -> MST): MstExpression = - MstExpression(this, MstField.block()) +@OptIn(ExperimentalContracts::class) +inline fun Field.mstInField(block: MstField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstField.block()) +} /** * Builds [MstExpression] over [ExtendedField]. */ -inline fun Field.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression = - MstExpression(this, MstExtendedField.block()) +@OptIn(ExperimentalContracts::class) +inline fun Field.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstExtendedField.block()) +} /** * Builds [MstExpression] over [FunctionalExpressionSpace]. */ -inline fun > FunctionalExpressionSpace.mstInSpace( - block: MstSpace.() -> MST -): MstExpression = algebra.mstInSpace(block) +@OptIn(ExperimentalContracts::class) +inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInSpace(block) +} /** * Builds [MstExpression] over [FunctionalExpressionRing]. */ -inline fun > FunctionalExpressionRing.mstInRing( - block: MstRing.() -> MST -): MstExpression = algebra.mstInRing(block) + +@OptIn(ExperimentalContracts::class) +inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInRing(block) +} /** * Builds [MstExpression] over [FunctionalExpressionField]. */ -inline fun > FunctionalExpressionField.mstInField( - block: MstField.() -> MST -): MstExpression = algebra.mstInField(block) +@OptIn(ExperimentalContracts::class) +inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInField(block) +} /** * Builds [MstExpression] over [FunctionalExpressionExtendedField]. */ -inline fun > FunctionalExpressionExtendedField.mstInExtendedField( - block: MstExtendedField.() -> MST -): MstExpression = algebra.mstInExtendedField(block) +@OptIn(ExperimentalContracts::class) +inline fun > FunctionalExpressionExtendedField.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInExtendedField(block) +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt index a637289b8..97bfba7f2 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt @@ -7,6 +7,9 @@ import scientifik.kmath.ast.MST import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import java.lang.reflect.Method +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.reflect.KClass private val methodNameAdapters: Map, String> by lazy { @@ -26,8 +29,11 @@ internal val KClass<*>.asm: Type /** * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. */ -internal inline fun T.wrapToArrayIf(predicate: (T) -> Boolean): Array = - if (predicate(this)) arrayOf(this) else emptyArray() +@OptIn(ExperimentalContracts::class) +internal inline fun T.wrapToArrayIf(predicate: (T) -> Boolean): Array { + contract { callsInPlace(predicate, InvocationKind.EXACTLY_ONCE) } + return if (predicate(this)) arrayOf(this) else emptyArray() +} /** * Creates an [InstructionAdapter] from this [MethodVisitor]. @@ -37,8 +43,11 @@ private fun MethodVisitor.instructionAdapter(): InstructionAdapter = Instruction /** * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. */ -internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = - instructionAdapter().apply(block) +@OptIn(ExperimentalContracts::class) +internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return instructionAdapter().apply(block) +} /** * Constructs a [Label], then applies it to this visitor. @@ -63,10 +72,14 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String { return buildName(mst, collision + 1) } +@OptIn(ExperimentalContracts::class) @Suppress("FunctionName") -internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = - ClassWriter(flags).apply(block) +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return ClassWriter(flags).apply(block) +} +@OptIn(ExperimentalContracts::class) internal inline fun ClassWriter.visitField( access: Int, name: String, @@ -74,7 +87,10 @@ internal inline fun ClassWriter.visitField( signature: String?, value: Any?, block: FieldVisitor.() -> Unit -): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) +): FieldVisitor { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return visitField(access, name, descriptor, signature, value).apply(block) +} private fun AsmBuilder.findSpecific(context: Algebra, name: String, parameterTypes: Array): Method? = context.javaClass.methods.find { method -> @@ -151,6 +167,7 @@ private fun AsmBuilder.tryInvokeSpecific( /** * Builds specialized algebra call with option to fallback to generic algebra operation accepting String. */ +@OptIn(ExperimentalContracts::class) internal inline fun AsmBuilder.buildAlgebraOperationCall( context: Algebra, name: String, @@ -158,6 +175,7 @@ internal inline fun AsmBuilder.buildAlgebraOperationCall( parameterTypes: Array, parameters: AsmBuilder.() -> Unit ) { + contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } val arity = parameterTypes.size loadAlgebra() if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name) diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index 64ebe8da3..9119991e5 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -5,6 +5,7 @@ import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.Field +import scientifik.kmath.operations.invoke import kotlin.properties.ReadOnlyProperty import kotlin.reflect.KProperty @@ -15,7 +16,6 @@ class DerivativeStructureField( val order: Int, val parameters: Map ) : ExtendedField { - override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) } override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) } @@ -23,17 +23,15 @@ class DerivativeStructureField( DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value) } - val variable = object : ReadOnlyProperty { - override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure { - return variables[property.name] ?: error("A variable with name ${property.name} does not exist") - } + val variable: ReadOnlyProperty = object : ReadOnlyProperty { + override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure = + variables[property.name] ?: error("A variable with name ${property.name} does not exist") } fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure = variables[name] ?: default ?: error("A variable with name $name does not exist") - - fun Number.const() = DerivativeStructure(order, parameters.size, toDouble()) + fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble()) fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double { return deriv(mapOf(parName to order)) @@ -83,16 +81,15 @@ class DerivativeStructureField( override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) - override operator fun Number.plus(b: DerivativeStructure) = b + this - override operator fun Number.minus(b: DerivativeStructure) = b - this + override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this + override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this } /** * A constructs that creates a derivative structure with required order on-demand */ class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression { - - override fun invoke(arguments: Map): Double = DerivativeStructureField( + override operator fun invoke(arguments: Map): Double = DerivativeStructureField( 0, arguments ).run(function).value @@ -101,45 +98,40 @@ class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStru * Get the derivative expression with given orders * TODO make result [DiffExpression] */ - fun derivative(orders: Map): Expression { - return object : Expression { - override fun invoke(arguments: Map): Double = - DerivativeStructureField(orders.values.max() ?: 0, arguments) - .run { - function().deriv(orders) - } - } + fun derivative(orders: Map): Expression = object : Expression { + override operator fun invoke(arguments: Map): Double = + (DerivativeStructureField(orders.values.max() ?: 0, arguments)) { function().deriv(orders) } } //TODO add gradient and maybe other vector operators } -fun DiffExpression.derivative(vararg orders: Pair) = derivative(mapOf(*orders)) -fun DiffExpression.derivative(name: String) = derivative(name to 1) +fun DiffExpression.derivative(vararg orders: Pair): Expression = derivative(mapOf(*orders)) +fun DiffExpression.derivative(name: String): Expression = derivative(name to 1) /** * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) */ object DiffExpressionAlgebra : ExpressionAlgebra, Field { - override fun variable(name: String, default: Double?) = + override fun variable(name: String, default: Double?): DiffExpression = DiffExpression { variable(name, default?.const()) } override fun const(value: Double): DiffExpression = DiffExpression { value.const() } - override fun add(a: DiffExpression, b: DiffExpression) = + override fun add(a: DiffExpression, b: DiffExpression): DiffExpression = DiffExpression { a.function(this) + b.function(this) } - override val zero = DiffExpression { 0.0.const() } + override val zero: DiffExpression = DiffExpression { 0.0.const() } - override fun multiply(a: DiffExpression, k: Number) = + override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k } - override val one = DiffExpression { 1.0.const() } + override val one: DiffExpression = DiffExpression { 1.0.const() } - override fun multiply(a: DiffExpression, b: DiffExpression) = + override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression = DiffExpression { a.function(this) * b.function(this) } - override fun divide(a: DiffExpression, b: DiffExpression) = + override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression = DiffExpression { a.function(this) / b.function(this) } } diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt index a17effccc..f0bbdbe65 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt @@ -1,8 +1,6 @@ package scientifik.kmath.commons.linear import org.apache.commons.math3.linear.* -import org.apache.commons.math3.linear.RealMatrix -import org.apache.commons.math3.linear.RealVector import scientifik.kmath.linear.* import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.NDStructure @@ -14,12 +12,12 @@ class CMMatrix(val origin: RealMatrix, features: Set? = null) : override val features: Set = features ?: sequence { if (origin is DiagonalMatrix) yield(DiagonalFeature) - }.toSet() + }.toHashSet() - override fun suggestFeature(vararg features: MatrixFeature) = + override fun suggestFeature(vararg features: MatrixFeature): CMMatrix = CMMatrix(origin, this.features + features) - override fun get(i: Int, j: Int): Double = origin.getEntry(i, j) + override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j) override fun equals(other: Any?): Boolean { return NDStructure.equals(this, other as? NDStructure<*> ?: return false) @@ -40,24 +38,22 @@ fun Matrix.toCM(): CMMatrix = if (this is CMMatrix) { CMMatrix(Array2DRowRealMatrix(array)) } -fun RealMatrix.asMatrix() = CMMatrix(this) +fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this) class CMVector(val origin: RealVector) : Point { override val size: Int get() = origin.dimension - override fun get(index: Int): Double = origin.getEntry(index) + override operator fun get(index: Int): Double = origin.getEntry(index) - override fun iterator(): Iterator = origin.toArray().iterator() + override operator fun iterator(): Iterator = origin.toArray().iterator() } -fun Point.toCM(): CMVector = if (this is CMVector) { - this -} else { +fun Point.toCM(): CMVector = if (this is CMVector) this else { val array = DoubleArray(size) { this[it] } CMVector(ArrayRealVector(array)) } -fun RealVector.toPoint() = CMVector(this) +fun RealVector.toPoint(): CMVector = CMVector(this) object CMMatrixContext : MatrixContext { override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix { @@ -65,32 +61,33 @@ object CMMatrixContext : MatrixContext { return CMMatrix(Array2DRowRealMatrix(array)) } - override fun Matrix.dot(other: Matrix) = + override fun Matrix.dot(other: Matrix): CMMatrix = CMMatrix(this.toCM().origin.multiply(other.toCM().origin)) override fun Matrix.dot(vector: Point): CMVector = CMVector(this.toCM().origin.preMultiply(vector.toCM().origin)) - override fun Matrix.unaryMinus(): CMMatrix = + override operator fun Matrix.unaryMinus(): CMMatrix = produce(rowNum, colNum) { i, j -> -get(i, j) } - override fun add(a: Matrix, b: Matrix) = + override fun add(a: Matrix, b: Matrix): CMMatrix = CMMatrix(a.toCM().origin.multiply(b.toCM().origin)) - override fun Matrix.minus(b: Matrix) = + override operator fun Matrix.minus(b: Matrix): CMMatrix = CMMatrix(this.toCM().origin.subtract(b.toCM().origin)) - override fun multiply(a: Matrix, k: Number) = + override fun multiply(a: Matrix, k: Number): CMMatrix = CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble())) - override fun Matrix.times(value: Double): Matrix = + override operator fun Matrix.times(value: Double): Matrix = produce(rowNum, colNum) { i, j -> get(i, j) * value } } operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin)) + operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin)) infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = - CMMatrix(this.origin.multiply(other.origin)) \ No newline at end of file + CMMatrix(this.origin.multiply(other.origin)) diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt index 13e79d60e..cb2b5dd9c 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -4,10 +4,9 @@ import scientifik.kmath.prob.RandomGenerator class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : org.apache.commons.math3.random.RandomGenerator { - private var generator = factory(intArrayOf()) + private var generator: RandomGenerator = factory(intArrayOf()) override fun nextBoolean(): Boolean = generator.nextBoolean() - override fun nextFloat(): Float = generator.nextDouble().toFloat() override fun setSeed(seed: Int) { @@ -27,12 +26,8 @@ class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : } override fun nextInt(): Int = generator.nextInt() - override fun nextInt(n: Int): Int = generator.nextInt(n) - override fun nextGaussian(): Double = TODO() - override fun nextDouble(): Double = generator.nextDouble() - override fun nextLong(): Long = generator.nextLong() -} \ No newline at end of file +} diff --git a/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt b/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt index c77f30eb7..ae4bb755a 100644 --- a/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt +++ b/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt @@ -1,11 +1,17 @@ package scientifik.kmath.commons.expressions import scientifik.kmath.expressions.invoke +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.test.Test import kotlin.test.assertEquals -inline fun diff(order: Int, vararg parameters: Pair, block: DerivativeStructureField.() -> R) = - DerivativeStructureField(order, mapOf(*parameters)).run(block) +@OptIn(ExperimentalContracts::class) +inline fun diff(order: Int, vararg parameters: Pair, block: DerivativeStructureField.() -> R): R { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return DerivativeStructureField(order, mapOf(*parameters)).run(block) +} class AutoDiffTest { @Test diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt index 8cd6e28f8..3bbc86ce5 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt @@ -4,28 +4,42 @@ import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.Field import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * Creates a functional expression with this [Space]. */ -fun Space.spaceExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression = - FunctionalExpressionSpace(this).run(block) +@OptIn(ExperimentalContracts::class) +inline fun Space.spaceExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionSpace(this).block() +} /** * Creates a functional expression with this [Ring]. */ -fun Ring.ringExpression(block: FunctionalExpressionRing>.() -> Expression): Expression = - FunctionalExpressionRing(this).run(block) +@OptIn(ExperimentalContracts::class) +inline fun Ring.ringExpression(block: FunctionalExpressionRing>.() -> Expression): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionRing(this).block() +} /** * Creates a functional expression with this [Field]. */ -fun Field.fieldExpression(block: FunctionalExpressionField>.() -> Expression): Expression = - FunctionalExpressionField(this).run(block) +@OptIn(ExperimentalContracts::class) +inline fun Field.fieldExpression(block: FunctionalExpressionField>.() -> Expression): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionField(this).block() +} /** * Creates a functional expression with this [ExtendedField]. */ -fun ExtendedField.fieldExpression( - block: FunctionalExpressionExtendedField>.() -> Expression -): Expression = FunctionalExpressionExtendedField(this).run(block) +@OptIn(ExperimentalContracts::class) +inline fun ExtendedField.extendedFieldExpression(block: FunctionalExpressionExtendedField>.() -> Expression): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionExtendedField(this).block() +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 380822f78..fd11c246d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -22,7 +22,7 @@ interface Expression { */ fun Algebra.expression(block: Algebra.(arguments: Map) -> T): Expression = object : Expression { - override fun invoke(arguments: Map): T = block(arguments) + override operator fun invoke(arguments: Map): T = block(arguments) } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt index 9fe8aaf93..d36c31a0d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -4,7 +4,7 @@ import scientifik.kmath.operations.* internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) : Expression { - override fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) + override operator fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) } internal class FunctionalBinaryOperation( @@ -13,17 +13,17 @@ internal class FunctionalBinaryOperation( val first: Expression, val second: Expression ) : Expression { - override fun invoke(arguments: Map): T = + override operator fun invoke(arguments: Map): T = context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) } internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression { - override fun invoke(arguments: Map): T = + override operator fun invoke(arguments: Map): T = arguments[name] ?: default ?: error("Parameter not found: $name") } internal class FunctionalConstantExpression(val value: T) : Expression { - override fun invoke(arguments: Map): T = value + override operator fun invoke(arguments: Map): T = value } internal class FunctionalConstProductExpression( @@ -31,7 +31,7 @@ internal class FunctionalConstProductExpression( private val expr: Expression, val const: Number ) : Expression { - override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) + override operator fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt index 2e1f32501..343b8287e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt @@ -53,16 +53,12 @@ class BufferMatrix( override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix = BufferMatrix(rowNum, colNum, buffer, this.features + features) - override fun get(index: IntArray): T = get(index[0], index[1]) + override operator fun get(index: IntArray): T = get(index[0], index[1]) - override fun get(i: Int, j: Int): T = buffer[i * colNum + j] + override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j] override fun elements(): Sequence> = sequence { - for (i in 0 until rowNum) { - for (j in 0 until colNum) { - yield(intArrayOf(i, j) to get(i, j)) - } - } + for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j)) } override fun equals(other: Any?): Boolean { @@ -95,7 +91,7 @@ class BufferMatrix( * Optimized dot product for real matrices */ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix { - if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})") + require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } val array = DoubleArray(this.rowNum * other.colNum) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt index dd17c4fe7..39256f4ac 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt @@ -4,6 +4,8 @@ import scientifik.kmath.operations.Ring import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Structure2D import scientifik.kmath.structures.asBuffer +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.math.sqrt /** @@ -26,15 +28,18 @@ interface FeaturedMatrix : Matrix { companion object } -fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix = - MatrixContext.real.produce(rows, columns, initializer) +@OptIn(ExperimentalContracts::class) +inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix { + contract { callsInPlace(initializer) } + return MatrixContext.real.produce(rows, columns, initializer) +} /** * Build a square matrix from given elements. */ fun Structure2D.Companion.square(vararg elements: T): FeaturedMatrix { val size: Int = sqrt(elements.size.toDouble()).toInt() - if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square") + require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" } val buffer = elements.asBuffer() return BufferMatrix(size, size, buffer) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt index d04a99fbb..f3e4f648f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt @@ -3,6 +3,7 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Field import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.BufferAccessor2D import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Structure2D @@ -60,15 +61,13 @@ class LUPDecomposition( * @return determinant of the matrix */ override val determinant: T by lazy { - with(elementContext) { - (0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } - } + elementContext { (0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } } } } fun , F : Field> GenericMatrixContext.abs(value: T): T = - if (value > elementContext.zero) value else with(elementContext) { -value } + if (value > elementContext.zero) value else elementContext { -value } /** @@ -88,43 +87,34 @@ fun , F : Field> GenericMatrixContext.lup( //TODO just waits for KEEP-176 BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { - elementContext.run { - + elementContext { val lu = create(matrix) // Initialize permutation array and parity - for (row in 0 until m) { - pivot[row] = row - } + for (row in 0 until m) pivot[row] = row var even = true // Initialize permutation array and parity - for (row in 0 until m) { - pivot[row] = row - } + for (row in 0 until m) pivot[row] = row // Loop over columns for (col in 0 until m) { - // upper for (row in 0 until col) { val luRow = lu.row(row) var sum = luRow[col] - for (i in 0 until row) { - sum -= luRow[i] * lu[i, col] - } + for (i in 0 until row) sum -= luRow[i] * lu[i, col] luRow[col] = sum } // lower var max = col // permutation row var largest = -one + for (row in col until m) { val luRow = lu.row(row) var sum = luRow[col] - for (i in 0 until col) { - sum -= luRow[i] * lu[i, col] - } + for (i in 0 until col) sum -= luRow[i] * lu[i, col] luRow[col] = sum // maintain best permutation choice @@ -135,19 +125,19 @@ fun , F : Field> GenericMatrixContext.lup( } // Singularity check - if (checkSingular(this@lup.abs(lu[max, col]))) { - error("The matrix is singular") - } + check(!checkSingular(this@lup.abs(lu[max, col]))) { "The matrix is singular" } // Pivot if necessary if (max != col) { val luMax = lu.row(max) val luCol = lu.row(col) + for (i in 0 until m) { val tmp = luMax[i] luMax[i] = luCol[i] luCol[i] = tmp } + val temp = pivot[max] pivot[max] = pivot[col] pivot[col] = temp @@ -156,9 +146,7 @@ fun , F : Field> GenericMatrixContext.lup( // Divide the lower elements by the "winning" diagonal elt. val luDiag = lu[col, col] - for (row in col + 1 until m) { - lu[row, col] /= luDiag - } + for (row in col + 1 until m) lu[row, col] /= luDiag } return LUPDecomposition(this@lup, lu.collect(), pivot, even) @@ -175,28 +163,23 @@ fun GenericMatrixContext.lup(matrix: Matrix): LUPDeco lup(Double::class, matrix) { it < 1e-11 } fun LUPDecomposition.solve(type: KClass, matrix: Matrix): Matrix { - - if (matrix.rowNum != pivot.size) { - error("Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}") - } + require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" } BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { - elementContext.run { - + elementContext { // Apply permutations to b val bp = create { _, _ -> zero } for (row in pivot.indices) { val bpRow = bp.row(row) val pRow = pivot[row] - for (col in 0 until matrix.colNum) { - bpRow[col] = matrix[pRow, col] - } + for (col in 0 until matrix.colNum) bpRow[col] = matrix[pRow, col] } // Solve LY = b for (col in pivot.indices) { val bpCol = bp.row(col) + for (i in col + 1 until pivot.size) { val bpI = bp.row(i) val luICol = lu[i, col] @@ -210,17 +193,15 @@ fun LUPDecomposition.solve(type: KClass, matrix: Matrix): Mat for (col in pivot.size - 1 downTo 0) { val bpCol = bp.row(col) val luDiag = lu[col, col] - for (j in 0 until matrix.colNum) { - bpCol[j] /= luDiag - } + for (j in 0 until matrix.colNum) bpCol[j] /= luDiag + for (i in 0 until col) { val bpI = bp.row(i) val luICol = lu[i, col] - for (j in 0 until matrix.colNum) { - bpI[j] -= bpCol[j] * luICol - } + for (j in 0 until matrix.colNum) bpI[j] -= bpCol[j] * luICol } } + return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] } } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt index 516f65bb8..390362f8c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt @@ -7,7 +7,7 @@ import scientifik.kmath.structures.asBuffer class MatrixBuilder(val rows: Int, val columns: Int) { operator fun invoke(vararg elements: T): FeaturedMatrix { - if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns") + require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" } val buffer = elements.asBuffer() return BufferMatrix(rows, columns, buffer) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt index 5dc86a7dd..763bb1615 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt @@ -2,6 +2,7 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Ring import scientifik.kmath.operations.SpaceOperations +import scientifik.kmath.operations.invoke import scientifik.kmath.operations.sum import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.BufferFactory @@ -37,8 +38,7 @@ interface MatrixContext : SpaceOperations> { fun > buffered( ring: R, bufferFactory: BufferFactory = Buffer.Companion::boxing - ): GenericMatrixContext = - BufferMatrixContext(ring, bufferFactory) + ): GenericMatrixContext = BufferMatrixContext(ring, bufferFactory) /** * Automatic buffered matrix, unboxed if it is possible @@ -61,45 +61,49 @@ interface GenericMatrixContext> : MatrixContext { override infix fun Matrix.dot(other: Matrix): Matrix { //TODO add typed error - if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})") + require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } + return produce(rowNum, other.colNum) { i, j -> val row = rows[i] val column = other.columns[j] - with(elementContext) { - sum(row.asSequence().zip(column.asSequence(), ::multiply)) - } + elementContext { sum(row.asSequence().zip(column.asSequence(), ::multiply)) } } } override infix fun Matrix.dot(vector: Point): Point { //TODO add typed error - if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})") + require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" } + return point(rowNum) { i -> val row = rows[i] - with(elementContext) { - sum(row.asSequence().zip(vector.asSequence(), ::multiply)) - } + elementContext { sum(row.asSequence().zip(vector.asSequence(), ::multiply)) } } } override operator fun Matrix.unaryMinus(): Matrix = - produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } } + produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } } override fun add(a: Matrix, b: Matrix): Matrix { - if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]") - return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] + b[i, j] } } + require(a.rowNum == b.rowNum && a.colNum == b.colNum) { + "Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]" + } + + return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } } } override operator fun Matrix.minus(b: Matrix): Matrix { - if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]") - return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } } + require(rowNum == b.rowNum && colNum == b.colNum) { + "Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]" + } + + return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } } } override fun multiply(a: Matrix, k: Number): Matrix = - produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } } + produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } } operator fun Number.times(matrix: FeaturedMatrix): Matrix = matrix * this - override fun Matrix.times(value: T): Matrix = - produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } } + override operator fun Matrix.times(value: T): Matrix = + produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt index 691b464fc..82e5c7ef6 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt @@ -2,6 +2,7 @@ package scientifik.kmath.linear import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.BufferFactory @@ -10,10 +11,9 @@ import scientifik.kmath.structures.BufferFactory * Could be used on any point-like structure */ interface VectorSpace> : Space> { - val size: Int - val space: S + override val zero: Point get() = produce { space.zero } fun produce(initializer: (Int) -> T): Point @@ -22,29 +22,24 @@ interface VectorSpace> : Space> { */ //fun produceElement(initializer: (Int) -> T): Vector - override val zero: Point get() = produce { space.zero } + override fun add(a: Point, b: Point): Point = produce { space { a[it] + b[it] } } - override fun add(a: Point, b: Point): Point = produce { with(space) { a[it] + b[it] } } - - override fun multiply(a: Point, k: Number): Point = produce { with(space) { a[it] * k } } + override fun multiply(a: Point, k: Number): Point = produce { space { a[it] * k } } //TODO add basis companion object { - - private val realSpaceCache = HashMap>() + private val realSpaceCache: MutableMap> = hashMapOf() /** * Non-boxing double vector space */ - fun real(size: Int): BufferVectorSpace { - return realSpaceCache.getOrPut(size) { - BufferVectorSpace( - size, - RealField, - Buffer.Companion::auto - ) - } + fun real(size: Int): BufferVectorSpace = realSpaceCache.getOrPut(size) { + BufferVectorSpace( + size, + RealField, + Buffer.Companion::auto + ) } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt index 207151d57..5266dc884 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt @@ -18,7 +18,7 @@ class VirtualMatrix( override val shape: IntArray get() = intArrayOf(rowNum, colNum) - override fun get(i: Int, j: Int): T = generator(i, j) + override operator fun get(i: Int, j: Int): T = generator(i, j) override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix = VirtualMatrix(rowNum, colNum, this.features + features, generator) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt index db8863ae8..1609f063a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt @@ -3,8 +3,12 @@ package scientifik.kmath.misc import scientifik.kmath.linear.Point import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.Field +import scientifik.kmath.operations.invoke import scientifik.kmath.operations.sum import scientifik.kmath.structures.asBuffer +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /* * Implementation of backward-mode automatic differentiation. @@ -27,15 +31,14 @@ class DerivationResult( /** * compute divergence */ - fun div(): T = context.run { sum(deriv.values) } + fun div(): T = context { sum(deriv.values) } /** * Compute a gradient for variables in given order */ - fun grad(vararg variables: Variable): Point = if (variables.isEmpty()) { - error("Variable order is not provided for gradient construction") - } else { - variables.map(::deriv).asBuffer() + fun grad(vararg variables: Variable): Point { + check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" } + return variables.map(::deriv).asBuffer() } } @@ -52,19 +55,28 @@ class DerivationResult( * assertEquals(9.0, x.d) // dy/dx * ``` */ -fun > F.deriv(body: AutoDiffField.() -> Variable): DerivationResult = - AutoDiffContext(this).run { +@OptIn(ExperimentalContracts::class) +inline fun > F.deriv(body: AutoDiffField.() -> Variable): DerivationResult { + contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } + + return (AutoDiffContext(this)) { val result = body() - result.d = context.one// computing derivative w.r.t result + result.d = context.one // computing derivative w.r.t result runBackwardPass() DerivationResult(result.value, derivatives, this@deriv) } +} abstract class AutoDiffField> : Field> { - abstract val context: F + /** + * A variable accessing inner state of derivatives. + * Use this function in inner builders to avoid creating additional derivative bindings + */ + abstract var Variable.d: T + /** * Performs update of derivative after the rest of the formula in the back-pass. * @@ -78,12 +90,6 @@ abstract class AutoDiffField> : Field> { */ abstract fun derive(value: R, block: F.(R) -> Unit): R - /** - * A variable accessing inner state of derivatives. - * Use this function in inner builders to avoid creating additional derivative bindings - */ - abstract var Variable.d: T - abstract fun variable(value: T): Variable inline fun variable(block: F.() -> T): Variable = variable(context.block()) @@ -98,46 +104,35 @@ abstract class AutoDiffField> : Field> { override operator fun Variable.plus(b: Number): Variable = b.plus(this) override operator fun Number.minus(b: Variable): Variable = - derive(variable { this@minus.toDouble() * one - b.value }) { z -> - b.d -= z.d - } + derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } override operator fun Variable.minus(b: Number): Variable = - derive(variable { this@minus.value - one * b.toDouble() }) { z -> - this@minus.d += z.d - } + derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } } /** * Automatic Differentiation context class. */ -private class AutoDiffContext>(override val context: F) : AutoDiffField() { - +@PublishedApi +internal class AutoDiffContext>(override val context: F) : AutoDiffField() { // this stack contains pairs of blocks and values to apply them to - private var stack = arrayOfNulls(8) - private var sp = 0 - - internal val derivatives = HashMap, T>() - + private var stack: Array = arrayOfNulls(8) + private var sp: Int = 0 + val derivatives: MutableMap, T> = hashMapOf() + override val zero: Variable get() = Variable(context.zero) + override val one: Variable get() = Variable(context.one) /** * A variable coupled with its derivative. For internal use only */ private class VariableWithDeriv(x: T, var d: T) : Variable(x) - override fun variable(value: T): Variable = VariableWithDeriv(value, context.zero) override var Variable.d: T get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero - set(value) { - if (this is VariableWithDeriv) { - d = value - } else { - derivatives[this] = value - } - } + set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value @Suppress("UNCHECKED_CAST") override fun derive(value: R, block: F.(R) -> Unit): R { @@ -160,67 +155,49 @@ private class AutoDiffContext>(override val context: F) : // Basic math (+, -, *, /) - override fun add(a: Variable, b: Variable): Variable = - derive(variable { a.value + b.value }) { z -> - a.d += z.d - b.d += z.d - } + override fun add(a: Variable, b: Variable): Variable = derive(variable { a.value + b.value }) { z -> + a.d += z.d + b.d += z.d + } - override fun multiply(a: Variable, b: Variable): Variable = - derive(variable { a.value * b.value }) { z -> - a.d += z.d * b.value - b.d += z.d * a.value - } + override fun multiply(a: Variable, b: Variable): Variable = derive(variable { a.value * b.value }) { z -> + a.d += z.d * b.value + b.d += z.d * a.value + } - override fun divide(a: Variable, b: Variable): Variable = - derive(variable { a.value / b.value }) { z -> - a.d += z.d / b.value - b.d -= z.d * a.value / (b.value * b.value) - } + override fun divide(a: Variable, b: Variable): Variable = derive(variable { a.value / b.value }) { z -> + a.d += z.d / b.value + b.d -= z.d * a.value / (b.value * b.value) + } - override fun multiply(a: Variable, k: Number): Variable = - derive(variable { k.toDouble() * a.value }) { z -> - a.d += z.d * k.toDouble() - } - - override val zero: Variable get() = Variable(context.zero) - override val one: Variable get() = Variable(context.one) + override fun multiply(a: Variable, k: Number): Variable = derive(variable { k.toDouble() * a.value }) { z -> + a.d += z.d * k.toDouble() + } } // Extensions for differentiation of various basic mathematical functions // x ^ 2 fun > AutoDiffField.sqr(x: Variable): Variable = - derive(variable { x.value * x.value }) { z -> - x.d += z.d * 2 * x.value - } + derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } // x ^ 1/2 fun > AutoDiffField.sqrt(x: Variable): Variable = - derive(variable { sqrt(x.value) }) { z -> - x.d += z.d * 0.5 / z.value - } + derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } // x ^ y (const) fun > AutoDiffField.pow(x: Variable, y: Double): Variable = - derive(variable { power(x.value, y) }) { z -> - x.d += z.d * y * power(x.value, y - 1) - } + derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } fun > AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble()) // exp(x) fun > AutoDiffField.exp(x: Variable): Variable = - derive(variable { exp(x.value) }) { z -> - x.d += z.d * z.value - } + derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value } // ln(x) -fun > AutoDiffField.ln(x: Variable): Variable = derive( - variable { ln(x.value) } -) { z -> - x.d += z.d / x.value -} +fun > AutoDiffField.ln(x: Variable): Variable = + derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value } // x ^ y (any) fun > AutoDiffField.pow(x: Variable, y: Variable): Variable = @@ -228,12 +205,8 @@ fun > AutoDiffField.pow(x: Variable, y: V // sin(x) fun > AutoDiffField.sin(x: Variable): Variable = - derive(variable { sin(x.value) }) { z -> - x.d += z.d * cos(x.value) - } + derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } // cos(x) fun > AutoDiffField.cos(x: Variable): Variable = - derive(variable { cos(x.value) }) { z -> - x.d -= z.d * sin(x.value) - } + derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt index d3bf0891f..1272ddd1c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt @@ -41,6 +41,6 @@ fun ClosedFloatingPointRange.toSequenceWithPoints(numPoints: Int): Seque */ @Deprecated("Replace by 'toSequenceWithPoints'") fun ClosedFloatingPointRange.toGrid(numPoints: Int): DoubleArray { - if (numPoints < 2) error("Can't create generic grid with less than two points") + require(numPoints >= 2) { "Can't create generic grid with less than two points" } return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt index a0f4525cc..dc7a4bf1f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt @@ -2,6 +2,8 @@ package scientifik.kmath.misc import scientifik.kmath.operations.Space import scientifik.kmath.operations.invoke +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.jvm.JvmName /** @@ -11,67 +13,69 @@ import kotlin.jvm.JvmName * @param R the type of resulting iterable. * @param initial lazy evaluated. */ -fun Iterator.cumulative(initial: R, operation: (R, T) -> R): Iterator = object : Iterator { - var state: R = initial - override fun hasNext(): Boolean = this@cumulative.hasNext() +@OptIn(ExperimentalContracts::class) +inline fun Iterator.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator { + contract { callsInPlace(operation) } - override fun next(): R { - state = operation(state, this@cumulative.next()) - return state + return object : Iterator { + var state: R = initial + + override fun hasNext(): Boolean = this@cumulative.hasNext() + + override fun next(): R { + state = operation(state, this@cumulative.next()) + return state + } } } -fun Iterable.cumulative(initial: R, operation: (R, T) -> R): Iterable = object : Iterable { - override fun iterator(): Iterator = this@cumulative.iterator().cumulative(initial, operation) -} +inline fun Iterable.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable = + Iterable { this@cumulative.iterator().cumulative(initial, operation) } -fun Sequence.cumulative(initial: R, operation: (R, T) -> R): Sequence = object : Sequence { - override fun iterator(): Iterator = this@cumulative.iterator().cumulative(initial, operation) +inline fun Sequence.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence = Sequence { + this@cumulative.iterator().cumulative(initial, operation) } fun List.cumulative(initial: R, operation: (R, T) -> R): List = - this.iterator().cumulative(initial, operation).asSequence().toList() + iterator().cumulative(initial, operation).asSequence().toList() //Cumulative sum /** * Cumulative sum with custom space */ -fun Iterable.cumulativeSum(space: Space): Iterable = space { - cumulative(zero) { element: T, sum: T -> sum + element } -} +fun Iterable.cumulativeSum(space: Space): Iterable = + space { cumulative(zero) { element: T, sum: T -> sum + element } } @JvmName("cumulativeSumOfDouble") -fun Iterable.cumulativeSum(): Iterable = this.cumulative(0.0) { element, sum -> sum + element } +fun Iterable.cumulativeSum(): Iterable = cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun Iterable.cumulativeSum(): Iterable = this.cumulative(0) { element, sum -> sum + element } +fun Iterable.cumulativeSum(): Iterable = cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun Iterable.cumulativeSum(): Iterable = this.cumulative(0L) { element, sum -> sum + element } +fun Iterable.cumulativeSum(): Iterable = cumulative(0L) { element, sum -> sum + element } -fun Sequence.cumulativeSum(space: Space): Sequence = with(space) { - cumulative(zero) { element: T, sum: T -> sum + element } -} +fun Sequence.cumulativeSum(space: Space): Sequence = + space { cumulative(zero) { element: T, sum: T -> sum + element } } @JvmName("cumulativeSumOfDouble") -fun Sequence.cumulativeSum(): Sequence = this.cumulative(0.0) { element, sum -> sum + element } +fun Sequence.cumulativeSum(): Sequence = cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun Sequence.cumulativeSum(): Sequence = this.cumulative(0) { element, sum -> sum + element } +fun Sequence.cumulativeSum(): Sequence = cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun Sequence.cumulativeSum(): Sequence = this.cumulative(0L) { element, sum -> sum + element } +fun Sequence.cumulativeSum(): Sequence = cumulative(0L) { element, sum -> sum + element } -fun List.cumulativeSum(space: Space): List = with(space) { - cumulative(zero) { element: T, sum: T -> sum + element } -} +fun List.cumulativeSum(space: Space): List = + space { cumulative(zero) { element: T, sum: T -> sum + element } } @JvmName("cumulativeSumOfDouble") -fun List.cumulativeSum(): List = this.cumulative(0.0) { element, sum -> sum + element } +fun List.cumulativeSum(): List = cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun List.cumulativeSum(): List = this.cumulative(0) { element, sum -> sum + element } +fun List.cumulativeSum(): List = cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun List.cumulativeSum(): List = this.cumulative(0L) { element, sum -> sum + element } +fun List.cumulativeSum(): List = cumulative(0L) { element, sum -> sum + element } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt index 07163750b..d0c4989d8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt @@ -3,6 +3,8 @@ package scientifik.kmath.operations import scientifik.kmath.operations.BigInt.Companion.BASE import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE import scientifik.kmath.structures.* +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.math.log2 import kotlin.math.max import kotlin.math.min @@ -431,8 +433,8 @@ fun ULong.toBigInt(): BigInt = BigInt( * Create a [BigInt] with this array of magnitudes with protective copy */ fun UIntArray.toBigInt(sign: Byte): BigInt { - if (sign == 0.toByte() && isNotEmpty()) error("") - return BigInt(sign, this.copyOf()) + require(sign != 0.toByte() || !isNotEmpty()) + return BigInt(sign, copyOf()) } val hexChToInt: MutableMap = hashMapOf( @@ -485,11 +487,17 @@ fun String.parseBigInteger(): BigInt? { return res * sign } -inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer = - boxing(size, initializer) +@OptIn(ExperimentalContracts::class) +inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer { + contract { callsInPlace(initializer) } + return boxing(size, initializer) +} -inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer = - boxing(size, initializer) +@OptIn(ExperimentalContracts::class) +inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer { + contract { callsInPlace(initializer) } + return boxing(size, initializer) +} fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing = BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt) @@ -497,5 +505,4 @@ fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing BigInt -): BufferedNDRingElement = - NDAlgebra.bigInt(*shape).produce(initializer) +): BufferedNDRingElement = NDAlgebra.bigInt(*shape).produce(initializer) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 76df0f45d..e80e6983a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -6,6 +6,8 @@ import scientifik.kmath.structures.MutableBuffer import scientifik.memory.MemoryReader import scientifik.memory.MemorySpec import scientifik.memory.MemoryWriter +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.math.* /** @@ -196,10 +198,14 @@ data class Complex(val re: Double, val im: Double) : FieldElement Complex): Buffer { + contract { callsInPlace(init) } return MemoryBuffer.create(Complex, size, init) } +@OptIn(ExperimentalContracts::class) inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer { + contract { callsInPlace(init) } return MemoryBuffer.create(Complex, size, init) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt index 4cbb565c1..be71645d1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt @@ -8,19 +8,17 @@ class BoxingNDField>( override val elementContext: F, val bufferFactory: BufferFactory ) : BufferedNDField { - + override val zero: BufferedNDFieldElement by lazy { produce { zero } } + override val one: BufferedNDFieldElement by lazy { produce { one } } override val strides: Strides = DefaultStrides(shape) fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) override fun check(vararg elements: NDBuffer) { - if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") + check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } } - override val zero: BufferedNDFieldElement by lazy { produce { zero } } - override val one: BufferedNDFieldElement by lazy { produce { one } } - override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement = BufferedNDFieldElement( this, @@ -28,6 +26,7 @@ class BoxingNDField>( override fun map(arg: NDBuffer, transform: F.(T) -> T): BufferedNDFieldElement { check(arg) + return BufferedNDFieldElement( this, buildBuffer(arg.strides.linearSize) { offset -> elementContext.transform(arg.buffer[offset]) }) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt index f7be95736..91b945e79 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt @@ -8,19 +8,16 @@ class BoxingNDRing>( override val elementContext: R, val bufferFactory: BufferFactory ) : BufferedNDRing { - override val strides: Strides = DefaultStrides(shape) - - fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = - bufferFactory(size, initializer) - - override fun check(vararg elements: NDBuffer) { - if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") - } - override val zero: BufferedNDRingElement by lazy { produce { zero } } override val one: BufferedNDRingElement by lazy { produce { one } } + fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) + + override fun check(vararg elements: NDBuffer) { + require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } + } + override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement = BufferedNDRingElement( this, diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt index 00832b69c..2c3d69094 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt @@ -6,7 +6,6 @@ import kotlin.reflect.KClass * A context that allows to operate on a [MutableBuffer] as on 2d array */ class BufferAccessor2D(val type: KClass, val rowNum: Int, val colNum: Int) { - operator fun Buffer.get(i: Int, j: Int): T = get(i + colNum * j) operator fun MutableBuffer.set(i: Int, j: Int, value: T) { @@ -26,15 +25,14 @@ class BufferAccessor2D(val type: KClass, val rowNum: Int, val colNum inner class Row(val buffer: MutableBuffer, val rowIndex: Int) : MutableBuffer { override val size: Int get() = colNum - override fun get(index: Int): T = buffer[rowIndex, index] + override operator fun get(index: Int): T = buffer[rowIndex, index] - override fun set(index: Int, value: T) { + override operator fun set(index: Int, value: T) { buffer[rowIndex, index] = value } override fun copy(): MutableBuffer = MutableBuffer.auto(type, colNum) { get(it) } - - override fun iterator(): Iterator = (0 until colNum).map(::get).iterator() + override operator fun iterator(): Iterator = (0 until colNum).map(::get).iterator() } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt index 06922c56f..2c0c2021f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt @@ -5,9 +5,8 @@ import scientifik.kmath.operations.* interface BufferedNDAlgebra : NDAlgebra> { val strides: Strides - override fun check(vararg elements: NDBuffer) { - if (!elements.all { it.strides == this.strides }) error("Strides mismatch") - } + override fun check(vararg elements: NDBuffer): Unit = + require(elements.all { it.strides == strides }) { ("Strides mismatch") } /** * Convert any [NDStructure] to buffered structure using strides from this context. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt index d1d622b23..20e34fadd 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt @@ -30,7 +30,6 @@ class BufferedNDRingElement>( override val context: BufferedNDRing, override val buffer: Buffer ) : BufferedNDElement(), RingElement, BufferedNDRingElement, BufferedNDRing> { - override fun unwrap(): NDBuffer = this override fun NDBuffer.wrap(): BufferedNDRingElement { @@ -43,7 +42,6 @@ class BufferedNDFieldElement>( override val context: BufferedNDField, override val buffer: Buffer ) : BufferedNDElement(), FieldElement, BufferedNDFieldElement, BufferedNDField> { - override fun unwrap(): NDBuffer = this override fun NDBuffer.wrap(): BufferedNDFieldElement { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt index 5fdf79e88..e065f4990 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt @@ -2,6 +2,8 @@ package scientifik.kmath.structures import scientifik.kmath.operations.Complex import scientifik.kmath.operations.complex +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.reflect.KClass /** @@ -117,15 +119,14 @@ interface MutableBuffer : Buffer { MutableListBuffer(MutableList(size, initializer)) @Suppress("UNCHECKED_CAST") - inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer { - return when (type) { + inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer = + when (type) { Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer else -> boxing(size, initializer) } - } /** * Create most appropriate mutable buffer for given type avoiding boxing wherever possible @@ -150,9 +151,8 @@ inline class ListBuffer(val list: List) : Buffer { override val size: Int get() = list.size - override fun get(index: Int): T = list[index] - - override fun iterator(): Iterator = list.iterator() + override operator fun get(index: Int): T = list[index] + override operator fun iterator(): Iterator = list.iterator() } /** @@ -167,7 +167,11 @@ fun List.asBuffer(): ListBuffer = ListBuffer(this) * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an array element given its index. */ -inline fun ListBuffer(size: Int, init: (Int) -> T): ListBuffer = List(size, init).asBuffer() +@OptIn(ExperimentalContracts::class) +inline fun ListBuffer(size: Int, init: (Int) -> T): ListBuffer { + contract { callsInPlace(init) } + return List(size, init).asBuffer() +} /** * [MutableBuffer] implementation over [MutableList]. @@ -176,17 +180,16 @@ inline fun ListBuffer(size: Int, init: (Int) -> T): ListBuffer = List(siz * @property list The underlying list. */ inline class MutableListBuffer(val list: MutableList) : MutableBuffer { - override val size: Int get() = list.size - override fun get(index: Int): T = list[index] + override operator fun get(index: Int): T = list[index] - override fun set(index: Int, value: T) { + override operator fun set(index: Int, value: T) { list[index] = value } - override fun iterator(): Iterator = list.iterator() + override operator fun iterator(): Iterator = list.iterator() override fun copy(): MutableBuffer = MutableListBuffer(ArrayList(list)) } @@ -201,14 +204,13 @@ class ArrayBuffer(private val array: Array) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): T = array[index] + override operator fun get(index: Int): T = array[index] - override fun set(index: Int, value: T) { + override operator fun set(index: Int, value: T) { array[index] = value } - override fun iterator(): Iterator = array.iterator() - + override operator fun iterator(): Iterator = array.iterator() override fun copy(): MutableBuffer = ArrayBuffer(array.copyOf()) } @@ -226,9 +228,9 @@ fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) inline class ReadOnlyBuffer(val buffer: MutableBuffer) : Buffer { override val size: Int get() = buffer.size - override fun get(index: Int): T = buffer[index] + override operator fun get(index: Int): T = buffer[index] - override fun iterator(): Iterator = buffer.iterator() + override operator fun iterator(): Iterator = buffer.iterator() } /** @@ -238,12 +240,12 @@ inline class ReadOnlyBuffer(val buffer: MutableBuffer) : Buffer { * @param T the type of elements provided by the buffer. */ class VirtualBuffer(override val size: Int, private val generator: (Int) -> T) : Buffer { - override fun get(index: Int): T { + override operator fun get(index: Int): T { if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index") return generator(index) } - override fun iterator(): Iterator = (0 until size).asSequence().map(generator).iterator() + override operator fun iterator(): Iterator = (0 until size).asSequence().map(generator).iterator() override fun contentEquals(other: Buffer<*>): Boolean { return if (other is VirtualBuffer) { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt index 370d8ff4d..a4855ca7c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt @@ -4,6 +4,9 @@ import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.FieldElement import scientifik.kmath.operations.complex +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract typealias ComplexNDElement = BufferedNDFieldElement @@ -109,7 +112,9 @@ inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(inde /** * Map one [ComplexNDElement] using function without indices. */ +@OptIn(ExperimentalContracts::class) inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement { + contract { callsInPlace(transform) } val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) } return BufferedNDFieldElement(context, buffer) } @@ -148,6 +153,8 @@ fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(In /** * Produce a context for n-dimensional operations inside this real field */ +@OptIn(ExperimentalContracts::class) inline fun ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R { - return NDField.complex(*shape).run(action) + contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } + return NDField.complex(*shape).action() } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt index a2d0a71b3..b78ac0beb 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt @@ -1,5 +1,7 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.experimental.and /** @@ -57,17 +59,19 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged override val size: Int get() = values.size - override fun get(index: Int): Double? = if (isValid(index)) values[index] else null + override operator fun get(index: Int): Double? = if (isValid(index)) values[index] else null - override fun iterator(): Iterator = values.indices.asSequence().map { + override operator fun iterator(): Iterator = values.indices.asSequence().map { if (isValid(it)) values[it] else null }.iterator() } +@OptIn(ExperimentalContracts::class) inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { - for (i in indices) { - if (isValid(i)) { - block(values[i]) - } - } + contract { callsInPlace(block) } + + indices + .asSequence() + .filter(::isValid) + .forEach { block(values[it]) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt index e42df8c14..24822056b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt @@ -1,5 +1,8 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract + /** * Specialized [MutableBuffer] implementation over [FloatArray]. * @@ -8,13 +11,13 @@ package scientifik.kmath.structures inline class FloatBuffer(val array: FloatArray) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): Float = array[index] + override operator fun get(index: Int): Float = array[index] - override fun set(index: Int, value: Float) { + override operator fun set(index: Int, value: Float) { array[index] = value } - override fun iterator(): FloatIterator = array.iterator() + override operator fun iterator(): FloatIterator = array.iterator() override fun copy(): MutableBuffer = FloatBuffer(array.copyOf()) @@ -27,7 +30,11 @@ inline class FloatBuffer(val array: FloatArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) }) +@OptIn(ExperimentalContracts::class) +inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer { + contract { callsInPlace(init) } + return FloatBuffer(FloatArray(size) { init(it) }) +} /** * Returns a new [FloatBuffer] of given elements. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt index a3f0f3c3e..229b62b48 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt @@ -1,5 +1,9 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + /** * Specialized [MutableBuffer] implementation over [IntArray]. * @@ -8,17 +12,16 @@ package scientifik.kmath.structures inline class IntBuffer(val array: IntArray) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): Int = array[index] + override operator fun get(index: Int): Int = array[index] - override fun set(index: Int, value: Int) { + override operator fun set(index: Int, value: Int) { array[index] = value } - override fun iterator(): IntIterator = array.iterator() + override operator fun iterator(): IntIterator = array.iterator() override fun copy(): MutableBuffer = IntBuffer(array.copyOf()) - } /** @@ -28,7 +31,11 @@ inline class IntBuffer(val array: IntArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) }) +@OptIn(ExperimentalContracts::class) +inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer { + contract { callsInPlace(init) } + return IntBuffer(IntArray(size) { init(it) }) +} /** * Returns a new [IntBuffer] of given elements. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt index 912656c68..cfedb5f35 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt @@ -1,5 +1,8 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract + /** * Specialized [MutableBuffer] implementation over [LongArray]. * @@ -8,13 +11,13 @@ package scientifik.kmath.structures inline class LongBuffer(val array: LongArray) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): Long = array[index] + override operator fun get(index: Int): Long = array[index] - override fun set(index: Int, value: Long) { + override operator fun set(index: Int, value: Long) { array[index] = value } - override fun iterator(): LongIterator = array.iterator() + override operator fun iterator(): LongIterator = array.iterator() override fun copy(): MutableBuffer = LongBuffer(array.copyOf()) @@ -28,7 +31,11 @@ inline class LongBuffer(val array: LongArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) }) +@OptIn(ExperimentalContracts::class) +inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer { + contract { callsInPlace(init) } + return LongBuffer(LongArray(size) { init(it) }) +} /** * Returns a new [LongBuffer] of given elements. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt index 1d0c87580..83c50b14b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt @@ -14,10 +14,8 @@ open class MemoryBuffer(protected val memory: Memory, protected val spe private val reader: MemoryReader = memory.reader() - override fun get(index: Int): T = reader.read(spec, spec.objectSize * index) - - override fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() - + override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index) + override operator fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() companion object { fun create(spec: MemorySpec, size: Int): MemoryBuffer = @@ -48,8 +46,7 @@ class MutableMemoryBuffer(memory: Memory, spec: MemorySpec) : Memory private val writer: MemoryWriter = memory.writer() - override fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) - + override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) override fun copy(): MutableBuffer = MutableMemoryBuffer(memory.copy(), spec) companion object { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt index 9dfe2b5a8..6cc0a72c0 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt @@ -26,19 +26,20 @@ interface NDElement> : NDStructure { fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement = NDField.real(*shape).produce(initializer) - - fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement = + inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement = real(intArrayOf(dim)) { initializer(it[0]) } + inline fun real2D( + dim1: Int, + dim2: Int, + crossinline initializer: (Int, Int) -> Double = { _, _ -> 0.0 } + ): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } - fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): RealNDElement = - real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } - - fun real3D( + inline fun real3D( dim1: Int, dim2: Int, dim3: Int, - initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 } + crossinline initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 } ): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } @@ -72,7 +73,6 @@ fun > NDElement.mapIndexed(transform: C.(index fun > NDElement.map(transform: C.(T) -> T): NDElement = context.map(unwrap(), transform).wrap() - /** * Element by element application of any operation on elements to the whole [NDElement] */ @@ -107,7 +107,6 @@ operator fun , N : NDStructure> NDElement.times(arg: operator fun , N : NDStructure> NDElement.div(arg: T): NDElement = map { value -> arg / value } - // /** // * Reverse sum operation // */ diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt index 9d7735053..a16caab68 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt @@ -1,5 +1,7 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.jvm.JvmName import kotlin.reflect.KClass @@ -138,10 +140,10 @@ interface MutableNDStructure : NDStructure { operator fun set(index: IntArray, value: T) } +@OptIn(ExperimentalContracts::class) inline fun MutableNDStructure.mapInPlace(action: (IntArray, T) -> T) { - elements().forEach { (index, oldValue) -> - this[index] = action(index, oldValue) - } + contract { callsInPlace(action) } + elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } } /** @@ -200,14 +202,12 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides }.toList() } - override fun offset(index: IntArray): Int { - return index.mapIndexed { i, value -> - if (value < 0 || value >= this.shape[i]) { - throw RuntimeException("Index $value out of shape bounds: (0,${this.shape[i]})") - } - value * strides[i] - }.sum() - } + override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> + if (value < 0 || value >= this.shape[i]) + throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})") + + value * strides[i] + }.sum() override fun index(offset: Int): IntArray { val res = IntArray(shape.size) @@ -259,7 +259,7 @@ abstract class NDBuffer : NDStructure { */ abstract val strides: Strides - override fun get(index: IntArray): T = buffer[strides.offset(index)] + override operator fun get(index: IntArray): T = buffer[strides.offset(index)] override val shape: IntArray get() = strides.shape @@ -319,13 +319,13 @@ class MutableBufferNDStructure( } } - override fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value) + override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value) } inline fun NDStructure.combine( struct: NDStructure, crossinline block: (T, T) -> T ): NDStructure { - if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination") + require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" } return NDStructure.auto(shape) { block(this[it], struct[it]) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt index e999e12b2..f897134db 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt @@ -1,5 +1,8 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract + /** * Specialized [MutableBuffer] implementation over [DoubleArray]. * @@ -8,13 +11,13 @@ package scientifik.kmath.structures inline class RealBuffer(val array: DoubleArray) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): Double = array[index] + override operator fun get(index: Int): Double = array[index] - override fun set(index: Int, value: Double) { + override operator fun set(index: Int, value: Double) { array[index] = value } - override fun iterator(): DoubleIterator = array.iterator() + override operator fun iterator(): DoubleIterator = array.iterator() override fun copy(): MutableBuffer = RealBuffer(array.copyOf()) @@ -27,7 +30,11 @@ inline class RealBuffer(val array: DoubleArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) }) +@OptIn(ExperimentalContracts::class) +inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer { + contract { callsInPlace(init) } + return RealBuffer(DoubleArray(size) { init(it) }) +} /** * Returns a new [RealBuffer] of given elements. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt index c6f19feaf..08d4a4376 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt @@ -1,5 +1,8 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract + /** * Specialized [MutableBuffer] implementation over [ShortArray]. * @@ -8,17 +11,16 @@ package scientifik.kmath.structures inline class ShortBuffer(val array: ShortArray) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): Short = array[index] + override operator fun get(index: Int): Short = array[index] - override fun set(index: Int, value: Short) { + override operator fun set(index: Int, value: Short) { array[index] = value } - override fun iterator(): ShortIterator = array.iterator() + override operator fun iterator(): ShortIterator = array.iterator() override fun copy(): MutableBuffer = ShortBuffer(array.copyOf()) - } /** @@ -28,7 +30,11 @@ inline class ShortBuffer(val array: ShortArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) }) +@OptIn(ExperimentalContracts::class) +inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer { + contract { callsInPlace(init) } + return ShortBuffer(ShortArray(size) { init(it) }) +} /** * Returns a new [ShortBuffer] of given elements. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt index faf022367..a796c2037 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt @@ -6,12 +6,12 @@ package scientifik.kmath.structures interface Structure1D : NDStructure, Buffer { override val dimension: Int get() = 1 - override fun get(index: IntArray): T { - if (index.size != 1) error("Index dimension mismatch. Expected 1 but found ${index.size}") + override operator fun get(index: IntArray): T { + require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" } return get(index[0]) } - override fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() + override operator fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() } /** @@ -22,7 +22,7 @@ private inline class Structure1DWrapper(val structure: NDStructure) : Stru override val shape: IntArray get() = structure.shape override val size: Int get() = structure.shape[0] - override fun get(index: Int): T = structure[index] + override operator fun get(index: Int): T = structure[index] override fun elements(): Sequence> = structure.elements() } @@ -39,7 +39,7 @@ private inline class Buffer1DWrapper(val buffer: Buffer) : Structure1D override fun elements(): Sequence> = asSequence().mapIndexed { index, value -> intArrayOf(index) to value } - override fun get(index: Int): T = buffer[index] + override operator fun get(index: Int): T = buffer[index] } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt index 30fd556d3..eeb6bd3dc 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt @@ -9,8 +9,8 @@ interface Structure2D : NDStructure { operator fun get(i: Int, j: Int): T - override fun get(index: IntArray): T { - if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}") + override operator fun get(index: IntArray): T { + require(index.size == 2) { "Index dimension mismatch. Expected 2 but found ${index.size}" } return get(index[0], index[1]) } @@ -39,10 +39,10 @@ interface Structure2D : NDStructure { * A 2D wrapper for nd-structure */ private inline class Structure2DWrapper(val structure: NDStructure) : Structure2D { - override fun get(i: Int, j: Int): T = structure[i, j] - override val shape: IntArray get() = structure.shape + override operator fun get(i: Int, j: Int): T = structure[i, j] + override fun elements(): Sequence> = structure.elements() } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt index 22b924ef9..485de08b4 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt @@ -3,6 +3,7 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals @@ -10,10 +11,12 @@ class ExpressionFieldTest { @Test fun testExpression() { val context = FunctionalExpressionField(RealField) - val expression = with(context) { + + val expression = context { val x = variable("x", 2.0) x * x + 2 * x + one } + assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression(), 9.0) } @@ -21,10 +24,12 @@ class ExpressionFieldTest { @Test fun testComplex() { val context = FunctionalExpressionField(ComplexField) - val expression = with(context) { + + val expression = context { val x = variable("x", Complex(2.0, 0.0)) x * x + 2 * x + one } + assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0)) assertEquals(expression(), Complex(9.0, 0.0)) } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt index 987426250..52a2f80a6 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt @@ -7,7 +7,6 @@ import kotlin.test.Test import kotlin.test.assertEquals class MatrixTest { - @Test fun testTranspose() { val matrix = MatrixContext.real.one(3, 3) @@ -51,6 +50,7 @@ class MatrixTest { fun test2DDot() { val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D() val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D() + MatrixContext.real.run { // val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() } // val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt index d48aabfd0..b7e2594ec 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt @@ -1,6 +1,7 @@ package scientifik.kmath.structures import scientifik.kmath.operations.Norm +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.NDElement.Companion.real2D import kotlin.math.abs import kotlin.math.pow @@ -56,17 +57,12 @@ class NumberNDFieldTest { } object L2Norm : Norm, Double> { - override fun norm(arg: NDStructure): Double { - return kotlin.math.sqrt(arg.elements().sumByDouble { it.second.toDouble() }) - } + override fun norm(arg: NDStructure): Double = + kotlin.math.sqrt(arg.elements().sumByDouble { it.second.toDouble() }) } @Test fun testInternalContext() { - NDField.real(*array1.shape).run { - with(L2Norm) { - 1 + norm(array1) + exp(array2) - } - } + (NDField.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } } } } diff --git a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt index 06f2b31ad..f10ef24da 100644 --- a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt @@ -17,10 +17,10 @@ object JBigIntegerField : Field { override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) - override fun BigInteger.minus(b: BigInteger): BigInteger = this.subtract(b) + override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b) override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger()) override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) - override fun BigInteger.unaryMinus(): BigInteger = negate() + override operator fun BigInteger.unaryMinus(): BigInteger = negate() } /** @@ -38,7 +38,7 @@ abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathCo get() = BigDecimal.ONE override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) - override fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) + override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) override fun multiply(a: BigDecimal, k: Number): BigDecimal = @@ -48,8 +48,7 @@ abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathCo override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) - override fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) - + override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) } /** diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt index 6cc9770af..aac4f2534 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt @@ -139,9 +139,10 @@ fun Chain.map(func: suspend (T) -> R): Chain = object : Chain { fun Chain.filter(block: (T) -> Boolean): Chain = object : Chain { override suspend fun next(): T { var next: T - do { - next = this@filter.next() - } while (!block(next)) + + do next = this@filter.next() + while (!block(next)) + return next } @@ -159,6 +160,7 @@ fun Chain.collect(mapper: suspend (Chain) -> R): Chain = object fun Chain.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain) -> R): Chain = object : Chain { override suspend fun next(): R = state.mapper(this@collectWithState) + override fun fork(): Chain = this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper) } @@ -168,6 +170,5 @@ fun Chain.collectWithState(state: S, stateFork: (S) -> S, mapper: s */ fun Chain.zip(other: Chain, block: suspend (T, U) -> R): Chain = object : Chain { override suspend fun next(): R = block(this@zip.next(), other.next()) - override fun fork(): Chain = this@zip.fork().zip(other.fork(), block) } diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt index e8537304c..6624722ce 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt @@ -7,15 +7,16 @@ import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.scanReduce import scientifik.kmath.operations.Space import scientifik.kmath.operations.SpaceOperations +import scientifik.kmath.operations.invoke @ExperimentalCoroutinesApi -fun Flow.cumulativeSum(space: SpaceOperations): Flow = with(space) { +fun Flow.cumulativeSum(space: SpaceOperations): Flow = space { scanReduce { sum: T, element: T -> sum + element } } @ExperimentalCoroutinesApi -fun Flow.mean(space: Space): Flow = with(space) { +fun Flow.mean(space: Space): Flow = space { class Accumulator(var sum: T, var num: Int) scan(Accumulator(zero, 0)) { sum, element -> diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt index 7e00b30a1..8ee6c52ad 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt @@ -3,6 +3,8 @@ package scientifik.kmath.coroutines import kotlinx.coroutines.* import kotlinx.coroutines.channels.produce import kotlinx.coroutines.flow.* +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract val Dispatchers.Math: CoroutineDispatcher get() = Default @@ -81,21 +83,24 @@ suspend fun AsyncFlow.collect(concurrency: Int, collector: FlowCollector< } } +@OptIn(ExperimentalContracts::class) @ExperimentalCoroutinesApi @FlowPreview -suspend fun AsyncFlow.collect(concurrency: Int, action: suspend (value: T) -> Unit) { +suspend inline fun AsyncFlow.collect(concurrency: Int, crossinline action: suspend (value: T) -> Unit) { + contract { callsInPlace(action) } + collect(concurrency, object : FlowCollector { override suspend fun emit(value: T): Unit = action(value) }) } +@OptIn(ExperimentalContracts::class) @ExperimentalCoroutinesApi @FlowPreview -fun Flow.mapParallel( +inline fun Flow.mapParallel( dispatcher: CoroutineDispatcher = Dispatchers.Default, - transform: suspend (T) -> R + crossinline transform: suspend (T) -> R ): Flow { - return flatMapMerge { value -> - flow { emit(transform(value)) } - }.flowOn(dispatcher) + contract { callsInPlace(transform) } + return flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher) } diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt index 245d003b3..f1c0bfc6a 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt @@ -20,7 +20,7 @@ class RingBuffer( override var size: Int = size private set - override fun get(index: Int): T { + override operator fun get(index: Int): T { require(index >= 0) { "Index must be positive" } require(index < size) { "Index $index is out of circular buffer size $size" } return buffer[startIndex.forward(index)] as T @@ -31,15 +31,13 @@ class RingBuffer( /** * Iterator could provide wrong results if buffer is changed in initialization (iteration is safe) */ - override fun iterator(): Iterator = object : AbstractIterator() { + override operator fun iterator(): Iterator = object : AbstractIterator() { private var count = size private var index = startIndex val copy = buffer.copy() override fun computeNext() { - if (count == 0) { - done() - } else { + if (count == 0) done() else { setNext(copy[index] as T) index = index.forward(1) count-- diff --git a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt index 0a3c67e00..5686b0ac0 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt @@ -1,7 +1,6 @@ package scientifik.kmath.chains import kotlinx.coroutines.runBlocking -import kotlin.sequences.Sequence /** * Represent a chain as regular iterator (uses blocking calls) @@ -15,6 +14,4 @@ operator fun Chain.iterator(): Iterator = object : Iterator { /** * Represent a chain as a sequence */ -fun Chain.asSequence(): Sequence = object : Sequence { - override fun iterator(): Iterator = this@asSequence.iterator() -} \ No newline at end of file +fun Chain.asSequence(): Sequence = Sequence { this@asSequence.iterator() } \ No newline at end of file diff --git a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt index 8d5145976..ff732a06b 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt @@ -18,7 +18,7 @@ class LazyNDStructure( suspend fun await(index: IntArray): T = deferred(index).await() - override fun get(index: IntArray): T = runBlocking { + override operator fun get(index: IntArray): T = runBlocking { deferred(index).await() } @@ -52,10 +52,12 @@ suspend fun NDStructure.await(index: IntArray): T = /** * PENDING would benefit from KEEP-176 */ -fun NDStructure.mapAsyncIndexed( +inline fun NDStructure.mapAsyncIndexed( scope: CoroutineScope, - function: suspend (T, index: IntArray) -> R + crossinline function: suspend (T, index: IntArray) -> R ): LazyNDStructure = LazyNDStructure(scope, shape) { index -> function(get(index), index) } -fun NDStructure.mapAsync(scope: CoroutineScope, function: suspend (T) -> R): LazyNDStructure = - LazyNDStructure(scope, shape) { index -> function(get(index)) } \ No newline at end of file +inline fun NDStructure.mapAsync( + scope: CoroutineScope, + crossinline function: suspend (T) -> R +): LazyNDStructure = LazyNDStructure(scope, shape) { index -> function(get(index)) } diff --git a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt index f447866c0..7b0244bdf 100644 --- a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt +++ b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt @@ -4,7 +4,9 @@ import scientifik.kmath.linear.GenericMatrixContext import scientifik.kmath.linear.MatrixContext import scientifik.kmath.linear.Point import scientifik.kmath.linear.transpose +import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Structure2D @@ -42,7 +44,7 @@ inline class DMatrixWrapper( val structure: Structure2D ) : DMatrix { override val shape: IntArray get() = structure.shape - override fun get(i: Int, j: Int): T = structure[i, j] + override operator fun get(i: Int, j: Int): T = structure[i, j] } /** @@ -70,9 +72,9 @@ inline class DPointWrapper(val point: Point) : DPoint { override val size: Int get() = point.size - override fun get(index: Int): T = point[index] + override operator fun get(index: Int): T = point[index] - override fun iterator(): Iterator = point.iterator() + override operator fun iterator(): Iterator = point.iterator() } @@ -82,12 +84,14 @@ inline class DPointWrapper(val point: Point) : inline class DMatrixContext>(val context: GenericMatrixContext) { inline fun Matrix.coerce(): DMatrix { - if (rowNum != Dimension.dim().toInt()) { - error("Row number mismatch: expected ${Dimension.dim()} but found $rowNum") - } - if (colNum != Dimension.dim().toInt()) { - error("Column number mismatch: expected ${Dimension.dim()} but found $colNum") - } + check( + rowNum == Dimension.dim().toInt() + ) { "Row number mismatch: expected ${Dimension.dim()} but found $rowNum" } + + check( + colNum == Dimension.dim().toInt() + ) { "Column number mismatch: expected ${Dimension.dim()} but found $colNum" } + return DMatrix.coerceUnsafe(this) } @@ -97,11 +101,12 @@ inline class DMatrixContext>(val context: GenericMatrixCon inline fun produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix { val rows = Dimension.dim() val cols = Dimension.dim() - return context.produce(rows.toInt(), cols.toInt(), initializer).coerce() + return context.produce(rows.toInt(), cols.toInt(), initializer).coerce() } inline fun point(noinline initializer: (Int) -> T): DPoint { val size = Dimension.dim() + return DPoint.coerceUnsafe( context.point( size.toInt(), @@ -112,37 +117,28 @@ inline class DMatrixContext>(val context: GenericMatrixCon inline infix fun DMatrix.dot( other: DMatrix - ): DMatrix { - return context.run { this@dot dot other }.coerce() - } + ): DMatrix = context { this@dot dot other }.coerce() - inline infix fun DMatrix.dot(vector: DPoint): DPoint { - return DPoint.coerceUnsafe(context.run { this@dot dot vector }) - } + inline infix fun DMatrix.dot(vector: DPoint): DPoint = + DPoint.coerceUnsafe(context { this@dot dot vector }) - inline operator fun DMatrix.times(value: T): DMatrix { - return context.run { this@times.times(value) }.coerce() - } + inline operator fun DMatrix.times(value: T): DMatrix = + context { this@times.times(value) }.coerce() inline operator fun T.times(m: DMatrix): DMatrix = m * this + inline operator fun DMatrix.plus(other: DMatrix): DMatrix = + context { this@plus + other }.coerce() - inline operator fun DMatrix.plus(other: DMatrix): DMatrix { - return context.run { this@plus + other }.coerce() - } + inline operator fun DMatrix.minus(other: DMatrix): DMatrix = + context { this@minus + other }.coerce() - inline operator fun DMatrix.minus(other: DMatrix): DMatrix { - return context.run { this@minus + other }.coerce() - } + inline operator fun DMatrix.unaryMinus(): DMatrix = + context { this@unaryMinus.unaryMinus() }.coerce() - inline operator fun DMatrix.unaryMinus(): DMatrix { - return context.run { this@unaryMinus.unaryMinus() }.coerce() - } - - inline fun DMatrix.transpose(): DMatrix { - return context.run { (this@transpose as Matrix).transpose() }.coerce() - } + inline fun DMatrix.transpose(): DMatrix = + context { (this@transpose as Matrix).transpose() }.coerce() /** * A square unit matrix @@ -156,6 +152,6 @@ inline class DMatrixContext>(val context: GenericMatrixCon } companion object { - val real = DMatrixContext(MatrixContext.real) + val real: DMatrixContext = DMatrixContext(MatrixContext.real) } -} \ No newline at end of file +} diff --git a/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt b/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt index 74d20205c..8dabdeeac 100644 --- a/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt +++ b/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt @@ -5,11 +5,10 @@ import scientifik.kmath.dimensions.D3 import scientifik.kmath.dimensions.DMatrixContext import kotlin.test.Test - class DMatrixContextTest { @Test fun testDimensionSafeMatrix() { - val res = DMatrixContext.real.run { + val res = with(DMatrixContext.real) { val m = produce { i, j -> (i + j).toDouble() } //The dimension of `one()` is inferred from type @@ -19,7 +18,7 @@ class DMatrixContextTest { @Test fun testTypeCheck() { - val res = DMatrixContext.real.run { + val res = with(DMatrixContext.real) { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i + j).toDouble() } diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt index 2b89904e3..811b54d7c 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt @@ -14,8 +14,8 @@ import kotlin.math.sqrt typealias RealPoint = Point -fun DoubleArray.asVector() = RealVector(this.asBuffer()) -fun List.asVector() = RealVector(this.asBuffer()) +fun DoubleArray.asVector(): RealVector = RealVector(this.asBuffer()) +fun List.asVector(): RealVector = RealVector(this.asBuffer()) object VectorL2Norm : Norm, Double> { override fun norm(arg: Point): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) @@ -32,15 +32,14 @@ inline class RealVector(private val point: Point) : override val size: Int get() = point.size - override fun get(index: Int): Double = point[index] + override operator fun get(index: Int): Double = point[index] - override fun iterator(): Iterator = point.iterator() + override operator fun iterator(): Iterator = point.iterator() companion object { + private val spaceCache: MutableMap> = hashMapOf() - private val spaceCache = HashMap>() - - inline operator fun invoke(dim: Int, initializer: (Int) -> Double) = + inline operator fun invoke(dim: Int, initializer: (Int) -> Double): RealVector = RealVector(RealBuffer(dim, initializer)) operator fun invoke(vararg values: Double): RealVector = values.asVector() @@ -49,4 +48,4 @@ inline class RealVector(private val point: Point) : BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) } } } -} \ No newline at end of file +} diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt index 65f86eec7..5ccba90a9 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt @@ -3,11 +3,14 @@ package scientifik.kmath.real import scientifik.kmath.linear.MatrixContext import scientifik.kmath.linear.RealMatrixContext.elementContext import scientifik.kmath.linear.VirtualMatrix +import scientifik.kmath.operations.invoke import scientifik.kmath.operations.sum import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asIterable +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.math.pow /* @@ -27,7 +30,7 @@ typealias RealMatrix = Matrix fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = MatrixContext.real.produce(rowNum, colNum, initializer) -fun Array.toMatrix(): RealMatrix{ +fun Array.toMatrix(): RealMatrix { return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] } } @@ -117,13 +120,17 @@ operator fun Matrix.minus(other: Matrix): RealMatrix = * Operations on columns */ -inline fun Matrix.appendColumn(crossinline mapper: (Buffer) -> Double) = - MatrixContext.real.produce(rowNum, colNum + 1) { row, col -> +@OptIn(ExperimentalContracts::class) +inline fun Matrix.appendColumn(crossinline mapper: (Buffer) -> Double): Matrix { + contract { callsInPlace(mapper) } + + return MatrixContext.real.produce(rowNum, colNum + 1) { row, col -> if (col < colNum) this[row, col] else mapper(rows[row]) } +} fun Matrix.extractColumns(columnRange: IntRange): RealMatrix = MatrixContext.real.produce(rowNum, columnRange.count()) { row, col -> @@ -135,17 +142,15 @@ fun Matrix.extractColumn(columnIndex: Int): RealMatrix = fun Matrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> val column = columns[j] - with(elementContext) { - sum(column.asIterable()) - } + elementContext { sum(column.asIterable()) } } fun Matrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> - columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column") + columns[j].asIterable().min() ?: error("Cannot produce min on empty column") } fun Matrix.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> - columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column") + columns[j].asIterable().max() ?: error("Cannot produce min on empty column") } fun Matrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> @@ -156,10 +161,7 @@ fun Matrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> * Operations processing all elements */ -fun Matrix.sum() = elements().map { (_, value) -> value }.sum() - -fun Matrix.min() = elements().map { (_, value) -> value }.min() - -fun Matrix.max() = elements().map { (_, value) -> value }.max() - -fun Matrix.average() = elements().map { (_, value) -> value }.average() +fun Matrix.sum(): Double = elements().map { (_, value) -> value }.sum() +fun Matrix.min(): Double? = elements().map { (_, value) -> value }.min() +fun Matrix.max(): Double? = elements().map { (_, value) -> value }.max() +fun Matrix.average(): Double = elements().map { (_, value) -> value }.average() diff --git a/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt b/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt index 28e62b066..ef7f40afe 100644 --- a/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt @@ -1,5 +1,6 @@ package scientifik.kmath.linear +import scientifik.kmath.operations.invoke import scientifik.kmath.real.RealVector import kotlin.test.Test import kotlin.test.assertEquals @@ -24,14 +25,10 @@ class VectorTest { fun testDot() { val vector1 = RealVector(5) { it.toDouble() } val vector2 = RealVector(5) { 5 - it.toDouble() } - val matrix1 = vector1.asMatrix() val matrix2 = vector2.asMatrix().transpose() - val product = MatrixContext.real.run { matrix1 dot matrix2 } - - + val product = MatrixContext.real { matrix1 dot matrix2 } assertEquals(5.0, product[1, 0]) assertEquals(6.0, product[2, 2]) } - -} \ No newline at end of file +} diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt index b747b521d..46b1d7c90 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt +++ b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt @@ -2,6 +2,10 @@ package scientifik.kmath.functions import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.math.max import kotlin.math.pow @@ -13,20 +17,21 @@ inline class Polynomial(val coefficients: List) { constructor(vararg coefficients: T) : this(coefficients.toList()) } -fun Polynomial.value() = +fun Polynomial.value(): Double = coefficients.reduceIndexed { index: Int, acc: Double, d: Double -> acc + d.pow(index) } - -fun > Polynomial.value(ring: C, arg: T): T = ring.run { - if (coefficients.isEmpty()) return@run zero +fun > Polynomial.value(ring: C, arg: T): T = ring { + if (coefficients.isEmpty()) return@ring zero var res = coefficients.first() var powerArg = arg + for (index in 1 until coefficients.size) { res += coefficients[index] * powerArg //recalculating power on each step to avoid power costs on long polynomials powerArg *= arg } - return@run res + + res } /** @@ -34,7 +39,7 @@ fun > Polynomial.value(ring: C, arg: T): T = ring.run { */ fun > Polynomial.asMathFunction(): MathFunction = object : MathFunction { - override fun C.invoke(arg: T): T = value(this, arg) + override operator fun C.invoke(arg: T): T = value(this, arg) } /** @@ -49,18 +54,16 @@ class PolynomialSpace>(val ring: C) : Space> override fun add(a: Polynomial, b: Polynomial): Polynomial { val dim = max(a.coefficients.size, b.coefficients.size) - ring.run { - return Polynomial(List(dim) { index -> + + return ring { + Polynomial(List(dim) { index -> a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero } }) } } - override fun multiply(a: Polynomial, k: Number): Polynomial { - ring.run { - return Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) - } - } + override fun multiply(a: Polynomial, k: Number): Polynomial = + ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) } override val zero: Polynomial = Polynomial(emptyList()) @@ -68,6 +71,8 @@ class PolynomialSpace>(val ring: C) : Space> operator fun Polynomial.invoke(arg: T): T = value(ring, arg) } -fun , R> C.polynomial(block: PolynomialSpace.() -> R): R { - return PolynomialSpace(this).run(block) -} \ No newline at end of file +@OptIn(ExperimentalContracts::class) +inline fun , R> C.polynomial(block: PolynomialSpace.() -> R): R { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return PolynomialSpace(this).block() +} diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt index 98beb4391..a7925180d 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt @@ -4,13 +4,13 @@ import scientifik.kmath.functions.OrderedPiecewisePolynomial import scientifik.kmath.functions.PiecewisePolynomial import scientifik.kmath.functions.Polynomial import scientifik.kmath.operations.Field +import scientifik.kmath.operations.invoke /** * Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java */ class LinearInterpolator>(override val algebra: Field) : PolynomialInterpolator { - - override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra.run { + override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra { require(points.size > 0) { "Point array should not be empty" } insureSorted(points) @@ -23,4 +23,4 @@ class LinearInterpolator>(override val algebra: Field) : Po } } } -} \ No newline at end of file +} diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt index e1af0c1a2..b709c4e87 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt @@ -4,6 +4,7 @@ import scientifik.kmath.functions.OrderedPiecewisePolynomial import scientifik.kmath.functions.PiecewisePolynomial import scientifik.kmath.functions.Polynomial import scientifik.kmath.operations.Field +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.MutableBufferFactory /** @@ -17,7 +18,7 @@ class SplineInterpolator>( //TODO possibly optimize zeroed buffers - override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra.run { + override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra { if (points.size < 3) { error("Can't use spline interpolator with less than 3 points") } diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt index d8e10b880..56953f9fc 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt +++ b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt @@ -14,9 +14,7 @@ interface XYZPointSet : XYPointSet { } internal fun > insureSorted(points: XYPointSet) { - for (i in 0 until points.size - 1) { - if (points.x[i + 1] <= points.x[i]) error("Input data is not sorted at index $i") - } + for (i in 0 until points.size - 1) require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" } } class NDStructureColumn(val structure: Structure2D, val column: Int) : Buffer { @@ -26,9 +24,9 @@ class NDStructureColumn(val structure: Structure2D, val column: Int) : Buf override val size: Int get() = structure.rowNum - override fun get(index: Int): T = structure[index, column] + override operator fun get(index: Int): T = structure[index, column] - override fun iterator(): Iterator = sequence { + override operator fun iterator(): Iterator = sequence { repeat(size) { yield(get(it)) } diff --git a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt index 2313b2170..f0dc49882 100644 --- a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt @@ -9,25 +9,21 @@ import kotlin.math.sqrt interface Vector2D : Point, Vector, SpaceElement { val x: Double val y: Double - + override val context: Euclidean2DSpace get() = Euclidean2DSpace override val size: Int get() = 2 - override fun get(index: Int): Double = when (index) { + override operator fun get(index: Int): Double = when (index) { 1 -> x 2 -> y else -> error("Accessing outside of point bounds") } - override fun iterator(): Iterator = listOf(x, y).iterator() - - override val context: Euclidean2DSpace get() = Euclidean2DSpace - + override operator fun iterator(): Iterator = listOf(x, y).iterator() override fun unwrap(): Vector2D = this - override fun Vector2D.wrap(): Vector2D = this } -val Vector2D.r: Double get() = Euclidean2DSpace.run { sqrt(norm()) } +val Vector2D.r: Double get() = Euclidean2DSpace { sqrt(norm()) } @Suppress("FunctionName") fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y) diff --git a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt index dd1776342..3748e58c7 100644 --- a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt @@ -2,6 +2,7 @@ package scientifik.kmath.geometry import scientifik.kmath.linear.Point import scientifik.kmath.operations.SpaceElement +import scientifik.kmath.operations.invoke import kotlin.math.sqrt @@ -9,19 +10,17 @@ interface Vector3D : Point, Vector, SpaceElement x 2 -> y 3 -> z else -> error("Accessing outside of point bounds") } - override fun iterator(): Iterator = listOf(x, y, z).iterator() - - override val context: Euclidean3DSpace get() = Euclidean3DSpace + override operator fun iterator(): Iterator = listOf(x, y, z).iterator() override fun unwrap(): Vector3D = this @@ -31,7 +30,7 @@ interface Vector3D : Point, Vector, SpaceElement { override fun Vector3D.dot(other: Vector3D): Double = x * other.x + y * other.y + z * other.z -} \ No newline at end of file +} diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index 43d50ad20..9ff2aacf5 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -4,6 +4,9 @@ import scientifik.kmath.domains.Domain import scientifik.kmath.linear.Point import scientifik.kmath.structures.ArrayBuffer import scientifik.kmath.structures.RealBuffer +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * The bin in the histogram. The histogram is by definition always done in the real space @@ -37,20 +40,20 @@ interface MutableHistogram> : Histogram { */ fun putWithWeight(point: Point, weight: Double) - fun put(point: Point) = putWithWeight(point, 1.0) + fun put(point: Point): Unit = putWithWeight(point, 1.0) } -fun MutableHistogram.put(vararg point: T) = put(ArrayBuffer(point)) +fun MutableHistogram.put(vararg point: T): Unit = put(ArrayBuffer(point)) -fun MutableHistogram.put(vararg point: Number) = +fun MutableHistogram.put(vararg point: Number): Unit = put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) -fun MutableHistogram.put(vararg point: Double) = put(RealBuffer(point)) +fun MutableHistogram.put(vararg point: Double): Unit = put(RealBuffer(point)) -fun MutableHistogram.fill(sequence: Iterable>) = sequence.forEach { put(it) } +fun MutableHistogram.fill(sequence: Iterable>): Unit = sequence.forEach { put(it) } /** * Pass a sequence builder into histogram */ -fun MutableHistogram.fill(buider: suspend SequenceScope>.() -> Unit) = - fill(sequence(buider).asIterable()) \ No newline at end of file +fun MutableHistogram.fill(block: suspend SequenceScope>.() -> Unit): Unit = + fill(sequence(block).asIterable()) diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index 628a68461..f05ae1694 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -2,6 +2,7 @@ package scientifik.kmath.histogram import scientifik.kmath.linear.Point import scientifik.kmath.operations.SpaceOperations +import scientifik.kmath.operations.invoke import scientifik.kmath.real.asVector import scientifik.kmath.structures.* import kotlin.math.floor @@ -9,19 +10,16 @@ import kotlin.math.floor data class BinDef>(val space: SpaceOperations>, val center: Point, val sizes: Point) { fun contains(vector: Point): Boolean { - if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") - val upper = space.run { center + sizes / 2.0 } - val lower = space.run { center - sizes / 2.0 } - return vector.asSequence().mapIndexed { i, value -> - value in lower[i]..upper[i] - }.all { it } + require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" } + val upper = space { center + sizes / 2.0 } + val lower = space { center - sizes / 2.0 } + return vector.asSequence().mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it } } } class MultivariateBin>(val def: BinDef, override val value: Number) : Bin { - - override fun contains(point: Point): Boolean = def.contains(point) + override operator fun contains(point: Point): Boolean = def.contains(point) override val dimension: Int get() = def.center.size @@ -39,47 +37,34 @@ class RealHistogram( private val upper: Buffer, private val binNums: IntArray = IntArray(lower.size) { 20 } ) : MutableHistogram> { - - private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 }) - private val values: NDStructure = NDStructure.auto(strides) { LongCounter() } - private val weights: NDStructure = NDStructure.auto(strides) { DoubleCounter() } - override val dimension: Int get() = lower.size - - private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } init { // argument checks - if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.") - if (lower.size != binNums.size) error("Dimension mismatch in bin count.") - if ((0 until dimension).any { upper[it] - lower[it] < 0 }) error("Range for one of axis is not strictly positive") + require(lower.size == upper.size) { "Dimension mismatch in histogram lower and upper limits." } + require(lower.size == binNums.size) { "Dimension mismatch in bin count." } + require(!(0 until dimension).any { upper[it] - lower[it] < 0 }) { "Range for one of axis is not strictly positive" } } /** * Get internal [NDStructure] bin index for given axis */ - private fun getIndex(axis: Int, value: Double): Int { - return when { - value >= upper[axis] -> binNums[axis] + 1 // overflow - value < lower[axis] -> 0 // underflow - else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1 - } + private fun getIndex(axis: Int, value: Double): Int = when { + value >= upper[axis] -> binNums[axis] + 1 // overflow + value < lower[axis] -> 0 // underflow + else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1 } private fun getIndex(point: Buffer): IntArray = IntArray(dimension) { getIndex(it, point[it]) } - private fun getValue(index: IntArray): Long { - return values[index].sum() - } + private fun getValue(index: IntArray): Long = values[index].sum() - fun getValue(point: Buffer): Long { - return getValue(getIndex(point)) - } + fun getValue(point: Buffer): Long = getValue(getIndex(point)) private fun getDef(index: IntArray): BinDef { val center = index.mapIndexed { axis, i -> @@ -89,14 +74,13 @@ class RealHistogram( else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis] } }.asBuffer() + return BinDef(RealBufferFieldOperations, center, binSize) } - fun getDef(point: Buffer): BinDef { - return getDef(getIndex(point)) - } + fun getDef(point: Buffer): BinDef = getDef(getIndex(point)) - override fun get(point: Buffer): MultivariateBin? { + override operator fun get(point: Buffer): MultivariateBin? { val index = getIndex(point) return MultivariateBin(getDef(index), getValue(index)) } @@ -112,26 +96,21 @@ class RealHistogram( weights[index].add(weight) } - override fun iterator(): Iterator> = weights.elements().map { (index, value) -> + override operator fun iterator(): Iterator> = weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }.iterator() /** * Convert this histogram into NDStructure containing bin values but not bin descriptions */ - fun values(): NDStructure { - return NDStructure.auto(values.shape) { values[it].sum() } - } + fun values(): NDStructure = NDStructure.auto(values.shape) { values[it].sum() } /** * Sum of weights */ - fun weights():NDStructure{ - return NDStructure.auto(weights.shape) { weights[it].sum() } - } + fun weights(): NDStructure = NDStructure.auto(weights.shape) { weights[it].sum() } companion object { - /** * Use it like * ``` @@ -141,12 +120,10 @@ class RealHistogram( *) *``` */ - fun fromRanges(vararg ranges: ClosedFloatingPointRange): RealHistogram { - return RealHistogram( - ranges.map { it.start }.asVector(), - ranges.map { it.endInclusive }.asVector() - ) - } + fun fromRanges(vararg ranges: ClosedFloatingPointRange): RealHistogram = RealHistogram( + ranges.map { it.start }.asVector(), + ranges.map { it.endInclusive }.asVector() + ) /** * Use it like @@ -157,13 +134,10 @@ class RealHistogram( *) *``` */ - fun fromRanges(vararg ranges: Pair, Int>): RealHistogram { - return RealHistogram( - ListBuffer(ranges.map { it.first.start }), - ListBuffer(ranges.map { it.first.endInclusive }), - ranges.map { it.second }.toIntArray() - ) - } + fun fromRanges(vararg ranges: Pair, Int>): RealHistogram = RealHistogram( + ListBuffer(ranges.map { it.first.start }), + ListBuffer(ranges.map { it.first.endInclusive }), + ranges.map { it.second }.toIntArray() + ) } - -} \ No newline at end of file +} diff --git a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt index af01205bf..e30a45f5a 100644 --- a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt +++ b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt @@ -46,11 +46,11 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U synchronized(this) { bins.put(it.position, it) } } - override fun get(point: Buffer): UnivariateBin? = get(point[0]) + override operator fun get(point: Buffer): UnivariateBin? = get(point[0]) override val dimension: Int get() = 1 - override fun iterator(): Iterator = bins.values.iterator() + override operator fun iterator(): Iterator = bins.values.iterator() /** * Thread safe put operation @@ -65,15 +65,14 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U } companion object { - fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram { - return UnivariateHistogram { value -> - val center = start + binSize * floor((value - start) / binSize + 0.5) - UnivariateBin(center, binSize) - } + fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value -> + val center = start + binSize * floor((value - start) / binSize + 0.5) + UnivariateBin(center, binSize) } fun custom(borders: DoubleArray): UnivariateHistogram { val sorted = borders.sortedArray() + return UnivariateHistogram { value -> when { value < sorted.first() -> UnivariateBin( diff --git a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt index 10deabd73..bd8fa782a 100644 --- a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt +++ b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt @@ -3,16 +3,16 @@ package scientifik.kmath.linear import koma.extensions.fill import koma.matrix.MatrixFactory import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.NDStructure class KomaMatrixContext( private val factory: MatrixFactory>, private val space: Space -) : - MatrixContext { +) : MatrixContext { - override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T) = + override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): KomaMatrix = KomaMatrix(factory.zeros(rows, columns).fill(initializer)) fun Matrix.toKoma(): KomaMatrix = if (this is KomaMatrix) { @@ -28,31 +28,28 @@ class KomaMatrixContext( } - override fun Matrix.dot(other: Matrix) = - KomaMatrix(this.toKoma().origin * other.toKoma().origin) + override fun Matrix.dot(other: Matrix): KomaMatrix = + KomaMatrix(toKoma().origin * other.toKoma().origin) - override fun Matrix.dot(vector: Point) = - KomaVector(this.toKoma().origin * vector.toKoma().origin) + override fun Matrix.dot(vector: Point): KomaVector = + KomaVector(toKoma().origin * vector.toKoma().origin) - override fun Matrix.unaryMinus() = - KomaMatrix(this.toKoma().origin.unaryMinus()) + override operator fun Matrix.unaryMinus(): KomaMatrix = + KomaMatrix(toKoma().origin.unaryMinus()) - override fun add(a: Matrix, b: Matrix) = + override fun add(a: Matrix, b: Matrix): KomaMatrix = KomaMatrix(a.toKoma().origin + b.toKoma().origin) - override fun Matrix.minus(b: Matrix) = - KomaMatrix(this.toKoma().origin - b.toKoma().origin) + override operator fun Matrix.minus(b: Matrix): KomaMatrix = + KomaMatrix(toKoma().origin - b.toKoma().origin) override fun multiply(a: Matrix, k: Number): Matrix = - produce(a.rowNum, a.colNum) { i, j -> space.run { a[i, j] * k } } + produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } } - override fun Matrix.times(value: T) = - KomaMatrix(this.toKoma().origin * value) - - companion object { - - } + override operator fun Matrix.times(value: T): KomaMatrix = + KomaMatrix(toKoma().origin * value) + companion object } fun KomaMatrixContext.solve(a: Matrix, b: Matrix) = @@ -70,10 +67,11 @@ class KomaMatrix(val origin: koma.matrix.Matrix, features: Set = features ?: setOf( + override val features: Set = features ?: hashSetOf( object : DeterminantFeature { override val determinant: T get() = origin.det() }, + object : LUPDecompositionFeature { private val lup by lazy { origin.LU() } override val l: FeaturedMatrix get() = KomaMatrix(lup.second) @@ -85,7 +83,7 @@ class KomaMatrix(val origin: koma.matrix.Matrix, features: Set = KomaMatrix(this.origin, this.features + features) - override fun get(i: Int, j: Int): T = origin.getGeneric(i, j) + override operator fun get(i: Int, j: Int): T = origin.getGeneric(i, j) override fun equals(other: Any?): Boolean { return NDStructure.equals(this, other as? NDStructure<*> ?: return false) @@ -101,14 +99,12 @@ class KomaMatrix(val origin: koma.matrix.Matrix, features: Set internal constructor(val origin: koma.matrix.Matrix) : Point { - init { - if (origin.numCols() != 1) error("Only single column matrices are allowed") - } - override val size: Int get() = origin.numRows() - override fun get(index: Int): T = origin.getGeneric(index) + init { + require(origin.numCols() == 1) { error("Only single column matrices are allowed") } + } - override fun iterator(): Iterator = origin.toIterable().iterator() + override operator fun get(index: Int): T = origin.getGeneric(index) + override operator fun iterator(): Iterator = origin.toIterable().iterator() } - diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt b/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt index a749a7074..9e6ca9bde 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt +++ b/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt @@ -1,5 +1,9 @@ package scientifik.memory +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + /** * Represents a display of certain memory structure. */ @@ -80,7 +84,9 @@ interface MemoryReader { /** * Uses the memory for read then releases the reader. */ +@OptIn(ExperimentalContracts::class) inline fun Memory.read(block: MemoryReader.() -> Unit) { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } reader().apply(block).release() } @@ -132,7 +138,9 @@ interface MemoryWriter { /** * Uses the memory for write then releases the writer. */ +@OptIn(ExperimentalContracts::class) inline fun Memory.write(block: MemoryWriter.() -> Unit) { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } writer().apply(block).release() } diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt index 59a93f290..1381afbec 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt +++ b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt @@ -38,11 +38,7 @@ fun MemoryWriter.write(spec: MemorySpec, offset: Int, value: T): Un * Reads array of [size] objects mapped by [spec] at certain [offset]. */ inline fun MemoryReader.readArray(spec: MemorySpec, offset: Int, size: Int): Array = - Array(size) { i -> - spec.run { - read(offset + i * objectSize) - } - } + Array(size) { i -> with(spec) { read(offset + i * objectSize) } } /** * Writes [array] of objects mapped by [spec] at certain [offset]. diff --git a/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt b/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt index b5a0dd51b..73ad7deec 100644 --- a/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt +++ b/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt @@ -1,12 +1,17 @@ package scientifik.memory +import java.io.IOException import java.nio.ByteBuffer import java.nio.channels.FileChannel import java.nio.file.Files import java.nio.file.Path import java.nio.file.StandardOpenOption +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract -private class ByteBufferMemory( +@PublishedApi +internal class ByteBufferMemory( val buffer: ByteBuffer, val startOffset: Int = 0, override val size: Int = buffer.limit() @@ -112,7 +117,12 @@ fun ByteBuffer.asMemory(startOffset: Int = 0, size: Int = limit()): Memory = /** * Uses direct memory-mapped buffer from file to read something and close it afterwards. */ -fun Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R = - FileChannel.open(this, StandardOpenOption.READ).use { - ByteBufferMemory(it.map(FileChannel.MapMode.READ_ONLY, position, size)).block() - } +@OptIn(ExperimentalContracts::class) +@Throws(IOException::class) +inline fun Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + + return FileChannel + .open(this, StandardOpenOption.READ) + .use { ByteBufferMemory(it.map(FileChannel.MapMode.READ_ONLY, position, size)).block() } +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt index 47fc6e4c5..fe5dc8c8c 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt @@ -1,6 +1,7 @@ package scientifik.kmath.prob import scientifik.kmath.chains.Chain +import kotlin.contracts.ExperimentalContracts /** * A possibly stateful chain producing random values. @@ -11,4 +12,4 @@ class RandomChain(val generator: RandomGenerator, private val gen: suspen override fun fork(): Chain = RandomChain(generator.fork(), gen) } -fun RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain = RandomChain(this, gen) \ No newline at end of file +fun RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain = RandomChain(this, gen) diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt index 3a60c0bda..7699df4ee 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt @@ -5,6 +5,7 @@ import scientifik.kmath.chains.ConstantChain import scientifik.kmath.chains.map import scientifik.kmath.chains.zip import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke class BasicSampler(val chainBuilder: (RandomGenerator) -> Chain) : Sampler { override fun sample(generator: RandomGenerator): Chain = chainBuilder(generator) @@ -22,10 +23,10 @@ class SamplerSpace(val space: Space) : Space> { override val zero: Sampler = ConstantSampler(space.zero) override fun add(a: Sampler, b: Sampler): Sampler = BasicSampler { generator -> - a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space.run { aValue + bValue } } + a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space { aValue + bValue } } } override fun multiply(a: Sampler, k: Number): Sampler = BasicSampler { generator -> - a.sample(generator).map { space.run { it * k.toDouble() } } + a.sample(generator).map { space { it * k.toDouble() } } } } \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt index 804aed089..0bc9c1565 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt @@ -29,8 +29,10 @@ interface Statistic { interface ComposableStatistic : Statistic { //compute statistic on a single block suspend fun computeIntermediate(data: Buffer): I + //Compose two blocks suspend fun composeIntermediate(first: I, second: I): I + //Transform block to result suspend fun toResult(intermediate: I): R @@ -58,26 +60,26 @@ private fun ComposableStatistic.flowIntermediate( fun ComposableStatistic.flow( flow: Flow>, dispatcher: CoroutineDispatcher = Dispatchers.Default -): Flow = flowIntermediate(flow,dispatcher).map(::toResult) +): Flow = flowIntermediate(flow, dispatcher).map(::toResult) /** * Arithmetic mean */ class Mean(val space: Space) : ComposableStatistic, T> { override suspend fun computeIntermediate(data: Buffer): Pair = - space.run { sum(data.asIterable()) } to data.size + space { sum(data.asIterable()) } to data.size override suspend fun composeIntermediate(first: Pair, second: Pair): Pair = - space.run { first.first + second.first } to (first.second + second.second) + space { first.first + second.first } to (first.second + second.second) override suspend fun toResult(intermediate: Pair): T = - space.run { intermediate.first / intermediate.second } + space { intermediate.first / intermediate.second } companion object { //TODO replace with optimized version which respects overflow - val real = Mean(RealField) - val int = Mean(IntRing) - val long = Mean(LongRing) + val real: Mean = Mean(RealField) + val int: Mean = Mean(IntRing) + val long: Mean = Mean(LongRing) } } @@ -85,11 +87,10 @@ class Mean(val space: Space) : ComposableStatistic, T> { * Non-composable median */ class Median(private val comparator: Comparator) : Statistic { - override suspend fun invoke(data: Buffer): T { - return data.asSequence().sortedWith(comparator).toList()[data.size / 2] //TODO check if this is correct - } + override suspend fun invoke(data: Buffer): T = + data.asSequence().sortedWith(comparator).toList()[data.size / 2] //TODO check if this is correct companion object { - val real = Median(Comparator { a: Double, b: Double -> a.compareTo(b) }) + val real: Median = Median(Comparator { a: Double, b: Double -> a.compareTo(b) }) } } \ No newline at end of file diff --git a/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt b/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt index 040eee951..e7bf173c3 100644 --- a/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt +++ b/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt @@ -16,5 +16,5 @@ inline class ViktorBuffer(val flatArray: F64FlatArray) : MutableBuffer { return ViktorBuffer(flatArray.copy().flatten()) } - override fun iterator(): Iterator = flatArray.data.iterator() + override operator fun iterator(): Iterator = flatArray.data.iterator() } \ No newline at end of file