diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c99c7bc1..4bcc57810 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,18 @@ # KMath ## [Unreleased] +### Added + +### Changed + +### Deprecated + +### Removed + +### Fixed + +### Security +## [0.1.4] ### Added - Functional Expressions API @@ -16,17 +28,23 @@ - Local coding conventions - Geometric Domains API in `kmath-core` - Blocking chains in `kmath-coroutines` +- Full hyperbolic functions support and default implementations within `ExtendedField` +- Norm support for `Complex` ### Changed +- `readAsMemory` now has `throws IOException` in JVM signature. +- Several functions taking functional types were made `inline`. +- Several functions taking functional types now have `callsInPlace` contracts. - BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations - `power(T, Int)` extension function has preconditions and supports `Field` - Memory objects have more preconditions (overflow checking) - `tg` function is renamed to `tan` (https://github.com/mipt-npm/kmath/pull/114) -- Gradle version: 6.3 -> 6.5.1 -- Moved probability distributions to commons-rng and to `kmath-prob`. +- Gradle version: 6.3 -> 6.6 +- Moved probability distributions to commons-rng and to `kmath-prob` ### Fixed - Missing copy method in Memory implementation on JS (https://github.com/mipt-npm/kmath/pull/106) - D3.dim value in `kmath-dimensions` - Multiplication in integer rings in `kmath-core` (https://github.com/mipt-npm/kmath/pull/101) - Commons RNG compatibility (https://github.com/mipt-npm/kmath/issues/93) +- Multiplication of BigInt by scalar diff --git a/build.gradle.kts b/build.gradle.kts index 8a2ba3617..b24ecd15b 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,8 +1,9 @@ plugins { id("scientifik.publish") apply false + id("org.jetbrains.changelog") version "0.4.0" } -val kmathVersion by extra("0.1.4-dev-8") +val kmathVersion by extra("0.1.4") val bintrayRepo by extra("scientifik") val githubProject by extra("kmath") @@ -14,8 +15,18 @@ allprojects { maven("https://dl.bintray.com/hotkeytlt/maven") } - group = "scientifik" + group = "kscience.kmath" version = kmathVersion + + afterEvaluate { + extensions.findByType()?.run { + targets.all { + sourceSets.all { + languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") + } + } + } + } } subprojects { diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 73def3572..f5a4d5831 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -56,9 +56,16 @@ benchmark { } } +kotlin.sourceSets.all { + with(languageSettings) { + useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") + useExperimentalAnnotation("kotlin.ExperimentalUnsignedTypes") + } +} tasks.withType { kotlinOptions { jvmTarget = Scientifik.JVM_TARGET.toString() + freeCompilerArgs = freeCompilerArgs + "-Xopt-in=kotlin.RequiresOptIn" } -} \ No newline at end of file +} diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt index ae27620f7..46da6c6d8 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt @@ -4,46 +4,38 @@ import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke @State(Scope.Benchmark) class NDFieldBenchmark { - @Benchmark fun autoFieldAdd() { - bufferedField.run { + bufferedField { var res: NDBuffer = one - repeat(n) { - res += one - } + repeat(n) { res += one } } } @Benchmark fun autoElementAdd() { var res = genericField.one - repeat(n) { - res += 1.0 - } + repeat(n) { res += 1.0 } } @Benchmark fun specializedFieldAdd() { - specializedField.run { + specializedField { var res: NDBuffer = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += 1.0 } } } @Benchmark fun boxingFieldAdd() { - genericField.run { + genericField { var res: NDBuffer = one - repeat(n) { - res += one - } + repeat(n) { res += one } } } diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt index f7b9661ef..9627743c9 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt @@ -5,23 +5,22 @@ import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke import scientifik.kmath.viktor.ViktorNDField - @State(Scope.Benchmark) class ViktorBenchmark { final val dim = 1000 final val n = 100 // automatically build context most suited for given type. - final val autoField = NDField.auto(RealField, dim, dim) - final val realField = NDField.real(dim, dim) - - final val viktorField = ViktorNDField(intArrayOf(dim, dim)) + final val autoField: BufferedNDField = NDField.auto(RealField, dim, dim) + final val realField: RealNDField = NDField.real(dim, dim) + final val viktorField: ViktorNDField = ViktorNDField(intArrayOf(dim, dim)) @Benchmark fun automaticFieldAddition() { - autoField.run { + autoField { var res = one repeat(n) { res += one } } @@ -29,7 +28,7 @@ class ViktorBenchmark { @Benchmark fun viktorFieldAddition() { - viktorField.run { + viktorField { var res = one repeat(n) { res += one } } @@ -44,7 +43,7 @@ class ViktorBenchmark { @Benchmark fun realdFieldLog() { - realField.run { + realField { val fortyTwo = produce { 42.0 } var res = one repeat(n) { res = ln(fortyTwo) } diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt index 6ec9e9c17..3b0d56291 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt +++ b/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt @@ -1,8 +1,11 @@ package scientifik.kmath.utils +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.system.measureTimeMillis internal inline fun measureAndPrint(title: String, block: () -> Unit) { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } val time = measureTimeMillis(block) println("$title completed in $time millis") -} \ No newline at end of file +} diff --git a/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt index 960f03de3..6cc5411b8 100644 --- a/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt @@ -5,6 +5,7 @@ import scientifik.kmath.commons.linear.CMMatrixContext import scientifik.kmath.commons.linear.inverse import scientifik.kmath.commons.linear.toCM import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.Matrix import kotlin.contracts.ExperimentalContracts import kotlin.random.Random @@ -21,29 +22,18 @@ fun main() { val n = 5000 // iterations - MatrixContext.real.run { - - repeat(50) { - val res = inverse(matrix) - } - - val inverseTime = measureTimeMillis { - repeat(n) { - val res = inverse(matrix) - } - } - + MatrixContext.real { + repeat(50) { val res = inverse(matrix) } + val inverseTime = measureTimeMillis { repeat(n) { val res = inverse(matrix) } } println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis") } //commons-math val commonsTime = measureTimeMillis { - CMMatrixContext.run { + CMMatrixContext { val cm = matrix.toCM() //avoid overhead on conversion - repeat(n) { - val res = inverse(cm) - } + repeat(n) { val res = inverse(cm) } } } @@ -53,7 +43,7 @@ fun main() { //koma-ejml val komaTime = measureTimeMillis { - KomaMatrixContext(EJMLMatrixFactory(), RealField).run { + (KomaMatrixContext(EJMLMatrixFactory(), RealField)) { val km = matrix.toKoma() //avoid overhead on conversion repeat(n) { val res = inverse(km) diff --git a/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt index 03bd0001c..3ae550682 100644 --- a/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt +++ b/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt @@ -4,6 +4,7 @@ import koma.matrix.ejml.EJMLMatrixFactory import scientifik.kmath.commons.linear.CMMatrixContext import scientifik.kmath.commons.linear.toCM import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.Matrix import kotlin.random.Random import kotlin.system.measureTimeMillis @@ -18,7 +19,7 @@ fun main() { // //warmup // matrix1 dot matrix2 - CMMatrixContext.run { + CMMatrixContext { val cmMatrix1 = matrix1.toCM() val cmMatrix2 = matrix2.toCM() @@ -29,8 +30,7 @@ fun main() { println("CM implementation time: $cmTime") } - - KomaMatrixContext(EJMLMatrixFactory(), RealField).run { + (KomaMatrixContext(EJMLMatrixFactory(), RealField)) { val komaMatrix1 = matrix1.toKoma() val komaMatrix2 = matrix2.toKoma() diff --git a/examples/src/main/kotlin/scientifik/kmath/operations/BigIntDemo.kt b/examples/src/main/kotlin/scientifik/kmath/operations/BigIntDemo.kt new file mode 100644 index 000000000..10b038943 --- /dev/null +++ b/examples/src/main/kotlin/scientifik/kmath/operations/BigIntDemo.kt @@ -0,0 +1,8 @@ +package scientifik.kmath.operations + +fun main() { + val res = BigIntField { + number(1) * 2 + } + println("bigint:$res") +} \ No newline at end of file diff --git a/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt b/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt index 4841f9dd8..6dbfebce1 100644 --- a/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt +++ b/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt @@ -9,13 +9,11 @@ fun main() { Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) } - - val compute = NDField.complex(8).run { + val compute = (NDField.complex(8)) { val a = produce { (it) -> i * it - it.toDouble() } val b = 3 val c = Complex(1.0, 1.0) (a pow b) + c } - -} \ No newline at end of file +} diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt index 991cd34a1..2329f3fc3 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt @@ -13,9 +13,8 @@ fun main() { val realField = NDField.real(dim, dim) val complexField = NDField.complex(dim, dim) - val realTime = measureTimeMillis { - realField.run { + realField { var res: NDBuffer = one repeat(n) { res += 1.0 @@ -26,18 +25,15 @@ fun main() { println("Real addition completed in $realTime millis") val complexTime = measureTimeMillis { - complexField.run { + complexField { var res: NDBuffer = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += 1.0 } } } println("Complex addition completed in $complexTime millis") } - fun complexExample() { //Create a context for 2-d structure with complex values ComplexField { @@ -46,10 +42,7 @@ fun complexExample() { val x = one * 2.5 operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im) //a structure generator specific to this context - val matrix = produce { (k, l) -> - k + l * i - } - + val matrix = produce { (k, l) -> k + l * i } //Perform sum val sum = matrix + x + 1.0 diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt index 2aafb504d..1bc0ed7c8 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt @@ -2,14 +2,18 @@ package scientifik.kmath.structures import kotlinx.coroutines.GlobalScope import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.system.measureTimeMillis internal inline fun measureAndPrint(title: String, block: () -> Unit) { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } val time = measureTimeMillis(block) println("$title completed in $time millis") } - fun main() { val dim = 1000 val n = 1000 @@ -22,27 +26,21 @@ fun main() { val genericField = NDField.boxing(RealField, dim, dim) measureAndPrint("Automatic field addition") { - autoField.run { + autoField { var res: NDBuffer = one - repeat(n) { - res += number(1.0) - } + repeat(n) { res += number(1.0) } } } measureAndPrint("Element addition") { var res = genericField.one - repeat(n) { - res += 1.0 - } + repeat(n) { res += 1.0 } } measureAndPrint("Specialized addition") { - specializedField.run { + specializedField { var res: NDBuffer = one - repeat(n) { - res += 1.0 - } + repeat(n) { res += 1.0 } } } @@ -60,12 +58,11 @@ fun main() { measureAndPrint("Generic addition") { //genericField.run(action) - genericField.run { + genericField { var res: NDBuffer = one repeat(n) { - res += one // con't avoid using `one` due to resolution ambiguity + res += one // couldn't avoid using `one` due to resolution ambiguity } } } } - -} \ No newline at end of file +} diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt b/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt index fdc09ed5d..5d323823a 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt @@ -23,13 +23,10 @@ fun DMatrixContext.custom() { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i - j).toDouble() } val m3 = produce { i, j -> (i - j).toDouble() } - (m1 dot m2) + m3 } -fun main() { - DMatrixContext.real.run { - simple() - custom() - } -} \ No newline at end of file +fun main(): Unit = with(DMatrixContext.real) { + simple() + custom() +} diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index d13a7712d..86b10bdc7 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -1,11 +1,8 @@ plugins { id("scientifik.mpp") } kotlin.sourceSets { -// all { -// languageSettings.apply{ -// enableLanguageFeature("NewInference") -// } -// } + all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") } + commonMain { dependencies { api(project(":kmath-core")) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt index b47c7cae8..23deae24b 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt @@ -84,9 +84,9 @@ object MstExtendedField : ExtendedField { override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) - override fun asin(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg) - override fun acos(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg) - override fun atan(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg) + override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) + override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) + override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) override fun add(a: MST, b: MST): MST = MstField.add(a, b) override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt index 59f3f15d8..3cee33956 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt @@ -2,6 +2,9 @@ package scientifik.kmath.ast import scientifik.kmath.expressions.* import scientifik.kmath.operations.* +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than @@ -24,7 +27,7 @@ class MstExpression(val algebra: Algebra, val mst: MST) : Expression { error("Numeric nodes are not supported by $this") } - override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) + override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) } /** @@ -38,51 +41,63 @@ inline fun , E : Algebra> A.mst( /** * Builds [MstExpression] over [Space]. */ -inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression = - MstExpression(this, MstSpace.block()) +inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstSpace.block()) +} /** * Builds [MstExpression] over [Ring]. */ -inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression = - MstExpression(this, MstRing.block()) +inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstRing.block()) +} /** * Builds [MstExpression] over [Field]. */ -inline fun Field.mstInField(block: MstField.() -> MST): MstExpression = - MstExpression(this, MstField.block()) +inline fun Field.mstInField(block: MstField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstField.block()) +} /** * Builds [MstExpression] over [ExtendedField]. */ -inline fun Field.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression = - MstExpression(this, MstExtendedField.block()) +inline fun Field.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstExtendedField.block()) +} /** * Builds [MstExpression] over [FunctionalExpressionSpace]. */ -inline fun > FunctionalExpressionSpace.mstInSpace( - block: MstSpace.() -> MST -): MstExpression = algebra.mstInSpace(block) +inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInSpace(block) +} /** * Builds [MstExpression] over [FunctionalExpressionRing]. */ -inline fun > FunctionalExpressionRing.mstInRing( - block: MstRing.() -> MST -): MstExpression = algebra.mstInRing(block) +inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInRing(block) +} /** * Builds [MstExpression] over [FunctionalExpressionField]. */ -inline fun > FunctionalExpressionField.mstInField( - block: MstField.() -> MST -): MstExpression = algebra.mstInField(block) +inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInField(block) +} /** * Builds [MstExpression] over [FunctionalExpressionExtendedField]. */ -inline fun > FunctionalExpressionExtendedField.mstInExtendedField( - block: MstExtendedField.() -> MST -): MstExpression = algebra.mstInExtendedField(block) +inline fun > FunctionalExpressionExtendedField.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInExtendedField(block) +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt index a637289b8..6f51fe855 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt @@ -7,6 +7,9 @@ import scientifik.kmath.ast.MST import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import java.lang.reflect.Method +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.reflect.KClass private val methodNameAdapters: Map, String> by lazy { @@ -26,8 +29,10 @@ internal val KClass<*>.asm: Type /** * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. */ -internal inline fun T.wrapToArrayIf(predicate: (T) -> Boolean): Array = - if (predicate(this)) arrayOf(this) else emptyArray() +internal inline fun T.wrapToArrayIf(predicate: (T) -> Boolean): Array { + contract { callsInPlace(predicate, InvocationKind.EXACTLY_ONCE) } + return if (predicate(this)) arrayOf(this) else emptyArray() +} /** * Creates an [InstructionAdapter] from this [MethodVisitor]. @@ -37,8 +42,10 @@ private fun MethodVisitor.instructionAdapter(): InstructionAdapter = Instruction /** * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. */ -internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = - instructionAdapter().apply(block) +internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return instructionAdapter().apply(block) +} /** * Constructs a [Label], then applies it to this visitor. @@ -64,8 +71,10 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String { } @Suppress("FunctionName") -internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = - ClassWriter(flags).apply(block) +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return ClassWriter(flags).apply(block) +} internal inline fun ClassWriter.visitField( access: Int, @@ -74,7 +83,10 @@ internal inline fun ClassWriter.visitField( signature: String?, value: Any?, block: FieldVisitor.() -> Unit -): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) +): FieldVisitor { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return visitField(access, name, descriptor, signature, value).apply(block) +} private fun AsmBuilder.findSpecific(context: Algebra, name: String, parameterTypes: Array): Method? = context.javaClass.methods.find { method -> @@ -158,6 +170,7 @@ internal inline fun AsmBuilder.buildAlgebraOperationCall( parameterTypes: Array, parameters: AsmBuilder.() -> Unit ) { + contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } val arity = parameterTypes.size loadAlgebra() if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name) diff --git a/kmath-commons/build.gradle.kts b/kmath-commons/build.gradle.kts index 5ce1b935a..63c832b7c 100644 --- a/kmath-commons/build.gradle.kts +++ b/kmath-commons/build.gradle.kts @@ -1,7 +1,4 @@ -plugins { - id("scientifik.jvm") -} - +plugins { id("scientifik.jvm") } description = "Commons math binding for kmath" dependencies { @@ -10,4 +7,6 @@ dependencies { api(project(":kmath-prob")) api(project(":kmath-functions")) api("org.apache.commons:commons-math3:3.6.1") -} \ No newline at end of file +} + +kotlin.sourceSets.all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") } diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index 54c404f57..9119991e5 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -5,6 +5,7 @@ import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.Field +import scientifik.kmath.operations.invoke import kotlin.properties.ReadOnlyProperty import kotlin.reflect.KProperty @@ -15,26 +16,22 @@ class DerivativeStructureField( val order: Int, val parameters: Map ) : ExtendedField { - override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) } - override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) } private val variables: Map = parameters.mapValues { (key, value) -> DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value) } - val variable = object : ReadOnlyProperty { - override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure { - return variables[property.name] ?: error("A variable with name ${property.name} does not exist") - } + val variable: ReadOnlyProperty = object : ReadOnlyProperty { + override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure = + variables[property.name] ?: error("A variable with name ${property.name} does not exist") } fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure = variables[name] ?: default ?: error("A variable with name $name does not exist") - - fun Number.const() = DerivativeStructure(order, parameters.size, toDouble()) + fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble()) fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double { return deriv(mapOf(parName to order)) @@ -60,10 +57,18 @@ class DerivativeStructureField( override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() + override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan() override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() + override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh() + override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh() + override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh() + override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh() + override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh() + override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh() + override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { is Double -> arg.pow(pow) is Int -> arg.pow(pow) @@ -71,23 +76,20 @@ class DerivativeStructureField( } fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow) - override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() - override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) - override operator fun Number.plus(b: DerivativeStructure) = b + this - override operator fun Number.minus(b: DerivativeStructure) = b - this + override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this + override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this } /** * A constructs that creates a derivative structure with required order on-demand */ class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression { - - override fun invoke(arguments: Map): Double = DerivativeStructureField( + override operator fun invoke(arguments: Map): Double = DerivativeStructureField( 0, arguments ).run(function).value @@ -96,45 +98,40 @@ class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStru * Get the derivative expression with given orders * TODO make result [DiffExpression] */ - fun derivative(orders: Map): Expression { - return object : Expression { - override fun invoke(arguments: Map): Double = - DerivativeStructureField(orders.values.max() ?: 0, arguments) - .run { - function().deriv(orders) - } - } + fun derivative(orders: Map): Expression = object : Expression { + override operator fun invoke(arguments: Map): Double = + (DerivativeStructureField(orders.values.max() ?: 0, arguments)) { function().deriv(orders) } } //TODO add gradient and maybe other vector operators } -fun DiffExpression.derivative(vararg orders: Pair) = derivative(mapOf(*orders)) -fun DiffExpression.derivative(name: String) = derivative(name to 1) +fun DiffExpression.derivative(vararg orders: Pair): Expression = derivative(mapOf(*orders)) +fun DiffExpression.derivative(name: String): Expression = derivative(name to 1) /** * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) */ object DiffExpressionAlgebra : ExpressionAlgebra, Field { - override fun variable(name: String, default: Double?) = + override fun variable(name: String, default: Double?): DiffExpression = DiffExpression { variable(name, default?.const()) } override fun const(value: Double): DiffExpression = DiffExpression { value.const() } - override fun add(a: DiffExpression, b: DiffExpression) = + override fun add(a: DiffExpression, b: DiffExpression): DiffExpression = DiffExpression { a.function(this) + b.function(this) } - override val zero = DiffExpression { 0.0.const() } + override val zero: DiffExpression = DiffExpression { 0.0.const() } - override fun multiply(a: DiffExpression, k: Number) = + override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k } - override val one = DiffExpression { 1.0.const() } + override val one: DiffExpression = DiffExpression { 1.0.const() } - override fun multiply(a: DiffExpression, b: DiffExpression) = + override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression = DiffExpression { a.function(this) * b.function(this) } - override fun divide(a: DiffExpression, b: DiffExpression) = + override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression = DiffExpression { a.function(this) / b.function(this) } } diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt index a17effccc..f0bbdbe65 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt @@ -1,8 +1,6 @@ package scientifik.kmath.commons.linear import org.apache.commons.math3.linear.* -import org.apache.commons.math3.linear.RealMatrix -import org.apache.commons.math3.linear.RealVector import scientifik.kmath.linear.* import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.NDStructure @@ -14,12 +12,12 @@ class CMMatrix(val origin: RealMatrix, features: Set? = null) : override val features: Set = features ?: sequence { if (origin is DiagonalMatrix) yield(DiagonalFeature) - }.toSet() + }.toHashSet() - override fun suggestFeature(vararg features: MatrixFeature) = + override fun suggestFeature(vararg features: MatrixFeature): CMMatrix = CMMatrix(origin, this.features + features) - override fun get(i: Int, j: Int): Double = origin.getEntry(i, j) + override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j) override fun equals(other: Any?): Boolean { return NDStructure.equals(this, other as? NDStructure<*> ?: return false) @@ -40,24 +38,22 @@ fun Matrix.toCM(): CMMatrix = if (this is CMMatrix) { CMMatrix(Array2DRowRealMatrix(array)) } -fun RealMatrix.asMatrix() = CMMatrix(this) +fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this) class CMVector(val origin: RealVector) : Point { override val size: Int get() = origin.dimension - override fun get(index: Int): Double = origin.getEntry(index) + override operator fun get(index: Int): Double = origin.getEntry(index) - override fun iterator(): Iterator = origin.toArray().iterator() + override operator fun iterator(): Iterator = origin.toArray().iterator() } -fun Point.toCM(): CMVector = if (this is CMVector) { - this -} else { +fun Point.toCM(): CMVector = if (this is CMVector) this else { val array = DoubleArray(size) { this[it] } CMVector(ArrayRealVector(array)) } -fun RealVector.toPoint() = CMVector(this) +fun RealVector.toPoint(): CMVector = CMVector(this) object CMMatrixContext : MatrixContext { override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix { @@ -65,32 +61,33 @@ object CMMatrixContext : MatrixContext { return CMMatrix(Array2DRowRealMatrix(array)) } - override fun Matrix.dot(other: Matrix) = + override fun Matrix.dot(other: Matrix): CMMatrix = CMMatrix(this.toCM().origin.multiply(other.toCM().origin)) override fun Matrix.dot(vector: Point): CMVector = CMVector(this.toCM().origin.preMultiply(vector.toCM().origin)) - override fun Matrix.unaryMinus(): CMMatrix = + override operator fun Matrix.unaryMinus(): CMMatrix = produce(rowNum, colNum) { i, j -> -get(i, j) } - override fun add(a: Matrix, b: Matrix) = + override fun add(a: Matrix, b: Matrix): CMMatrix = CMMatrix(a.toCM().origin.multiply(b.toCM().origin)) - override fun Matrix.minus(b: Matrix) = + override operator fun Matrix.minus(b: Matrix): CMMatrix = CMMatrix(this.toCM().origin.subtract(b.toCM().origin)) - override fun multiply(a: Matrix, k: Number) = + override fun multiply(a: Matrix, k: Number): CMMatrix = CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble())) - override fun Matrix.times(value: Double): Matrix = + override operator fun Matrix.times(value: Double): Matrix = produce(rowNum, colNum) { i, j -> get(i, j) * value } } operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin)) + operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin)) infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = - CMMatrix(this.origin.multiply(other.origin)) \ No newline at end of file + CMMatrix(this.origin.multiply(other.origin)) diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt index 13e79d60e..cb2b5dd9c 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -4,10 +4,9 @@ import scientifik.kmath.prob.RandomGenerator class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : org.apache.commons.math3.random.RandomGenerator { - private var generator = factory(intArrayOf()) + private var generator: RandomGenerator = factory(intArrayOf()) override fun nextBoolean(): Boolean = generator.nextBoolean() - override fun nextFloat(): Float = generator.nextDouble().toFloat() override fun setSeed(seed: Int) { @@ -27,12 +26,8 @@ class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : } override fun nextInt(): Int = generator.nextInt() - override fun nextInt(n: Int): Int = generator.nextInt(n) - override fun nextGaussian(): Double = TODO() - override fun nextDouble(): Double = generator.nextDouble() - override fun nextLong(): Long = generator.nextLong() -} \ No newline at end of file +} diff --git a/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt b/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt index c77f30eb7..bbdcff2fc 100644 --- a/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt +++ b/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt @@ -1,11 +1,15 @@ package scientifik.kmath.commons.expressions import scientifik.kmath.expressions.invoke +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.test.Test import kotlin.test.assertEquals -inline fun diff(order: Int, vararg parameters: Pair, block: DerivativeStructureField.() -> R) = - DerivativeStructureField(order, mapOf(*parameters)).run(block) +inline fun diff(order: Int, vararg parameters: Pair, block: DerivativeStructureField.() -> R): R { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return DerivativeStructureField(order, mapOf(*parameters)).run(block) +} class AutoDiffTest { @Test diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index bea0fbf42..18c0cc771 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -1,7 +1,11 @@ -plugins { id("scientifik.mpp") } +plugins { + id("scientifik.mpp") +} kotlin.sourceSets { commonMain { - dependencies { api(project(":kmath-memory")) } + dependencies { + api(project(":kmath-memory")) + } } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt index 8cd6e28f8..8d0b82a89 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt @@ -4,28 +4,38 @@ import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.Field import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * Creates a functional expression with this [Space]. */ -fun Space.spaceExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression = - FunctionalExpressionSpace(this).run(block) +inline fun Space.spaceExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionSpace(this).block() +} /** * Creates a functional expression with this [Ring]. */ -fun Ring.ringExpression(block: FunctionalExpressionRing>.() -> Expression): Expression = - FunctionalExpressionRing(this).run(block) +inline fun Ring.ringExpression(block: FunctionalExpressionRing>.() -> Expression): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionRing(this).block() +} /** * Creates a functional expression with this [Field]. */ -fun Field.fieldExpression(block: FunctionalExpressionField>.() -> Expression): Expression = - FunctionalExpressionField(this).run(block) +inline fun Field.fieldExpression(block: FunctionalExpressionField>.() -> Expression): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionField(this).block() +} /** * Creates a functional expression with this [ExtendedField]. */ -fun ExtendedField.fieldExpression( - block: FunctionalExpressionExtendedField>.() -> Expression -): Expression = FunctionalExpressionExtendedField(this).run(block) +inline fun ExtendedField.extendedFieldExpression(block: FunctionalExpressionExtendedField>.() -> Expression): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionExtendedField(this).block() +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 380822f78..fd11c246d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -22,7 +22,7 @@ interface Expression { */ fun Algebra.expression(block: Algebra.(arguments: Map) -> T): Expression = object : Expression { - override fun invoke(arguments: Map): T = block(arguments) + override operator fun invoke(arguments: Map): T = block(arguments) } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt index dd5fb572a..d36c31a0d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -4,7 +4,7 @@ import scientifik.kmath.operations.* internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) : Expression { - override fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) + override operator fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) } internal class FunctionalBinaryOperation( @@ -13,17 +13,17 @@ internal class FunctionalBinaryOperation( val first: Expression, val second: Expression ) : Expression { - override fun invoke(arguments: Map): T = + override operator fun invoke(arguments: Map): T = context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) } internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression { - override fun invoke(arguments: Map): T = + override operator fun invoke(arguments: Map): T = arguments[name] ?: default ?: error("Parameter not found: $name") } internal class FunctionalConstantExpression(val value: T) : Expression { - override fun invoke(arguments: Map): T = value + override operator fun invoke(arguments: Map): T = value } internal class FunctionalConstProductExpression( @@ -31,7 +31,7 @@ internal class FunctionalConstProductExpression( private val expr: Expression, val const: Number ) : Expression { - override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) + override operator fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } /** @@ -139,15 +139,9 @@ open class FunctionalExpressionExtendedField(algebra: A) : ExtendedField> where A : ExtendedField, A : NumericAlgebra { override fun sin(arg: Expression): Expression = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) override fun cos(arg: Expression): Expression = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) - - override fun asin(arg: Expression): Expression = - unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg) - - override fun acos(arg: Expression): Expression = - unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg) - - override fun atan(arg: Expression): Expression = - unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg) + override fun asin(arg: Expression): Expression = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) + override fun acos(arg: Expression): Expression = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) + override fun atan(arg: Expression): Expression = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) override fun power(arg: Expression, pow: Number): Expression = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt index 2e1f32501..343b8287e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt @@ -53,16 +53,12 @@ class BufferMatrix( override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix = BufferMatrix(rowNum, colNum, buffer, this.features + features) - override fun get(index: IntArray): T = get(index[0], index[1]) + override operator fun get(index: IntArray): T = get(index[0], index[1]) - override fun get(i: Int, j: Int): T = buffer[i * colNum + j] + override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j] override fun elements(): Sequence> = sequence { - for (i in 0 until rowNum) { - for (j in 0 until colNum) { - yield(intArrayOf(i, j) to get(i, j)) - } - } + for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j)) } override fun equals(other: Any?): Boolean { @@ -95,7 +91,7 @@ class BufferMatrix( * Optimized dot product for real matrices */ infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix { - if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})") + require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } val array = DoubleArray(this.rowNum * other.colNum) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt index dd17c4fe7..9b60bf719 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt @@ -4,6 +4,8 @@ import scientifik.kmath.operations.Ring import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Structure2D import scientifik.kmath.structures.asBuffer +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.math.sqrt /** @@ -26,15 +28,17 @@ interface FeaturedMatrix : Matrix { companion object } -fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix = - MatrixContext.real.produce(rows, columns, initializer) +inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix { + contract { callsInPlace(initializer) } + return MatrixContext.real.produce(rows, columns, initializer) +} /** * Build a square matrix from given elements. */ fun Structure2D.Companion.square(vararg elements: T): FeaturedMatrix { val size: Int = sqrt(elements.size.toDouble()).toInt() - if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square") + require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" } val buffer = elements.asBuffer() return BufferMatrix(size, size, buffer) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt index d04a99fbb..f3e4f648f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt @@ -3,6 +3,7 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Field import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.BufferAccessor2D import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Structure2D @@ -60,15 +61,13 @@ class LUPDecomposition( * @return determinant of the matrix */ override val determinant: T by lazy { - with(elementContext) { - (0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } - } + elementContext { (0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } } } } fun , F : Field> GenericMatrixContext.abs(value: T): T = - if (value > elementContext.zero) value else with(elementContext) { -value } + if (value > elementContext.zero) value else elementContext { -value } /** @@ -88,43 +87,34 @@ fun , F : Field> GenericMatrixContext.lup( //TODO just waits for KEEP-176 BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { - elementContext.run { - + elementContext { val lu = create(matrix) // Initialize permutation array and parity - for (row in 0 until m) { - pivot[row] = row - } + for (row in 0 until m) pivot[row] = row var even = true // Initialize permutation array and parity - for (row in 0 until m) { - pivot[row] = row - } + for (row in 0 until m) pivot[row] = row // Loop over columns for (col in 0 until m) { - // upper for (row in 0 until col) { val luRow = lu.row(row) var sum = luRow[col] - for (i in 0 until row) { - sum -= luRow[i] * lu[i, col] - } + for (i in 0 until row) sum -= luRow[i] * lu[i, col] luRow[col] = sum } // lower var max = col // permutation row var largest = -one + for (row in col until m) { val luRow = lu.row(row) var sum = luRow[col] - for (i in 0 until col) { - sum -= luRow[i] * lu[i, col] - } + for (i in 0 until col) sum -= luRow[i] * lu[i, col] luRow[col] = sum // maintain best permutation choice @@ -135,19 +125,19 @@ fun , F : Field> GenericMatrixContext.lup( } // Singularity check - if (checkSingular(this@lup.abs(lu[max, col]))) { - error("The matrix is singular") - } + check(!checkSingular(this@lup.abs(lu[max, col]))) { "The matrix is singular" } // Pivot if necessary if (max != col) { val luMax = lu.row(max) val luCol = lu.row(col) + for (i in 0 until m) { val tmp = luMax[i] luMax[i] = luCol[i] luCol[i] = tmp } + val temp = pivot[max] pivot[max] = pivot[col] pivot[col] = temp @@ -156,9 +146,7 @@ fun , F : Field> GenericMatrixContext.lup( // Divide the lower elements by the "winning" diagonal elt. val luDiag = lu[col, col] - for (row in col + 1 until m) { - lu[row, col] /= luDiag - } + for (row in col + 1 until m) lu[row, col] /= luDiag } return LUPDecomposition(this@lup, lu.collect(), pivot, even) @@ -175,28 +163,23 @@ fun GenericMatrixContext.lup(matrix: Matrix): LUPDeco lup(Double::class, matrix) { it < 1e-11 } fun LUPDecomposition.solve(type: KClass, matrix: Matrix): Matrix { - - if (matrix.rowNum != pivot.size) { - error("Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}") - } + require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" } BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { - elementContext.run { - + elementContext { // Apply permutations to b val bp = create { _, _ -> zero } for (row in pivot.indices) { val bpRow = bp.row(row) val pRow = pivot[row] - for (col in 0 until matrix.colNum) { - bpRow[col] = matrix[pRow, col] - } + for (col in 0 until matrix.colNum) bpRow[col] = matrix[pRow, col] } // Solve LY = b for (col in pivot.indices) { val bpCol = bp.row(col) + for (i in col + 1 until pivot.size) { val bpI = bp.row(i) val luICol = lu[i, col] @@ -210,17 +193,15 @@ fun LUPDecomposition.solve(type: KClass, matrix: Matrix): Mat for (col in pivot.size - 1 downTo 0) { val bpCol = bp.row(col) val luDiag = lu[col, col] - for (j in 0 until matrix.colNum) { - bpCol[j] /= luDiag - } + for (j in 0 until matrix.colNum) bpCol[j] /= luDiag + for (i in 0 until col) { val bpI = bp.row(i) val luICol = lu[i, col] - for (j in 0 until matrix.colNum) { - bpI[j] -= bpCol[j] * luICol - } + for (j in 0 until matrix.colNum) bpI[j] -= bpCol[j] * luICol } } + return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] } } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt index 516f65bb8..390362f8c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt @@ -7,7 +7,7 @@ import scientifik.kmath.structures.asBuffer class MatrixBuilder(val rows: Int, val columns: Int) { operator fun invoke(vararg elements: T): FeaturedMatrix { - if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns") + require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" } val buffer = elements.asBuffer() return BufferMatrix(rows, columns, buffer) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt index 5dc86a7dd..763bb1615 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt @@ -2,6 +2,7 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Ring import scientifik.kmath.operations.SpaceOperations +import scientifik.kmath.operations.invoke import scientifik.kmath.operations.sum import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.BufferFactory @@ -37,8 +38,7 @@ interface MatrixContext : SpaceOperations> { fun > buffered( ring: R, bufferFactory: BufferFactory = Buffer.Companion::boxing - ): GenericMatrixContext = - BufferMatrixContext(ring, bufferFactory) + ): GenericMatrixContext = BufferMatrixContext(ring, bufferFactory) /** * Automatic buffered matrix, unboxed if it is possible @@ -61,45 +61,49 @@ interface GenericMatrixContext> : MatrixContext { override infix fun Matrix.dot(other: Matrix): Matrix { //TODO add typed error - if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})") + require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } + return produce(rowNum, other.colNum) { i, j -> val row = rows[i] val column = other.columns[j] - with(elementContext) { - sum(row.asSequence().zip(column.asSequence(), ::multiply)) - } + elementContext { sum(row.asSequence().zip(column.asSequence(), ::multiply)) } } } override infix fun Matrix.dot(vector: Point): Point { //TODO add typed error - if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})") + require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" } + return point(rowNum) { i -> val row = rows[i] - with(elementContext) { - sum(row.asSequence().zip(vector.asSequence(), ::multiply)) - } + elementContext { sum(row.asSequence().zip(vector.asSequence(), ::multiply)) } } } override operator fun Matrix.unaryMinus(): Matrix = - produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } } + produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } } override fun add(a: Matrix, b: Matrix): Matrix { - if (a.rowNum != b.rowNum || a.colNum != b.colNum) error("Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]") - return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] + b[i, j] } } + require(a.rowNum == b.rowNum && a.colNum == b.colNum) { + "Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]" + } + + return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } } } override operator fun Matrix.minus(b: Matrix): Matrix { - if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]") - return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } } + require(rowNum == b.rowNum && colNum == b.colNum) { + "Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]" + } + + return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } } } override fun multiply(a: Matrix, k: Number): Matrix = - produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } } + produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } } operator fun Number.times(matrix: FeaturedMatrix): Matrix = matrix * this - override fun Matrix.times(value: T): Matrix = - produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } } + override operator fun Matrix.times(value: T): Matrix = + produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt index 691b464fc..82e5c7ef6 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt @@ -2,6 +2,7 @@ package scientifik.kmath.linear import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.BufferFactory @@ -10,10 +11,9 @@ import scientifik.kmath.structures.BufferFactory * Could be used on any point-like structure */ interface VectorSpace> : Space> { - val size: Int - val space: S + override val zero: Point get() = produce { space.zero } fun produce(initializer: (Int) -> T): Point @@ -22,29 +22,24 @@ interface VectorSpace> : Space> { */ //fun produceElement(initializer: (Int) -> T): Vector - override val zero: Point get() = produce { space.zero } + override fun add(a: Point, b: Point): Point = produce { space { a[it] + b[it] } } - override fun add(a: Point, b: Point): Point = produce { with(space) { a[it] + b[it] } } - - override fun multiply(a: Point, k: Number): Point = produce { with(space) { a[it] * k } } + override fun multiply(a: Point, k: Number): Point = produce { space { a[it] * k } } //TODO add basis companion object { - - private val realSpaceCache = HashMap>() + private val realSpaceCache: MutableMap> = hashMapOf() /** * Non-boxing double vector space */ - fun real(size: Int): BufferVectorSpace { - return realSpaceCache.getOrPut(size) { - BufferVectorSpace( - size, - RealField, - Buffer.Companion::auto - ) - } + fun real(size: Int): BufferVectorSpace = realSpaceCache.getOrPut(size) { + BufferVectorSpace( + size, + RealField, + Buffer.Companion::auto + ) } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt index 207151d57..5266dc884 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt @@ -18,7 +18,7 @@ class VirtualMatrix( override val shape: IntArray get() = intArrayOf(rowNum, colNum) - override fun get(i: Int, j: Int): T = generator(i, j) + override operator fun get(i: Int, j: Int): T = generator(i, j) override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix = VirtualMatrix(rowNum, colNum, this.features + features, generator) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt index db8863ae8..be222783e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt @@ -3,8 +3,12 @@ package scientifik.kmath.misc import scientifik.kmath.linear.Point import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.Field +import scientifik.kmath.operations.invoke import scientifik.kmath.operations.sum import scientifik.kmath.structures.asBuffer +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /* * Implementation of backward-mode automatic differentiation. @@ -27,15 +31,14 @@ class DerivationResult( /** * compute divergence */ - fun div(): T = context.run { sum(deriv.values) } + fun div(): T = context { sum(deriv.values) } /** * Compute a gradient for variables in given order */ - fun grad(vararg variables: Variable): Point = if (variables.isEmpty()) { - error("Variable order is not provided for gradient construction") - } else { - variables.map(::deriv).asBuffer() + fun grad(vararg variables: Variable): Point { + check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" } + return variables.map(::deriv).asBuffer() } } @@ -52,19 +55,27 @@ class DerivationResult( * assertEquals(9.0, x.d) // dy/dx * ``` */ -fun > F.deriv(body: AutoDiffField.() -> Variable): DerivationResult = - AutoDiffContext(this).run { +inline fun > F.deriv(body: AutoDiffField.() -> Variable): DerivationResult { + contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } + + return (AutoDiffContext(this)) { val result = body() - result.d = context.one// computing derivative w.r.t result + result.d = context.one // computing derivative w.r.t result runBackwardPass() DerivationResult(result.value, derivatives, this@deriv) } +} abstract class AutoDiffField> : Field> { - abstract val context: F + /** + * A variable accessing inner state of derivatives. + * Use this function in inner builders to avoid creating additional derivative bindings + */ + abstract var Variable.d: T + /** * Performs update of derivative after the rest of the formula in the back-pass. * @@ -78,12 +89,6 @@ abstract class AutoDiffField> : Field> { */ abstract fun derive(value: R, block: F.(R) -> Unit): R - /** - * A variable accessing inner state of derivatives. - * Use this function in inner builders to avoid creating additional derivative bindings - */ - abstract var Variable.d: T - abstract fun variable(value: T): Variable inline fun variable(block: F.() -> T): Variable = variable(context.block()) @@ -98,46 +103,35 @@ abstract class AutoDiffField> : Field> { override operator fun Variable.plus(b: Number): Variable = b.plus(this) override operator fun Number.minus(b: Variable): Variable = - derive(variable { this@minus.toDouble() * one - b.value }) { z -> - b.d -= z.d - } + derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } override operator fun Variable.minus(b: Number): Variable = - derive(variable { this@minus.value - one * b.toDouble() }) { z -> - this@minus.d += z.d - } + derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } } /** * Automatic Differentiation context class. */ -private class AutoDiffContext>(override val context: F) : AutoDiffField() { - +@PublishedApi +internal class AutoDiffContext>(override val context: F) : AutoDiffField() { // this stack contains pairs of blocks and values to apply them to - private var stack = arrayOfNulls(8) - private var sp = 0 - - internal val derivatives = HashMap, T>() - + private var stack: Array = arrayOfNulls(8) + private var sp: Int = 0 + val derivatives: MutableMap, T> = hashMapOf() + override val zero: Variable get() = Variable(context.zero) + override val one: Variable get() = Variable(context.one) /** * A variable coupled with its derivative. For internal use only */ private class VariableWithDeriv(x: T, var d: T) : Variable(x) - override fun variable(value: T): Variable = VariableWithDeriv(value, context.zero) override var Variable.d: T get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero - set(value) { - if (this is VariableWithDeriv) { - d = value - } else { - derivatives[this] = value - } - } + set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value @Suppress("UNCHECKED_CAST") override fun derive(value: R, block: F.(R) -> Unit): R { @@ -160,67 +154,49 @@ private class AutoDiffContext>(override val context: F) : // Basic math (+, -, *, /) - override fun add(a: Variable, b: Variable): Variable = - derive(variable { a.value + b.value }) { z -> - a.d += z.d - b.d += z.d - } + override fun add(a: Variable, b: Variable): Variable = derive(variable { a.value + b.value }) { z -> + a.d += z.d + b.d += z.d + } - override fun multiply(a: Variable, b: Variable): Variable = - derive(variable { a.value * b.value }) { z -> - a.d += z.d * b.value - b.d += z.d * a.value - } + override fun multiply(a: Variable, b: Variable): Variable = derive(variable { a.value * b.value }) { z -> + a.d += z.d * b.value + b.d += z.d * a.value + } - override fun divide(a: Variable, b: Variable): Variable = - derive(variable { a.value / b.value }) { z -> - a.d += z.d / b.value - b.d -= z.d * a.value / (b.value * b.value) - } + override fun divide(a: Variable, b: Variable): Variable = derive(variable { a.value / b.value }) { z -> + a.d += z.d / b.value + b.d -= z.d * a.value / (b.value * b.value) + } - override fun multiply(a: Variable, k: Number): Variable = - derive(variable { k.toDouble() * a.value }) { z -> - a.d += z.d * k.toDouble() - } - - override val zero: Variable get() = Variable(context.zero) - override val one: Variable get() = Variable(context.one) + override fun multiply(a: Variable, k: Number): Variable = derive(variable { k.toDouble() * a.value }) { z -> + a.d += z.d * k.toDouble() + } } // Extensions for differentiation of various basic mathematical functions // x ^ 2 fun > AutoDiffField.sqr(x: Variable): Variable = - derive(variable { x.value * x.value }) { z -> - x.d += z.d * 2 * x.value - } + derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } // x ^ 1/2 fun > AutoDiffField.sqrt(x: Variable): Variable = - derive(variable { sqrt(x.value) }) { z -> - x.d += z.d * 0.5 / z.value - } + derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } // x ^ y (const) fun > AutoDiffField.pow(x: Variable, y: Double): Variable = - derive(variable { power(x.value, y) }) { z -> - x.d += z.d * y * power(x.value, y - 1) - } + derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } fun > AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble()) // exp(x) fun > AutoDiffField.exp(x: Variable): Variable = - derive(variable { exp(x.value) }) { z -> - x.d += z.d * z.value - } + derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value } // ln(x) -fun > AutoDiffField.ln(x: Variable): Variable = derive( - variable { ln(x.value) } -) { z -> - x.d += z.d / x.value -} +fun > AutoDiffField.ln(x: Variable): Variable = + derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value } // x ^ y (any) fun > AutoDiffField.pow(x: Variable, y: Variable): Variable = @@ -228,12 +204,8 @@ fun > AutoDiffField.pow(x: Variable, y: V // sin(x) fun > AutoDiffField.sin(x: Variable): Variable = - derive(variable { sin(x.value) }) { z -> - x.d += z.d * cos(x.value) - } + derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } // cos(x) fun > AutoDiffField.cos(x: Variable): Variable = - derive(variable { cos(x.value) }) { z -> - x.d -= z.d * sin(x.value) - } + derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt index d3bf0891f..1272ddd1c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt @@ -41,6 +41,6 @@ fun ClosedFloatingPointRange.toSequenceWithPoints(numPoints: Int): Seque */ @Deprecated("Replace by 'toSequenceWithPoints'") fun ClosedFloatingPointRange.toGrid(numPoints: Int): DoubleArray { - if (numPoints < 2) error("Can't create generic grid with less than two points") + require(numPoints >= 2) { "Can't create generic grid with less than two points" } return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt index a0f4525cc..e11adc135 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt @@ -2,6 +2,8 @@ package scientifik.kmath.misc import scientifik.kmath.operations.Space import scientifik.kmath.operations.invoke +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.jvm.JvmName /** @@ -11,67 +13,68 @@ import kotlin.jvm.JvmName * @param R the type of resulting iterable. * @param initial lazy evaluated. */ -fun Iterator.cumulative(initial: R, operation: (R, T) -> R): Iterator = object : Iterator { - var state: R = initial - override fun hasNext(): Boolean = this@cumulative.hasNext() +inline fun Iterator.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator { + contract { callsInPlace(operation) } - override fun next(): R { - state = operation(state, this@cumulative.next()) - return state + return object : Iterator { + var state: R = initial + + override fun hasNext(): Boolean = this@cumulative.hasNext() + + override fun next(): R { + state = operation(state, this@cumulative.next()) + return state + } } } -fun Iterable.cumulative(initial: R, operation: (R, T) -> R): Iterable = object : Iterable { - override fun iterator(): Iterator = this@cumulative.iterator().cumulative(initial, operation) -} +inline fun Iterable.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable = + Iterable { this@cumulative.iterator().cumulative(initial, operation) } -fun Sequence.cumulative(initial: R, operation: (R, T) -> R): Sequence = object : Sequence { - override fun iterator(): Iterator = this@cumulative.iterator().cumulative(initial, operation) +inline fun Sequence.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence = Sequence { + this@cumulative.iterator().cumulative(initial, operation) } fun List.cumulative(initial: R, operation: (R, T) -> R): List = - this.iterator().cumulative(initial, operation).asSequence().toList() + iterator().cumulative(initial, operation).asSequence().toList() //Cumulative sum /** * Cumulative sum with custom space */ -fun Iterable.cumulativeSum(space: Space): Iterable = space { - cumulative(zero) { element: T, sum: T -> sum + element } -} +fun Iterable.cumulativeSum(space: Space): Iterable = + space { cumulative(zero) { element: T, sum: T -> sum + element } } @JvmName("cumulativeSumOfDouble") -fun Iterable.cumulativeSum(): Iterable = this.cumulative(0.0) { element, sum -> sum + element } +fun Iterable.cumulativeSum(): Iterable = cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun Iterable.cumulativeSum(): Iterable = this.cumulative(0) { element, sum -> sum + element } +fun Iterable.cumulativeSum(): Iterable = cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun Iterable.cumulativeSum(): Iterable = this.cumulative(0L) { element, sum -> sum + element } +fun Iterable.cumulativeSum(): Iterable = cumulative(0L) { element, sum -> sum + element } -fun Sequence.cumulativeSum(space: Space): Sequence = with(space) { - cumulative(zero) { element: T, sum: T -> sum + element } -} +fun Sequence.cumulativeSum(space: Space): Sequence = + space { cumulative(zero) { element: T, sum: T -> sum + element } } @JvmName("cumulativeSumOfDouble") -fun Sequence.cumulativeSum(): Sequence = this.cumulative(0.0) { element, sum -> sum + element } +fun Sequence.cumulativeSum(): Sequence = cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun Sequence.cumulativeSum(): Sequence = this.cumulative(0) { element, sum -> sum + element } +fun Sequence.cumulativeSum(): Sequence = cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun Sequence.cumulativeSum(): Sequence = this.cumulative(0L) { element, sum -> sum + element } +fun Sequence.cumulativeSum(): Sequence = cumulative(0L) { element, sum -> sum + element } -fun List.cumulativeSum(space: Space): List = with(space) { - cumulative(zero) { element: T, sum: T -> sum + element } -} +fun List.cumulativeSum(space: Space): List = + space { cumulative(zero) { element: T, sum: T -> sum + element } } @JvmName("cumulativeSumOfDouble") -fun List.cumulativeSum(): List = this.cumulative(0.0) { element, sum -> sum + element } +fun List.cumulativeSum(): List = cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun List.cumulativeSum(): List = this.cumulative(0) { element, sum -> sum + element } +fun List.cumulativeSum(): List = cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun List.cumulativeSum(): List = this.cumulative(0L) { element, sum -> sum + element } +fun List.cumulativeSum(): List = cumulative(0L) { element, sum -> sum + element } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt index fd7719157..0eed7132e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt @@ -3,12 +3,13 @@ package scientifik.kmath.operations import scientifik.kmath.operations.BigInt.Companion.BASE import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE import scientifik.kmath.structures.* +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.math.log2 import kotlin.math.max import kotlin.math.min import kotlin.math.sign - typealias Magnitude = UIntArray typealias TBase = ULong @@ -22,8 +23,9 @@ object BigIntField : Field { override val one: BigInt = BigInt.ONE override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b) + override fun number(value: Number): BigInt = value.toLong().toBigInt() - override fun multiply(a: BigInt, k: Number): BigInt = a.times(k.toLong()) + override fun multiply(a: BigInt, k: Number): BigInt = a.times(number(k)) override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b) @@ -430,8 +432,8 @@ fun ULong.toBigInt(): BigInt = BigInt( * Create a [BigInt] with this array of magnitudes with protective copy */ fun UIntArray.toBigInt(sign: Byte): BigInt { - if (sign == 0.toByte() && isNotEmpty()) error("") - return BigInt(sign, this.copyOf()) + require(sign != 0.toByte() || !isNotEmpty()) + return BigInt(sign, copyOf()) } val hexChToInt: MutableMap = hashMapOf( @@ -484,11 +486,15 @@ fun String.parseBigInteger(): BigInt? { return res * sign } -inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer = - boxing(size, initializer) +inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer { + contract { callsInPlace(initializer) } + return boxing(size, initializer) +} -inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer = - boxing(size, initializer) +inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer { + contract { callsInPlace(initializer) } + return boxing(size, initializer) +} fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing = BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt) @@ -496,5 +502,4 @@ fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing BigInt -): BufferedNDRingElement = - NDAlgebra.bigInt(*shape).produce(initializer) +): BufferedNDRingElement = NDAlgebra.bigInt(*shape).produce(initializer) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 0ce144a33..dcfd97d1a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -6,17 +6,45 @@ import scientifik.kmath.structures.MutableBuffer import scientifik.memory.MemoryReader import scientifik.memory.MemorySpec import scientifik.memory.MemoryWriter +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.math.* +/** + * This complex's conjugate. + */ +val Complex.conjugate: Complex + get() = Complex(re, -im) + +/** + * This complex's reciprocal. + */ +val Complex.reciprocal: Complex + get() { + val scale = re * re + im * im + return Complex(re / scale, -im / scale) + } + +/** + * Absolute value of complex number. + */ +val Complex.r: Double + get() = sqrt(re * re + im * im) + +/** + * An angle between vector represented by complex number and X axis. + */ +val Complex.theta: Double + get() = atan(im / re) + private val PI_DIV_2 = Complex(PI / 2, 0) /** * A field of [Complex]. */ -object ComplexField : ExtendedField { - override val zero: Complex = Complex(0.0, 0.0) - - override val one: Complex = Complex(1.0, 0.0) +object ComplexField : ExtendedField, Norm { + override val zero: Complex = 0.0.toComplex() + override val one: Complex = 1.0.toComplex() /** * The imaginary unit. @@ -30,19 +58,53 @@ object ComplexField : ExtendedField { override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) - override fun divide(a: Complex, b: Complex): Complex { - val norm = b.re * b.re + b.im * b.im - return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm) + override fun divide(a: Complex, b: Complex): Complex = when { + b.re.isNaN() || b.im.isNaN() -> Complex(Double.NaN, Double.NaN) + + (if (b.im < 0) -b.im else +b.im) < (if (b.re < 0) -b.re else +b.re) -> { + val wr = b.im / b.re + val wd = b.re + wr * b.im + + if (wd.isNaN() || wd == 0.0) + Complex(Double.NaN, Double.NaN) + else + Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd) + } + + b.im == 0.0 -> Complex(Double.NaN, Double.NaN) + + else -> { + val wr = b.re / b.im + val wd = b.im + wr * b.re + + if (wd.isNaN() || wd == 0.0) + Complex(Double.NaN, Double.NaN) + else + Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd) + } } override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 - override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg) - override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg) - override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2 - override fun power(arg: Complex, pow: Number): Complex = - arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta)) + override fun tan(arg: Complex): Complex { + val e1 = exp(-i * arg) + val e2 = exp(i * arg) + return i * (e1 - e2) / (e1 + e2) + } + + override fun asin(arg: Complex): Complex = -i * ln(sqrt(1 - (arg * arg)) + i * arg) + override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(1 - (arg * arg)) + i * arg) + + override fun atan(arg: Complex): Complex { + val iArg = i * arg + return i * (ln(1 - iArg) - ln(1 + iArg)) / 2 + } + + override fun power(arg: Complex, pow: Number): Complex = if (arg.im == 0.0) + arg.re.pow(pow.toDouble()).toComplex() + else + exp(pow * ln(arg)) override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im)) @@ -93,6 +155,8 @@ object ComplexField : ExtendedField { */ operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) + override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) + override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) } @@ -105,12 +169,12 @@ object ComplexField : ExtendedField { data class Complex(val re: Double, val im: Double) : FieldElement, Comparable { constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) + override val context: ComplexField get() = ComplexField + override fun unwrap(): Complex = this override fun Complex.wrap(): Complex = this - override val context: ComplexField get() = ComplexField - override fun compareTo(other: Complex): Int = r.compareTo(other.r) companion object : MemorySpec { @@ -126,33 +190,20 @@ data class Complex(val re: Double, val im: Double) : FieldElement Complex): Buffer { + contract { callsInPlace(init) } return MemoryBuffer.create(Complex, size, init) } inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer { + contract { callsInPlace(init) } return MemoryBuffer.create(Complex, size, init) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index b113e07a1..0735a96da 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -1,5 +1,6 @@ package scientifik.kmath.operations +import scientifik.kmath.operations.RealField.pow import kotlin.math.abs import kotlin.math.pow as kpow @@ -7,19 +8,28 @@ import kotlin.math.pow as kpow * Advanced Number-like semifield that implements basic operations. */ interface ExtendedFieldOperations : - InverseTrigonometricOperations, + FieldOperations, + TrigonometricOperations, + HyperbolicOperations, PowerOperations, ExponentialOperations { override fun tan(arg: T): T = sin(arg) / cos(arg) + override fun tanh(arg: T): T = sinh(arg) / cosh(arg) override fun unaryOperation(operation: String, arg: T): T = when (operation) { TrigonometricOperations.COS_OPERATION -> cos(arg) TrigonometricOperations.SIN_OPERATION -> sin(arg) TrigonometricOperations.TAN_OPERATION -> tan(arg) - InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg) - InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg) - InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg) + TrigonometricOperations.ACOS_OPERATION -> acos(arg) + TrigonometricOperations.ASIN_OPERATION -> asin(arg) + TrigonometricOperations.ATAN_OPERATION -> atan(arg) + HyperbolicOperations.COSH_OPERATION -> cosh(arg) + HyperbolicOperations.SINH_OPERATION -> sinh(arg) + HyperbolicOperations.TANH_OPERATION -> tanh(arg) + HyperbolicOperations.ACOSH_OPERATION -> acosh(arg) + HyperbolicOperations.ASINH_OPERATION -> asinh(arg) + HyperbolicOperations.ATANH_OPERATION -> atanh(arg) PowerOperations.SQRT_OPERATION -> sqrt(arg) ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.LN_OPERATION -> ln(arg) @@ -32,6 +42,13 @@ interface ExtendedFieldOperations : * Advanced Number-like field that implements basic operations. */ interface ExtendedField : ExtendedFieldOperations, Field { + override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2 + override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2 + override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) + override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg) + override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) + override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2 + override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { PowerOperations.POW_OPERATION -> power(left, right) else -> super.rightSideNumberOperation(operation, left, right) @@ -46,12 +63,13 @@ interface ExtendedField : ExtendedFieldOperations, Field { * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 */ inline class Real(val value: Double) : FieldElement { + override val context: RealField + get() = RealField + override fun unwrap(): Double = value override fun Double.wrap(): Real = Real(value) - override val context: RealField get() = RealField - companion object } @@ -60,12 +78,22 @@ inline class Real(val value: Double) : FieldElement { */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object RealField : ExtendedField, Norm { - override val zero: Double = 0.0 + override val zero: Double + get() = 0.0 + + override val one: Double + get() = 1.0 + + override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { + PowerOperations.POW_OPERATION -> left pow right + else -> super.binaryOperation(operation, left, right) + } + override inline fun add(a: Double, b: Double): Double = a + b - override inline fun multiply(a: Double, b: Double): Double = a * b override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() - override val one: Double = 1.0 + override inline fun multiply(a: Double, b: Double): Double = a * b + override inline fun divide(a: Double, b: Double): Double = a / b override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) @@ -75,27 +103,24 @@ object RealField : ExtendedField, Norm { override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) - override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) + override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg) + override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg) + override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg) + override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg) + override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg) + override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg) + override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) override inline fun norm(arg: Double): Double = abs(arg) override inline fun Double.unaryMinus(): Double = -this - override inline fun Double.plus(b: Double): Double = this + b - override inline fun Double.minus(b: Double): Double = this - b - override inline fun Double.times(b: Double): Double = this * b - override inline fun Double.div(b: Double): Double = this / b - - override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { - PowerOperations.POW_OPERATION -> left pow right - else -> super.binaryOperation(operation, left, right) - } } /** @@ -103,12 +128,22 @@ object RealField : ExtendedField, Norm { */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object FloatField : ExtendedField, Norm { - override val zero: Float = 0f + override val zero: Float + get() = 0.0f + + override val one: Float + get() = 1.0f + + override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) { + PowerOperations.POW_OPERATION -> left pow right + else -> super.binaryOperation(operation, left, right) + } + override inline fun add(a: Float, b: Float): Float = a + b - override inline fun multiply(a: Float, b: Float): Float = a * b override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat() - override val one: Float = 1f + override inline fun multiply(a: Float, b: Float): Float = a * b + override inline fun divide(a: Float, b: Float): Float = a / b override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) @@ -118,108 +153,118 @@ object FloatField : ExtendedField, Norm { override inline fun asin(arg: Float): Float = kotlin.math.asin(arg) override inline fun atan(arg: Float): Float = kotlin.math.atan(arg) - override inline fun power(arg: Float, pow: Number): Float = arg.pow(pow.toFloat()) + override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg) + override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg) + override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg) + override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg) + override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg) + override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg) + override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat()) override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) override inline fun norm(arg: Float): Float = abs(arg) override inline fun Float.unaryMinus(): Float = -this - override inline fun Float.plus(b: Float): Float = this + b - override inline fun Float.minus(b: Float): Float = this - b - override inline fun Float.times(b: Float): Float = this * b - override inline fun Float.div(b: Float): Float = this / b } /** - * A field for [Int] without boxing. Does not produce corresponding field element + * A field for [Int] without boxing. Does not produce corresponding ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object IntRing : Ring, Norm { - override val zero: Int = 0 + override val zero: Int + get() = 0 + + override val one: Int + get() = 1 + override inline fun add(a: Int, b: Int): Int = a + b - override inline fun multiply(a: Int, b: Int): Int = a * b override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a - override val one: Int = 1 + + override inline fun multiply(a: Int, b: Int): Int = a * b override inline fun norm(arg: Int): Int = abs(arg) override inline fun Int.unaryMinus(): Int = -this - override inline fun Int.plus(b: Int): Int = this + b - override inline fun Int.minus(b: Int): Int = this - b - override inline fun Int.times(b: Int): Int = this * b } /** - * A field for [Short] without boxing. Does not produce appropriate field element + * A field for [Short] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object ShortRing : Ring, Norm { - override val zero: Short = 0 + override val zero: Short + get() = 0 + + override val one: Short + get() = 1 + override inline fun add(a: Short, b: Short): Short = (a + b).toShort() - override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() - override val one: Short = 1 + + override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() override inline fun Short.unaryMinus(): Short = (-this).toShort() - override inline fun Short.plus(b: Short): Short = (this + b).toShort() - override inline fun Short.minus(b: Short): Short = (this - b).toShort() - override inline fun Short.times(b: Short): Short = (this * b).toShort() } /** - * A field for [Byte] values + * A field for [Byte] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object ByteRing : Ring, Norm { - override val zero: Byte = 0 + override val zero: Byte + get() = 0 + + override val one: Byte + get() = 1 + override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() - override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() - override val one: Byte = 1 + + override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() override inline fun Byte.unaryMinus(): Byte = (-this).toByte() - override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() - override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() - override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() } /** - * A field for [Long] values + * A field for [Double] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object LongRing : Ring, Norm { - override val zero: Long = 0 - override inline fun add(a: Long, b: Long): Long = (a + b) - override inline fun multiply(a: Long, b: Long): Long = (a * b) + override val zero: Long + get() = 0 + + override val one: Long + get() = 1 + + override inline fun add(a: Long, b: Long): Long = a + b override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() - override val one: Long = 1 + + override inline fun multiply(a: Long, b: Long): Long = a * b override fun norm(arg: Long): Long = abs(arg) override inline fun Long.unaryMinus(): Long = (-this) - override inline fun Long.plus(b: Long): Long = (this + b) - override inline fun Long.minus(b: Long): Long = (this - b) - override inline fun Long.times(b: Long): Long = (this * b) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index dea45a145..1dac649aa 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -1,12 +1,11 @@ package scientifik.kmath.operations /** - * A container for trigonometric operations for specific type. They are limited to semifields. + * A container for trigonometric operations for specific type. * - * The operations are not exposed to class directly to avoid method bloat but instead are declared in the field. - * It also allows to override behavior for optional operations. + * @param T the type of element of this structure. */ -interface TrigonometricOperations : FieldOperations { +interface TrigonometricOperations : Algebra { /** * Computes the sine of [arg]. */ @@ -22,31 +21,6 @@ interface TrigonometricOperations : FieldOperations { */ fun tan(arg: T): T - companion object { - /** - * The identifier of sine. - */ - const val SIN_OPERATION: String = "sin" - - /** - * The identifier of cosine. - */ - const val COS_OPERATION: String = "cos" - - /** - * The identifier of tangent. - */ - const val TAN_OPERATION: String = "tan" - } -} - -/** - * A container for inverse trigonometric operations for specific type. They are limited to semifields. - * - * The operations are not exposed to class directly to avoid method bloat but instead are declared in the field. - * It also allows to override behavior for optional operations. - */ -interface InverseTrigonometricOperations : TrigonometricOperations { /** * Computes the inverse sine of [arg]. */ @@ -63,6 +37,21 @@ interface InverseTrigonometricOperations : TrigonometricOperations { fun atan(arg: T): T companion object { + /** + * The identifier of sine. + */ + const val SIN_OPERATION: String = "sin" + + /** + * The identifier of cosine. + */ + const val COS_OPERATION: String = "cos" + + /** + * The identifier of tangent. + */ + const val TAN_OPERATION: String = "tan" + /** * The identifier of inverse sine. */ @@ -98,20 +87,121 @@ fun >> tan(arg: T): T = arg.conte /** * Computes the inverse sine of [arg]. */ -fun >> asin(arg: T): T = arg.context.asin(arg) +fun >> asin(arg: T): T = arg.context.asin(arg) /** * Computes the inverse cosine of [arg]. */ -fun >> acos(arg: T): T = arg.context.acos(arg) +fun >> acos(arg: T): T = arg.context.acos(arg) /** * Computes the inverse tangent of [arg]. */ -fun >> atan(arg: T): T = arg.context.atan(arg) +fun >> atan(arg: T): T = arg.context.atan(arg) + +/** + * A container for hyperbolic trigonometric operations for specific type. + * + * @param T the type of element of this structure. + */ +interface HyperbolicOperations : Algebra { + /** + * Computes the hyperbolic sine of [arg]. + */ + fun sinh(arg: T): T + + /** + * Computes the hyperbolic cosine of [arg]. + */ + fun cosh(arg: T): T + + /** + * Computes the hyperbolic tangent of [arg]. + */ + fun tanh(arg: T): T + + /** + * Computes the inverse hyperbolic sine of [arg]. + */ + fun asinh(arg: T): T + + /** + * Computes the inverse hyperbolic cosine of [arg]. + */ + fun acosh(arg: T): T + + /** + * Computes the inverse hyperbolic tangent of [arg]. + */ + fun atanh(arg: T): T + + companion object { + /** + * The identifier of hyperbolic sine. + */ + const val SINH_OPERATION: String = "sinh" + + /** + * The identifier of hyperbolic cosine. + */ + const val COSH_OPERATION: String = "cosh" + + /** + * The identifier of hyperbolic tangent. + */ + const val TANH_OPERATION: String = "tanh" + + /** + * The identifier of inverse hyperbolic sine. + */ + const val ASINH_OPERATION: String = "asinh" + + /** + * The identifier of inverse hyperbolic cosine. + */ + const val ACOSH_OPERATION: String = "acosh" + + /** + * The identifier of inverse hyperbolic tangent. + */ + const val ATANH_OPERATION: String = "atanh" + } +} + +/** + * Computes the hyperbolic sine of [arg]. + */ +fun >> sinh(arg: T): T = arg.context.sinh(arg) + +/** + * Computes the hyperbolic cosine of [arg]. + */ +fun >> cosh(arg: T): T = arg.context.cosh(arg) + +/** + * Computes the hyperbolic tangent of [arg]. + */ +fun >> tanh(arg: T): T = arg.context.tanh(arg) + +/** + * Computes the inverse hyperbolic sine of [arg]. + */ +fun >> asinh(arg: T): T = arg.context.asinh(arg) + +/** + * Computes the inverse hyperbolic cosine of [arg]. + */ +fun >> acosh(arg: T): T = arg.context.acosh(arg) + +/** + * Computes the inverse hyperbolic tangent of [arg]. + */ +fun >> atanh(arg: T): T = arg.context.atanh(arg) /** * A context extension to include power operations based on exponentiation. + * + * @param T the type of element of this structure. */ interface PowerOperations : Algebra { /** @@ -163,6 +253,8 @@ fun >> sqr(arg: T): T = arg pow 2.0 /** * A container for operations related to `exp` and `ln` functions. + * + * @param T the type of element of this structure. */ interface ExponentialOperations : Algebra { /** @@ -200,6 +292,9 @@ fun >> ln(arg: T): T = arg.context. /** * A container for norm functional on element. + * + * @param T the type of element having norm defined. + * @param R the type of norm. */ interface Norm { /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt index 4cbb565c1..be71645d1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt @@ -8,19 +8,17 @@ class BoxingNDField>( override val elementContext: F, val bufferFactory: BufferFactory ) : BufferedNDField { - + override val zero: BufferedNDFieldElement by lazy { produce { zero } } + override val one: BufferedNDFieldElement by lazy { produce { one } } override val strides: Strides = DefaultStrides(shape) fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) override fun check(vararg elements: NDBuffer) { - if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") + check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } } - override val zero: BufferedNDFieldElement by lazy { produce { zero } } - override val one: BufferedNDFieldElement by lazy { produce { one } } - override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement = BufferedNDFieldElement( this, @@ -28,6 +26,7 @@ class BoxingNDField>( override fun map(arg: NDBuffer, transform: F.(T) -> T): BufferedNDFieldElement { check(arg) + return BufferedNDFieldElement( this, buildBuffer(arg.strides.linearSize) { offset -> elementContext.transform(arg.buffer[offset]) }) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt index f7be95736..91b945e79 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt @@ -8,19 +8,16 @@ class BoxingNDRing>( override val elementContext: R, val bufferFactory: BufferFactory ) : BufferedNDRing { - override val strides: Strides = DefaultStrides(shape) - - fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = - bufferFactory(size, initializer) - - override fun check(vararg elements: NDBuffer) { - if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") - } - override val zero: BufferedNDRingElement by lazy { produce { zero } } override val one: BufferedNDRingElement by lazy { produce { one } } + fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) + + override fun check(vararg elements: NDBuffer) { + require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } + } + override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement = BufferedNDRingElement( this, diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt index 00832b69c..2c3d69094 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt @@ -6,7 +6,6 @@ import kotlin.reflect.KClass * A context that allows to operate on a [MutableBuffer] as on 2d array */ class BufferAccessor2D(val type: KClass, val rowNum: Int, val colNum: Int) { - operator fun Buffer.get(i: Int, j: Int): T = get(i + colNum * j) operator fun MutableBuffer.set(i: Int, j: Int, value: T) { @@ -26,15 +25,14 @@ class BufferAccessor2D(val type: KClass, val rowNum: Int, val colNum inner class Row(val buffer: MutableBuffer, val rowIndex: Int) : MutableBuffer { override val size: Int get() = colNum - override fun get(index: Int): T = buffer[rowIndex, index] + override operator fun get(index: Int): T = buffer[rowIndex, index] - override fun set(index: Int, value: T) { + override operator fun set(index: Int, value: T) { buffer[rowIndex, index] = value } override fun copy(): MutableBuffer = MutableBuffer.auto(type, colNum) { get(it) } - - override fun iterator(): Iterator = (0 until colNum).map(::get).iterator() + override operator fun iterator(): Iterator = (0 until colNum).map(::get).iterator() } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt index 06922c56f..2c0c2021f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt @@ -5,9 +5,8 @@ import scientifik.kmath.operations.* interface BufferedNDAlgebra : NDAlgebra> { val strides: Strides - override fun check(vararg elements: NDBuffer) { - if (!elements.all { it.strides == this.strides }) error("Strides mismatch") - } + override fun check(vararg elements: NDBuffer): Unit = + require(elements.all { it.strides == strides }) { ("Strides mismatch") } /** * Convert any [NDStructure] to buffered structure using strides from this context. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt index d1d622b23..20e34fadd 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt @@ -30,7 +30,6 @@ class BufferedNDRingElement>( override val context: BufferedNDRing, override val buffer: Buffer ) : BufferedNDElement(), RingElement, BufferedNDRingElement, BufferedNDRing> { - override fun unwrap(): NDBuffer = this override fun NDBuffer.wrap(): BufferedNDRingElement { @@ -43,7 +42,6 @@ class BufferedNDFieldElement>( override val context: BufferedNDField, override val buffer: Buffer ) : BufferedNDElement(), FieldElement, BufferedNDFieldElement, BufferedNDField> { - override fun unwrap(): NDBuffer = this override fun NDBuffer.wrap(): BufferedNDFieldElement { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt index 5fdf79e88..4afaa63ab 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt @@ -2,6 +2,8 @@ package scientifik.kmath.structures import scientifik.kmath.operations.Complex import scientifik.kmath.operations.complex +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.reflect.KClass /** @@ -117,15 +119,14 @@ interface MutableBuffer : Buffer { MutableListBuffer(MutableList(size, initializer)) @Suppress("UNCHECKED_CAST") - inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer { - return when (type) { + inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer = + when (type) { Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer else -> boxing(size, initializer) } - } /** * Create most appropriate mutable buffer for given type avoiding boxing wherever possible @@ -150,9 +151,8 @@ inline class ListBuffer(val list: List) : Buffer { override val size: Int get() = list.size - override fun get(index: Int): T = list[index] - - override fun iterator(): Iterator = list.iterator() + override operator fun get(index: Int): T = list[index] + override operator fun iterator(): Iterator = list.iterator() } /** @@ -167,7 +167,10 @@ fun List.asBuffer(): ListBuffer = ListBuffer(this) * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an array element given its index. */ -inline fun ListBuffer(size: Int, init: (Int) -> T): ListBuffer = List(size, init).asBuffer() +inline fun ListBuffer(size: Int, init: (Int) -> T): ListBuffer { + contract { callsInPlace(init) } + return List(size, init).asBuffer() +} /** * [MutableBuffer] implementation over [MutableList]. @@ -176,17 +179,16 @@ inline fun ListBuffer(size: Int, init: (Int) -> T): ListBuffer = List(siz * @property list The underlying list. */ inline class MutableListBuffer(val list: MutableList) : MutableBuffer { - override val size: Int get() = list.size - override fun get(index: Int): T = list[index] + override operator fun get(index: Int): T = list[index] - override fun set(index: Int, value: T) { + override operator fun set(index: Int, value: T) { list[index] = value } - override fun iterator(): Iterator = list.iterator() + override operator fun iterator(): Iterator = list.iterator() override fun copy(): MutableBuffer = MutableListBuffer(ArrayList(list)) } @@ -201,14 +203,13 @@ class ArrayBuffer(private val array: Array) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): T = array[index] + override operator fun get(index: Int): T = array[index] - override fun set(index: Int, value: T) { + override operator fun set(index: Int, value: T) { array[index] = value } - override fun iterator(): Iterator = array.iterator() - + override operator fun iterator(): Iterator = array.iterator() override fun copy(): MutableBuffer = ArrayBuffer(array.copyOf()) } @@ -226,9 +227,9 @@ fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) inline class ReadOnlyBuffer(val buffer: MutableBuffer) : Buffer { override val size: Int get() = buffer.size - override fun get(index: Int): T = buffer[index] + override operator fun get(index: Int): T = buffer[index] - override fun iterator(): Iterator = buffer.iterator() + override operator fun iterator(): Iterator = buffer.iterator() } /** @@ -238,12 +239,12 @@ inline class ReadOnlyBuffer(val buffer: MutableBuffer) : Buffer { * @param T the type of elements provided by the buffer. */ class VirtualBuffer(override val size: Int, private val generator: (Int) -> T) : Buffer { - override fun get(index: Int): T { + override operator fun get(index: Int): T { if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index") return generator(index) } - override fun iterator(): Iterator = (0 until size).asSequence().map(generator).iterator() + override operator fun iterator(): Iterator = (0 until size).asSequence().map(generator).iterator() override fun contentEquals(other: Buffer<*>): Boolean { return if (other is VirtualBuffer) { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt index be0b9e5c6..2c6e3a5c7 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt @@ -4,6 +4,9 @@ import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.FieldElement import scientifik.kmath.operations.complex +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract typealias ComplexNDElement = BufferedNDFieldElement @@ -15,7 +18,6 @@ class ComplexNDField(override val shape: IntArray) : ExtendedNDField> { override val strides: Strides = DefaultStrides(shape) - override val elementContext: ComplexField get() = ComplexField override val zero: ComplexNDElement by lazy { produce { zero } } override val one: ComplexNDElement by lazy { produce { one } } @@ -45,6 +47,7 @@ class ComplexNDField(override val shape: IntArray) : transform: ComplexField.(index: IntArray, Complex) -> Complex ): ComplexNDElement { check(arg) + return BufferedNDFieldElement( this, buildBuffer(arg.strides.linearSize) { offset -> @@ -61,6 +64,7 @@ class ComplexNDField(override val shape: IntArray) : transform: ComplexField.(Complex, Complex) -> Complex ): ComplexNDElement { check(a, b) + return BufferedNDFieldElement( this, buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) @@ -69,23 +73,25 @@ class ComplexNDField(override val shape: IntArray) : override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = BufferedNDFieldElement(this@ComplexNDField, buffer) - override fun power(arg: NDBuffer, pow: Number): ComplexNDElement = map(arg) { power(it, pow) } + override fun power(arg: NDBuffer, pow: Number): ComplexNDElement = + map(arg) { power(it, pow) } override fun exp(arg: NDBuffer): ComplexNDElement = map(arg) { exp(it) } - override fun ln(arg: NDBuffer): ComplexNDElement = map(arg) { ln(it) } override fun sin(arg: NDBuffer): ComplexNDElement = map(arg) { sin(it) } - override fun cos(arg: NDBuffer): ComplexNDElement = map(arg) { cos(it) } - override fun tan(arg: NDBuffer): ComplexNDElement = map(arg) { tan(it) } - override fun asin(arg: NDBuffer): ComplexNDElement = map(arg) { asin(it) } - override fun acos(arg: NDBuffer): ComplexNDElement = map(arg) { acos(it) } - override fun atan(arg: NDBuffer): ComplexNDElement = map(arg) { atan(it) } + + override fun sinh(arg: NDBuffer): ComplexNDElement = map(arg) { sinh(it) } + override fun cosh(arg: NDBuffer): ComplexNDElement = map(arg) { cosh(it) } + override fun tanh(arg: NDBuffer): ComplexNDElement = map(arg) { tanh(it) } + override fun asinh(arg: NDBuffer): ComplexNDElement = map(arg) { asinh(it) } + override fun acosh(arg: NDBuffer): ComplexNDElement = map(arg) { acosh(it) } + override fun atanh(arg: NDBuffer): ComplexNDElement = map(arg) { atanh(it) } } @@ -107,6 +113,7 @@ inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(inde * Map one [ComplexNDElement] using function without indices. */ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement { + contract { callsInPlace(transform) } val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) } return BufferedNDFieldElement(context, buffer) } @@ -146,5 +153,6 @@ fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(In * Produce a context for n-dimensional operations inside this real field */ inline fun ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R { - return NDField.complex(*shape).run(action) + contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } + return NDField.complex(*shape).action() } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt index a2d0a71b3..9c32aa31b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt @@ -1,5 +1,7 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.experimental.and /** @@ -57,17 +59,18 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged override val size: Int get() = values.size - override fun get(index: Int): Double? = if (isValid(index)) values[index] else null + override operator fun get(index: Int): Double? = if (isValid(index)) values[index] else null - override fun iterator(): Iterator = values.indices.asSequence().map { + override operator fun iterator(): Iterator = values.indices.asSequence().map { if (isValid(it)) values[it] else null }.iterator() } inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { - for (i in indices) { - if (isValid(i)) { - block(values[i]) - } - } + contract { callsInPlace(block) } + + indices + .asSequence() + .filter(::isValid) + .forEach { block(values[it]) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt index e42df8c14..9e974c644 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt @@ -1,5 +1,8 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract + /** * Specialized [MutableBuffer] implementation over [FloatArray]. * @@ -8,13 +11,13 @@ package scientifik.kmath.structures inline class FloatBuffer(val array: FloatArray) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): Float = array[index] + override operator fun get(index: Int): Float = array[index] - override fun set(index: Int, value: Float) { + override operator fun set(index: Int, value: Float) { array[index] = value } - override fun iterator(): FloatIterator = array.iterator() + override operator fun iterator(): FloatIterator = array.iterator() override fun copy(): MutableBuffer = FloatBuffer(array.copyOf()) @@ -27,7 +30,10 @@ inline class FloatBuffer(val array: FloatArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) }) +inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer { + contract { callsInPlace(init) } + return FloatBuffer(FloatArray(size) { init(it) }) +} /** * Returns a new [FloatBuffer] of given elements. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt index a3f0f3c3e..95651c547 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt @@ -1,5 +1,9 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + /** * Specialized [MutableBuffer] implementation over [IntArray]. * @@ -8,17 +12,16 @@ package scientifik.kmath.structures inline class IntBuffer(val array: IntArray) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): Int = array[index] + override operator fun get(index: Int): Int = array[index] - override fun set(index: Int, value: Int) { + override operator fun set(index: Int, value: Int) { array[index] = value } - override fun iterator(): IntIterator = array.iterator() + override operator fun iterator(): IntIterator = array.iterator() override fun copy(): MutableBuffer = IntBuffer(array.copyOf()) - } /** @@ -28,7 +31,10 @@ inline class IntBuffer(val array: IntArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) }) +inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer { + contract { callsInPlace(init) } + return IntBuffer(IntArray(size) { init(it) }) +} /** * Returns a new [IntBuffer] of given elements. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt index 912656c68..a44109f8a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt @@ -1,5 +1,8 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract + /** * Specialized [MutableBuffer] implementation over [LongArray]. * @@ -8,13 +11,13 @@ package scientifik.kmath.structures inline class LongBuffer(val array: LongArray) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): Long = array[index] + override operator fun get(index: Int): Long = array[index] - override fun set(index: Int, value: Long) { + override operator fun set(index: Int, value: Long) { array[index] = value } - override fun iterator(): LongIterator = array.iterator() + override operator fun iterator(): LongIterator = array.iterator() override fun copy(): MutableBuffer = LongBuffer(array.copyOf()) @@ -28,7 +31,10 @@ inline class LongBuffer(val array: LongArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) }) +inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer { + contract { callsInPlace(init) } + return LongBuffer(LongArray(size) { init(it) }) +} /** * Returns a new [LongBuffer] of given elements. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt index 1d0c87580..83c50b14b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt @@ -14,10 +14,8 @@ open class MemoryBuffer(protected val memory: Memory, protected val spe private val reader: MemoryReader = memory.reader() - override fun get(index: Int): T = reader.read(spec, spec.objectSize * index) - - override fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() - + override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index) + override operator fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() companion object { fun create(spec: MemorySpec, size: Int): MemoryBuffer = @@ -48,8 +46,7 @@ class MutableMemoryBuffer(memory: Memory, spec: MemorySpec) : Memory private val writer: MemoryWriter = memory.writer() - override fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) - + override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) override fun copy(): MutableBuffer = MutableMemoryBuffer(memory.copy(), spec) companion object { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt index 9dfe2b5a8..6cc0a72c0 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt @@ -26,19 +26,20 @@ interface NDElement> : NDStructure { fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement = NDField.real(*shape).produce(initializer) - - fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement = + inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement = real(intArrayOf(dim)) { initializer(it[0]) } + inline fun real2D( + dim1: Int, + dim2: Int, + crossinline initializer: (Int, Int) -> Double = { _, _ -> 0.0 } + ): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } - fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): RealNDElement = - real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } - - fun real3D( + inline fun real3D( dim1: Int, dim2: Int, dim3: Int, - initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 } + crossinline initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 } ): RealNDElement = real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } @@ -72,7 +73,6 @@ fun > NDElement.mapIndexed(transform: C.(index fun > NDElement.map(transform: C.(T) -> T): NDElement = context.map(unwrap(), transform).wrap() - /** * Element by element application of any operation on elements to the whole [NDElement] */ @@ -107,7 +107,6 @@ operator fun , N : NDStructure> NDElement.times(arg: operator fun , N : NDStructure> NDElement.div(arg: T): NDElement = map { value -> arg / value } - // /** // * Reverse sum operation // */ diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt index 9d7735053..f4eb93b9e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt @@ -1,5 +1,7 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.jvm.JvmName import kotlin.reflect.KClass @@ -139,9 +141,8 @@ interface MutableNDStructure : NDStructure { } inline fun MutableNDStructure.mapInPlace(action: (IntArray, T) -> T) { - elements().forEach { (index, oldValue) -> - this[index] = action(index, oldValue) - } + contract { callsInPlace(action) } + elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } } /** @@ -200,14 +201,12 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides }.toList() } - override fun offset(index: IntArray): Int { - return index.mapIndexed { i, value -> - if (value < 0 || value >= this.shape[i]) { - throw RuntimeException("Index $value out of shape bounds: (0,${this.shape[i]})") - } - value * strides[i] - }.sum() - } + override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> + if (value < 0 || value >= this.shape[i]) + throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})") + + value * strides[i] + }.sum() override fun index(offset: Int): IntArray { val res = IntArray(shape.size) @@ -259,7 +258,7 @@ abstract class NDBuffer : NDStructure { */ abstract val strides: Strides - override fun get(index: IntArray): T = buffer[strides.offset(index)] + override operator fun get(index: IntArray): T = buffer[strides.offset(index)] override val shape: IntArray get() = strides.shape @@ -319,13 +318,13 @@ class MutableBufferNDStructure( } } - override fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value) + override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value) } inline fun NDStructure.combine( struct: NDStructure, crossinline block: (T, T) -> T ): NDStructure { - if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination") + require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" } return NDStructure.auto(shape) { block(this[it], struct[it]) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt index e999e12b2..cba8e9689 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt @@ -1,5 +1,8 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract + /** * Specialized [MutableBuffer] implementation over [DoubleArray]. * @@ -8,13 +11,13 @@ package scientifik.kmath.structures inline class RealBuffer(val array: DoubleArray) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): Double = array[index] + override operator fun get(index: Int): Double = array[index] - override fun set(index: Int, value: Double) { + override operator fun set(index: Int, value: Double) { array[index] = value } - override fun iterator(): DoubleIterator = array.iterator() + override operator fun iterator(): DoubleIterator = array.iterator() override fun copy(): MutableBuffer = RealBuffer(array.copyOf()) @@ -27,7 +30,10 @@ inline class RealBuffer(val array: DoubleArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) }) +inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer { + contract { callsInPlace(init) } + return RealBuffer(DoubleArray(size) { init(it) }) +} /** * Returns a new [RealBuffer] of given elements. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 33198aac1..a11826e7e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -10,14 +10,15 @@ import kotlin.math.* */ object RealBufferFieldOperations : ExtendedFieldOperations> { override fun add(a: Buffer, b: Buffer): RealBuffer { - require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + require(b.size == a.size) { + "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " + } return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) - } else - RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) + } else RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) } override fun multiply(a: Buffer, k: Number): RealBuffer { @@ -26,12 +27,13 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { return if (a is RealBuffer) { val aArray = a.array RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) - } else - RealBuffer(DoubleArray(a.size) { a[it] * kValue }) + } else RealBuffer(DoubleArray(a.size) { a[it] * kValue }) } override fun multiply(a: Buffer, b: Buffer): RealBuffer { - require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + require(b.size == a.size) { + "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " + } return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array @@ -42,34 +44,31 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { } override fun divide(a: Buffer, b: Buffer): RealBuffer { - require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + require(b.size == a.size) { + "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " + } return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) - } else - RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) + } else RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) } override fun sin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) - } else { - RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) - } + } else RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) override fun cos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) - } else - RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + } else RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) override fun tan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) - } else - RealBuffer(DoubleArray(arg.size) { tan(arg[it]) }) + } else RealBuffer(DoubleArray(arg.size) { tan(arg[it]) }) override fun asin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array @@ -90,23 +89,50 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { } else RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) + override fun sinh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { sinh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) }) + + override fun cosh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { cosh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) }) + + override fun tanh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { tanh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) }) + + override fun asinh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { asinh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) }) + + override fun acosh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { acosh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) }) + + override fun atanh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { atanh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) }) + override fun power(arg: Buffer, pow: Number): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else - RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + } else RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) override fun exp(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) - } else - RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) + } else RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) override fun ln(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else - RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) + } else RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } /** @@ -168,6 +194,36 @@ class RealBufferField(val size: Int) : ExtendedField> { return RealBufferFieldOperations.atan(arg) } + override fun sinh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.sinh(arg) + } + + override fun cosh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.cosh(arg) + } + + override fun tanh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.tanh(arg) + } + + override fun asinh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.asinh(arg) + } + + override fun acosh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.acosh(arg) + } + + override fun atanh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.atanh(arg) + } + override fun power(arg: Buffer, pow: Number): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index e2a1a33df..6533f64be 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -40,6 +40,7 @@ class RealNDField(override val shape: IntArray) : transform: RealField.(index: IntArray, Double) -> Double ): RealNDElement { check(arg) + return BufferedNDFieldElement( this, buildBuffer(arg.strides.linearSize) { offset -> @@ -71,16 +72,18 @@ class RealNDField(override val shape: IntArray) : override fun ln(arg: NDBuffer): RealNDElement = map(arg) { ln(it) } override fun sin(arg: NDBuffer): RealNDElement = map(arg) { sin(it) } - override fun cos(arg: NDBuffer): RealNDElement = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): RealNDElement = map(arg) { tan(it) } + override fun asin(arg: NDBuffer): RealNDElement = map(arg) { asin(it) } + override fun acos(arg: NDBuffer): RealNDElement = map(arg) { acos(it) } + override fun atan(arg: NDBuffer): RealNDElement = map(arg) { atan(it) } - override fun tan(arg: NDBuffer): NDBuffer = map(arg) { tan(it) } - - override fun asin(arg: NDBuffer): NDBuffer = map(arg) { asin(it) } - - override fun acos(arg: NDBuffer): NDBuffer = map(arg) { acos(it) } - - override fun atan(arg: NDBuffer): NDBuffer = map(arg) { atan(it) } + override fun sinh(arg: NDBuffer): RealNDElement = map(arg) { sinh(it) } + override fun cosh(arg: NDBuffer): RealNDElement = map(arg) { cosh(it) } + override fun tanh(arg: NDBuffer): RealNDElement = map(arg) { tanh(it) } + override fun asinh(arg: NDBuffer): RealNDElement = map(arg) { asinh(it) } + override fun acosh(arg: NDBuffer): RealNDElement = map(arg) { acosh(it) } + override fun atanh(arg: NDBuffer): RealNDElement = map(arg) { atanh(it) } } @@ -130,6 +133,5 @@ operator fun RealNDElement.minus(arg: Double): RealNDElement = /** * Produce a context for n-dimensional operations inside this real field */ -inline fun RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R { - return NDField.real(*shape).run(action) -} + +inline fun RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt index c6f19feaf..9aa674177 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt @@ -1,5 +1,8 @@ package scientifik.kmath.structures +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract + /** * Specialized [MutableBuffer] implementation over [ShortArray]. * @@ -8,17 +11,16 @@ package scientifik.kmath.structures inline class ShortBuffer(val array: ShortArray) : MutableBuffer { override val size: Int get() = array.size - override fun get(index: Int): Short = array[index] + override operator fun get(index: Int): Short = array[index] - override fun set(index: Int, value: Short) { + override operator fun set(index: Int, value: Short) { array[index] = value } - override fun iterator(): ShortIterator = array.iterator() + override operator fun iterator(): ShortIterator = array.iterator() override fun copy(): MutableBuffer = ShortBuffer(array.copyOf()) - } /** @@ -28,7 +30,10 @@ inline class ShortBuffer(val array: ShortArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) }) +inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer { + contract { callsInPlace(init) } + return ShortBuffer(ShortArray(size) { init(it) }) +} /** * Returns a new [ShortBuffer] of given elements. diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt index faf022367..a796c2037 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt @@ -6,12 +6,12 @@ package scientifik.kmath.structures interface Structure1D : NDStructure, Buffer { override val dimension: Int get() = 1 - override fun get(index: IntArray): T { - if (index.size != 1) error("Index dimension mismatch. Expected 1 but found ${index.size}") + override operator fun get(index: IntArray): T { + require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" } return get(index[0]) } - override fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() + override operator fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() } /** @@ -22,7 +22,7 @@ private inline class Structure1DWrapper(val structure: NDStructure) : Stru override val shape: IntArray get() = structure.shape override val size: Int get() = structure.shape[0] - override fun get(index: Int): T = structure[index] + override operator fun get(index: Int): T = structure[index] override fun elements(): Sequence> = structure.elements() } @@ -39,7 +39,7 @@ private inline class Buffer1DWrapper(val buffer: Buffer) : Structure1D override fun elements(): Sequence> = asSequence().mapIndexed { index, value -> intArrayOf(index) to value } - override fun get(index: Int): T = buffer[index] + override operator fun get(index: Int): T = buffer[index] } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt index 30fd556d3..eeb6bd3dc 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt @@ -9,8 +9,8 @@ interface Structure2D : NDStructure { operator fun get(i: Int, j: Int): T - override fun get(index: IntArray): T { - if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}") + override operator fun get(index: IntArray): T { + require(index.size == 2) { "Index dimension mismatch. Expected 2 but found ${index.size}" } return get(index[0], index[1]) } @@ -39,10 +39,10 @@ interface Structure2D : NDStructure { * A 2D wrapper for nd-structure */ private inline class Structure2DWrapper(val structure: NDStructure) : Structure2D { - override fun get(i: Int, j: Int): T = structure[i, j] - override val shape: IntArray get() = structure.shape + override operator fun get(i: Int, j: Int): T = structure[i, j] + override fun elements(): Sequence> = structure.elements() } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt index 22b924ef9..485de08b4 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt @@ -3,6 +3,7 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals @@ -10,10 +11,12 @@ class ExpressionFieldTest { @Test fun testExpression() { val context = FunctionalExpressionField(RealField) - val expression = with(context) { + + val expression = context { val x = variable("x", 2.0) x * x + 2 * x + one } + assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression(), 9.0) } @@ -21,10 +24,12 @@ class ExpressionFieldTest { @Test fun testComplex() { val context = FunctionalExpressionField(ComplexField) - val expression = with(context) { + + val expression = context { val x = variable("x", Complex(2.0, 0.0)) x * x + 2 * x + one } + assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0)) assertEquals(expression(), Complex(9.0, 0.0)) } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt index 987426250..52a2f80a6 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt @@ -7,7 +7,6 @@ import kotlin.test.Test import kotlin.test.assertEquals class MatrixTest { - @Test fun testTranspose() { val matrix = MatrixContext.real.one(3, 3) @@ -51,6 +50,7 @@ class MatrixTest { fun test2DDot() { val firstMatrix = NDStructure.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D() val secondMatrix = NDStructure.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D() + MatrixContext.real.run { // val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() } // val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt index c22d2f27b..d140f1017 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt @@ -1,9 +1,13 @@ package scientifik.kmath.operations +import scientifik.kmath.operations.internal.RingVerifier import kotlin.test.Test import kotlin.test.assertEquals -class BigIntAlgebraTest { +internal class BigIntAlgebraTest { + @Test + fun verify() = BigIntField { RingVerifier(this, +"42", +"10", +"-12", 10).verify() } + @Test fun testKBigIntegerRingSum() { val res = BigIntField { diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexFieldTest.kt new file mode 100644 index 000000000..2c480ebea --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexFieldTest.kt @@ -0,0 +1,77 @@ +package scientifik.kmath.operations + +import scientifik.kmath.operations.internal.FieldVerifier +import kotlin.math.PI +import kotlin.math.abs +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +internal class ComplexFieldTest { + @Test + fun verify() = ComplexField { FieldVerifier(this, 42.0 * i, 66.0 + 28 * i, 2.0 + 0 * i, 5).verify() } + + @Test + fun testAddition() { + assertEquals(Complex(42, 42), ComplexField { Complex(16, 16) + Complex(26, 26) }) + assertEquals(Complex(42, 16), ComplexField { Complex(16, 16) + 26 }) + assertEquals(Complex(42, 16), ComplexField { 26 + Complex(16, 16) }) + } + + @Test + fun testSubtraction() { + assertEquals(Complex(42, 42), ComplexField { Complex(86, 55) - Complex(44, 13) }) + assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 }) + assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) }) + } + + @Test + fun testMultiplication() { + assertEquals(Complex(42, 42), ComplexField { Complex(4.2, 0) * Complex(10, 10) }) + assertEquals(Complex(42, 21), ComplexField { Complex(4.2, 2.1) * 10 }) + assertEquals(Complex(42, 21), ComplexField { 10 * Complex(4.2, 2.1) }) + } + + @Test + fun testDivision() { + assertEquals(Complex(42, 42), ComplexField { Complex(0, 168) / Complex(2, 2) }) + assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 }) + assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) }) + assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(Double.NaN, Double.NaN) }) + assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(0, 0) }) + } + + @Test + fun testSine() { + assertEquals(ComplexField { i * sinh(one) }, ComplexField { sin(i) }) + assertEquals(ComplexField { i * sinh(PI.toComplex()) }, ComplexField { sin(i * PI.toComplex()) }) + } + + @Test + fun testInverseSine() { + assertEquals(Complex(0, -0.0), ComplexField { asin(zero) }) + assertTrue(abs(ComplexField { i * asinh(one) }.r - ComplexField { asin(i) }.r) < 0.000000000000001) + } + + @Test + fun testInverseHyperbolicSine() { + assertEquals( + ComplexField { i * PI.toComplex() / 2 }, + ComplexField { asinh(i) }) + } + + @Test + fun testPower() { + assertEquals(ComplexField.zero, ComplexField { zero pow 2 }) + assertEquals(ComplexField.zero, ComplexField { zero pow 2 }) + + assertEquals( + ComplexField { i * 8 }.let { it.im.toInt() to it.re.toInt() }, + ComplexField { Complex(2, 2) pow 2 }.let { it.im.toInt() to it.re.toInt() }) + } + + @Test + fun testNorm() { + assertEquals(2.toComplex(), ComplexField { norm(2 * i) }) + } +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexTest.kt new file mode 100644 index 000000000..e8d698c70 --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexTest.kt @@ -0,0 +1,38 @@ +package scientifik.kmath.operations + +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class ComplexTest { + @Test + fun conjugate() { + assertEquals( + Complex(0, -42), (ComplexField.i * 42).conjugate + ) + } + + @Test + fun reciprocal() { + assertEquals(Complex(0.5, -0.0), 2.toComplex().reciprocal) + } + + @Test + fun r() { + assertEquals(kotlin.math.sqrt(2.0), (ComplexField.i + 1.0.toComplex()).r) + } + + @Test + fun theta() { + assertEquals(0.0, 1.toComplex().theta) + } + + @Test + fun toComplex() { + assertEquals(Complex(42, 0), 42.toComplex()) + assertEquals(Complex(42.0, 0), 42.0.toComplex()) + assertEquals(Complex(42f, 0), 42f.toComplex()) + assertEquals(Complex(42.0, 0), 42.0.toComplex()) + assertEquals(Complex(42.toByte(), 0), 42.toByte().toComplex()) + assertEquals(Complex(42.toShort(), 0), 42.toShort().toComplex()) + } +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt index 9dfa3bdd1..a168b4afd 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt @@ -1,14 +1,16 @@ package scientifik.kmath.operations +import scientifik.kmath.operations.internal.FieldVerifier import kotlin.test.Test import kotlin.test.assertEquals -class RealFieldTest { +internal class RealFieldTest { + @Test + fun verify() = FieldVerifier(RealField, 42.0, 66.0, 2.0, 5).verify() + @Test fun testSqrt() { - val sqrt = RealField { - sqrt(25 * one) - } + val sqrt = RealField { sqrt(25 * one) } assertEquals(5.0, sqrt) } } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/AlgebraicVerifier.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/AlgebraicVerifier.kt new file mode 100644 index 000000000..cb097d46e --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/AlgebraicVerifier.kt @@ -0,0 +1,9 @@ +package scientifik.kmath.operations.internal + +import scientifik.kmath.operations.Algebra + +internal interface AlgebraicVerifier where A : Algebra { + val algebra: A + + fun verify() +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/FieldVerifier.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/FieldVerifier.kt new file mode 100644 index 000000000..973fd00b1 --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/FieldVerifier.kt @@ -0,0 +1,24 @@ +package scientifik.kmath.operations.internal + +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.invoke +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals + +internal class FieldVerifier(override val algebra: Field, a: T, b: T, c: T, x: Number) : + RingVerifier(algebra, a, b, c, x) { + + override fun verify() { + super.verify() + + algebra { + assertNotEquals(a / b, b / a, "Division in $algebra is not anti-commutative.") + assertNotEquals((a / b) / c, a / (b / c), "Division in $algebra is associative.") + assertEquals((a + b) / c, (a / c) + (b / c), "Division in $algebra is not right-distributive.") + assertEquals(a, a / one, "$one in $algebra is not neutral division element.") + assertEquals(one, one / a * a, "$algebra does not provide single reciprocal element.") + assertEquals(zero / a, zero, "$zero in $algebra is not left neutral element for division.") + assertEquals(-one, a / (-a), "Division by sign reversal element in $algebra does not give ${-one}.") + } + } +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/RingVerifier.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/RingVerifier.kt new file mode 100644 index 000000000..047a213e9 --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/RingVerifier.kt @@ -0,0 +1,28 @@ +package scientifik.kmath.operations.internal + +import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.invoke +import kotlin.test.assertEquals + +internal open class RingVerifier(override val algebra: Ring, a: T, b: T, c: T, x: Number) : + SpaceVerifier(algebra, a, b, c, x) { + override fun verify() { + super.verify() + + algebra { + assertEquals(a * b, a * b, "Multiplication in $algebra is not commutative.") + assertEquals(a * b * c, a * (b * c), "Multiplication in $algebra is not associative.") + assertEquals(c * (a + b), (c * a) + (c * b), "Multiplication in $algebra is not distributive.") + assertEquals(a * one, one * a, "$one in $algebra is not a neutral multiplication element.") + assertEquals(a, one * a, "$one in $algebra is not a neutral multiplication element.") + assertEquals(a, a * one, "$one in $algebra is not a neutral multiplication element.") + assertEquals(a, one * a, "$one in $algebra is not a neutral multiplication element.") + assertEquals(a, a * one * one, "Multiplication by $one in $algebra is not idempotent.") + assertEquals(a, a * one * one * one, "Multiplication by $one in $algebra is not idempotent.") + assertEquals(a, a * one * one * one * one, "Multiplication by $one in $algebra is not idempotent.") + assertEquals(zero, a * zero, "Multiplication by $zero in $algebra doesn't give $zero.") + assertEquals(zero, zero * a, "Multiplication by $zero in $algebra doesn't give $zero.") + assertEquals(a * zero, a * zero, "Multiplication by $zero in $algebra doesn't give $zero.") + } + } +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/SpaceVerifier.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/SpaceVerifier.kt new file mode 100644 index 000000000..bc241c97d --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/SpaceVerifier.kt @@ -0,0 +1,33 @@ +package scientifik.kmath.operations.internal + +import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals + +internal open class SpaceVerifier( + override val algebra: Space, + val a: T, + val b: T, + val c: T, + val x: Number +) : + AlgebraicVerifier> { + override fun verify() { + algebra { + assertEquals(a + b, b + a, "Addition in $algebra is not commutative.") + assertEquals(a + b + c, a + (b + c), "Addition in $algebra is not associative.") + assertEquals(x * (a + b), x * a + x * b, "Addition in $algebra is not distributive.") + assertEquals((a + b) * x, a * x + b * x, "Addition in $algebra is not distributive.") + assertEquals(a + zero, zero + a, "$zero in $algebra is not a neutral addition element.") + assertEquals(a, a + zero, "$zero in $algebra is not a neutral addition element.") + assertEquals(a, zero + a, "$zero in $algebra is not a neutral addition element.") + assertEquals(a - b, -(b - a), "Subtraction in $algebra is not anti-commutative.") + assertNotEquals(a - b - c, a - (b - c), "Subtraction in $algebra is associative.") + assertEquals(x * (a - b), x * a - x * b, "Subtraction in $algebra is not distributive.") + assertEquals(a, a - zero, "$zero in $algebra is not a neutral addition element.") + assertEquals(a * x, x * a, "Multiplication by scalar in $algebra is not commutative.") + assertEquals(x * (a + b), (x * a) + (x * b), "Multiplication by scalar in $algebra is not distributive.") + } + } +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt index d48aabfd0..b7e2594ec 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt @@ -1,6 +1,7 @@ package scientifik.kmath.structures import scientifik.kmath.operations.Norm +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.NDElement.Companion.real2D import kotlin.math.abs import kotlin.math.pow @@ -56,17 +57,12 @@ class NumberNDFieldTest { } object L2Norm : Norm, Double> { - override fun norm(arg: NDStructure): Double { - return kotlin.math.sqrt(arg.elements().sumByDouble { it.second.toDouble() }) - } + override fun norm(arg: NDStructure): Double = + kotlin.math.sqrt(arg.elements().sumByDouble { it.second.toDouble() }) } @Test fun testInternalContext() { - NDField.real(*array1.shape).run { - with(L2Norm) { - 1 + norm(array1) + exp(array2) - } - } + (NDField.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } } } } diff --git a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt index 06f2b31ad..f10ef24da 100644 --- a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt @@ -17,10 +17,10 @@ object JBigIntegerField : Field { override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) - override fun BigInteger.minus(b: BigInteger): BigInteger = this.subtract(b) + override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b) override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger()) override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) - override fun BigInteger.unaryMinus(): BigInteger = negate() + override operator fun BigInteger.unaryMinus(): BigInteger = negate() } /** @@ -38,7 +38,7 @@ abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathCo get() = BigDecimal.ONE override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) - override fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) + override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) override fun multiply(a: BigDecimal, k: Number): BigDecimal = @@ -48,8 +48,7 @@ abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathCo override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) - override fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) - + override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) } /** diff --git a/kmath-coroutines/build.gradle.kts b/kmath-coroutines/build.gradle.kts index 373d9b8ac..4469a9ef6 100644 --- a/kmath-coroutines/build.gradle.kts +++ b/kmath-coroutines/build.gradle.kts @@ -4,20 +4,27 @@ plugins { } kotlin.sourceSets { + all { + with(languageSettings) { + useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") + useExperimentalAnnotation("kotlinx.coroutines.InternalCoroutinesApi") + useExperimentalAnnotation("kotlinx.coroutines.ExperimentalCoroutinesApi") + useExperimentalAnnotation("kotlinx.coroutines.FlowPreview") + } + } + commonMain { dependencies { api(project(":kmath-core")) api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Scientifik.coroutinesVersion}") } } + jvmMain { - dependencies { - api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Scientifik.coroutinesVersion}") - } + dependencies { api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Scientifik.coroutinesVersion}") } } + jsMain { - dependencies { - api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Scientifik.coroutinesVersion}") - } + dependencies { api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Scientifik.coroutinesVersion}") } } } diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt index 6cc9770af..f0ffd13cd 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt @@ -16,9 +16,9 @@ package scientifik.kmath.chains -import kotlinx.coroutines.InternalCoroutinesApi import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.FlowCollector +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock @@ -37,14 +37,8 @@ interface Chain : Flow { */ fun fork(): Chain - @OptIn(InternalCoroutinesApi::class) - override suspend fun collect(collector: FlowCollector) { - kotlinx.coroutines.flow.flow { - while (true) { - emit(next()) - } - }.collect(collector) - } + override suspend fun collect(collector: FlowCollector): Unit = + flow { while (true) emit(next()) }.collect(collector) companion object } @@ -139,9 +133,10 @@ fun Chain.map(func: suspend (T) -> R): Chain = object : Chain { fun Chain.filter(block: (T) -> Boolean): Chain = object : Chain { override suspend fun next(): T { var next: T - do { - next = this@filter.next() - } while (!block(next)) + + do next = this@filter.next() + while (!block(next)) + return next } @@ -159,6 +154,7 @@ fun Chain.collect(mapper: suspend (Chain) -> R): Chain = object fun Chain.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain) -> R): Chain = object : Chain { override suspend fun next(): R = state.mapper(this@collectWithState) + override fun fork(): Chain = this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper) } @@ -168,6 +164,5 @@ fun Chain.collectWithState(state: S, stateFork: (S) -> S, mapper: s */ fun Chain.zip(other: Chain, block: suspend (T, U) -> R): Chain = object : Chain { override suspend fun next(): R = block(this@zip.next(), other.next()) - override fun fork(): Chain = this@zip.fork().zip(other.fork(), block) } diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt index e8537304c..5db660c39 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt @@ -7,15 +7,15 @@ import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.scanReduce import scientifik.kmath.operations.Space import scientifik.kmath.operations.SpaceOperations - +import scientifik.kmath.operations.invoke @ExperimentalCoroutinesApi -fun Flow.cumulativeSum(space: SpaceOperations): Flow = with(space) { +fun Flow.cumulativeSum(space: SpaceOperations): Flow = space { scanReduce { sum: T, element: T -> sum + element } } @ExperimentalCoroutinesApi -fun Flow.mean(space: Space): Flow = with(space) { +fun Flow.mean(space: Space): Flow = space { class Accumulator(var sum: T, var num: Int) scan(Accumulator(zero, 0)) { sum, element -> diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt index 7e00b30a1..692f89589 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt @@ -3,6 +3,7 @@ package scientifik.kmath.coroutines import kotlinx.coroutines.* import kotlinx.coroutines.channels.produce import kotlinx.coroutines.flow.* +import kotlin.contracts.contract val Dispatchers.Math: CoroutineDispatcher get() = Default @@ -23,15 +24,11 @@ internal class LazyDeferred(val dispatcher: CoroutineDispatcher, val block: s } class AsyncFlow internal constructor(internal val deferredFlow: Flow>) : Flow { - @InternalCoroutinesApi override suspend fun collect(collector: FlowCollector) { - deferredFlow.collect { - collector.emit((it.await())) - } + deferredFlow.collect { collector.emit((it.await())) } } } -@FlowPreview fun Flow.async( dispatcher: CoroutineDispatcher = Dispatchers.Default, block: suspend CoroutineScope.(T) -> R @@ -42,7 +39,6 @@ fun Flow.async( return AsyncFlow(flow) } -@FlowPreview fun AsyncFlow.map(action: (T) -> R): AsyncFlow = AsyncFlow(deferredFlow.map { input -> //TODO add function composition @@ -52,10 +48,9 @@ fun AsyncFlow.map(action: (T) -> R): AsyncFlow = } }) -@ExperimentalCoroutinesApi -@FlowPreview suspend fun AsyncFlow.collect(concurrency: Int, collector: FlowCollector) { require(concurrency >= 1) { "Buffer size should be more than 1, but was $concurrency" } + coroutineScope { //Starting up to N deferred coroutines ahead of time val channel = produce(capacity = concurrency - 1) { @@ -81,21 +76,18 @@ suspend fun AsyncFlow.collect(concurrency: Int, collector: FlowCollector< } } -@ExperimentalCoroutinesApi -@FlowPreview -suspend fun AsyncFlow.collect(concurrency: Int, action: suspend (value: T) -> Unit) { +suspend inline fun AsyncFlow.collect(concurrency: Int, crossinline action: suspend (value: T) -> Unit) { + contract { callsInPlace(action) } + collect(concurrency, object : FlowCollector { override suspend fun emit(value: T): Unit = action(value) }) } -@ExperimentalCoroutinesApi -@FlowPreview -fun Flow.mapParallel( +inline fun Flow.mapParallel( dispatcher: CoroutineDispatcher = Dispatchers.Default, - transform: suspend (T) -> R + crossinline transform: suspend (T) -> R ): Flow { - return flatMapMerge { value -> - flow { emit(transform(value)) } - }.flowOn(dispatcher) + contract { callsInPlace(transform) } + return flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher) } diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt index 245d003b3..f1c0bfc6a 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt @@ -20,7 +20,7 @@ class RingBuffer( override var size: Int = size private set - override fun get(index: Int): T { + override operator fun get(index: Int): T { require(index >= 0) { "Index must be positive" } require(index < size) { "Index $index is out of circular buffer size $size" } return buffer[startIndex.forward(index)] as T @@ -31,15 +31,13 @@ class RingBuffer( /** * Iterator could provide wrong results if buffer is changed in initialization (iteration is safe) */ - override fun iterator(): Iterator = object : AbstractIterator() { + override operator fun iterator(): Iterator = object : AbstractIterator() { private var count = size private var index = startIndex val copy = buffer.copy() override fun computeNext() { - if (count == 0) { - done() - } else { + if (count == 0) done() else { setNext(copy[index] as T) index = index.forward(1) count-- diff --git a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt index 0a3c67e00..5686b0ac0 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt @@ -1,7 +1,6 @@ package scientifik.kmath.chains import kotlinx.coroutines.runBlocking -import kotlin.sequences.Sequence /** * Represent a chain as regular iterator (uses blocking calls) @@ -15,6 +14,4 @@ operator fun Chain.iterator(): Iterator = object : Iterator { /** * Represent a chain as a sequence */ -fun Chain.asSequence(): Sequence = object : Sequence { - override fun iterator(): Iterator = this@asSequence.iterator() -} \ No newline at end of file +fun Chain.asSequence(): Sequence = Sequence { this@asSequence.iterator() } \ No newline at end of file diff --git a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt index 8d5145976..ff732a06b 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt @@ -18,7 +18,7 @@ class LazyNDStructure( suspend fun await(index: IntArray): T = deferred(index).await() - override fun get(index: IntArray): T = runBlocking { + override operator fun get(index: IntArray): T = runBlocking { deferred(index).await() } @@ -52,10 +52,12 @@ suspend fun NDStructure.await(index: IntArray): T = /** * PENDING would benefit from KEEP-176 */ -fun NDStructure.mapAsyncIndexed( +inline fun NDStructure.mapAsyncIndexed( scope: CoroutineScope, - function: suspend (T, index: IntArray) -> R + crossinline function: suspend (T, index: IntArray) -> R ): LazyNDStructure = LazyNDStructure(scope, shape) { index -> function(get(index), index) } -fun NDStructure.mapAsync(scope: CoroutineScope, function: suspend (T) -> R): LazyNDStructure = - LazyNDStructure(scope, shape) { index -> function(get(index)) } \ No newline at end of file +inline fun NDStructure.mapAsync( + scope: CoroutineScope, + crossinline function: suspend (T) -> R +): LazyNDStructure = LazyNDStructure(scope, shape) { index -> function(get(index)) } diff --git a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt index f447866c0..7b0244bdf 100644 --- a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt +++ b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt @@ -4,7 +4,9 @@ import scientifik.kmath.linear.GenericMatrixContext import scientifik.kmath.linear.MatrixContext import scientifik.kmath.linear.Point import scientifik.kmath.linear.transpose +import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.Structure2D @@ -42,7 +44,7 @@ inline class DMatrixWrapper( val structure: Structure2D ) : DMatrix { override val shape: IntArray get() = structure.shape - override fun get(i: Int, j: Int): T = structure[i, j] + override operator fun get(i: Int, j: Int): T = structure[i, j] } /** @@ -70,9 +72,9 @@ inline class DPointWrapper(val point: Point) : DPoint { override val size: Int get() = point.size - override fun get(index: Int): T = point[index] + override operator fun get(index: Int): T = point[index] - override fun iterator(): Iterator = point.iterator() + override operator fun iterator(): Iterator = point.iterator() } @@ -82,12 +84,14 @@ inline class DPointWrapper(val point: Point) : inline class DMatrixContext>(val context: GenericMatrixContext) { inline fun Matrix.coerce(): DMatrix { - if (rowNum != Dimension.dim().toInt()) { - error("Row number mismatch: expected ${Dimension.dim()} but found $rowNum") - } - if (colNum != Dimension.dim().toInt()) { - error("Column number mismatch: expected ${Dimension.dim()} but found $colNum") - } + check( + rowNum == Dimension.dim().toInt() + ) { "Row number mismatch: expected ${Dimension.dim()} but found $rowNum" } + + check( + colNum == Dimension.dim().toInt() + ) { "Column number mismatch: expected ${Dimension.dim()} but found $colNum" } + return DMatrix.coerceUnsafe(this) } @@ -97,11 +101,12 @@ inline class DMatrixContext>(val context: GenericMatrixCon inline fun produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix { val rows = Dimension.dim() val cols = Dimension.dim() - return context.produce(rows.toInt(), cols.toInt(), initializer).coerce() + return context.produce(rows.toInt(), cols.toInt(), initializer).coerce() } inline fun point(noinline initializer: (Int) -> T): DPoint { val size = Dimension.dim() + return DPoint.coerceUnsafe( context.point( size.toInt(), @@ -112,37 +117,28 @@ inline class DMatrixContext>(val context: GenericMatrixCon inline infix fun DMatrix.dot( other: DMatrix - ): DMatrix { - return context.run { this@dot dot other }.coerce() - } + ): DMatrix = context { this@dot dot other }.coerce() - inline infix fun DMatrix.dot(vector: DPoint): DPoint { - return DPoint.coerceUnsafe(context.run { this@dot dot vector }) - } + inline infix fun DMatrix.dot(vector: DPoint): DPoint = + DPoint.coerceUnsafe(context { this@dot dot vector }) - inline operator fun DMatrix.times(value: T): DMatrix { - return context.run { this@times.times(value) }.coerce() - } + inline operator fun DMatrix.times(value: T): DMatrix = + context { this@times.times(value) }.coerce() inline operator fun T.times(m: DMatrix): DMatrix = m * this + inline operator fun DMatrix.plus(other: DMatrix): DMatrix = + context { this@plus + other }.coerce() - inline operator fun DMatrix.plus(other: DMatrix): DMatrix { - return context.run { this@plus + other }.coerce() - } + inline operator fun DMatrix.minus(other: DMatrix): DMatrix = + context { this@minus + other }.coerce() - inline operator fun DMatrix.minus(other: DMatrix): DMatrix { - return context.run { this@minus + other }.coerce() - } + inline operator fun DMatrix.unaryMinus(): DMatrix = + context { this@unaryMinus.unaryMinus() }.coerce() - inline operator fun DMatrix.unaryMinus(): DMatrix { - return context.run { this@unaryMinus.unaryMinus() }.coerce() - } - - inline fun DMatrix.transpose(): DMatrix { - return context.run { (this@transpose as Matrix).transpose() }.coerce() - } + inline fun DMatrix.transpose(): DMatrix = + context { (this@transpose as Matrix).transpose() }.coerce() /** * A square unit matrix @@ -156,6 +152,6 @@ inline class DMatrixContext>(val context: GenericMatrixCon } companion object { - val real = DMatrixContext(MatrixContext.real) + val real: DMatrixContext = DMatrixContext(MatrixContext.real) } -} \ No newline at end of file +} diff --git a/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt b/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt index 74d20205c..8dabdeeac 100644 --- a/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt +++ b/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt @@ -5,11 +5,10 @@ import scientifik.kmath.dimensions.D3 import scientifik.kmath.dimensions.DMatrixContext import kotlin.test.Test - class DMatrixContextTest { @Test fun testDimensionSafeMatrix() { - val res = DMatrixContext.real.run { + val res = with(DMatrixContext.real) { val m = produce { i, j -> (i + j).toDouble() } //The dimension of `one()` is inferred from type @@ -19,7 +18,7 @@ class DMatrixContextTest { @Test fun testTypeCheck() { - val res = DMatrixContext.real.run { + val res = with(DMatrixContext.real) { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i + j).toDouble() } diff --git a/kmath-for-real/build.gradle.kts b/kmath-for-real/build.gradle.kts index a8a8975bc..46d2682f7 100644 --- a/kmath-for-real/build.gradle.kts +++ b/kmath-for-real/build.gradle.kts @@ -1,11 +1,6 @@ -plugins { - id("scientifik.mpp") -} +plugins { id("scientifik.mpp") } kotlin.sourceSets { - commonMain { - dependencies { - api(project(":kmath-core")) - } - } -} \ No newline at end of file + all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") } + commonMain { dependencies { api(project(":kmath-core")) } } +} diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt index 2b89904e3..811b54d7c 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt @@ -14,8 +14,8 @@ import kotlin.math.sqrt typealias RealPoint = Point -fun DoubleArray.asVector() = RealVector(this.asBuffer()) -fun List.asVector() = RealVector(this.asBuffer()) +fun DoubleArray.asVector(): RealVector = RealVector(this.asBuffer()) +fun List.asVector(): RealVector = RealVector(this.asBuffer()) object VectorL2Norm : Norm, Double> { override fun norm(arg: Point): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) @@ -32,15 +32,14 @@ inline class RealVector(private val point: Point) : override val size: Int get() = point.size - override fun get(index: Int): Double = point[index] + override operator fun get(index: Int): Double = point[index] - override fun iterator(): Iterator = point.iterator() + override operator fun iterator(): Iterator = point.iterator() companion object { + private val spaceCache: MutableMap> = hashMapOf() - private val spaceCache = HashMap>() - - inline operator fun invoke(dim: Int, initializer: (Int) -> Double) = + inline operator fun invoke(dim: Int, initializer: (Int) -> Double): RealVector = RealVector(RealBuffer(dim, initializer)) operator fun invoke(vararg values: Double): RealVector = values.asVector() @@ -49,4 +48,4 @@ inline class RealVector(private val point: Point) : BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) } } } -} \ No newline at end of file +} diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt index 65f86eec7..3752fc3ca 100644 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt @@ -3,11 +3,14 @@ package scientifik.kmath.real import scientifik.kmath.linear.MatrixContext import scientifik.kmath.linear.RealMatrixContext.elementContext import scientifik.kmath.linear.VirtualMatrix +import scientifik.kmath.operations.invoke import scientifik.kmath.operations.sum import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.RealBuffer import scientifik.kmath.structures.asIterable +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract import kotlin.math.pow /* @@ -27,7 +30,7 @@ typealias RealMatrix = Matrix fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = MatrixContext.real.produce(rowNum, colNum, initializer) -fun Array.toMatrix(): RealMatrix{ +fun Array.toMatrix(): RealMatrix { return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] } } @@ -117,13 +120,16 @@ operator fun Matrix.minus(other: Matrix): RealMatrix = * Operations on columns */ -inline fun Matrix.appendColumn(crossinline mapper: (Buffer) -> Double) = - MatrixContext.real.produce(rowNum, colNum + 1) { row, col -> +inline fun Matrix.appendColumn(crossinline mapper: (Buffer) -> Double): Matrix { + contract { callsInPlace(mapper) } + + return MatrixContext.real.produce(rowNum, colNum + 1) { row, col -> if (col < colNum) this[row, col] else mapper(rows[row]) } +} fun Matrix.extractColumns(columnRange: IntRange): RealMatrix = MatrixContext.real.produce(rowNum, columnRange.count()) { row, col -> @@ -135,17 +141,15 @@ fun Matrix.extractColumn(columnIndex: Int): RealMatrix = fun Matrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> val column = columns[j] - with(elementContext) { - sum(column.asIterable()) - } + elementContext { sum(column.asIterable()) } } fun Matrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> - columns[j].asIterable().min() ?: throw Exception("Cannot produce min on empty column") + columns[j].asIterable().min() ?: error("Cannot produce min on empty column") } fun Matrix.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> - columns[j].asIterable().max() ?: throw Exception("Cannot produce min on empty column") + columns[j].asIterable().max() ?: error("Cannot produce min on empty column") } fun Matrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> @@ -156,10 +160,7 @@ fun Matrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> * Operations processing all elements */ -fun Matrix.sum() = elements().map { (_, value) -> value }.sum() - -fun Matrix.min() = elements().map { (_, value) -> value }.min() - -fun Matrix.max() = elements().map { (_, value) -> value }.max() - -fun Matrix.average() = elements().map { (_, value) -> value }.average() +fun Matrix.sum(): Double = elements().map { (_, value) -> value }.sum() +fun Matrix.min(): Double? = elements().map { (_, value) -> value }.min() +fun Matrix.max(): Double? = elements().map { (_, value) -> value }.max() +fun Matrix.average(): Double = elements().map { (_, value) -> value }.average() diff --git a/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt b/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt index 28e62b066..ef7f40afe 100644 --- a/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt @@ -1,5 +1,6 @@ package scientifik.kmath.linear +import scientifik.kmath.operations.invoke import scientifik.kmath.real.RealVector import kotlin.test.Test import kotlin.test.assertEquals @@ -24,14 +25,10 @@ class VectorTest { fun testDot() { val vector1 = RealVector(5) { it.toDouble() } val vector2 = RealVector(5) { 5 - it.toDouble() } - val matrix1 = vector1.asMatrix() val matrix2 = vector2.asMatrix().transpose() - val product = MatrixContext.real.run { matrix1 dot matrix2 } - - + val product = MatrixContext.real { matrix1 dot matrix2 } assertEquals(5.0, product[1, 0]) assertEquals(6.0, product[2, 2]) } - -} \ No newline at end of file +} diff --git a/kmath-functions/build.gradle.kts b/kmath-functions/build.gradle.kts index 4c158a32e..46d2682f7 100644 --- a/kmath-functions/build.gradle.kts +++ b/kmath-functions/build.gradle.kts @@ -1,11 +1,6 @@ -plugins { - id("scientifik.mpp") -} +plugins { id("scientifik.mpp") } kotlin.sourceSets { - commonMain { - dependencies { - api(project(":kmath-core")) - } - } + all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") } + commonMain { dependencies { api(project(":kmath-core")) } } } diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt index b747b521d..c4470ad27 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt +++ b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt @@ -2,6 +2,10 @@ package scientifik.kmath.functions import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.math.max import kotlin.math.pow @@ -13,20 +17,21 @@ inline class Polynomial(val coefficients: List) { constructor(vararg coefficients: T) : this(coefficients.toList()) } -fun Polynomial.value() = +fun Polynomial.value(): Double = coefficients.reduceIndexed { index: Int, acc: Double, d: Double -> acc + d.pow(index) } - -fun > Polynomial.value(ring: C, arg: T): T = ring.run { - if (coefficients.isEmpty()) return@run zero +fun > Polynomial.value(ring: C, arg: T): T = ring { + if (coefficients.isEmpty()) return@ring zero var res = coefficients.first() var powerArg = arg + for (index in 1 until coefficients.size) { res += coefficients[index] * powerArg //recalculating power on each step to avoid power costs on long polynomials powerArg *= arg } - return@run res + + res } /** @@ -34,7 +39,7 @@ fun > Polynomial.value(ring: C, arg: T): T = ring.run { */ fun > Polynomial.asMathFunction(): MathFunction = object : MathFunction { - override fun C.invoke(arg: T): T = value(this, arg) + override operator fun C.invoke(arg: T): T = value(this, arg) } /** @@ -49,18 +54,16 @@ class PolynomialSpace>(val ring: C) : Space> override fun add(a: Polynomial, b: Polynomial): Polynomial { val dim = max(a.coefficients.size, b.coefficients.size) - ring.run { - return Polynomial(List(dim) { index -> + + return ring { + Polynomial(List(dim) { index -> a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero } }) } } - override fun multiply(a: Polynomial, k: Number): Polynomial { - ring.run { - return Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) - } - } + override fun multiply(a: Polynomial, k: Number): Polynomial = + ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) } override val zero: Polynomial = Polynomial(emptyList()) @@ -68,6 +71,7 @@ class PolynomialSpace>(val ring: C) : Space> operator fun Polynomial.invoke(arg: T): T = value(ring, arg) } -fun , R> C.polynomial(block: PolynomialSpace.() -> R): R { - return PolynomialSpace(this).run(block) -} \ No newline at end of file +inline fun , R> C.polynomial(block: PolynomialSpace.() -> R): R { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return PolynomialSpace(this).block() +} diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt index 98beb4391..a7925180d 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt @@ -4,13 +4,13 @@ import scientifik.kmath.functions.OrderedPiecewisePolynomial import scientifik.kmath.functions.PiecewisePolynomial import scientifik.kmath.functions.Polynomial import scientifik.kmath.operations.Field +import scientifik.kmath.operations.invoke /** * Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java */ class LinearInterpolator>(override val algebra: Field) : PolynomialInterpolator { - - override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra.run { + override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra { require(points.size > 0) { "Point array should not be empty" } insureSorted(points) @@ -23,4 +23,4 @@ class LinearInterpolator>(override val algebra: Field) : Po } } } -} \ No newline at end of file +} diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt index e1af0c1a2..b709c4e87 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt @@ -4,6 +4,7 @@ import scientifik.kmath.functions.OrderedPiecewisePolynomial import scientifik.kmath.functions.PiecewisePolynomial import scientifik.kmath.functions.Polynomial import scientifik.kmath.operations.Field +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.MutableBufferFactory /** @@ -17,7 +18,7 @@ class SplineInterpolator>( //TODO possibly optimize zeroed buffers - override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra.run { + override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra { if (points.size < 3) { error("Can't use spline interpolator with less than 3 points") } diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt index d8e10b880..56953f9fc 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt +++ b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt @@ -14,9 +14,7 @@ interface XYZPointSet : XYPointSet { } internal fun > insureSorted(points: XYPointSet) { - for (i in 0 until points.size - 1) { - if (points.x[i + 1] <= points.x[i]) error("Input data is not sorted at index $i") - } + for (i in 0 until points.size - 1) require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" } } class NDStructureColumn(val structure: Structure2D, val column: Int) : Buffer { @@ -26,9 +24,9 @@ class NDStructureColumn(val structure: Structure2D, val column: Int) : Buf override val size: Int get() = structure.rowNum - override fun get(index: Int): T = structure[index, column] + override operator fun get(index: Int): T = structure[index, column] - override fun iterator(): Iterator = sequence { + override operator fun iterator(): Iterator = sequence { repeat(size) { yield(get(it)) } diff --git a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt index 2313b2170..f0dc49882 100644 --- a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt @@ -9,25 +9,21 @@ import kotlin.math.sqrt interface Vector2D : Point, Vector, SpaceElement { val x: Double val y: Double - + override val context: Euclidean2DSpace get() = Euclidean2DSpace override val size: Int get() = 2 - override fun get(index: Int): Double = when (index) { + override operator fun get(index: Int): Double = when (index) { 1 -> x 2 -> y else -> error("Accessing outside of point bounds") } - override fun iterator(): Iterator = listOf(x, y).iterator() - - override val context: Euclidean2DSpace get() = Euclidean2DSpace - + override operator fun iterator(): Iterator = listOf(x, y).iterator() override fun unwrap(): Vector2D = this - override fun Vector2D.wrap(): Vector2D = this } -val Vector2D.r: Double get() = Euclidean2DSpace.run { sqrt(norm()) } +val Vector2D.r: Double get() = Euclidean2DSpace { sqrt(norm()) } @Suppress("FunctionName") fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y) diff --git a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt index dd1776342..3748e58c7 100644 --- a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt @@ -2,6 +2,7 @@ package scientifik.kmath.geometry import scientifik.kmath.linear.Point import scientifik.kmath.operations.SpaceElement +import scientifik.kmath.operations.invoke import kotlin.math.sqrt @@ -9,19 +10,17 @@ interface Vector3D : Point, Vector, SpaceElement x 2 -> y 3 -> z else -> error("Accessing outside of point bounds") } - override fun iterator(): Iterator = listOf(x, y, z).iterator() - - override val context: Euclidean3DSpace get() = Euclidean3DSpace + override operator fun iterator(): Iterator = listOf(x, y, z).iterator() override fun unwrap(): Vector3D = this @@ -31,7 +30,7 @@ interface Vector3D : Point, Vector, SpaceElement { override fun Vector3D.dot(other: Vector3D): Double = x * other.x + y * other.y + z * other.z -} \ No newline at end of file +} diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index 43d50ad20..9ff2aacf5 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -4,6 +4,9 @@ import scientifik.kmath.domains.Domain import scientifik.kmath.linear.Point import scientifik.kmath.structures.ArrayBuffer import scientifik.kmath.structures.RealBuffer +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * The bin in the histogram. The histogram is by definition always done in the real space @@ -37,20 +40,20 @@ interface MutableHistogram> : Histogram { */ fun putWithWeight(point: Point, weight: Double) - fun put(point: Point) = putWithWeight(point, 1.0) + fun put(point: Point): Unit = putWithWeight(point, 1.0) } -fun MutableHistogram.put(vararg point: T) = put(ArrayBuffer(point)) +fun MutableHistogram.put(vararg point: T): Unit = put(ArrayBuffer(point)) -fun MutableHistogram.put(vararg point: Number) = +fun MutableHistogram.put(vararg point: Number): Unit = put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) -fun MutableHistogram.put(vararg point: Double) = put(RealBuffer(point)) +fun MutableHistogram.put(vararg point: Double): Unit = put(RealBuffer(point)) -fun MutableHistogram.fill(sequence: Iterable>) = sequence.forEach { put(it) } +fun MutableHistogram.fill(sequence: Iterable>): Unit = sequence.forEach { put(it) } /** * Pass a sequence builder into histogram */ -fun MutableHistogram.fill(buider: suspend SequenceScope>.() -> Unit) = - fill(sequence(buider).asIterable()) \ No newline at end of file +fun MutableHistogram.fill(block: suspend SequenceScope>.() -> Unit): Unit = + fill(sequence(block).asIterable()) diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index 628a68461..f05ae1694 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -2,6 +2,7 @@ package scientifik.kmath.histogram import scientifik.kmath.linear.Point import scientifik.kmath.operations.SpaceOperations +import scientifik.kmath.operations.invoke import scientifik.kmath.real.asVector import scientifik.kmath.structures.* import kotlin.math.floor @@ -9,19 +10,16 @@ import kotlin.math.floor data class BinDef>(val space: SpaceOperations>, val center: Point, val sizes: Point) { fun contains(vector: Point): Boolean { - if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") - val upper = space.run { center + sizes / 2.0 } - val lower = space.run { center - sizes / 2.0 } - return vector.asSequence().mapIndexed { i, value -> - value in lower[i]..upper[i] - }.all { it } + require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" } + val upper = space { center + sizes / 2.0 } + val lower = space { center - sizes / 2.0 } + return vector.asSequence().mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it } } } class MultivariateBin>(val def: BinDef, override val value: Number) : Bin { - - override fun contains(point: Point): Boolean = def.contains(point) + override operator fun contains(point: Point): Boolean = def.contains(point) override val dimension: Int get() = def.center.size @@ -39,47 +37,34 @@ class RealHistogram( private val upper: Buffer, private val binNums: IntArray = IntArray(lower.size) { 20 } ) : MutableHistogram> { - - private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 }) - private val values: NDStructure = NDStructure.auto(strides) { LongCounter() } - private val weights: NDStructure = NDStructure.auto(strides) { DoubleCounter() } - override val dimension: Int get() = lower.size - - private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } init { // argument checks - if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.") - if (lower.size != binNums.size) error("Dimension mismatch in bin count.") - if ((0 until dimension).any { upper[it] - lower[it] < 0 }) error("Range for one of axis is not strictly positive") + require(lower.size == upper.size) { "Dimension mismatch in histogram lower and upper limits." } + require(lower.size == binNums.size) { "Dimension mismatch in bin count." } + require(!(0 until dimension).any { upper[it] - lower[it] < 0 }) { "Range for one of axis is not strictly positive" } } /** * Get internal [NDStructure] bin index for given axis */ - private fun getIndex(axis: Int, value: Double): Int { - return when { - value >= upper[axis] -> binNums[axis] + 1 // overflow - value < lower[axis] -> 0 // underflow - else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1 - } + private fun getIndex(axis: Int, value: Double): Int = when { + value >= upper[axis] -> binNums[axis] + 1 // overflow + value < lower[axis] -> 0 // underflow + else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1 } private fun getIndex(point: Buffer): IntArray = IntArray(dimension) { getIndex(it, point[it]) } - private fun getValue(index: IntArray): Long { - return values[index].sum() - } + private fun getValue(index: IntArray): Long = values[index].sum() - fun getValue(point: Buffer): Long { - return getValue(getIndex(point)) - } + fun getValue(point: Buffer): Long = getValue(getIndex(point)) private fun getDef(index: IntArray): BinDef { val center = index.mapIndexed { axis, i -> @@ -89,14 +74,13 @@ class RealHistogram( else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis] } }.asBuffer() + return BinDef(RealBufferFieldOperations, center, binSize) } - fun getDef(point: Buffer): BinDef { - return getDef(getIndex(point)) - } + fun getDef(point: Buffer): BinDef = getDef(getIndex(point)) - override fun get(point: Buffer): MultivariateBin? { + override operator fun get(point: Buffer): MultivariateBin? { val index = getIndex(point) return MultivariateBin(getDef(index), getValue(index)) } @@ -112,26 +96,21 @@ class RealHistogram( weights[index].add(weight) } - override fun iterator(): Iterator> = weights.elements().map { (index, value) -> + override operator fun iterator(): Iterator> = weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }.iterator() /** * Convert this histogram into NDStructure containing bin values but not bin descriptions */ - fun values(): NDStructure { - return NDStructure.auto(values.shape) { values[it].sum() } - } + fun values(): NDStructure = NDStructure.auto(values.shape) { values[it].sum() } /** * Sum of weights */ - fun weights():NDStructure{ - return NDStructure.auto(weights.shape) { weights[it].sum() } - } + fun weights(): NDStructure = NDStructure.auto(weights.shape) { weights[it].sum() } companion object { - /** * Use it like * ``` @@ -141,12 +120,10 @@ class RealHistogram( *) *``` */ - fun fromRanges(vararg ranges: ClosedFloatingPointRange): RealHistogram { - return RealHistogram( - ranges.map { it.start }.asVector(), - ranges.map { it.endInclusive }.asVector() - ) - } + fun fromRanges(vararg ranges: ClosedFloatingPointRange): RealHistogram = RealHistogram( + ranges.map { it.start }.asVector(), + ranges.map { it.endInclusive }.asVector() + ) /** * Use it like @@ -157,13 +134,10 @@ class RealHistogram( *) *``` */ - fun fromRanges(vararg ranges: Pair, Int>): RealHistogram { - return RealHistogram( - ListBuffer(ranges.map { it.first.start }), - ListBuffer(ranges.map { it.first.endInclusive }), - ranges.map { it.second }.toIntArray() - ) - } + fun fromRanges(vararg ranges: Pair, Int>): RealHistogram = RealHistogram( + ListBuffer(ranges.map { it.first.start }), + ListBuffer(ranges.map { it.first.endInclusive }), + ranges.map { it.second }.toIntArray() + ) } - -} \ No newline at end of file +} diff --git a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt index af01205bf..e30a45f5a 100644 --- a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt +++ b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt @@ -46,11 +46,11 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U synchronized(this) { bins.put(it.position, it) } } - override fun get(point: Buffer): UnivariateBin? = get(point[0]) + override operator fun get(point: Buffer): UnivariateBin? = get(point[0]) override val dimension: Int get() = 1 - override fun iterator(): Iterator = bins.values.iterator() + override operator fun iterator(): Iterator = bins.values.iterator() /** * Thread safe put operation @@ -65,15 +65,14 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U } companion object { - fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram { - return UnivariateHistogram { value -> - val center = start + binSize * floor((value - start) / binSize + 0.5) - UnivariateBin(center, binSize) - } + fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value -> + val center = start + binSize * floor((value - start) / binSize + 0.5) + UnivariateBin(center, binSize) } fun custom(borders: DoubleArray): UnivariateHistogram { val sorted = borders.sortedArray() + return UnivariateHistogram { value -> when { value < sorted.first() -> UnivariateBin( diff --git a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt index 10deabd73..bd8fa782a 100644 --- a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt +++ b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt @@ -3,16 +3,16 @@ package scientifik.kmath.linear import koma.extensions.fill import koma.matrix.MatrixFactory import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.NDStructure class KomaMatrixContext( private val factory: MatrixFactory>, private val space: Space -) : - MatrixContext { +) : MatrixContext { - override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T) = + override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): KomaMatrix = KomaMatrix(factory.zeros(rows, columns).fill(initializer)) fun Matrix.toKoma(): KomaMatrix = if (this is KomaMatrix) { @@ -28,31 +28,28 @@ class KomaMatrixContext( } - override fun Matrix.dot(other: Matrix) = - KomaMatrix(this.toKoma().origin * other.toKoma().origin) + override fun Matrix.dot(other: Matrix): KomaMatrix = + KomaMatrix(toKoma().origin * other.toKoma().origin) - override fun Matrix.dot(vector: Point) = - KomaVector(this.toKoma().origin * vector.toKoma().origin) + override fun Matrix.dot(vector: Point): KomaVector = + KomaVector(toKoma().origin * vector.toKoma().origin) - override fun Matrix.unaryMinus() = - KomaMatrix(this.toKoma().origin.unaryMinus()) + override operator fun Matrix.unaryMinus(): KomaMatrix = + KomaMatrix(toKoma().origin.unaryMinus()) - override fun add(a: Matrix, b: Matrix) = + override fun add(a: Matrix, b: Matrix): KomaMatrix = KomaMatrix(a.toKoma().origin + b.toKoma().origin) - override fun Matrix.minus(b: Matrix) = - KomaMatrix(this.toKoma().origin - b.toKoma().origin) + override operator fun Matrix.minus(b: Matrix): KomaMatrix = + KomaMatrix(toKoma().origin - b.toKoma().origin) override fun multiply(a: Matrix, k: Number): Matrix = - produce(a.rowNum, a.colNum) { i, j -> space.run { a[i, j] * k } } + produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } } - override fun Matrix.times(value: T) = - KomaMatrix(this.toKoma().origin * value) - - companion object { - - } + override operator fun Matrix.times(value: T): KomaMatrix = + KomaMatrix(toKoma().origin * value) + companion object } fun KomaMatrixContext.solve(a: Matrix, b: Matrix) = @@ -70,10 +67,11 @@ class KomaMatrix(val origin: koma.matrix.Matrix, features: Set = features ?: setOf( + override val features: Set = features ?: hashSetOf( object : DeterminantFeature { override val determinant: T get() = origin.det() }, + object : LUPDecompositionFeature { private val lup by lazy { origin.LU() } override val l: FeaturedMatrix get() = KomaMatrix(lup.second) @@ -85,7 +83,7 @@ class KomaMatrix(val origin: koma.matrix.Matrix, features: Set = KomaMatrix(this.origin, this.features + features) - override fun get(i: Int, j: Int): T = origin.getGeneric(i, j) + override operator fun get(i: Int, j: Int): T = origin.getGeneric(i, j) override fun equals(other: Any?): Boolean { return NDStructure.equals(this, other as? NDStructure<*> ?: return false) @@ -101,14 +99,12 @@ class KomaMatrix(val origin: koma.matrix.Matrix, features: Set internal constructor(val origin: koma.matrix.Matrix) : Point { - init { - if (origin.numCols() != 1) error("Only single column matrices are allowed") - } - override val size: Int get() = origin.numRows() - override fun get(index: Int): T = origin.getGeneric(index) + init { + require(origin.numCols() == 1) { error("Only single column matrices are allowed") } + } - override fun iterator(): Iterator = origin.toIterable().iterator() + override operator fun get(index: Int): T = origin.getGeneric(index) + override operator fun iterator(): Iterator = origin.toIterable().iterator() } - diff --git a/kmath-memory/build.gradle.kts b/kmath-memory/build.gradle.kts index 1f34a4f17..75b4f174e 100644 --- a/kmath-memory/build.gradle.kts +++ b/kmath-memory/build.gradle.kts @@ -1,3 +1,3 @@ plugins { id("scientifik.mpp") -} +} \ No newline at end of file diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt b/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt index a749a7074..177c6b46b 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt +++ b/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt @@ -1,5 +1,8 @@ package scientifik.memory +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + /** * Represents a display of certain memory structure. */ @@ -80,8 +83,12 @@ interface MemoryReader { /** * Uses the memory for read then releases the reader. */ -inline fun Memory.read(block: MemoryReader.() -> Unit) { - reader().apply(block).release() +inline fun Memory.read(block: MemoryReader.() -> R): R { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + val reader = reader() + val result = reader.block() + reader.release() + return result } /** @@ -133,6 +140,7 @@ interface MemoryWriter { * Uses the memory for write then releases the writer. */ inline fun Memory.write(block: MemoryWriter.() -> Unit) { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } writer().apply(block).release() } diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt index 59a93f290..1381afbec 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt +++ b/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt @@ -38,11 +38,7 @@ fun MemoryWriter.write(spec: MemorySpec, offset: Int, value: T): Un * Reads array of [size] objects mapped by [spec] at certain [offset]. */ inline fun MemoryReader.readArray(spec: MemorySpec, offset: Int, size: Int): Array = - Array(size) { i -> - spec.run { - read(offset + i * objectSize) - } - } + Array(size) { i -> with(spec) { read(offset + i * objectSize) } } /** * Writes [array] of objects mapped by [spec] at certain [offset]. diff --git a/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt b/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt index b5a0dd51b..f4967bf5c 100644 --- a/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt +++ b/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt @@ -1,12 +1,17 @@ package scientifik.memory +import java.io.IOException import java.nio.ByteBuffer import java.nio.channels.FileChannel import java.nio.file.Files import java.nio.file.Path import java.nio.file.StandardOpenOption +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract -private class ByteBufferMemory( +@PublishedApi +internal class ByteBufferMemory( val buffer: ByteBuffer, val startOffset: Int = 0, override val size: Int = buffer.limit() @@ -112,7 +117,11 @@ fun ByteBuffer.asMemory(startOffset: Int = 0, size: Int = limit()): Memory = /** * Uses direct memory-mapped buffer from file to read something and close it afterwards. */ -fun Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R = - FileChannel.open(this, StandardOpenOption.READ).use { - ByteBufferMemory(it.map(FileChannel.MapMode.READ_ONLY, position, size)).block() - } +@Throws(IOException::class) +inline fun Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + + return FileChannel + .open(this, StandardOpenOption.READ) + .use { ByteBufferMemory(it.map(FileChannel.MapMode.READ_ONLY, position, size)).block() } +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt index 47fc6e4c5..49163c701 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt @@ -11,4 +11,4 @@ class RandomChain(val generator: RandomGenerator, private val gen: suspen override fun fork(): Chain = RandomChain(generator.fork(), gen) } -fun RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain = RandomChain(this, gen) \ No newline at end of file +fun RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain = RandomChain(this, gen) diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt index 3a60c0bda..02f98439e 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt @@ -5,6 +5,7 @@ import scientifik.kmath.chains.ConstantChain import scientifik.kmath.chains.map import scientifik.kmath.chains.zip import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke class BasicSampler(val chainBuilder: (RandomGenerator) -> Chain) : Sampler { override fun sample(generator: RandomGenerator): Chain = chainBuilder(generator) @@ -22,10 +23,10 @@ class SamplerSpace(val space: Space) : Space> { override val zero: Sampler = ConstantSampler(space.zero) override fun add(a: Sampler, b: Sampler): Sampler = BasicSampler { generator -> - a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space.run { aValue + bValue } } + a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space { aValue + bValue } } } override fun multiply(a: Sampler, k: Number): Sampler = BasicSampler { generator -> - a.sample(generator).map { space.run { it * k.toDouble() } } + a.sample(generator).map { space { it * k.toDouble() } } } -} \ No newline at end of file +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt index 804aed089..c82d262bf 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt @@ -29,8 +29,10 @@ interface Statistic { interface ComposableStatistic : Statistic { //compute statistic on a single block suspend fun computeIntermediate(data: Buffer): I + //Compose two blocks suspend fun composeIntermediate(first: I, second: I): I + //Transform block to result suspend fun toResult(intermediate: I): R @@ -58,26 +60,26 @@ private fun ComposableStatistic.flowIntermediate( fun ComposableStatistic.flow( flow: Flow>, dispatcher: CoroutineDispatcher = Dispatchers.Default -): Flow = flowIntermediate(flow,dispatcher).map(::toResult) +): Flow = flowIntermediate(flow, dispatcher).map(::toResult) /** * Arithmetic mean */ class Mean(val space: Space) : ComposableStatistic, T> { override suspend fun computeIntermediate(data: Buffer): Pair = - space.run { sum(data.asIterable()) } to data.size + space { sum(data.asIterable()) } to data.size override suspend fun composeIntermediate(first: Pair, second: Pair): Pair = - space.run { first.first + second.first } to (first.second + second.second) + space { first.first + second.first } to (first.second + second.second) override suspend fun toResult(intermediate: Pair): T = - space.run { intermediate.first / intermediate.second } + space { intermediate.first / intermediate.second } companion object { //TODO replace with optimized version which respects overflow - val real = Mean(RealField) - val int = Mean(IntRing) - val long = Mean(LongRing) + val real: Mean = Mean(RealField) + val int: Mean = Mean(IntRing) + val long: Mean = Mean(LongRing) } } @@ -85,11 +87,10 @@ class Mean(val space: Space) : ComposableStatistic, T> { * Non-composable median */ class Median(private val comparator: Comparator) : Statistic { - override suspend fun invoke(data: Buffer): T { - return data.asSequence().sortedWith(comparator).toList()[data.size / 2] //TODO check if this is correct - } + override suspend fun invoke(data: Buffer): T = + data.asSequence().sortedWith(comparator).toList()[data.size / 2] //TODO check if this is correct companion object { - val real = Median(Comparator { a: Double, b: Double -> a.compareTo(b) }) + val real: Median = Median(Comparator { a: Double, b: Double -> a.compareTo(b) }) } -} \ No newline at end of file +} diff --git a/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt b/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt index 040eee951..551b877a7 100644 --- a/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt +++ b/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt @@ -16,5 +16,5 @@ inline class ViktorBuffer(val flatArray: F64FlatArray) : MutableBuffer { return ViktorBuffer(flatArray.copy().flatten()) } - override fun iterator(): Iterator = flatArray.data.iterator() -} \ No newline at end of file + override operator fun iterator(): Iterator = flatArray.data.iterator() +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 487e1d87f..6601fd053 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,6 +1,6 @@ pluginManagement { - val toolsVersion = "0.5.0" + val toolsVersion = "0.5.2" plugins { id("kotlinx.benchmark") version "0.2.0-dev-8" @@ -20,14 +20,6 @@ pluginManagement { maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/kotlin/kotlinx") } - - resolutionStrategy { - eachPlugin { - when (requested.id.id) { - "scientifik.mpp", "scientifik.jvm", "scientifik.publish" -> useModule("scientifik:gradle-tools:$toolsVersion") - } - } - } } rootProject.name = "kmath"