From e65e1eedaad36c4f77bab18bbe7d72d506f20e8c Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 4 Apr 2022 18:43:20 +0700 Subject: [PATCH] Apply context receivers to operator extension functions in main algebras --- benchmarks/build.gradle.kts | 12 +- .../ExpressionsInterpretersBenchmark.kt | 5 +- build.gradle.kts | 2 +- buildSrc/gradle.properties | 2 +- .../kmath/ejml/codegen/ejmlCodegen.kt | 6 +- docs/algebra.md | 4 - examples/build.gradle.kts | 3 +- .../space/kscience/kmath/ast/expressions.kt | 3 +- .../space/kscience/kmath/fit/chiSquared.kt | 2 + .../kotlin/space/kscience/kmath/fit/qowFit.kt | 2 + .../kmath/functions/matrixIntegration.kt | 3 +- .../kmath/operations/mixedNDOperations.kt | 3 +- .../kscience/kmath/structures/ComplexND.kt | 2 + .../kmath/structures/StreamDoubleFieldND.kt | 2 +- .../kscience/kmath/structures/buffers.kt | 5 +- .../kscience/kmath/tensors/OLSWithSVD.kt | 17 +- .../space/kscience/kmath/tensors/PCA.kt | 3 +- .../kmath/tensors/dataSetNormalization.kt | 3 +- .../space/kscience/kmath/tensors/multik.kt | 2 + .../kscience/kmath/tensors/neuralNetwork.kt | 21 +- kmath-ast/build.gradle.kts | 5 + .../TestCompilerConsistencyWithInterpreter.kt | 4 +- .../kmath/ast/TestCompilerOperations.kt | 3 +- .../DerivativeStructureExpression.kt | 2 +- .../DerivativeStructureExpressionTest.kt | 3 + .../commons/optimization/OptimizeTest.kt | 2 +- kmath-complex/build.gradle.kts | 7 +- .../space/kscience/kmath/complex/Complex.kt | 102 ++++----- .../kscience/kmath/complex/Quaternion.kt | 5 +- .../kmath/complex/ComplexFieldTest.kt | 2 +- .../kscience/kmath/complex/ComplexTest.kt | 2 + .../complex/ExpressionFieldForComplexTest.kt | 2 + .../kmath/complex/QuaternionFieldTest.kt | 2 + kmath-core/build.gradle.kts | 7 +- .../FunctionalExpressionAlgebra.kt | 47 +++-- .../kscience/kmath/expressions/MstAlgebra.kt | 25 +-- .../kmath/expressions/SimpleAutoDiff.kt | 4 +- .../kmath/expressions/specialExpressions.kt | 3 +- .../kmath/linear/DoubleLinearSpace.kt | 4 +- .../kscience/kmath/linear/LinearSpace.kt | 5 +- .../kscience/kmath/linear/MatrixWrapper.kt | 1 - .../space/kscience/kmath/misc/cumulative.kt | 1 + .../space/kscience/kmath/nd/AlgebraND.kt | 178 ++++++++-------- .../kscience/kmath/nd/BufferAlgebraND.kt | 2 +- .../space/kscience/kmath/nd/DoubleFieldND.kt | 45 +--- .../space/kscience/kmath/nd/StructureND.kt | 1 + .../kscience/kmath/operations/Algebra.kt | 123 ++++++----- .../space/kscience/kmath/operations/BigInt.kt | 2 +- .../kmath/operations/BufferAlgebra.kt | 4 +- .../kmath/operations/DoubleBufferOps.kt | 20 +- .../kmath/operations/NumericAlgebra.kt | 58 ++--- .../kscience/kmath/operations/numbers.kt | 63 +++--- .../kmath/expressions/ExpressionFieldTest.kt | 2 + .../kmath/expressions/InterpretTest.kt | 5 +- .../kmath/expressions/SimpleAutoDiffTest.kt | 3 +- .../kscience/kmath/structures/NDFieldTest.kt | 1 + .../kmath/structures/NumberNDFieldTest.kt | 5 +- .../kscience/kmath/testutils/FieldVerifier.kt | 3 +- .../kscience/kmath/testutils/RingVerifier.kt | 4 +- .../kscience/kmath/testutils/SpaceVerifier.kt | 4 +- .../kscience/kmath/operations/BigNumbers.kt | 11 +- kmath-coroutines/build.gradle.kts | 9 +- .../space/kscience/kmath/chains/flowExtra.kt | 5 +- .../kmath/structures/LazyStructureND.kt | 14 +- kmath-dimensions/build.gradle.kts | 7 +- kmath-ejml/build.gradle.kts | 19 +- .../space/kscience/kmath/ejml/_generated.kt | 24 +-- kmath-for-real/build.gradle.kts | 7 +- kmath-functions/build.gradle.kts | 7 +- .../kscience/kmath/functions/Polynomial.kt | 6 +- .../kmath/integration/GaussIntegrator.kt | 3 + .../kmath/integration/SimpsonIntegrator.kt | 8 +- .../kmath/interpolation/LinearInterpolator.kt | 3 +- .../kmath/interpolation/SplineInterpolator.kt | 4 +- kmath-geometry/build.gradle.kts | 7 +- .../kmath/geometry/Euclidean2DSpace.kt | 5 +- .../kmath/geometry/Euclidean3DSpace.kt | 5 +- .../kscience/kmath/geometry/Projections.kt | 5 + .../kmath/geometry/Euclidean2DSpaceTest.kt | 2 + .../kmath/geometry/Euclidean3DSpaceTest.kt | 2 + .../kmath/geometry/ProjectionAlongTest.kt | 1 + .../kmath/geometry/ProjectionOntoLineTest.kt | 1 + kmath-histograms/build.gradle.kts | 13 +- .../space/kscience/kmath/histogram/Counter.kt | 3 +- .../kscience/kmath/histogram/HistogramND.kt | 4 +- .../kmath/histogram/UniformHistogram1D.kt | 9 +- .../histogram/UniformHistogramGroupND.kt | 5 +- .../histogram/MultivariateHistogramTest.kt | 1 + .../kmath/histogram/UniformHistogram1DTest.kt | 7 +- .../kmath/histogram/TreeHistogramGroup.kt | 17 +- .../kscience/kmath/jafama/KMathJafama.kt | 16 +- .../kscience/kmath/jupyter/KMathJupyter.kt | 2 + .../kmath/kotlingrad/scalarsAdapters.kt | 1 - kmath-memory/build.gradle.kts | 2 +- .../kscience/kmath/memory/ByteBufferMemory.kt | 3 +- .../kmath/multik/MultikDoubleAlgebra.kt | 18 +- .../kmath/multik/MultikTensorAlgebra.kt | 86 ++++---- .../kscience/kmath/nd4j/Nd4jArrayAlgebra.kt | 116 +++++----- .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt | 59 ++++-- kmath-optimization/build.gradle.kts | 7 +- kmath-stat/build.gradle.kts | 9 +- .../kscience/kmath/stat/SamplerAlgebra.kt | 6 +- .../kmath/tensorflow/TensorFlowAlgebra.kt | 12 +- .../kmath/tensorflow/DoubleTensorFlowOps.kt | 12 +- kmath-tensors/build.gradle.kts | 7 +- .../tensors/api/AnalyticTensorAlgebra.kt | 4 +- .../tensors/api/LinearOpsTensorAlgebra.kt | 2 +- .../kmath/tensors/api/TensorAlgebra.kt | 112 +++------- .../api/TensorPartialDivisionAlgebra.kt | 31 +-- .../core/BroadcastDoubleTensorAlgebra.kt | 16 +- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 199 ++++++++++-------- .../kmath/tensors/core/internal/linUtils.kt | 7 +- .../kmath/tensors/core/TestBroadcasting.kt | 12 +- .../core/TestDoubleLinearOpsAlgebra.kt | 11 +- .../tensors/core/TestDoubleTensorAlgebra.kt | 60 ++++-- .../kscience/kmath/viktor/ViktorBuffer.kt | 5 +- .../kscience/kmath/viktor/ViktorFieldOpsND.kt | 29 +-- .../kmath/viktor/ViktorStructureND.kt | 7 +- 118 files changed, 989 insertions(+), 939 deletions(-) diff --git a/benchmarks/build.gradle.kts b/benchmarks/build.gradle.kts index 22712816d..d4b8671a4 100644 --- a/benchmarks/build.gradle.kts +++ b/benchmarks/build.gradle.kts @@ -18,9 +18,10 @@ repositories { kotlin { jvm() - js(IR) { - nodejs() - } +// Testing multi-receiver! +// js(IR) { +// nodejs() +// } sourceSets { all { @@ -74,7 +75,8 @@ benchmark { // Setup configurations targets { register("jvm") - register("js") + // Testing multi-receiver! + // register("js") } fun kotlinx.benchmark.gradle.BenchmarkConfiguration.commonConfiguration() { @@ -158,7 +160,7 @@ kotlin.sourceSets.all { tasks.withType { kotlinOptions { jvmTarget = "11" - freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy" + freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy" + "-Xcontext-receivers" } } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt index db3524e67..74d70974c 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt @@ -11,10 +11,7 @@ import kotlinx.benchmark.Scope import kotlinx.benchmark.State import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.expressions.* -import space.kscience.kmath.operations.Algebra -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.bindSymbol -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import kotlin.math.sin import kotlin.random.Random diff --git a/build.gradle.kts b/build.gradle.kts index 1e703c5e1..dc4a6c60c 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -11,7 +11,7 @@ allprojects { } group = "space.kscience" - version = "0.3.0-dev-21" + version = "0.4.0-dev-1" } subprojects { diff --git a/buildSrc/gradle.properties b/buildSrc/gradle.properties index a0b05e812..906db76f9 100644 --- a/buildSrc/gradle.properties +++ b/buildSrc/gradle.properties @@ -4,4 +4,4 @@ # kotlin.code.style=official -toolsVersion=0.11.2-kotlin-1.6.10 +toolsVersion=0.11.2-kotlin-1.6.20 diff --git a/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt b/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt index 7f8cb35b3..0361e543c 100644 --- a/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt +++ b/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt @@ -207,7 +207,7 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, structure.getFeature(type)?.let { return it } val origin = structure.toEjml().origin - return when (type) { + return type.cast(when (type) { ${ if (isDense) """ InverseMatrixFeature::class -> object : InverseMatrixFeature<${type}> { @@ -318,8 +318,8 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, }""" } - else -> null - }?.let(type::cast) + else -> return null + }) } /** diff --git a/docs/algebra.md b/docs/algebra.md index 20158a125..76c77ffa9 100644 --- a/docs/algebra.md +++ b/docs/algebra.md @@ -62,10 +62,6 @@ val c2 = ComplexField { c1 - 1.0 } // Returns: Complex(re=0.0, im=2.0) val c3 = ComplexField { c1 - i * 2.0 } ``` -**Note**: In theory it is possible to add behaviors directly to the context, but as for now Kotlin does not support -that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and -[KEEP-176](https://github.com/Kotlin/KEEP/pull/176) for updates. - ## Nested fields Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 60f8f5aed..8575cbd50 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -61,7 +61,8 @@ kotlin.sourceSets.all { tasks.withType { kotlinOptions { jvmTarget = "11" - freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy" + freeCompilerArgs = + freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy" + "-Xcontext-receivers" } } diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt index 907f1bbe4..640f11326 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt @@ -8,8 +8,7 @@ package space.kscience.kmath.ast import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.expressions.MstExtendedField import space.kscience.kmath.expressions.Symbol.Companion.x -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* fun main() { val expr = MstExtendedField { diff --git a/examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt b/examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt index 63e57bd8c..59f77a877 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt @@ -13,6 +13,8 @@ import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.expressions.chiSquaredExpression import space.kscience.kmath.expressions.symbol import space.kscience.kmath.operations.asIterable +import space.kscience.kmath.operations.plus +import space.kscience.kmath.operations.times import space.kscience.kmath.operations.toList import space.kscience.kmath.optimization.FunctionOptimizationTarget import space.kscience.kmath.optimization.optimizeWith diff --git a/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt b/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt index d52976671..02e4235d5 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt @@ -14,6 +14,8 @@ import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.binding import space.kscience.kmath.expressions.symbol import space.kscience.kmath.operations.asIterable +import space.kscience.kmath.operations.plus +import space.kscience.kmath.operations.times import space.kscience.kmath.operations.toList import space.kscience.kmath.optimization.QowOptimizer import space.kscience.kmath.optimization.chiSquaredOrNull diff --git a/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt b/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt index 4f99aeb47..780924e8a 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt @@ -13,6 +13,7 @@ import space.kscience.kmath.nd.structureND import space.kscience.kmath.nd.withNdAlgebra import space.kscience.kmath.operations.algebra import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.times fun main(): Unit = Double.algebra { withNdAlgebra(2, 2) { @@ -31,4 +32,4 @@ fun main(): Unit = Double.algebra { //the value is nullable because in some cases the integration could not succeed println(result.value) } -} \ No newline at end of file +} diff --git a/examples/src/main/kotlin/space/kscience/kmath/operations/mixedNDOperations.kt b/examples/src/main/kotlin/space/kscience/kmath/operations/mixedNDOperations.kt index 62c9c8076..3f0da2e64 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/operations/mixedNDOperations.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/operations/mixedNDOperations.kt @@ -7,7 +7,6 @@ package space.kscience.kmath.operations import space.kscience.kmath.commons.linear.CMLinearSpace import space.kscience.kmath.linear.matrix -import space.kscience.kmath.nd.DoubleBufferND import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.nd.ndAlgebra @@ -21,7 +20,7 @@ fun main() { val cmMatrix: Structure2D = CMLinearSpace.matrix(2, 2)(0.0, 1.0, 0.0, 3.0) - val res: DoubleBufferND = DoubleField.ndAlgebra { + val res = DoubleField.ndAlgebra { exp(viktorStructure) + 2.0 * cmMatrix } diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt index d55f3df09..f565f88af 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt @@ -15,6 +15,8 @@ import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.structureND import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.plus +import space.kscience.kmath.operations.times import kotlin.system.measureTimeMillis fun main() { diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt index 548fb16c1..0ae2d3e88 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt @@ -80,7 +80,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND.unaryMinus(): StructureND = map { -it } + override fun negate(arg: StructureND): StructureND = arg.map { -it } override fun scale(a: StructureND, value: Double): StructureND = a.map { it * value } diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt index 889ea99bd..513a936aa 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt @@ -5,10 +5,7 @@ package space.kscience.kmath.structures -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.buffer -import space.kscience.kmath.operations.bufferAlgebra -import space.kscience.kmath.operations.withSize +import space.kscience.kmath.operations.* inline fun MutableBuffer.Companion.same( n: Int, diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt index b42602988..e8831b503 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt @@ -5,12 +5,13 @@ package space.kscience.kmath.tensors +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.abs import space.kscience.kmath.operations.invoke -import space.kscience.kmath.tensors.core.DoubleTensor +import space.kscience.kmath.operations.minus +import space.kscience.kmath.operations.plus import space.kscience.kmath.tensors.core.DoubleTensorAlgebra -import kotlin.math.abs - // OLS estimator using SVD fun main() { @@ -48,14 +49,16 @@ fun main() { // inverse Sigma matrix can be restored from singular values with diagonalEmbedding function - val sigma = diagonalEmbedding(singValues.map{ if (abs(it) < 1e-3) 0.0 else 1.0/it }) + val sigma = diagonalEmbedding(singValues.map { if (abs(it) < 1e-3) 0.0 else 1.0 / it }) val alphaOLS = v dot sigma dot u.transpose() dot y - println("Estimated alpha:\n" + - "$alphaOLS") + println( + "Estimated alpha:\n" + + "$alphaOLS" + ) // figure out MSE of approximation - fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double { + fun mse(yTrue: StructureND, yPred: StructureND): Double { require(yTrue.shape.size == 1) require(yTrue.shape contentEquals yPred.shape) diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/PCA.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/PCA.kt index aced0cf7d..ab7c6d3b5 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/PCA.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/PCA.kt @@ -5,10 +5,11 @@ package space.kscience.kmath.tensors +import space.kscience.kmath.operations.plus +import space.kscience.kmath.operations.times import space.kscience.kmath.tensors.core.tensorAlgebra import space.kscience.kmath.tensors.core.withBroadcast - // simple PCA fun main(): Unit = Double.tensorAlgebra.withBroadcast { // work in context with broadcast methods diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/dataSetNormalization.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/dataSetNormalization.kt index a436ae1c3..b90c580fb 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/dataSetNormalization.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/dataSetNormalization.kt @@ -5,10 +5,11 @@ package space.kscience.kmath.tensors +import space.kscience.kmath.operations.div +import space.kscience.kmath.operations.minus import space.kscience.kmath.tensors.core.tensorAlgebra import space.kscience.kmath.tensors.core.withBroadcast - // Dataset normalization fun main() = Double.tensorAlgebra.withBroadcast { // work in context with broadcast methods diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt index f2d1f0b41..24e55de53 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt @@ -10,6 +10,8 @@ import org.jetbrains.kotlinx.multik.api.ndarray import space.kscience.kmath.multik.multikAlgebra import space.kscience.kmath.nd.one import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.minus +import space.kscience.kmath.operations.plus fun main(): Unit = with(DoubleField.multikAlgebra) { val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType().wrap() diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt index 5c41ab0f1..2c52e43f4 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt @@ -5,11 +5,9 @@ package space.kscience.kmath.tensors -import space.kscience.kmath.operations.invoke -import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra -import space.kscience.kmath.tensors.core.DoubleTensor -import space.kscience.kmath.tensors.core.DoubleTensorAlgebra -import space.kscience.kmath.tensors.core.copyArray +import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.operations.* +import space.kscience.kmath.tensors.core.* import kotlin.math.sqrt const val seed = 100500L @@ -27,13 +25,10 @@ open class Activation( val activation: (DoubleTensor) -> DoubleTensor, val activationDer: (DoubleTensor) -> DoubleTensor, ) : Layer { - override fun forward(input: DoubleTensor): DoubleTensor { - return activation(input) - } + override fun forward(input: DoubleTensor): DoubleTensor = activation(input) - override fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor { - return DoubleTensorAlgebra { outputError * activationDer(input) } - } + override fun backward(input: DoubleTensor, outputError: DoubleTensor): DoubleTensor = + DoubleTensorAlgebra { outputError * activationDer(input) } } fun relu(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra { @@ -106,8 +101,8 @@ fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double { } // neural network class -@OptIn(ExperimentalStdlibApi::class) class NeuralNetwork(private val layers: List) { + @OptIn(PerformancePitfall::class) private fun softMaxLoss(yPred: DoubleTensor, yTrue: DoubleTensor): DoubleTensor = BroadcastDoubleTensorAlgebra { val onesForAnswers = yPred.zeroesLike() @@ -174,7 +169,7 @@ class NeuralNetwork(private val layers: List) { } -@OptIn(ExperimentalStdlibApi::class) +@OptIn(PerformancePitfall::class) fun main() = BroadcastDoubleTensorAlgebra { val features = 5 val sampleSize = 250 diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index 15b1d0900..e9bf2e483 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -84,3 +84,8 @@ readme { ref = "src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt" ) { "Extendable MST rendering" } } + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerConsistencyWithInterpreter.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerConsistencyWithInterpreter.kt index 1edb5923e..8f02796bf 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerConsistencyWithInterpreter.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerConsistencyWithInterpreter.kt @@ -9,9 +9,7 @@ import space.kscience.kmath.expressions.MstField import space.kscience.kmath.expressions.MstRing import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.interpret -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.IntRing -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt index be8a92f3e..7526df935 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt @@ -8,8 +8,7 @@ package space.kscience.kmath.ast import space.kscience.kmath.expressions.MstExtendedField import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.invoke -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 82694d95a..55f7da7d5 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -68,7 +68,7 @@ public class DerivativeStructureField( public fun DerivativeStructure.derivative(vararg symbols: Symbol): Double = derivative(symbols.toList()) - override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate() + override fun negate(arg: DerivativeStructure): DerivativeStructure = arg.negate() override fun add(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.add(right) diff --git a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt index 56252ab34..67f4f37ad 100644 --- a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt +++ b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt @@ -6,6 +6,9 @@ package space.kscience.kmath.commons.expressions import space.kscience.kmath.expressions.* +import space.kscience.kmath.operations.plus +import space.kscience.kmath.operations.times +import space.kscience.kmath.operations.unaryMinus import kotlin.contracts.InvocationKind import kotlin.contracts.contract import kotlin.test.Test diff --git a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt index 0977dc247..e4b47adab 100644 --- a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -13,7 +13,7 @@ import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.Symbol.Companion.y import space.kscience.kmath.expressions.chiSquaredExpression import space.kscience.kmath.expressions.symbol -import space.kscience.kmath.operations.map +import space.kscience.kmath.operations.* import space.kscience.kmath.optimization.* import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.structures.DoubleBuffer diff --git a/kmath-complex/build.gradle.kts b/kmath-complex/build.gradle.kts index ea74df646..4fd7ca5f6 100644 --- a/kmath-complex/build.gradle.kts +++ b/kmath-complex/build.gradle.kts @@ -1,7 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") - id("ru.mipt.npm.gradle.native") +// id("ru.mipt.npm.gradle.native") } kotlin.sourceSets { @@ -29,3 +29,8 @@ readme { ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt" ) } + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt index 77fe782a9..0df8311d8 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.complex +import space.kscience.kmath.complex.ComplexField.plus import space.kscience.kmath.memory.MemoryReader import space.kscience.kmath.memory.MemorySpec import space.kscience.kmath.memory.MemoryWriter @@ -71,14 +72,13 @@ public object ComplexField : */ public val i: Complex by lazy { Complex(0.0, 1.0) } - override fun Complex.unaryMinus(): Complex = Complex(-re, -im) override fun number(value: Number): Complex = Complex(value.toDouble(), 0.0) override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value) override fun add(left: Complex, right: Complex): Complex = Complex(left.re + right.re, left.im + right.im) -// override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) + override fun negate(arg: Complex): Complex = Complex(-arg.re, -arg.im) override fun multiply(left: Complex, right: Complex): Complex = Complex(left.re * right.re - left.im * right.im, left.re * right.im + left.im * right.re) @@ -107,8 +107,6 @@ public object ComplexField : } } - override operator fun Complex.div(k: Number): Complex = Complex(re / k.toDouble(), im / k.toDouble()) - override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2.0 override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2.0 @@ -139,61 +137,65 @@ public object ComplexField : override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re) - /** - * Adds complex number to real one. - * - * @receiver the augend. - * @param c the addend. - * @return the sum. - */ - public operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c) - - /** - * Subtracts complex number from real one. - * - * @receiver the minuend. - * @param c the subtrahend. - * @return the difference. - */ - public operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c) - - /** - * Adds real number to complex one. - * - * @receiver the augend. - * @param d the addend. - * @return the sum. - */ - public operator fun Complex.plus(d: Double): Complex = d + this - - /** - * Subtracts real number from complex one. - * - * @receiver the minuend. - * @param d the subtrahend. - * @return the difference. - */ - public operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex()) - - /** - * Multiplies real number by complex one. - * - * @receiver the multiplier. - * @param c the multiplicand. - * @receiver the product. - */ - public operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) - override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) } +/** + * Adds complex number to real one. + * + * @receiver the augend. + * @param c the addend. + * @return the sum. + */ +context(ComplexField) +public operator fun Double.plus(c: Complex): Complex = add(toComplex(), c) + +/** + * Adds real number to complex one. + * + * @receiver the augend. + * @param d the addend. + * @return the sum. + */ +context(ComplexField) +public operator fun Complex.plus(d: Double): Complex = d + this + +/** + * Subtracts complex number from real one. + * + * @receiver the minuend. + * @param c the subtrahend. + * @return the difference. + */ +context(ComplexField) +public operator fun Double.minus(c: Complex): Complex = add(toComplex(), -c) + +/** + * Subtracts real number from complex one. + * + * @receiver the minuend. + * @param d the subtrahend. + * @return the difference. + */ +context(ComplexField) +public operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex()) + +/** + * Multiplies real number by complex one. + * + * @receiver the multiplier. + * @param c the multiplicand. + * @receiver the product. + */ +context(ComplexField) +public operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) + /** * Represents `double`-based complex number. * * @property re The real part. * @property im The imaginary part. */ -@OptIn(UnstableKMathAPI::class) public data class Complex(val re: Double, val im: Double) { public constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) public constructor(re: Number) : this(re.toDouble(), 0.0) diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt index 3ef3428c6..154c7a866 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt @@ -166,10 +166,7 @@ public object QuaternionField : Field, Norm, override operator fun Quaternion.plus(other: Number): Quaternion = Quaternion(w + other.toDouble(), x, y, z) override operator fun Quaternion.minus(other: Number): Quaternion = Quaternion(w - other.toDouble(), x, y, z) - override operator fun Number.times(arg: Quaternion): Quaternion = - Quaternion(toDouble() * arg.w, toDouble() * arg.x, toDouble() * arg.y, toDouble() * arg.z) - - override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z) + override fun negate(arg: Quaternion): Quaternion = Quaternion(-arg.w, -arg.x, -arg.y, -arg.z) override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg) override fun bindSymbolOrNull(value: String): Quaternion? = when (value) { diff --git a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexFieldTest.kt b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexFieldTest.kt index cbaaa815b..e4ac0ed6d 100644 --- a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexFieldTest.kt +++ b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexFieldTest.kt @@ -5,7 +5,7 @@ package space.kscience.kmath.complex -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import kotlin.math.PI import kotlin.math.abs import kotlin.test.Test diff --git a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexTest.kt b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexTest.kt index 7ad7f883d..59ca232df 100644 --- a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexTest.kt +++ b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexTest.kt @@ -6,6 +6,8 @@ package space.kscience.kmath.complex import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.minus +import space.kscience.kmath.operations.plus import kotlin.math.sqrt import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt index 4279471d4..21e625b4c 100644 --- a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt +++ b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt @@ -9,6 +9,8 @@ import space.kscience.kmath.expressions.FunctionalExpressionField import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.symbol import space.kscience.kmath.operations.bindSymbol +import space.kscience.kmath.operations.plus +import space.kscience.kmath.operations.times import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/QuaternionFieldTest.kt b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/QuaternionFieldTest.kt index 6784f3516..a04d931d0 100644 --- a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/QuaternionFieldTest.kt +++ b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/QuaternionFieldTest.kt @@ -6,6 +6,8 @@ package space.kscience.kmath.complex import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.plus +import space.kscience.kmath.operations.times import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index 4a35a54fb..41e467745 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -1,7 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") - id("ru.mipt.npm.gradle.native") +// id("ru.mipt.npm.gradle.native") // id("com.xcporter.metaview") version "0.0.5" } @@ -72,3 +72,8 @@ readme { ref = "src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt" ) { "Automatic differentiation" } } + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 68cc8e791..1fdec0acc 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -51,8 +51,8 @@ public open class FunctionalExpressionGroup>( ) : FunctionalExpressionAlgebra(algebra), Group> { override val zero: Expression get() = const(algebra.zero) - override fun Expression.unaryMinus(): Expression = - unaryOperation(GroupOps.MINUS_OPERATION, this) + override fun negate(arg: Expression): Expression = + unaryOperation(GroupOps.MINUS_OPERATION, arg) /** * Builds an Expression of addition of two another expressions. @@ -60,26 +60,25 @@ public open class FunctionalExpressionGroup>( override fun add(left: Expression, right: Expression): Expression = binaryOperation(GroupOps.PLUS_OPERATION, left, right) -// /** -// * Builds an Expression of multiplication of expression by number. -// */ -// override fun multiply(a: Expression, k: Number): Expression = Expression { arguments -> -// algebra.multiply(a.invoke(arguments), k) -// } - - public operator fun Expression.plus(arg: T): Expression = this + const(arg) - public operator fun Expression.minus(arg: T): Expression = this - const(arg) - public operator fun T.plus(arg: Expression): Expression = arg + this - public operator fun T.minus(arg: Expression): Expression = arg - this - override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = super.unaryOperationFunction(operation) override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = super.binaryOperationFunction(operation) - } +context(FunctionalExpressionGroup) +public operator fun > Expression.plus(arg: T): Expression = this + const(arg) + +context(FunctionalExpressionGroup) +public operator fun > Expression.minus(arg: T): Expression = this - const(arg) + +context(FunctionalExpressionGroup) +public operator fun > T.plus(arg: Expression): Expression = arg + this + +context(FunctionalExpressionGroup) +public operator fun > T.minus(arg: Expression): Expression = arg - this + public open class FunctionalExpressionRing>( algebra: A, ) : FunctionalExpressionGroup(algebra), Ring> { @@ -91,9 +90,6 @@ public open class FunctionalExpressionRing>( override fun multiply(left: Expression, right: Expression): Expression = binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right) - public operator fun Expression.times(arg: T): Expression = this * const(arg) - public operator fun T.times(arg: Expression): Expression = arg * this - override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = super.unaryOperationFunction(operation) @@ -101,6 +97,12 @@ public open class FunctionalExpressionRing>( super.binaryOperationFunction(operation) } +context(FunctionalExpressionRing) +public operator fun > Expression.times(arg: T): Expression = this * const(arg) + +context(FunctionalExpressionRing) +public operator fun > T.times(arg: Expression): Expression = arg * this + public open class FunctionalExpressionField>( algebra: A, ) : FunctionalExpressionRing(algebra), Field>, ScaleOperations> { @@ -110,9 +112,6 @@ public open class FunctionalExpressionField>( override fun divide(left: Expression, right: Expression): Expression = binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right) - public operator fun Expression.div(arg: T): Expression = this / const(arg) - public operator fun T.div(arg: Expression): Expression = arg / this - override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = super.unaryOperationFunction(operation) @@ -127,6 +126,12 @@ public open class FunctionalExpressionField>( super.bindSymbolOrNull(value) } +context(FunctionalExpressionField) +public operator fun > Expression.div(arg: T): Expression = this / const(arg) + +context(FunctionalExpressionField) +public operator fun > T.div(arg: Expression): Expression = arg / this + public open class FunctionalExpressionExtendedField>( algebra: A, ) : FunctionalExpressionField(algebra), ExtendedField> { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt index 4bd2a6c53..232b16d46 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt @@ -32,14 +32,12 @@ public object MstGroup : Group, NumericAlgebra, ScaleOperations { override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun add(left: MST, right: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(left, right) - override operator fun MST.unaryPlus(): MST.Unary = - unaryOperationFunction(GroupOps.PLUS_OPERATION)(this) - override operator fun MST.unaryMinus(): MST.Unary = - unaryOperationFunction(GroupOps.MINUS_OPERATION)(this) + override fun negate(arg: MST): MST.Unary = + unaryOperationFunction(GroupOps.MINUS_OPERATION)(arg) - override operator fun MST.minus(arg: MST): MST.Binary = - binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, arg) + override fun subtract(left: MST, right: MST): MST.Binary = + binaryOperationFunction(GroupOps.MINUS_OPERATION)(left, right) override fun scale(a: MST, value: Double): MST.Binary = binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value)) @@ -70,9 +68,8 @@ public object MstRing : Ring, NumbersAddOps, ScaleOperations { override fun multiply(left: MST, right: MST): MST.Binary = binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right) - override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus } - override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus } - override operator fun MST.minus(arg: MST): MST.Binary = MstGroup { this@minus - arg } + override fun negate(arg: MST): MST.Unary = MstGroup.negate(arg) + override fun subtract(left: MST, right: MST): MST.Binary = MstGroup.subtract(left, right) override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstGroup.binaryOperationFunction(operation) @@ -101,9 +98,8 @@ public object MstField : Field, NumbersAddOps, ScaleOperations { override fun divide(left: MST, right: MST): MST.Binary = binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right) - override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } - override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } - override operator fun MST.minus(arg: MST): MST.Binary = MstRing { this@minus - arg } + override fun negate(arg: MST): MST.Unary = MstRing.negate(arg) + override fun subtract(left: MST, right: MST): MST.Binary = MstRing.subtract(left, right) override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstRing.binaryOperationFunction(operation) @@ -142,9 +138,8 @@ public object MstExtendedField : ExtendedField, NumericAlgebra { override fun multiply(left: MST, right: MST): MST.Binary = MstField.multiply(left, right) override fun divide(left: MST, right: MST): MST.Binary = MstField.divide(left, right) - override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus } - override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus } - override operator fun MST.minus(arg: MST): MST.Binary = MstField { this@minus - arg } + override fun negate(arg: MST): MST.Unary = MstField.negate(arg) + override fun subtract(left: MST, right: MST): MST.Binary = MstField.subtract(left, right) override fun power(arg: MST, pow: Number): MST.Binary = binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow)) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt index ac8c44446..15cfac1c5 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -163,8 +163,8 @@ public open class SimpleAutoDiffField>( // derive(const { this@minus.value - one * b.toDouble() }) { z -> d += z.d } - override fun AutoDiffValue.unaryMinus(): AutoDiffValue = - derive(const { -value }) { z -> d -= z.d } + override fun negate(arg: AutoDiffValue): AutoDiffValue = + derive(const { -arg.value }) { z -> arg.d -= z.d } // Basic math (+, -, *, /) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/specialExpressions.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/specialExpressions.kt index 907ce4004..ad4a7411e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/specialExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/specialExpressions.kt @@ -5,8 +5,7 @@ package space.kscience.kmath.expressions -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.asIterable +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.indices import kotlin.jvm.JvmName diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt index 4e6debc60..870bdb284 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt @@ -9,9 +9,7 @@ import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.DoubleFieldOpsND import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.asND -import space.kscience.kmath.operations.DoubleBufferOps -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.DoubleBuffer diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt index 715fad07b..34a1ef036 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt @@ -10,10 +10,7 @@ import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.nd.StructureFeature import space.kscience.kmath.nd.as1D -import space.kscience.kmath.operations.BufferRingOps -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.Ring -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.DoubleBuffer diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt index b1812f49d..918824f4a 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt @@ -89,6 +89,5 @@ public class TransposedFeature(public val original: Matrix) : Ma * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` */ @Suppress("UNCHECKED_CAST") -@OptIn(UnstableKMathAPI::class) public fun Matrix.transpose(): Matrix = getFeature(TransposedFeature::class)?.original as? Matrix ?: VirtualMatrix(colNum, rowNum) { i, j -> get(j, i) }.withFeature(TransposedFeature(this)) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/cumulative.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/cumulative.kt index ee7f1d8be..c990f8bc0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/cumulative.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/cumulative.kt @@ -7,6 +7,7 @@ package space.kscience.kmath.misc import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.plus import kotlin.jvm.JvmName /** diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt index a9712e870..ac6230a8b 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt @@ -127,51 +127,54 @@ public interface GroupOpsND> : GroupOps>, override fun add(left: StructureND, right: StructureND): StructureND = zip(left, right) { aValue, bValue -> add(aValue, bValue) } - // TODO move to extensions after KEEP-176 - - /** - * Adds an ND structure to an element of it. - * - * @receiver the augend. - * @param arg the addend. - * @return the sum. - */ - @OptIn(PerformancePitfall::class) - public operator fun StructureND.plus(arg: T): StructureND = this.map { value -> add(arg, value) } - - /** - * Subtracts an element from ND structure of it. - * - * @receiver the dividend. - * @param arg the divisor. - * @return the quotient. - */ - @OptIn(PerformancePitfall::class) - public operator fun StructureND.minus(arg: T): StructureND = this.map { value -> add(arg, -value) } - - /** - * Adds an element to ND structure of it. - * - * @receiver the augend. - * @param arg the addend. - * @return the sum. - */ - @OptIn(PerformancePitfall::class) - public operator fun T.plus(arg: StructureND): StructureND = arg.map { value -> add(this@plus, value) } - - /** - * Subtracts an ND structure from an element of it. - * - * @receiver the dividend. - * @param arg the divisor. - * @return the quotient. - */ - @OptIn(PerformancePitfall::class) - public operator fun T.minus(arg: StructureND): StructureND = arg.map { value -> add(-this@minus, value) } - public companion object } +/** + * Adds an ND structure to an element of it. + * + * @receiver the augend. + * @param arg the addend. + * @return the sum. + */ +context(GroupOpsND) +@OptIn(PerformancePitfall::class) +public operator fun > StructureND.plus(arg: T): StructureND = this.map { value -> add(arg, value) } + +/** + * Subtracts an element from ND structure of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ +context(GroupOpsND) +@OptIn(PerformancePitfall::class) +public operator fun > StructureND.minus(arg: T): StructureND = this.map { value -> add(arg, -value) } + +/** + * Adds an element to ND structure of it. + * + * @receiver the augend. + * @param arg the addend. + * @return the sum. + */ +context(GroupOpsND) +@OptIn(PerformancePitfall::class) +public operator fun > T.plus(arg: StructureND): StructureND = arg.map { value -> add(this@plus, value) } + +/** + * Subtracts an ND structure from an element of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ +context(GroupOpsND) +@OptIn(PerformancePitfall::class) +public operator fun > T.minus(arg: StructureND): StructureND = arg.map { value -> add(-this@minus, value) } + + public interface GroupND> : Group>, GroupOpsND, WithShape { override val zero: StructureND get() = structureND(shape) { elementAlgebra.zero } } @@ -194,31 +197,34 @@ public interface RingOpsND> : RingOps>, Gro override fun multiply(left: StructureND, right: StructureND): StructureND = zip(left, right) { aValue, bValue -> multiply(aValue, bValue) } - //TODO move to extensions with context receivers - - /** - * Multiplies an ND structure by an element of it. - * - * @receiver the multiplicand. - * @param arg the multiplier. - * @return the product. - */ - @OptIn(PerformancePitfall::class) - public operator fun StructureND.times(arg: T): StructureND = this.map { value -> multiply(arg, value) } - - /** - * Multiplies an element by a ND structure of it. - * - * @receiver the multiplicand. - * @param arg the multiplier. - * @return the product. - */ - @OptIn(PerformancePitfall::class) - public operator fun T.times(arg: StructureND): StructureND = arg.map { value -> multiply(this@times, value) } public companion object } +/** + * Multiplies an ND structure by an element of it. + * + * @receiver the multiplicand. + * @param arg the multiplier. + * @return the product. + */ +context(RingOpsND) +@OptIn(PerformancePitfall::class) +public operator fun > StructureND.times(arg: T): StructureND = + this.map { value -> multiply(arg, value) } + +/** + * Multiplies an element by a ND structure of it. + * + * @receiver the multiplicand. + * @param arg the multiplier. + * @return the product. + */ +context(RingOpsND) +@OptIn(PerformancePitfall::class) +public operator fun > T.times(arg: StructureND): StructureND = + arg.map { value -> multiply(this@times, value) } + public interface RingND> : Ring>, RingOpsND, GroupND, WithShape { override val one: StructureND get() = structureND(shape) { elementAlgebra.one } } @@ -245,31 +251,33 @@ public interface FieldOpsND> : override fun divide(left: StructureND, right: StructureND): StructureND = zip(left, right) { aValue, bValue -> divide(aValue, bValue) } - //TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md - /** - * Divides an ND structure by an element of it. - * - * @receiver the dividend. - * @param arg the divisor. - * @return the quotient. - */ - @OptIn(PerformancePitfall::class) - public operator fun StructureND.div(arg: T): StructureND = this.map { value -> divide(arg, value) } - - /** - * Divides an element by an ND structure of it. - * - * @receiver the dividend. - * @param arg the divisor. - * @return the quotient. - */ - @OptIn(PerformancePitfall::class) - public operator fun T.div(arg: StructureND): StructureND = arg.map { divide(it, this@div) } - @OptIn(PerformancePitfall::class) override fun scale(a: StructureND, value: Double): StructureND = a.map { scale(it, value) } } +/** + * Divides an ND structure by an element of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ +context(FieldOpsND) +@OptIn(PerformancePitfall::class) +public operator fun > StructureND.div(arg: T): StructureND = this.map { value -> divide(arg, value) } + +/** + * Divides an element by an ND structure of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ +context(FieldOpsND) +@OptIn(PerformancePitfall::class) +public operator fun > T.div(arg: StructureND): StructureND = arg.map { divide(it, this@div) } + + public interface FieldND> : Field>, FieldOpsND, RingND, WithShape { override val one: StructureND get() = structureND(shape) { elementAlgebra.one } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt index b09344d12..80c4c28ac 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt @@ -101,7 +101,7 @@ public open class BufferedGroupNDOps>( override val bufferAlgebra: BufferAlgebra, override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder, ) : GroupOpsND, BufferAlgebraND { - override fun StructureND.unaryMinus(): StructureND = map { -it } + override fun negate(arg: StructureND): StructureND = arg.map { -it } } public open class BufferedRingOpsND>( diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt index d01a8ee95..a08bc1f0c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt @@ -76,50 +76,17 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D override fun add(left: StructureND, right: StructureND): DoubleBufferND = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l + r } + override fun negate(arg: StructureND): DoubleBufferND = mapInline(arg.toBufferND()) { -it } + + override fun subtract(left: StructureND, right: StructureND): DoubleBufferND = + zipInline(left.toBufferND(), right.toBufferND()) { l: Double, r: Double -> l - r } + override fun multiply(left: StructureND, right: StructureND): DoubleBufferND = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l * r } - override fun StructureND.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it } - - override fun StructureND.div(arg: StructureND): DoubleBufferND = - zipInline(toBufferND(), arg.toBufferND()) { l, r -> l / r } - override fun divide(left: StructureND, right: StructureND): DoubleBufferND = zipInline(left.toBufferND(), right.toBufferND()) { l: Double, r: Double -> l / r } - override fun StructureND.div(arg: Double): DoubleBufferND = - mapInline(toBufferND()) { it / arg } - - override fun Double.div(arg: StructureND): DoubleBufferND = - mapInline(arg.toBufferND()) { this / it } - - override fun StructureND.unaryPlus(): DoubleBufferND = toBufferND() - - override fun StructureND.plus(arg: StructureND): DoubleBufferND = - zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l + r } - - override fun StructureND.minus(arg: StructureND): DoubleBufferND = - zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l - r } - - override fun StructureND.times(arg: StructureND): DoubleBufferND = - zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l * r } - - override fun StructureND.times(k: Number): DoubleBufferND = - mapInline(toBufferND()) { it * k.toDouble() } - - override fun StructureND.div(k: Number): DoubleBufferND = - mapInline(toBufferND()) { it / k.toDouble() } - - override fun Number.times(arg: StructureND): DoubleBufferND = arg * this - - override fun StructureND.plus(arg: Double): DoubleBufferND = mapInline(toBufferND()) { it + arg } - - override fun StructureND.minus(arg: Double): StructureND = mapInline(toBufferND()) { it - arg } - - override fun Double.plus(arg: StructureND): StructureND = arg + this - - override fun Double.minus(arg: StructureND): StructureND = mapInline(arg.toBufferND()) { this - it } - override fun scale(a: StructureND, value: Double): DoubleBufferND = mapInline(a.toBufferND()) { it * value } @@ -181,7 +148,7 @@ public class DoubleFieldND(override val shape: Shape) : it.kpow(pow) } - override fun power(arg: StructureND, pow: Number): DoubleBufferND = if(pow.isInteger()){ + override fun power(arg: StructureND, pow: Number): DoubleBufferND = if (pow.isInteger()) { power(arg, pow.toInt()) } else { val dpow = pow.toDouble() diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt index d948cf36f..68631aeec 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt @@ -11,6 +11,7 @@ import space.kscience.kmath.misc.Featured import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.minus import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import kotlin.jvm.JvmName diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt index 45ba32c13..93e7ee3a3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt @@ -9,12 +9,6 @@ import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.Ring.Companion.optimizedPower -/** - * Stub for DSL the [Algebra] is. - */ -@DslMarker -public annotation class KMathContext - /** * Represents an algebraic structure. * @@ -137,46 +131,26 @@ public interface GroupOps : Algebra { */ public fun add(left: T, right: T): T - // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176. - /** * The negation of this element. * - * @receiver this value. + * @param arg the element. * @return the additive inverse of this value. */ - public operator fun T.unaryMinus(): T - - /** - * Returns this value. - * - * @receiver this value. - * @return this value. - */ - public operator fun T.unaryPlus(): T = this - - /** - * Addition of two elements. - * - * @receiver the augend. - * @param arg the addend. - * @return the sum. - */ - public operator fun T.plus(arg: T): T = add(this, arg) + public fun negate(arg: T): T /** * Subtraction of two elements. * - * @receiver the minuend. - * @param arg the subtrahend. + * @parm left the minuend. + * @param right the subtrahend. * @return the difference. */ - public operator fun T.minus(arg: T): T = add(this, -arg) + public fun subtract(left: T, right: T): T = add(left, -right) - // Dynamic dispatch of operations override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { - PLUS_OPERATION -> { arg -> +arg } - MINUS_OPERATION -> { arg -> -arg } + PLUS_OPERATION -> { arg -> arg } + MINUS_OPERATION -> ::negate else -> super.unaryOperationFunction(operation) } @@ -199,6 +173,44 @@ public interface GroupOps : Algebra { } } +/** + * The negation of this element. + * + * @receiver the element. + * @return the additive inverse of this value. + */ +context(GroupOps) +public operator fun T.unaryMinus(): T = negate(this) + +/** + * Returns this value. + * + * @receiver this value. + * @return this value. + */ +context(GroupOps) +public operator fun T.unaryPlus(): T = this + +/** + * Addition of two elements. + * + * @receiver the augend. + * @param arg the addend. + * @return the sum. + */ +context(GroupOps) +public operator fun T.plus(arg: T): T = add(this, arg) + +/** + * Subtraction of two elements. + * + * @receiver the minuend. + * @param arg the subtrahend. + * @return the difference. + */ +context(GroupOps) +public operator fun T.minus(arg: T): T = subtract(this, arg) + /** * Represents group i.e., algebraic structure with associative, binary operation [add]. * @@ -226,14 +238,6 @@ public interface RingOps : GroupOps { */ public fun multiply(left: T, right: T): T - /** - * Multiplies this element by scalar. - * - * @receiver the multiplier. - * @param arg the multiplicand. - */ - public operator fun T.times(arg: T): T = multiply(this, arg) - override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { TIMES_OPERATION -> ::multiply else -> super.binaryOperationFunction(operation) @@ -247,6 +251,15 @@ public interface RingOps : GroupOps { } } +/** + * Multiplies two elements. + * + * @receiver the multiplier. + * @param arg the multiplicand. + */ +context(RingOps) +public operator fun T.times(arg: T): T = multiply(this, arg) + /** * Represents ring i.e., algebraic structure with two associative binary operations called "addition" and * "multiplication" and their neutral elements. @@ -264,7 +277,7 @@ public interface Ring : Group, RingOps { */ public fun power(arg: T, pow: UInt): T = optimizedPower(arg, pow) - public companion object{ + public companion object { /** * Raises [arg] to the non-negative integer power [exponent]. * @@ -311,15 +324,6 @@ public interface FieldOps : RingOps { */ public fun divide(left: T, right: T): T - /** - * Division of two elements. - * - * @receiver the dividend. - * @param arg the divisor. - * @return the quotient. - */ - public operator fun T.div(arg: T): T = divide(this, arg) - override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { DIV_OPERATION -> ::divide else -> super.binaryOperationFunction(operation) @@ -333,6 +337,16 @@ public interface FieldOps : RingOps { } } +/** + * Division of two elements. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ +context(FieldOps) +public operator fun T.div(arg: T): T = divide(this, arg) + /** * Represents field i.e., algebraic structure with three operations: associative, commutative addition and * multiplication, and division. **This interface differs from the eponymous mathematical definition: fields in KMath @@ -345,7 +359,7 @@ public interface Field : Ring, FieldOps, ScaleOperations, NumericAlg public fun power(arg: T, pow: Int): T = optimizedPower(arg, pow) - public companion object{ + public companion object { /** * Raises [arg] to the integer power [exponent]. * @@ -358,7 +372,10 @@ public interface Field : Ring, FieldOps, ScaleOperations, NumericAlg * @author Iaroslav Postovalov, Evgeniy Zhelenskiy */ private fun Field.optimizedPower(arg: T, exponent: Int): T = when { - exponent < 0 -> one / (this as Ring).optimizedPower(arg, if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt()) + exponent < 0 -> one / (this as Ring).optimizedPower( + arg, + if (exponent == Int.MIN_VALUE) Int.MAX_VALUE.toUInt().inc() else (-exponent).toUInt() + ) else -> (this as Ring).optimizedPower(arg, exponent.toUInt()) } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt index 99268348b..b4b39dfc0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt @@ -33,7 +33,7 @@ public object BigIntField : Field, NumbersAddOps, ScaleOperation override fun number(value: Number): BigInt = value.toLong().toBigInt() @Suppress("EXTENSION_SHADOWED_BY_MEMBER") - override fun BigInt.unaryMinus(): BigInt = -this + override fun negate(arg: BigInt): BigInt = -arg override fun add(left: BigInt, right: BigInt): BigInt = left.plus(right) override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value)) override fun multiply(left: BigInt, right: BigInt): BigInt = left.times(right) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt index 653552044..3764484ef 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt @@ -137,7 +137,7 @@ public open class BufferRingOps>( override fun add(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l + r } override fun multiply(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l * r } - override fun Buffer.unaryMinus(): Buffer = map { -it } + override fun negate(arg: Buffer): Buffer = arg.map { negate(it) } override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer = super.unaryOperationFunction(operation) @@ -159,7 +159,7 @@ public open class BufferFieldOps>( override fun divide(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l / r } override fun scale(a: Buffer, value: Double): Buffer = a.map { scale(it, value) } - override fun Buffer.unaryMinus(): Buffer = map { -it } + override fun negate(arg: Buffer): Buffer = arg.map { -it } override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = super.binaryOperationFunction(operation) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt index 0ee591acc..25fb5333d 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt @@ -6,7 +6,6 @@ package space.kscience.kmath.operations import space.kscience.kmath.linear.Point -import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.DoubleBuffer @@ -32,8 +31,6 @@ public abstract class DoubleBufferOps : BufferAlgebra, Exte override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = super.binaryOperationFunction(operation) - override fun Buffer.unaryMinus(): DoubleBuffer = mapInline { -it } - override fun add(left: Buffer, right: Buffer): DoubleBuffer { require(right.size == left.size) { "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} " @@ -46,18 +43,17 @@ public abstract class DoubleBufferOps : BufferAlgebra, Exte } else DoubleBuffer(DoubleArray(left.size) { left[it] + right[it] }) } - override fun Buffer.plus(arg: Buffer): DoubleBuffer = add(this, arg) + override fun negate(arg: Buffer): DoubleBuffer = arg.mapInline { -it } - override fun Buffer.minus(arg: Buffer): DoubleBuffer { - require(arg.size == this.size) { - "The size of the first buffer ${this.size} should be the same as for second one: ${arg.size} " + override fun subtract(left: Buffer, right: Buffer): DoubleBuffer { + require(left.size == right.size) { + "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} " } - return if (this is DoubleBuffer && arg is DoubleBuffer) { - val aArray = this.array - val bArray = arg.array - DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] }) - } else DoubleBuffer(DoubleArray(this.size) { this[it] - arg[it] }) + return if (left is DoubleBuffer && right is DoubleBuffer) + DoubleBuffer(DoubleArray(left.size) { left.array[it] - right.array[it] }) + else + DoubleBuffer(DoubleArray(left.size) { left[it] - right[it] }) } // diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt index d0405c705..b0f4f45b8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt @@ -116,35 +116,39 @@ public interface ScaleOperations : Algebra { * @return the produce. */ public fun scale(a: T, value: Double): T - - /** - * Multiplication of this element by a scalar. - * - * @receiver the multiplier. - * @param k the multiplicand. - * @return the product. - */ - public operator fun T.times(k: Number): T = scale(this, k.toDouble()) - - /** - * Division of this element by scalar. - * - * @receiver the dividend. - * @param k the divisor. - * @return the quotient. - */ - public operator fun T.div(k: Number): T = scale(this, 1.0 / k.toDouble()) - - /** - * Multiplication of this number by element. - * - * @receiver the multiplier. - * @param arg the multiplicand. - * @return the product. - */ - public operator fun Number.times(arg: T): T = arg * this } + +/** + * Multiplication of this element by a scalar. + * + * @receiver the multiplier. + * @param k the multiplicand. + * @return the product. + */ +context(ScaleOperations) +public operator fun T.times(k: Number): T = scale(this, k.toDouble()) + +/** + * Division of this element by scalar. + * + * @receiver the dividend. + * @param k the divisor. + * @return the quotient. + */ +context(ScaleOperations) +public operator fun T.div(k: Number): T = scale(this, 1.0 / k.toDouble()) + +/** + * Multiplication of this number by element. + * + * @receiver the multiplier. + * @param arg the multiplicand. + * @return the product. + */ +context(ScaleOperations) +public operator fun Number.times(arg: T): T = arg * this + /** * A combination of [NumericAlgebra] and [Ring] that adds intrinsic simple operations on numbers like `T+1` * TODO to be removed and replaced by extensions after multiple receivers are there diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt index 07a137415..1dcec24cc 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt @@ -68,7 +68,7 @@ public object DoubleField : ExtendedField, Norm, ScaleOp override inline val zero: Double get() = 0.0 override inline val one: Double get() = 1.0 - override inline fun number(value: Number): Double = value.toDouble() + override fun number(value: Number): Double = value.toDouble() override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = when (operation) { @@ -77,6 +77,8 @@ public object DoubleField : ExtendedField, Norm, ScaleOp } override inline fun add(left: Double, right: Double): Double = left + right + override inline fun negate(arg: Double): Double = -arg + override inline fun subtract(left: Double, right: Double): Double = left - right override inline fun multiply(left: Double, right: Double): Double = left * right override inline fun divide(left: Double, right: Double): Double = left / right @@ -108,12 +110,6 @@ public object DoubleField : ExtendedField, Norm, ScaleOp override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) override inline fun norm(arg: Double): Double = abs(arg) - - override inline fun Double.unaryMinus(): Double = -this - override inline fun Double.plus(arg: Double): Double = this + arg - override inline fun Double.minus(arg: Double): Double = this - arg - override inline fun Double.times(arg: Double): Double = this * arg - override inline fun Double.div(arg: Double): Double = this / arg } public val Double.Companion.algebra: DoubleField get() = DoubleField @@ -135,7 +131,10 @@ public object FloatField : ExtendedField, Norm { } override inline fun add(left: Float, right: Float): Float = left + right - override fun scale(a: Float, value: Double): Float = a * value.toFloat() + override inline fun negate(arg: Float): Float = -arg + override inline fun subtract(left: Float, right: Float): Float = left - right + + override inline fun scale(a: Float, value: Double): Float = a * value.toFloat() override inline fun multiply(left: Float, right: Float): Float = left * right @@ -162,12 +161,6 @@ public object FloatField : ExtendedField, Norm { override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) override inline fun norm(arg: Float): Float = abs(arg) - - override inline fun Float.unaryMinus(): Float = -this - override inline fun Float.plus(arg: Float): Float = this + arg - override inline fun Float.minus(arg: Float): Float = this - arg - override inline fun Float.times(arg: Float): Float = this * arg - override inline fun Float.div(arg: Float): Float = this / arg } public val Float.Companion.algebra: FloatField get() = FloatField @@ -185,13 +178,11 @@ public object IntRing : Ring, Norm, NumericAlgebra { override fun number(value: Number): Int = value.toInt() override inline fun add(left: Int, right: Int): Int = left + right + override inline fun negate(arg: Int): Int = -arg + override inline fun subtract(left: Int, right: Int): Int = left - right + override inline fun multiply(left: Int, right: Int): Int = left * right override inline fun norm(arg: Int): Int = abs(arg) - - override inline fun Int.unaryMinus(): Int = -this - override inline fun Int.plus(arg: Int): Int = this + arg - override inline fun Int.minus(arg: Int): Int = this - arg - override inline fun Int.times(arg: Int): Int = this * arg } public val Int.Companion.algebra: IntRing get() = IntRing @@ -209,13 +200,11 @@ public object ShortRing : Ring, Norm, NumericAlgebra override fun number(value: Number): Short = value.toShort() override inline fun add(left: Short, right: Short): Short = (left + right).toShort() - override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort() - override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() + override inline fun negate(arg: Short): Short = (-arg).toShort() + override inline fun subtract(left: Short, right: Short): Short = (left - right).toShort() - override inline fun Short.unaryMinus(): Short = (-this).toShort() - override inline fun Short.plus(arg: Short): Short = (this + arg).toShort() - override inline fun Short.minus(arg: Short): Short = (this - arg).toShort() - override inline fun Short.times(arg: Short): Short = (this * arg).toShort() + override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort() + override inline fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() } public val Short.Companion.algebra: ShortRing get() = ShortRing @@ -233,13 +222,12 @@ public object ByteRing : Ring, Norm, NumericAlgebra { override fun number(value: Number): Byte = value.toByte() override inline fun add(left: Byte, right: Byte): Byte = (left + right).toByte() - override inline fun multiply(left: Byte, right: Byte): Byte = (left * right).toByte() - override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() + override inline fun negate(arg: Byte): Byte = (-arg).toByte() + override inline fun subtract(left: Byte, right: Byte): Byte = (left - right).toByte() - override inline fun Byte.unaryMinus(): Byte = (-this).toByte() - override inline fun Byte.plus(arg: Byte): Byte = (this + arg).toByte() - override inline fun Byte.minus(arg: Byte): Byte = (this - arg).toByte() - override inline fun Byte.times(arg: Byte): Byte = (this * arg).toByte() + override inline fun multiply(left: Byte, right: Byte): Byte = (left * right).toByte() + + override inline fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() } public val Byte.Companion.algebra: ByteRing get() = ByteRing @@ -256,14 +244,13 @@ public object LongRing : Ring, Norm, NumericAlgebra { get() = 1L override fun number(value: Number): Long = value.toLong() - override inline fun add(left: Long, right: Long): Long = left + right - override inline fun multiply(left: Long, right: Long): Long = left * right - override fun norm(arg: Long): Long = abs(arg) - override inline fun Long.unaryMinus(): Long = (-this) - override inline fun Long.plus(arg: Long): Long = (this + arg) - override inline fun Long.minus(arg: Long): Long = (this - arg) - override inline fun Long.times(arg: Long): Long = (this * arg) + override inline fun add(left: Long, right: Long): Long = left + right + override inline fun negate(arg: Long): Long = -arg + override inline fun subtract(left: Long, right: Long): Long = left - right + + override inline fun multiply(left: Long, right: Long): Long = left * right + override inline fun norm(arg: Long): Long = abs(arg) } public val Long.Companion.algebra: LongRing get() = LongRing diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/ExpressionFieldTest.kt index 80c5943cf..4e89e3377 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/ExpressionFieldTest.kt @@ -6,6 +6,8 @@ package space.kscience.kmath.expressions import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.plus +import space.kscience.kmath.operations.times import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFails diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/InterpretTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/InterpretTest.kt index 156334b2e..6ec5b1c6e 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/InterpretTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/InterpretTest.kt @@ -7,13 +7,10 @@ package space.kscience.kmath.expressions import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.BooleanAlgebra -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import kotlin.test.Test import kotlin.test.assertEquals - internal class InterpretTest { @Test fun interpretation() { diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt index 201890933..d9241a773 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -5,8 +5,7 @@ package space.kscience.kmath.expressions -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.bindSymbol +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.asBuffer import kotlin.math.E diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt index b7b89d107..7a7210def 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt @@ -10,6 +10,7 @@ import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.structureND import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.times import space.kscience.kmath.testutils.FieldVerifier import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt index d33eb5112..3669740e5 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt @@ -8,10 +8,7 @@ package space.kscience.kmath.structures import space.kscience.kmath.linear.linearSpace import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.* -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.Norm -import space.kscience.kmath.operations.algebra -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import kotlin.math.abs import kotlin.math.pow import kotlin.test.Test diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/FieldVerifier.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/FieldVerifier.kt index 20a7b6a72..a24f9895b 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/FieldVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/FieldVerifier.kt @@ -5,8 +5,7 @@ package space.kscience.kmath.testutils -import space.kscience.kmath.operations.Field -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import kotlin.test.assertEquals import kotlin.test.assertNotEquals diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/RingVerifier.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/RingVerifier.kt index daf18834a..3261f27b9 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/RingVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/RingVerifier.kt @@ -5,9 +5,7 @@ package space.kscience.kmath.testutils -import space.kscience.kmath.operations.Ring -import space.kscience.kmath.operations.ScaleOperations -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import kotlin.test.assertEquals internal open class RingVerifier(algebra: A, a: T, b: T, c: T, x: Number) : diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/SpaceVerifier.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/SpaceVerifier.kt index 951197fc6..a3e62d4cd 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/SpaceVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/SpaceVerifier.kt @@ -5,9 +5,7 @@ package space.kscience.kmath.testutils -import space.kscience.kmath.operations.Ring -import space.kscience.kmath.operations.ScaleOperations -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import kotlin.test.assertEquals import kotlin.test.assertNotEquals diff --git a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt index 6e22c2381..5eefb4be8 100644 --- a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt @@ -19,10 +19,10 @@ public object JBigIntegerField : Ring, NumericAlgebra { override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) override fun add(left: BigInteger, right: BigInteger): BigInteger = left.add(right) - override operator fun BigInteger.minus(arg: BigInteger): BigInteger = subtract(arg) - override fun multiply(left: BigInteger, right: BigInteger): BigInteger = left.multiply(right) + override fun negate(arg: BigInteger): BigInteger = arg.negate() + override fun subtract(left: BigInteger, right: BigInteger): BigInteger = left.subtract(right) - override operator fun BigInteger.unaryMinus(): BigInteger = negate() + override fun multiply(left: BigInteger, right: BigInteger): BigInteger = left.multiply(right) } /** @@ -40,7 +40,9 @@ public abstract class JBigDecimalFieldBase internal constructor( get() = BigDecimal.ONE override fun add(left: BigDecimal, right: BigDecimal): BigDecimal = left.add(right) - override operator fun BigDecimal.minus(arg: BigDecimal): BigDecimal = subtract(arg) + override fun negate(arg: BigDecimal): BigDecimal = arg.negate(mathContext) + override fun subtract(left: BigDecimal, right: BigDecimal): BigDecimal = left.subtract(right) + override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) override fun scale(a: BigDecimal, value: Double): BigDecimal = @@ -50,7 +52,6 @@ public abstract class JBigDecimalFieldBase internal constructor( override fun divide(left: BigDecimal, right: BigDecimal): BigDecimal = left.divide(right, mathContext) override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) - override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) } /** diff --git a/kmath-coroutines/build.gradle.kts b/kmath-coroutines/build.gradle.kts index aa30c412b..92a72b4aa 100644 --- a/kmath-coroutines/build.gradle.kts +++ b/kmath-coroutines/build.gradle.kts @@ -1,7 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") - id("ru.mipt.npm.gradle.native") +// id("ru.mipt.npm.gradle.native") } kotlin.sourceSets { @@ -24,4 +24,9 @@ kotlin.sourceSets { readme { maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL -} \ No newline at end of file +} + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt index 7bf54d50f..add1a5b54 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt @@ -10,10 +10,7 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.runningReduce import kotlinx.coroutines.flow.scan -import space.kscience.kmath.operations.GroupOps -import space.kscience.kmath.operations.Ring -import space.kscience.kmath.operations.ScaleOperations -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* public fun Flow.cumulativeSum(group: GroupOps): Flow = group { runningReduce { sum, element -> sum + element } } diff --git a/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/structures/LazyStructureND.kt b/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/structures/LazyStructureND.kt index ac9eb773a..68b14b9af 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/structures/LazyStructureND.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/structures/LazyStructureND.kt @@ -39,15 +39,11 @@ public fun StructureND.deferred(index: IntArray): Deferred = public suspend fun StructureND.await(index: IntArray): T = if (this is LazyStructureND) await(index) else get(index) -/** - * PENDING would benefit from KEEP-176 - */ +context(CoroutineScope) public inline fun StructureND.mapAsyncIndexed( - scope: CoroutineScope, crossinline function: suspend (T, index: IntArray) -> R, -): LazyStructureND = LazyStructureND(scope, shape) { index -> function(get(index), index) } +): LazyStructureND = LazyStructureND(this@CoroutineScope, shape) { index -> function(get(index), index) } -public inline fun StructureND.mapAsync( - scope: CoroutineScope, - crossinline function: suspend (T) -> R, -): LazyStructureND = LazyStructureND(scope, shape) { index -> function(get(index)) } +context(CoroutineScope) +public inline fun StructureND.mapAsync(crossinline function: suspend (T) -> R): LazyStructureND = + LazyStructureND(this@CoroutineScope, shape) { index -> function(get(index)) } diff --git a/kmath-dimensions/build.gradle.kts b/kmath-dimensions/build.gradle.kts index 885f3c227..f4b878120 100644 --- a/kmath-dimensions/build.gradle.kts +++ b/kmath-dimensions/build.gradle.kts @@ -1,7 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") - id("ru.mipt.npm.gradle.native") +// id("ru.mipt.npm.gradle.native") } description = "A proof of concept module for adding type-safe dimensions to structures" @@ -23,3 +23,8 @@ kotlin.sourceSets { readme { maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE } + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-ejml/build.gradle.kts b/kmath-ejml/build.gradle.kts index 727d21e3a..cfb215366 100644 --- a/kmath-ejml/build.gradle.kts +++ b/kmath-ejml/build.gradle.kts @@ -33,10 +33,19 @@ readme { ) { "LinearSpace implementations." } } -kotlin.sourceSets.main { - val codegen by tasks.creating { - ejmlCodegen(kotlin.srcDirs.first().absolutePath + "/space/kscience/kmath/ejml/_generated.kt") - } +kotlin.sourceSets { + filter { it.name.contains("test", true) } + .map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings) + .forEach { + it.optIn("space.kscience.kmath.misc.PerformancePitfall") + it.optIn("space.kscience.kmath.misc.UnstableKMathAPI") + } - kotlin.srcDirs(files().builtBy(codegen)) + main { + val codegen by tasks.creating { + ejmlCodegen(kotlin.srcDirs.first().absolutePath + "/space/kscience/kmath/ejml/_generated.kt") + } + + kotlin.srcDirs(files().builtBy(codegen)) + } } diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt index dce739dc2..eec4b8b30 100644 --- a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt +++ b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt @@ -208,7 +208,7 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace object : InverseMatrixFeature { override val inverse: Matrix by lazy { val res = origin.copy() @@ -270,8 +270,8 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace by lazy { lup.getRowPivot(null).wrapMatrix() } } - else -> null - }?.let(type::cast) + else -> return null + }) } /** @@ -442,7 +442,7 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace object : InverseMatrixFeature { override val inverse: Matrix by lazy { val res = origin.copy() @@ -504,8 +504,8 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace by lazy { lup.getRowPivot(null).wrapMatrix() } } - else -> null - }?.let(type::cast) + else -> return null + }) } /** @@ -684,7 +684,7 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace object : QRDecompositionFeature { private val qr by lazy { DecompositionFactory_DSCC.qr(FillReducing.NONE).apply { decompose(origin.copy()) } @@ -733,8 +733,8 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace null - }?.let(type::cast) + else -> return null + }) } /** @@ -913,7 +913,7 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace object : QRDecompositionFeature { private val qr by lazy { DecompositionFactory_FSCC.qr(FillReducing.NONE).apply { decompose(origin.copy()) } @@ -962,8 +962,8 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace null - }?.let(type::cast) + else -> return null + }) } /** diff --git a/kmath-for-real/build.gradle.kts b/kmath-for-real/build.gradle.kts index 18c2c50ad..4b643beed 100644 --- a/kmath-for-real/build.gradle.kts +++ b/kmath-for-real/build.gradle.kts @@ -1,7 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") - id("ru.mipt.npm.gradle.native") +// id("ru.mipt.npm.gradle.native") } kotlin.sourceSets.commonMain { @@ -40,3 +40,8 @@ readme { "Uniform grid generators" } } + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-functions/build.gradle.kts b/kmath-functions/build.gradle.kts index fadbac091..71dba6bb0 100644 --- a/kmath-functions/build.gradle.kts +++ b/kmath-functions/build.gradle.kts @@ -1,7 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") - id("ru.mipt.npm.gradle.native") +// id("ru.mipt.npm.gradle.native") } description = "Functions, integration and interpolation" @@ -32,3 +32,8 @@ readme { "Univariate and multivariate quadratures" } } + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt index a36d36f52..f782496f3 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt @@ -71,7 +71,7 @@ public fun Polynomial.integrate( ): Polynomial where A : Field, A : NumericAlgebra = algebra { val integratedCoefficients = buildList(coefficients.size + 1) { add(zero) - coefficients.forEachIndexed{ index, t -> add(t / (number(index) + one)) } + coefficients.forEachIndexed { index, t -> add(t / (number(index) + one)) } } Polynomial(integratedCoefficients) } @@ -100,8 +100,8 @@ public class PolynomialSpace( ) : Group>, ScaleOperations> where C : Ring, C : ScaleOperations { override val zero: Polynomial = Polynomial(emptyList()) - override fun Polynomial.unaryMinus(): Polynomial = ring { - Polynomial(coefficients.map { -it }) + override fun negate(arg: Polynomial): Polynomial = ring { + Polynomial(arg.coefficients.map { -it }) } override fun add(left: Polynomial, right: Polynomial): Polynomial { diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt index 9785d7744..62b983e40 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt @@ -6,6 +6,9 @@ package space.kscience.kmath.integration import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.Field +import space.kscience.kmath.operations.minus +import space.kscience.kmath.operations.plus +import space.kscience.kmath.operations.times import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.asBuffer import space.kscience.kmath.structures.indices diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SimpsonIntegrator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SimpsonIntegrator.kt index 7815757aa..5b2e29873 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SimpsonIntegrator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SimpsonIntegrator.kt @@ -6,10 +6,7 @@ package space.kscience.kmath.integration import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.Field -import space.kscience.kmath.operations.invoke -import space.kscience.kmath.operations.sum +import space.kscience.kmath.operations.* /** * Use double pass Simpson rule integration with a fixed number of points. @@ -60,7 +57,8 @@ public class SimpsonIntegrator( } @UnstableKMathAPI -public val Field.simpsonIntegrator: SimpsonIntegrator get() = SimpsonIntegrator(this) +public val Field.simpsonIntegrator: SimpsonIntegrator + get() = SimpsonIntegrator(this) /** * Use double pass Simpson rule integration with a fixed number of points. diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/LinearInterpolator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/LinearInterpolator.kt index 34d7bcf41..2dfbc57f6 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/LinearInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/LinearInterpolator.kt @@ -9,8 +9,7 @@ import space.kscience.kmath.data.XYColumnarData import space.kscience.kmath.functions.PiecewisePolynomial import space.kscience.kmath.functions.Polynomial import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.Field -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* @OptIn(UnstableKMathAPI::class) internal fun > insureSorted(points: XYColumnarData<*, T, *>) { diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt index afcb33bd4..84753ba4b 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt @@ -9,9 +9,7 @@ import space.kscience.kmath.data.XYColumnarData import space.kscience.kmath.functions.PiecewisePolynomial import space.kscience.kmath.functions.Polynomial import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.Field -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.MutableBufferFactory diff --git a/kmath-geometry/build.gradle.kts b/kmath-geometry/build.gradle.kts index 7eb814683..ddcec2874 100644 --- a/kmath-geometry/build.gradle.kts +++ b/kmath-geometry/build.gradle.kts @@ -1,7 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") - id("ru.mipt.npm.gradle.native") +// id("ru.mipt.npm.gradle.native") } kotlin.sourceSets.commonMain { @@ -13,3 +13,8 @@ kotlin.sourceSets.commonMain { readme { maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE } + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt index d00575bcc..e2321cd31 100644 --- a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt @@ -6,12 +6,11 @@ package space.kscience.kmath.geometry import space.kscience.kmath.linear.Point -import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.minus import kotlin.math.sqrt -@OptIn(UnstableKMathAPI::class) public interface Vector2D : Point, Vector { public val x: Double public val y: Double @@ -44,7 +43,7 @@ public object Euclidean2DSpace : GeometrySpace, ScaleOperations, Vector { public val x: Double public val y: Double @@ -43,7 +42,7 @@ public object Euclidean3DSpace : GeometrySpace, ScaleOperations { + enabled = false +} diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Counter.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Counter.kt index fe3278026..7725283db 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Counter.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Counter.kt @@ -9,6 +9,7 @@ import kotlinx.atomicfu.atomic import kotlinx.atomicfu.getAndUpdate import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.Group +import space.kscience.kmath.operations.plus /** * Common representation for atomic counters @@ -72,5 +73,3 @@ public class ObjectCounter(private val group: Group) : Counter { override val value: T get() = innerValue.value } - - diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/HistogramND.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/HistogramND.kt index 68b24db5d..a644653ca 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/HistogramND.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/HistogramND.kt @@ -11,9 +11,7 @@ import space.kscience.kmath.nd.DefaultStrides import space.kscience.kmath.nd.FieldOpsND import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.operations.Group -import space.kscience.kmath.operations.ScaleOperations -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* /** * @param T the type of the argument space diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogram1D.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogram1D.kt index e13928394..8acd6bb53 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogram1D.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogram1D.kt @@ -7,10 +7,7 @@ package space.kscience.kmath.histogram import space.kscience.kmath.domains.DoubleDomain1D import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.Group -import space.kscience.kmath.operations.Ring -import space.kscience.kmath.operations.ScaleOperations -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import kotlin.math.floor @@ -69,8 +66,8 @@ public class UniformHistogram1DGroup( ) } - override fun Histogram1D.unaryMinus(): UniformHistogram1D = valueAlgebra { - UniformHistogram1D(this@UniformHistogram1DGroup, produceFrom(this@unaryMinus).values.mapValues { -it.value }) + override fun negate(arg: Histogram1D): UniformHistogram1D = valueAlgebra { + UniformHistogram1D(this@UniformHistogram1DGroup, produceFrom(arg).values.mapValues { -it.value }) } override fun scale( diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt index 90ec29ce3..8c65bc60c 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt @@ -14,6 +14,7 @@ import space.kscience.kmath.nd.* import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.Field import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.unaryMinus import space.kscience.kmath.structures.* import kotlin.math.floor @@ -98,8 +99,8 @@ public class UniformHistogramGroupND>( return HistogramND(this, values) } - override fun HistogramND.unaryMinus(): HistogramND = - this * (-1) + override fun negate(arg: HistogramND): HistogramND = + HistogramND(this, valueAlgebraND { -arg.values }) } /** diff --git a/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt b/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt index ca7c2f324..95120a30a 100644 --- a/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt +++ b/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt @@ -10,6 +10,7 @@ package space.kscience.kmath.histogram import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.DefaultStrides import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.minus import space.kscience.kmath.real.DoubleVector import kotlin.random.Random import kotlin.test.* diff --git a/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/UniformHistogram1DTest.kt b/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/UniformHistogram1DTest.kt index 09bf3939d..c4bef7346 100644 --- a/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/UniformHistogram1DTest.kt +++ b/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/UniformHistogram1DTest.kt @@ -10,6 +10,7 @@ import kotlinx.coroutines.test.runTest import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.plus import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.stat.nextBuffer import kotlin.native.concurrent.ThreadLocal @@ -36,7 +37,7 @@ internal class UniformHistogram1DTest { @Test fun rebinDown() = runTest { val h1 = Histogram.uniform1D(DoubleField, 0.01).produce(generator.nextDoubleBuffer(10000)) - val h2 = Histogram.uniform1D(DoubleField,0.03).produceFrom(h1) + val h2 = Histogram.uniform1D(DoubleField, 0.03).produceFrom(h1) assertEquals(10000, h2.bins.sumOf { it.binValue }.toInt()) } @@ -44,13 +45,13 @@ internal class UniformHistogram1DTest { @Test fun rebinUp() = runTest { val h1 = Histogram.uniform1D(DoubleField, 0.03).produce(generator.nextDoubleBuffer(10000)) - val h2 = Histogram.uniform1D(DoubleField,0.01).produceFrom(h1) + val h2 = Histogram.uniform1D(DoubleField, 0.01).produceFrom(h1) assertEquals(10000, h2.bins.sumOf { it.binValue }.toInt()) } @ThreadLocal - companion object{ + companion object { private val generator = RandomGenerator.default(123) } } \ No newline at end of file diff --git a/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramGroup.kt b/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramGroup.kt index 6bec01f9b..969e30614 100644 --- a/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramGroup.kt +++ b/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramGroup.kt @@ -11,9 +11,7 @@ import space.kscience.kmath.domains.DoubleDomain1D import space.kscience.kmath.domains.center import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.sorted -import space.kscience.kmath.operations.Group -import space.kscience.kmath.operations.Ring -import space.kscience.kmath.operations.ScaleOperations +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.first import space.kscience.kmath.structures.indices @@ -123,7 +121,18 @@ public class TreeHistogramGroup( return TreeHistogram(bins) } - override fun TreeHistogram.unaryMinus(): TreeHistogram = this * (-1) + override fun negate(arg: TreeHistogram): TreeHistogram { + val bins = TreeMap>().apply { + arg.bins.forEach { bin -> + put( + bin.domain.center, + Bin1D(bin.domain, valueAlgebra { -bin.binValue }) + ) + } + } + + return TreeHistogram(bins) + } override val zero: TreeHistogram = produce { } } diff --git a/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt b/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt index 91d952a76..32653a7ad 100644 --- a/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt +++ b/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt @@ -29,6 +29,8 @@ public object JafamaDoubleField : ExtendedField, Norm, S } override inline fun add(left: Double, right: Double): Double = left + right + override inline fun negate(arg: Double): Double = -arg + override inline fun subtract(left: Double, right: Double): Double = left - right override inline fun multiply(left: Double, right: Double): Double = left * right override inline fun divide(left: Double, right: Double): Double = left / right @@ -55,12 +57,6 @@ public object JafamaDoubleField : ExtendedField, Norm, S override inline fun ln(arg: Double): Double = FastMath.log(arg) override inline fun norm(arg: Double): Double = FastMath.abs(arg) - - override inline fun Double.unaryMinus(): Double = -this - override inline fun Double.plus(arg: Double): Double = this + arg - override inline fun Double.minus(arg: Double): Double = this - arg - override inline fun Double.times(arg: Double): Double = this * arg - override inline fun Double.div(arg: Double): Double = this / arg } /** @@ -80,6 +76,8 @@ public object StrictJafamaDoubleField : ExtendedField, Norm, Norm= 0) { "offset shouldn't be negative: $offset" } diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt index 1dc318517..7d4aa176b 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt @@ -7,28 +7,29 @@ package space.kscience.kmath.multik import org.jetbrains.kotlinx.multik.ndarray.data.DataType import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.ExponentialOperations -import space.kscience.kmath.operations.TrigonometricOperations +import space.kscience.kmath.operations.* public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra(), TrigonometricOperations>, ExponentialOperations> { override val elementAlgebra: DoubleField get() = DoubleField override val type: DataType get() = DataType.DoubleDataType - override fun sin(arg: StructureND): MultikTensor = multikMath.mathEx.sin(arg.asMultik().array).wrap() + override fun sin(arg: StructureND): MultikTensor = + multikMath.mathEx.sin(arg.asMultik().array).wrap() - override fun cos(arg: StructureND): MultikTensor = multikMath.mathEx.cos(arg.asMultik().array).wrap() + override fun cos(arg: StructureND): MultikTensor = + multikMath.mathEx.cos(arg.asMultik().array).wrap() override fun tan(arg: StructureND): MultikTensor = sin(arg) / cos(arg) - override fun asin(arg: StructureND): MultikTensor = arg.map { asin(it) } + override fun asin(arg: StructureND): MultikTensor = arg.map { asin(it) } override fun acos(arg: StructureND): MultikTensor = arg.map { acos(it) } override fun atan(arg: StructureND): MultikTensor = arg.map { atan(it) } - override fun exp(arg: StructureND): MultikTensor = multikMath.mathEx.exp(arg.asMultik().array).wrap() + override fun exp(arg: StructureND): MultikTensor = + multikMath.mathEx.exp(arg.asMultik().array).wrap() override fun ln(arg: StructureND): MultikTensor = multikMath.mathEx.log(arg.asMultik().array).wrap() @@ -39,7 +40,7 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra): MultikTensor { val expPlus = exp(arg) val expMinus = exp(-arg) - return (expPlus - expMinus) / (expPlus + expMinus) + return divide((expPlus - expMinus), (expPlus + expMinus)) } override fun asinh(arg: StructureND): MultikTensor = arg.map { asinh(it) } @@ -51,4 +52,3 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra get() = MultikDoubleAlgebra public val DoubleField.multikAlgebra: MultikTensorAlgebra get() = MultikDoubleAlgebra - diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt index 250ef7e7f..d08faa159 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt @@ -50,7 +50,7 @@ private fun MultiArray.asD2Array(): D2Array { else throw ClassCastException("Cannot cast MultiArray to NDArray.") } -public abstract class MultikTensorAlgebra> : TensorAlgebra +public abstract class MultikTensorAlgebra> : TensorAlgebra where T : Number, T : Comparable { public abstract val type: DataType @@ -138,14 +138,8 @@ public abstract class MultikTensorAlgebra> : TensorAlgebra get(intArrayOf(0)) } else null - override fun T.plus(arg: StructureND): MultikTensor = - arg.plus(this) - - override fun StructureND.plus(arg: T): MultikTensor = - asMultik().array.deepCopy().apply { plusAssign(arg) }.wrap() - - override fun StructureND.plus(arg: StructureND): MultikTensor = - asMultik().array.plus(arg.asMultik().array).wrap() + override fun add(left: StructureND, right: StructureND): MultikTensor = + left.asMultik().array.plus(right.asMultik().array).wrap() override fun Tensor.plusAssign(value: T) { if (this is MultikTensor) { @@ -163,13 +157,8 @@ public abstract class MultikTensorAlgebra> : TensorAlgebra } } - override fun T.minus(arg: StructureND): MultikTensor = (-(arg.asMultik().array - this)).wrap() - - override fun StructureND.minus(arg: T): MultikTensor = - asMultik().array.deepCopy().apply { minusAssign(arg) }.wrap() - - override fun StructureND.minus(arg: StructureND): MultikTensor = - asMultik().array.minus(arg.asMultik().array).wrap() + override fun subtract(left: StructureND, right: StructureND): MultikTensor = + left.asMultik().array.minus(right.asMultik().array).wrap() override fun Tensor.minusAssign(value: T) { if (this is MultikTensor) { @@ -187,14 +176,8 @@ public abstract class MultikTensorAlgebra> : TensorAlgebra } } - override fun T.times(arg: StructureND): MultikTensor = - arg.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap() - - override fun StructureND.times(arg: T): Tensor = - asMultik().array.deepCopy().apply { timesAssign(arg) }.wrap() - - override fun StructureND.times(arg: StructureND): MultikTensor = - asMultik().array.times(arg.asMultik().array).wrap() + override fun multiply(left: StructureND, right: StructureND): MultikTensor = + left.asMultik().array.times(right.asMultik().array).wrap() override fun Tensor.timesAssign(value: T) { if (this is MultikTensor) { @@ -212,12 +195,11 @@ public abstract class MultikTensorAlgebra> : TensorAlgebra } } - override fun StructureND.unaryMinus(): MultikTensor = - asMultik().array.unaryMinus().wrap() + override fun negate(arg: StructureND): MultikTensor = (-arg.asMultik().array).wrap() override fun Tensor.get(i: Int): MultikTensor = asMultik().array.mutableView(i).wrap() - override fun Tensor.transpose(i: Int, j: Int): MultikTensor = asMultik().array.transpose(i, j).wrap() + override fun StructureND.transpose(i: Int, j: Int): MultikTensor = asMultik().array.transpose(i, j).wrap() override fun Tensor.view(shape: IntArray): MultikTensor { require(shape.all { it > 0 }) @@ -282,16 +264,36 @@ public abstract class MultikTensorAlgebra> : TensorAlgebra } } -public abstract class MultikDivisionTensorAlgebra> +context(MultikTensorAlgebra) +public operator fun > T.plus(arg: StructureND): MultikTensor where T : Comparable, T : Number = + arg.plus(this) + +context(MultikTensorAlgebra) +public operator fun > StructureND.plus(arg: T): MultikTensor where T : Comparable, T : Number = + asMultik().array.deepCopy().apply { plusAssign(arg) }.wrap() + +context(MultikTensorAlgebra) +public operator fun > T.minus(arg: StructureND): MultikTensor where T : Comparable, T : Number = + (-(arg.asMultik().array - this)).wrap() + +context(MultikTensorAlgebra) +public operator fun > StructureND.minus(arg: T): MultikTensor where T : Comparable, T : Number = + asMultik().array.deepCopy().apply { minusAssign(arg) }.wrap() + +context(MultikTensorAlgebra) +public operator fun > T.times(arg: StructureND): MultikTensor where T : Comparable, T : Number = + arg.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap() + +context(MultikTensorAlgebra) +public operator fun > StructureND.times(arg: T): Tensor where T : Comparable, T : Number = + asMultik().array.deepCopy().apply { timesAssign(arg) }.wrap() + + +public abstract class MultikDivisionTensorAlgebra> : MultikTensorAlgebra(), TensorPartialDivisionAlgebra where T : Number, T : Comparable { - override fun T.div(arg: StructureND): MultikTensor = arg.map { elementAlgebra.divide(this@div, it) } - - override fun StructureND.div(arg: T): MultikTensor = - asMultik().array.deepCopy().apply { divAssign(arg) }.wrap() - - override fun StructureND.div(arg: StructureND): MultikTensor = - asMultik().array.div(arg.asMultik().array).wrap() + override fun divide(left: StructureND, right: StructureND): MultikTensor = + left.asMultik().array.div(right.asMultik().array).wrap() override fun Tensor.divAssign(value: T) { if (this is MultikTensor) { @@ -310,6 +312,18 @@ public abstract class MultikDivisionTensorAlgebra> } } +context(MultikDivisionTensorAlgebra) +public operator fun > StructureND.div(arg: StructureND): MultikTensor where T : Number, T : Comparable = + divide(this, arg) + +context(MultikDivisionTensorAlgebra) +public operator fun > T.div(arg: StructureND): MultikTensor where T : Number, T : Comparable = + arg.map { elementAlgebra.divide(this@div, it) } + +context(MultikDivisionTensorAlgebra) +public operator fun > StructureND.div(arg: T): MultikTensor where T : Number, T : Comparable = + asMultik().array.deepCopy().apply { divAssign(arg) }.wrap() + public object MultikFloatAlgebra : MultikDivisionTensorAlgebra() { override val elementAlgebra: FloatField get() = FloatField override val type: DataType get() = DataType.FloatDataType @@ -340,4 +354,4 @@ public object MultikLongAlgebra : MultikTensorAlgebra() { } public val Long.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra -public val LongRing.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra \ No newline at end of file +public val LongRing.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt index e29a3f467..29072c294 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt @@ -45,6 +45,7 @@ public sealed interface Nd4jArrayAlgebra> : AlgebraND.mapIndexed( transform: C.(index: IntArray, T) -> T, ): Nd4jArrayStructure { @@ -53,6 +54,7 @@ public sealed interface Nd4jArrayAlgebra> : AlgebraND, right: StructureND, @@ -72,15 +74,14 @@ public sealed interface Nd4jArrayAlgebra> : AlgebraND> : GroupOpsND, Nd4jArrayAlgebra { - override fun add(left: StructureND, right: StructureND): Nd4jArrayStructure = left.ndArray.add(right.ndArray).wrap() - override operator fun StructureND.minus(arg: StructureND): Nd4jArrayStructure = - ndArray.sub(arg.ndArray).wrap() + override fun subtract(left: StructureND, right: StructureND): Nd4jArrayStructure = + left.ndArray.sub(right.ndArray).wrap() - override operator fun StructureND.unaryMinus(): Nd4jArrayStructure = - ndArray.neg().wrap() + override fun negate(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.neg().wrap() public fun multiply(a: StructureND, k: Number): Nd4jArrayStructure = a.ndArray.mul(k).wrap() @@ -92,7 +93,6 @@ public sealed interface Nd4jArrayGroupOps> : GroupOpsND * @param T the type of the element contained in ND structure. * @param R the type of ring of structure elements. */ -@OptIn(UnstableKMathAPI::class) public sealed interface Nd4jArrayRingOps> : RingOpsND, Nd4jArrayGroupOps { override fun multiply(left: StructureND, right: StructureND): Nd4jArrayStructure = @@ -201,23 +201,29 @@ public open class DoubleNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps, value: Double): Nd4jArrayStructure = a.ndArray.mul(value).wrap() - override operator fun StructureND.div(arg: Double): Nd4jArrayStructure = ndArray.div(arg).wrap() - - override operator fun StructureND.plus(arg: Double): Nd4jArrayStructure = ndArray.add(arg).wrap() - - override operator fun StructureND.minus(arg: Double): Nd4jArrayStructure = ndArray.sub(arg).wrap() - - override operator fun StructureND.times(arg: Double): Nd4jArrayStructure = ndArray.mul(arg).wrap() - - override operator fun Double.div(arg: StructureND): Nd4jArrayStructure = - arg.ndArray.rdiv(this).wrap() - - override operator fun Double.minus(arg: StructureND): Nd4jArrayStructure = - arg.ndArray.rsub(this).wrap() - public companion object : DoubleNd4jArrayFieldOps() } +context(DoubleNd4jArrayFieldOps) +public operator fun StructureND.div(arg: Double): Nd4jArrayStructure = ndArray.div(arg).wrap() + +context(DoubleNd4jArrayFieldOps) +public operator fun StructureND.plus(arg: Double): Nd4jArrayStructure = ndArray.add(arg).wrap() + +context(DoubleNd4jArrayFieldOps) +public operator fun StructureND.minus(arg: Double): Nd4jArrayStructure = ndArray.sub(arg).wrap() + +context(DoubleNd4jArrayFieldOps) +public operator fun StructureND.times(arg: Double): Nd4jArrayStructure = ndArray.mul(arg).wrap() + +context(DoubleNd4jArrayFieldOps) +public operator fun Double.div(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rdiv(this).wrap() + +context(DoubleNd4jArrayFieldOps) +public operator fun Double.minus(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rsub(this).wrap() + public val DoubleField.nd4j: DoubleNd4jArrayFieldOps get() = DoubleNd4jArrayFieldOps public class DoubleNd4jArrayField(override val shape: Shape) : DoubleNd4jArrayFieldOps(), FieldND @@ -246,27 +252,33 @@ public open class FloatNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps, value: Double): StructureND = a.ndArray.mul(value).wrap() - override operator fun StructureND.div(arg: Float): Nd4jArrayStructure = - ndArray.div(arg).wrap() - - override operator fun StructureND.plus(arg: Float): Nd4jArrayStructure = - ndArray.add(arg).wrap() - - override operator fun StructureND.minus(arg: Float): Nd4jArrayStructure = - ndArray.sub(arg).wrap() - - override operator fun StructureND.times(arg: Float): Nd4jArrayStructure = - ndArray.mul(arg).wrap() - - override operator fun Float.div(arg: StructureND): Nd4jArrayStructure = - arg.ndArray.rdiv(this).wrap() - - override operator fun Float.minus(arg: StructureND): Nd4jArrayStructure = - arg.ndArray.rsub(this).wrap() - public companion object : FloatNd4jArrayFieldOps() } +context(FloatNd4jArrayFieldOps) +public operator fun StructureND.div(arg: Float): Nd4jArrayStructure = + ndArray.div(arg).wrap() + +context(FloatNd4jArrayFieldOps) +public operator fun StructureND.plus(arg: Float): Nd4jArrayStructure = + ndArray.add(arg).wrap() + +context(FloatNd4jArrayFieldOps) +public operator fun StructureND.minus(arg: Float): Nd4jArrayStructure = + ndArray.sub(arg).wrap() + +context(FloatNd4jArrayFieldOps) +public operator fun StructureND.times(arg: Float): Nd4jArrayStructure = + ndArray.mul(arg).wrap() + +context(FloatNd4jArrayFieldOps) +public operator fun Float.div(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rdiv(this).wrap() + +context(FloatNd4jArrayFieldOps) +public operator fun Float.minus(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rsub(this).wrap() + public class FloatNd4jArrayField(override val shape: Shape) : FloatNd4jArrayFieldOps(), RingND public val FloatField.nd4j: FloatNd4jArrayFieldOps get() = FloatNd4jArrayFieldOps @@ -291,21 +303,25 @@ public open class IntNd4jArrayRingOps : Nd4jArrayRingOps { } } - override operator fun StructureND.plus(arg: Int): Nd4jArrayStructure = - ndArray.add(arg).wrap() - - override operator fun StructureND.minus(arg: Int): Nd4jArrayStructure = - ndArray.sub(arg).wrap() - - override operator fun StructureND.times(arg: Int): Nd4jArrayStructure = - ndArray.mul(arg).wrap() - - override operator fun Int.minus(arg: StructureND): Nd4jArrayStructure = - arg.ndArray.rsub(this).wrap() - public companion object : IntNd4jArrayRingOps() } +context(IntNd4jArrayRingOps) +public operator fun StructureND.plus(arg: Int): Nd4jArrayStructure = + ndArray.add(arg).wrap() + +context(IntNd4jArrayRingOps) +public operator fun StructureND.minus(arg: Int): Nd4jArrayStructure = + ndArray.sub(arg).wrap() + +context(IntNd4jArrayRingOps) +public operator fun StructureND.times(arg: Int): Nd4jArrayStructure = + ndArray.mul(arg).wrap() + +context(IntNd4jArrayRingOps) +public operator fun Int.minus(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rsub(this).wrap() + public val IntRing.nd4j: IntNd4jArrayRingOps get() = IntNd4jArrayRingOps public class IntNd4jArrayRing(override val shape: Shape) : IntNd4jArrayRingOps(), RingND diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt index ceb384f0d..e8702815d 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -16,6 +16,7 @@ import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.DefaultStrides import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.Field import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra @@ -26,7 +27,7 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra /** * ND4J based [TensorAlgebra] implementation. */ -public sealed interface Nd4jTensorAlgebra> : AnalyticTensorAlgebra { +public sealed interface Nd4jTensorAlgebra> : AnalyticTensorAlgebra { /** * Wraps [INDArray] to [Nd4jArrayStructure]. @@ -51,10 +52,8 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe return structureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) } } - override fun T.plus(arg: StructureND): Nd4jArrayStructure = arg.ndArray.add(this).wrap() - override fun StructureND.plus(arg: T): Nd4jArrayStructure = ndArray.add(arg).wrap() - - override fun StructureND.plus(arg: StructureND): Nd4jArrayStructure = ndArray.add(arg.ndArray).wrap() + override fun add(left: StructureND, right: StructureND): Nd4jArrayStructure = + left.ndArray.add(right.ndArray).wrap() override fun Tensor.plusAssign(value: T) { ndArray.addi(value) @@ -64,9 +63,8 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe ndArray.addi(arg.ndArray) } - override fun T.minus(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rsub(this).wrap() - override fun StructureND.minus(arg: T): Nd4jArrayStructure = ndArray.sub(arg).wrap() - override fun StructureND.minus(arg: StructureND): Nd4jArrayStructure = ndArray.sub(arg.ndArray).wrap() + override fun subtract(left: StructureND, right: StructureND): Nd4jArrayStructure = + left.ndArray.sub(right.ndArray).wrap() override fun Tensor.minusAssign(value: T) { ndArray.rsubi(value) @@ -76,12 +74,7 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe ndArray.subi(arg.ndArray) } - override fun T.times(arg: StructureND): Nd4jArrayStructure = arg.ndArray.mul(this).wrap() - - override fun StructureND.times(arg: T): Nd4jArrayStructure = - ndArray.mul(arg).wrap() - - override fun StructureND.times(arg: StructureND): Nd4jArrayStructure = ndArray.mul(arg.ndArray).wrap() + override fun multiply(left: StructureND, right: StructureND): Nd4jArrayStructure = left.ndArray.mul(right.ndArray).wrap() override fun Tensor.timesAssign(value: T) { ndArray.muli(value) @@ -91,9 +84,9 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe ndArray.mmuli(arg.ndArray) } - override fun StructureND.unaryMinus(): Nd4jArrayStructure = ndArray.neg().wrap() + override fun negate(arg: StructureND): Nd4jArrayStructure = arg.ndArray.neg().wrap() override fun Tensor.get(i: Int): Nd4jArrayStructure = ndArray.slice(i.toLong()).wrap() - override fun Tensor.transpose(i: Int, j: Int): Nd4jArrayStructure = ndArray.swapAxes(i, j).wrap() + override fun StructureND.transpose(i: Int, j: Int): Nd4jArrayStructure = ndArray.swapAxes(i, j).wrap() override fun StructureND.dot(other: StructureND): Nd4jArrayStructure = ndArray.mmul(other.ndArray).wrap() override fun StructureND.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure = @@ -140,9 +133,7 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe override fun StructureND.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.std(true, keepDim, dim).wrap() - override fun T.div(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rdiv(this).wrap() - override fun StructureND.div(arg: T): Nd4jArrayStructure = ndArray.div(arg).wrap() - override fun StructureND.div(arg: StructureND): Nd4jArrayStructure = ndArray.div(arg.ndArray).wrap() + override fun divide(left: StructureND, right: StructureND): Nd4jArrayStructure = left.ndArray.div(right.ndArray).wrap() override fun Tensor.divAssign(value: T) { ndArray.divi(value) @@ -160,6 +151,36 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe } } +context(Nd4jTensorAlgebra) +public fun > T.plus(arg: StructureND): Nd4jArrayStructure = arg.ndArray.add(this).wrap() + +context(Nd4jTensorAlgebra) +public fun > StructureND.plus(arg: T): Nd4jArrayStructure = ndArray.add(arg).wrap() + +context(Nd4jTensorAlgebra) +public fun > T.times(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.mul(this).wrap() + +context(Nd4jTensorAlgebra) +public fun > StructureND.times(arg: T): Nd4jArrayStructure = + ndArray.mul(arg).wrap() + +context(Nd4jTensorAlgebra) +public operator fun > T.minus(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rsub(this).wrap() + +context(Nd4jTensorAlgebra) +public operator fun > StructureND.minus(arg: T): Nd4jArrayStructure = + ndArray.sub(arg).wrap() + +context(Nd4jTensorAlgebra) +public operator fun > T.div(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rdiv(this).wrap() + +context(Nd4jTensorAlgebra) +public operator fun > StructureND.div(arg: T): Nd4jArrayStructure = + ndArray.div(arg).wrap() + /** * [Double] specialization of [Nd4jTensorAlgebra]. */ diff --git a/kmath-optimization/build.gradle.kts b/kmath-optimization/build.gradle.kts index b920b9267..f3a7924ad 100644 --- a/kmath-optimization/build.gradle.kts +++ b/kmath-optimization/build.gradle.kts @@ -1,6 +1,6 @@ plugins { id("ru.mipt.npm.gradle.mpp") - id("ru.mipt.npm.gradle.native") +// id("ru.mipt.npm.gradle.native") } kscience { @@ -22,3 +22,8 @@ kotlin.sourceSets { readme { maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL } + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-stat/build.gradle.kts b/kmath-stat/build.gradle.kts index 41a1666f8..fe55ea882 100644 --- a/kmath-stat/build.gradle.kts +++ b/kmath-stat/build.gradle.kts @@ -1,6 +1,6 @@ plugins { id("ru.mipt.npm.gradle.mpp") - id("ru.mipt.npm.gradle.native") +// id("ru.mipt.npm.gradle.native") } kscience { @@ -24,4 +24,9 @@ kotlin.sourceSets { readme { maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL -} \ No newline at end of file +} + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt index 1f442c09b..deadcaca4 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt @@ -9,9 +9,7 @@ import space.kscience.kmath.chains.Chain import space.kscience.kmath.chains.ConstantChain import space.kscience.kmath.chains.map import space.kscience.kmath.chains.zip -import space.kscience.kmath.operations.Group -import space.kscience.kmath.operations.ScaleOperations -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* /** * Implements [Sampler] by sampling only certain [value]. @@ -51,5 +49,5 @@ public class SamplerSpace(public val algebra: S) : Group.unaryMinus(): Sampler = scale(this, -1.0) + override fun negate(arg: Sampler): Sampler = scale(arg, -1.0) } diff --git a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt index b40739ee0..5e2b99ba1 100644 --- a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt +++ b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt @@ -153,7 +153,8 @@ public abstract class TensorFlowAlgebra> internal c override fun StructureND.plus(arg: T): TensorFlowOutput = operate(arg, ops.math::add) - override fun StructureND.plus(arg: StructureND): TensorFlowOutput = operate(arg, ops.math::add) + override fun add(left: StructureND, right: StructureND): TensorFlowOutput = + left.operate(right, ops.math::add) override fun Tensor.plusAssign(value: T): Unit = operateInPlace(value, ops.math::add) @@ -161,7 +162,8 @@ public abstract class TensorFlowAlgebra> internal c override fun StructureND.minus(arg: T): TensorFlowOutput = operate(arg, ops.math::sub) - override fun StructureND.minus(arg: StructureND): TensorFlowOutput = operate(arg, ops.math::sub) + override fun subtract(left: StructureND, right: StructureND): TensorFlowOutput = + left.operate(right, ops.math::sub) override fun T.minus(arg: StructureND): Tensor = operate(arg, ops.math::sub) @@ -173,19 +175,19 @@ public abstract class TensorFlowAlgebra> internal c override fun StructureND.times(arg: T): TensorFlowOutput = operate(arg, ops.math::mul) - override fun StructureND.times(arg: StructureND): TensorFlowOutput = operate(arg, ops.math::mul) + override fun multiply(left: StructureND, right: StructureND): TensorFlowOutput = left.operate(right, ops.math::mul) override fun Tensor.timesAssign(value: T): Unit = operateInPlace(value, ops.math::mul) override fun Tensor.timesAssign(arg: StructureND): Unit = operateInPlace(arg, ops.math::mul) - override fun StructureND.unaryMinus(): TensorFlowOutput = operate(ops.math::neg) + override fun negate(arg: StructureND): TensorFlowOutput = arg.operate(ops.math::neg) override fun Tensor.get(i: Int): Tensor = operate { TODO("Not yet implemented") } - override fun Tensor.transpose(i: Int, j: Int): Tensor = operate { + override fun StructureND.transpose(i: Int, j: Int): Tensor = operate { ops.linalg.transpose(it, ops.constant(intArrayOf(i, j))) } diff --git a/kmath-tensorflow/src/test/kotlin/space/kscience/kmath/tensorflow/DoubleTensorFlowOps.kt b/kmath-tensorflow/src/test/kotlin/space/kscience/kmath/tensorflow/DoubleTensorFlowOps.kt index 308469eed..b271d8a39 100644 --- a/kmath-tensorflow/src/test/kotlin/space/kscience/kmath/tensorflow/DoubleTensorFlowOps.kt +++ b/kmath-tensorflow/src/test/kotlin/space/kscience/kmath/tensorflow/DoubleTensorFlowOps.kt @@ -2,8 +2,8 @@ package space.kscience.kmath.tensorflow import org.junit.jupiter.api.Test import space.kscience.kmath.nd.get -import space.kscience.kmath.nd.structureND import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.plus import space.kscience.kmath.tensors.core.DoubleTensorAlgebra import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.sum import kotlin.test.assertEquals @@ -12,7 +12,7 @@ class DoubleTensorFlowOps { @Test fun basicOps() { val res = DoubleField.produceWithTF { - val initial = structureND(2, 2) { 1.0 } + val initial = structureND(intArrayOf(2, 2)) { 1.0 } initial + (initial * 2.0) } @@ -21,7 +21,7 @@ class DoubleTensorFlowOps { } @Test - fun dot(){ + fun dot() { val dim = 1000 val tensor1 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12224) @@ -33,14 +33,14 @@ class DoubleTensorFlowOps { } @Test - fun extensionOps(){ + fun extensionOps() { val res = DoubleField.produceWithTF { - val i = structureND(2, 2) { 0.5 } + val i = structureND(intArrayOf(2, 2)) { 0.5 } sin(i).pow(2) + cos(i).pow(2) } - assertEquals(1.0, res[0,0],0.01) + assertEquals(1.0, res[0, 0], 0.01) } diff --git a/kmath-tensors/build.gradle.kts b/kmath-tensors/build.gradle.kts index 66316d21d..fb38ddb93 100644 --- a/kmath-tensors/build.gradle.kts +++ b/kmath-tensors/build.gradle.kts @@ -1,7 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") - id("ru.mipt.npm.gradle.native") +// id("ru.mipt.npm.gradle.native") } kotlin.sourceSets { @@ -40,3 +40,8 @@ readme { ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt" ) { "Advanced linear algebra operations like LU decomposition, SVD, etc." } } + +// Testing multi-receiver! +tasks.withType { + enabled = false +} diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt index 3ed34ae5e..40f11332b 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt @@ -15,7 +15,7 @@ import space.kscience.kmath.operations.Field * * @param T the type of items closed under analytic functions in the tensors. */ -public interface AnalyticTensorAlgebra> : +public interface AnalyticTensorAlgebra> : TensorPartialDivisionAlgebra, ExtendedFieldOps> { /** @@ -146,4 +146,4 @@ public interface AnalyticTensorAlgebra> : override fun acosh(arg: StructureND): StructureND = arg.acosh() override fun atanh(arg: StructureND): StructureND = arg.atanh() -} \ No newline at end of file +} diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt index 0bddc3f9c..5a6e7f29f 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt @@ -13,7 +13,7 @@ import space.kscience.kmath.operations.Field * * @param T the type of items closed under division in the tensors. */ -public interface LinearOpsTensorAlgebra> : TensorPartialDivisionAlgebra { +public interface LinearOpsTensorAlgebra> : TensorPartialDivisionAlgebra { /** * Computes the determinant of a square matrix input, or of each square matrix in a batched input. diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index 86d4eaa4e..e3547f174 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -7,7 +7,7 @@ package space.kscience.kmath.tensors.api import space.kscience.kmath.nd.RingOpsND import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.operations.Ring +import space.kscience.kmath.operations.* /** * Algebra over a ring on [Tensor]. @@ -15,7 +15,7 @@ import space.kscience.kmath.operations.Ring * * @param T the type of items in the tensors. */ -public interface TensorAlgebra> : RingOpsND { +public interface TensorAlgebra> : RingOpsND { /** * Returns a single tensor value of unit dimension if tensor shape equals to [1]. * @@ -31,32 +31,6 @@ public interface TensorAlgebra> : RingOpsND { public fun StructureND.value(): T = valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape") - /** - * Each element of the tensor [arg] is added to this value. - * The resulting tensor is returned. - * - * @param arg tensor to be added. - * @return the sum of this value and tensor [arg]. - */ - override operator fun T.plus(arg: StructureND): Tensor - - /** - * Adds the scalar [arg] to each element of this tensor and returns a new resulting tensor. - * - * @param arg the number to be added to each element of this tensor. - * @return the sum of this tensor and [arg]. - */ - override operator fun StructureND.plus(arg: T): Tensor - - /** - * Each element of the tensor [arg] is added to each element of this tensor. - * The resulting tensor is returned. - * - * @param arg tensor to be added. - * @return the sum of this tensor and [arg]. - */ - override operator fun StructureND.plus(arg: StructureND): Tensor - /** * Adds the scalar [value] to each element of this tensor. * @@ -71,23 +45,6 @@ public interface TensorAlgebra> : RingOpsND { */ public operator fun Tensor.plusAssign(arg: StructureND) - /** - * Each element of the tensor [arg] is subtracted from this value. - * The resulting tensor is returned. - * - * @param arg tensor to be subtracted. - * @return the difference between this value and tensor [arg]. - */ - override operator fun T.minus(arg: StructureND): Tensor - - /** - * Subtracts the scalar [arg] from each element of this tensor and returns a new resulting tensor. - * - * @param arg the number to be subtracted from each element of this tensor. - * @return the difference between this tensor and [arg]. - */ - override operator fun StructureND.minus(arg: T): Tensor - /** * Each element of the tensor [arg] is subtracted from each element of this tensor. * The resulting tensor is returned. @@ -95,7 +52,7 @@ public interface TensorAlgebra> : RingOpsND { * @param arg tensor to be subtracted. * @return the difference between this tensor and [arg]. */ - override operator fun StructureND.minus(arg: StructureND): Tensor + override fun subtract(left: StructureND, right: StructureND): Tensor /** * Subtracts the scalar [value] from each element of this tensor. @@ -111,33 +68,6 @@ public interface TensorAlgebra> : RingOpsND { */ public operator fun Tensor.minusAssign(arg: StructureND) - - /** - * Each element of the tensor [arg] is multiplied by this value. - * The resulting tensor is returned. - * - * @param arg tensor to be multiplied. - * @return the product of this value and tensor [arg]. - */ - override operator fun T.times(arg: StructureND): Tensor - - /** - * Multiplies the scalar [arg] by each element of this tensor and returns a new resulting tensor. - * - * @param arg the number to be multiplied by each element of this tensor. - * @return the product of this tensor and [arg]. - */ - override operator fun StructureND.times(arg: T): Tensor - - /** - * Each element of the tensor [arg] is multiplied by each element of this tensor. - * The resulting tensor is returned. - * - * @param arg tensor to be multiplied. - * @return the product of this tensor and [arg]. - */ - override operator fun StructureND.times(arg: StructureND): Tensor - /** * Multiplies the scalar [value] by each element of this tensor. * @@ -152,13 +82,6 @@ public interface TensorAlgebra> : RingOpsND { */ public operator fun Tensor.timesAssign(arg: StructureND) - /** - * Numerical negative, element-wise. - * - * @return tensor negation of the original tensor. - */ - override operator fun StructureND.unaryMinus(): Tensor - /** * Returns the tensor at index i * For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html @@ -176,7 +99,7 @@ public interface TensorAlgebra> : RingOpsND { * @param j the second dimension to be transposed * @return transposed tensor */ - public fun Tensor.transpose(i: Int = -2, j: Int = -1): Tensor + public fun StructureND.transpose(i: Int = -2, j: Int = -1): Tensor /** * Returns a new tensor with the same data as the self tensor but of a different shape. @@ -326,7 +249,30 @@ public interface TensorAlgebra> : RingOpsND { */ public fun StructureND.argMax(dim: Int, keepDim: Boolean): Tensor - override fun add(left: StructureND, right: StructureND): Tensor = left + right + /** + * Each element of the tensor [right] is added to each element of [left] tensor. + * The resulting tensor is returned. + * + * @param left tensor to be added. + * @param right tensor to be added. + * @return the sum of [left] tensor and [right] one. + */ + override fun add(left: StructureND, right: StructureND): Tensor - override fun multiply(left: StructureND, right: StructureND): Tensor = left * right + /** + * Numerical negative, element-wise. + * + * @return tensor negation of the original tensor. + */ + override fun negate(arg: StructureND): Tensor + + /** + * Each element of the tensor [right] is multiplied by each element of [left] tensor. + * The resulting tensor is returned. + * + * @param left tensor to be multiplied. + * @param right tensor to be multiplied. + * @return the product of [left] tensor and [right] one. + */ + override fun multiply(left: StructureND, right: StructureND): Tensor } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt index 9c492cda1..7260ddd41 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt @@ -15,35 +15,16 @@ import space.kscience.kmath.operations.Field * * @param T the type of items closed under division in the tensors. */ -public interface TensorPartialDivisionAlgebra> : TensorAlgebra, FieldOpsND { - +public interface TensorPartialDivisionAlgebra> : TensorAlgebra, FieldOpsND { /** - * Each element of the tensor [arg] is divided by this value. + * Each element of the tensor [right] is divided by each element of [left] tensor. * The resulting tensor is returned. * - * @param arg tensor to divide by. - * @return the division of this value by the tensor [arg]. + * @param left tensor to be divided by. + * @param right tensor to be divided by. + * @return the division of [left] tensor by [right] one. */ - override operator fun T.div(arg: StructureND): Tensor - - /** - * Divide by the scalar [arg] each element of this tensor returns a new resulting tensor. - * - * @param arg the number to divide by each element of this tensor. - * @return the division of this tensor by the [arg]. - */ - override operator fun StructureND.div(arg: T): Tensor - - /** - * Each element of the tensor [arg] is divided by each element of this tensor. - * The resulting tensor is returned. - * - * @param arg tensor to be divided by. - * @return the division of this tensor by [arg]. - */ - override operator fun StructureND.div(arg: StructureND): Tensor - - override fun divide(left: StructureND, right: StructureND): StructureND = left.div(right) + override fun divide(left: StructureND, right: StructureND): Tensor /** * Divides by the scalar [value] each element of this tensor. diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt index e412ab5bb..0e1529705 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt @@ -22,8 +22,8 @@ import space.kscience.kmath.tensors.core.internal.tensor @PerformancePitfall public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { - override fun StructureND.plus(arg: StructureND): DoubleTensor { - val broadcast = broadcastTensors(tensor, arg.tensor) + override fun add(left: StructureND, right: StructureND): DoubleTensor { + val broadcast = broadcastTensors(left.tensor, right.tensor) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> @@ -40,8 +40,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } } - override fun StructureND.minus(arg: StructureND): DoubleTensor { - val broadcast = broadcastTensors(tensor, arg.tensor) + override fun subtract(left: StructureND, right: StructureND): DoubleTensor { + val broadcast = broadcastTensors(left.tensor, right.tensor) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> @@ -58,8 +58,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } } - override fun StructureND.times(arg: StructureND): DoubleTensor { - val broadcast = broadcastTensors(tensor, arg.tensor) + override fun multiply(left: StructureND, right: StructureND): DoubleTensor { + val broadcast = broadcastTensors(left.tensor, right.tensor) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> @@ -77,8 +77,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } } - override fun StructureND.div(arg: StructureND): DoubleTensor { - val broadcast = broadcastTensors(tensor, arg.tensor) + override fun divide(left: StructureND, right: StructureND): DoubleTensor { + val broadcast = broadcastTensors(left.tensor, right.tensor) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index e9dc34748..d9a9cfd80 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -10,7 +10,7 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.* -import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.indices import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra @@ -206,21 +206,12 @@ public open class DoubleTensorAlgebra : public fun StructureND.copy(): DoubleTensor = DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) - override fun Double.plus(arg: StructureND): DoubleTensor { - val resBuffer = DoubleArray(arg.tensor.numElements) { i -> - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + this + override fun add(left: StructureND, right: StructureND): DoubleTensor { + checkShapesCompatible(left, right) + val resBuffer = DoubleArray(left.tensor.numElements) { i -> + left.tensor.mutableBuffer.array()[i] + right.tensor.mutableBuffer.array()[i] } - return DoubleTensor(arg.shape, resBuffer) - } - - override fun StructureND.plus(arg: Double): DoubleTensor = arg + tensor - - override fun StructureND.plus(arg: StructureND): DoubleTensor { - checkShapesCompatible(tensor, arg.tensor) - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[i] + arg.tensor.mutableBuffer.array()[i] - } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(left.tensor.shape, resBuffer) } override fun Tensor.plusAssign(value: Double) { @@ -230,33 +221,19 @@ public open class DoubleTensorAlgebra : } override fun Tensor.plusAssign(arg: StructureND) { - checkShapesCompatible(tensor, arg.tensor) + checkShapesCompatible(tensor, arg) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] += arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Double.minus(arg: StructureND): DoubleTensor { - val resBuffer = DoubleArray(arg.tensor.numElements) { i -> - this - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + override fun subtract(left: StructureND, right: StructureND): DoubleTensor { + checkShapesCompatible(left, right) + val resBuffer = DoubleArray(left.tensor.numElements) { i -> + left.tensor.mutableBuffer.array()[i] - right.tensor.mutableBuffer.array()[i] } - return DoubleTensor(arg.shape, resBuffer) - } - - override fun StructureND.minus(arg: Double): DoubleTensor { - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] - arg - } - return DoubleTensor(tensor.shape, resBuffer) - } - - override fun StructureND.minus(arg: StructureND): DoubleTensor { - checkShapesCompatible(tensor, arg) - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[i] - arg.tensor.mutableBuffer.array()[i] - } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(left.tensor.shape, resBuffer) } override fun Tensor.minusAssign(value: Double) { @@ -273,22 +250,13 @@ public open class DoubleTensorAlgebra : } } - override fun Double.times(arg: StructureND): DoubleTensor { - val resBuffer = DoubleArray(arg.tensor.numElements) { i -> - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this + override fun multiply(left: StructureND, right: StructureND): DoubleTensor { + checkShapesCompatible(left, right) + val resBuffer = DoubleArray(left.tensor.numElements) { i -> + left.tensor.mutableBuffer.array()[left.tensor.bufferStart + i] * + right.tensor.mutableBuffer.array()[right.tensor.bufferStart + i] } - return DoubleTensor(arg.shape, resBuffer) - } - - override fun StructureND.times(arg: Double): DoubleTensor = arg * tensor - - override fun StructureND.times(arg: StructureND): DoubleTensor { - checkShapesCompatible(tensor, arg) - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] * - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] - } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(left.tensor.shape, resBuffer) } override fun Tensor.timesAssign(value: Double) { @@ -305,27 +273,13 @@ public open class DoubleTensorAlgebra : } } - override fun Double.div(arg: StructureND): DoubleTensor { - val resBuffer = DoubleArray(arg.tensor.numElements) { i -> - this / arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + override fun divide(left: StructureND, right: StructureND): DoubleTensor { + checkShapesCompatible(left, right) + val resBuffer = DoubleArray(left.tensor.numElements) { i -> + left.tensor.mutableBuffer.array()[right.tensor.bufferStart + i] / + right.tensor.mutableBuffer.array()[right.tensor.bufferStart + i] } - return DoubleTensor(arg.shape, resBuffer) - } - - override fun StructureND.div(arg: Double): DoubleTensor { - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] / arg - } - return DoubleTensor(shape, resBuffer) - } - - override fun StructureND.div(arg: StructureND): DoubleTensor { - checkShapesCompatible(tensor, arg) - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] / - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] - } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(right.tensor.shape, resBuffer) } override fun Tensor.divAssign(value: Double) { @@ -342,14 +296,14 @@ public open class DoubleTensorAlgebra : } } - override fun StructureND.unaryMinus(): DoubleTensor { - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() + override fun negate(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i].unaryMinus() } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(arg.tensor.shape, resBuffer) } - override fun Tensor.transpose(i: Int, j: Int): DoubleTensor { + override fun StructureND.transpose(i: Int, j: Int): DoubleTensor { val ii = tensor.minusIndex(i) val jj = tensor.minusIndex(j) checkTranspose(tensor.dimension, ii, jj) @@ -383,7 +337,7 @@ public open class DoubleTensorAlgebra : override infix fun StructureND.dot(other: StructureND): DoubleTensor { if (tensor.shape.size == 1 && other.shape.size == 1) { - return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) + return DoubleTensor(intArrayOf(1), doubleArrayOf((tensor * other).tensor.mutableBuffer.array().sum())) } var newThis = tensor.copy() @@ -551,7 +505,7 @@ public open class DoubleTensorAlgebra : * @param tensors the [List] of tensors with same shapes to concatenate * @return tensor with concatenation result */ - public fun stack(tensors: List>): DoubleTensor { + public fun stack(tensors: List>): DoubleTensor { check(tensors.isNotEmpty()) { "List must have at least 1 element" } val shape = tensors[0].shape check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } @@ -586,8 +540,10 @@ public open class DoubleTensorAlgebra : } val resNumElements = resShape.reduce(Int::times) val init = foldFunction(DoubleArray(1) { 0.0 }) - val resTensor = BufferedTensor(resShape, - MutableBuffer.auto(resNumElements) { init }, 0) + val resTensor = BufferedTensor( + resShape, + MutableBuffer.auto(resNumElements) { init }, 0 + ) for (index in resTensor.indices) { val prefix = index.take(dim).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray() @@ -882,7 +838,8 @@ public open class DoubleTensorAlgebra : return Triple(uTensor.transpose(), sTensor, vTensor.transpose()) } - override fun StructureND.symEig(): Pair = symEigJacobi(maxIteration = 50, epsilon = 1e-15) + override fun StructureND.symEig(): Pair = + symEigJacobi(maxIteration = 50, epsilon = 1e-15) /** * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices, @@ -1138,7 +1095,85 @@ public open class DoubleTensorAlgebra : override fun StructureND.lu(): Triple = lu(1e-9) } +/** + * Divides each element of the tensor [arg] by this double. + */ +context(DoubleTensorAlgebra) +public operator fun Double.div(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + this / arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + } + return DoubleTensor(arg.shape, resBuffer) +} + +/** + * Divides by the double [arg] each element of this tensor returns a new resulting tensor. + */ +context(DoubleTensorAlgebra) +public operator fun StructureND.div(arg: Double): DoubleTensor { + val resBuffer = DoubleArray(tensor.numElements) { i -> + tensor.mutableBuffer.array()[tensor.bufferStart + i] / arg + } + return DoubleTensor(shape, resBuffer) +} + +/** + * Adds each element of the tensor [arg] to this double. + */ +context(DoubleTensorAlgebra) +public operator fun Double.plus(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + this + } + return DoubleTensor(arg.shape, resBuffer) +} + +/** + * Each element of the [arg] tensor is added to each element of the receiver tensor. The resulting tensor is returned. + */ +context(DoubleTensorAlgebra) +public operator fun StructureND.plus(arg: StructureND): DoubleTensor = add(this, arg) + +/** + * Adds the scalar [arg] to each element of this tensor and returns a new resulting tensor. + */ +context(DoubleTensorAlgebra) + public operator fun StructureND.plus(arg: Double): DoubleTensor = arg.plus(this) + +/** + * Subtracts each element of the tensor [arg] from this value. + */ +context(DoubleTensorAlgebra) +public operator fun Double.minus(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] - this + } + return DoubleTensor(arg.shape, resBuffer) +} + +/** + * Subtracts the scalar [arg] from each element of this tensor and returns a new resulting tensor. + */ +context(DoubleTensorAlgebra) +public operator fun StructureND.minus(arg: Double): DoubleTensor = arg.plus(-this) + +context(DoubleTensorAlgebra) +public operator fun StructureND.times(arg: StructureND): DoubleTensor = multiply(this, arg) + +context(DoubleTensorAlgebra) +public operator fun Double.times(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this + } + return DoubleTensor(arg.shape, resBuffer) +} + +/** + * Multiplies each element of the [arg] tensor is by each element of receiver tensor and returns a new resulting tensor. + */ +context(DoubleTensorAlgebra) +public operator fun StructureND.times(arg: Double): DoubleTensor = arg * tensor + public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra + public val DoubleField.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra - - diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt index aba6167ce..bc9be062e 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt @@ -12,10 +12,7 @@ import space.kscience.kmath.nd.as2D import space.kscience.kmath.operations.asSequence import space.kscience.kmath.operations.invoke import space.kscience.kmath.structures.VirtualBuffer -import space.kscience.kmath.tensors.core.BufferedTensor -import space.kscience.kmath.tensors.core.DoubleTensor -import space.kscience.kmath.tensors.core.DoubleTensorAlgebra -import space.kscience.kmath.tensors.core.IntTensor +import space.kscience.kmath.tensors.core.* import kotlin.math.abs import kotlin.math.min import kotlin.math.sqrt @@ -316,7 +313,7 @@ internal fun DoubleTensorAlgebra.svdHelper( outerProduct[i * v.shape[0] + j] = u[i].value() * v[j].value() } } - a = a - singularValue.times(DoubleTensor(intArrayOf(u.shape[0], v.shape[0]), outerProduct)) + a = subtract(a, singularValue.times(DoubleTensor(intArrayOf(u.shape[0], v.shape[0]), outerProduct))) } var v: DoubleTensor var u: DoubleTensor diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt index 6788ae792..0781d7b68 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt @@ -5,7 +5,7 @@ package space.kscience.kmath.tensors.core -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import space.kscience.kmath.tensors.core.internal.* import kotlin.test.Test import kotlin.test.assertTrue @@ -73,8 +73,8 @@ internal class TestBroadcasting { @Test fun testBroadcastOuterTensorsShapes() = DoubleTensorAlgebra { - val tensor1 = fromArray(intArrayOf(2, 1, 3, 2, 3), DoubleArray(2 * 1 * 3 * 2 * 3) {0.0}) - val tensor2 = fromArray(intArrayOf(4, 2, 5, 1, 3, 3), DoubleArray(4 * 2 * 5 * 1 * 3 * 3) {0.0}) + val tensor1 = fromArray(intArrayOf(2, 1, 3, 2, 3), DoubleArray(2 * 1 * 3 * 2 * 3) { 0.0 }) + val tensor2 = fromArray(intArrayOf(4, 2, 5, 1, 3, 3), DoubleArray(4 * 2 * 5 * 1 * 3 * 3) { 0.0 }) val tensor3 = fromArray(intArrayOf(1, 1), doubleArrayOf(500.0)) val res = broadcastOuterTensors(tensor1, tensor2, tensor3) @@ -95,16 +95,16 @@ internal class TestBroadcasting { val tensor32 = tensor3 - tensor2 assertTrue(tensor21.shape contentEquals intArrayOf(2, 3)) - assertTrue(tensor21.mutableBuffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) + assertTrue(tensor21.tensor.mutableBuffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) assertTrue(tensor31.shape contentEquals intArrayOf(1, 2, 3)) assertTrue( - tensor31.mutableBuffer.array() + tensor31.tensor.mutableBuffer.array() contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0) ) assertTrue(tensor32.shape contentEquals intArrayOf(1, 1, 3)) - assertTrue(tensor32.mutableBuffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0)) + assertTrue(tensor32.tensor.mutableBuffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0)) } } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index e025d4b71..fddbd3b2b 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -5,9 +5,10 @@ package space.kscience.kmath.tensors.core -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import space.kscience.kmath.tensors.core.internal.array import space.kscience.kmath.tensors.core.internal.svd1d +import space.kscience.kmath.tensors.core.internal.tensor import kotlin.math.abs import kotlin.test.Test import kotlin.test.assertEquals @@ -147,7 +148,7 @@ internal class TestDoubleLinearOpsTensorAlgebra { ) val low = sigma.cholesky() val sigmChol = low dot low.transpose() - assertTrue(sigma.eq(sigmChol)) + assertTrue(sigma.tensor.eq(sigmChol)) } @Test @@ -162,7 +163,7 @@ internal class TestDoubleLinearOpsTensorAlgebra { } @Test - fun testSVD() = DoubleTensorAlgebra{ + fun testSVD() = DoubleTensorAlgebra { testSVDFor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))) testSVDFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0))) } @@ -181,10 +182,8 @@ internal class TestDoubleLinearOpsTensorAlgebra { val tensorSigma = tensor + tensor.transpose() val (tensorS, tensorV) = tensorSigma.symEig() val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose()) - assertTrue(tensorSigma.eq(tensorSigmaCalc)) + assertTrue(tensorSigma.tensor.eq(tensorSigmaCalc)) } - - } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt index 205ae2fee..e65d0d259 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt @@ -5,8 +5,7 @@ package space.kscience.kmath.tensors.core - -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.* import space.kscience.kmath.tensors.core.internal.array import kotlin.test.Test import kotlin.test.assertFalse @@ -24,7 +23,7 @@ internal class TestDoubleTensorAlgebra { @Test fun testDoubleDiv() = DoubleTensorAlgebra { val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0)) - val res = 2.0/tensor + val res = 2.0 / tensor assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 0.5)) } @@ -127,10 +126,12 @@ internal class TestDoubleTensorAlgebra { assertTrue(res11.shape contentEquals intArrayOf(2, 2)) val res45 = tensor4.dot(tensor5) - assertTrue(res45.mutableBuffer.array() contentEquals doubleArrayOf( - 36.0, 42.0, 48.0, 81.0, 96.0, 111.0, 126.0, 150.0, 174.0, - 468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0 - )) + assertTrue( + res45.mutableBuffer.array() contentEquals doubleArrayOf( + 36.0, 42.0, 48.0, 81.0, 96.0, 111.0, 126.0, 150.0, 174.0, + 468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0 + ) + ) assertTrue(res45.shape contentEquals intArrayOf(2, 3, 3)) } @@ -140,31 +141,44 @@ internal class TestDoubleTensorAlgebra { val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor3 = zeros(intArrayOf(2, 3, 4, 5)) - assertTrue(diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals - intArrayOf(2, 3, 4, 5, 5)) - assertTrue(diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals - intArrayOf(2, 3, 4, 6, 6)) - assertTrue(diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals - intArrayOf(7, 2, 3, 7, 4)) + assertTrue( + diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals + intArrayOf(2, 3, 4, 5, 5) + ) + assertTrue( + diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals + intArrayOf(2, 3, 4, 6, 6) + ) + assertTrue( + diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals + intArrayOf(7, 2, 3, 7, 4) + ) val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0) assertTrue(diagonal1.shape contentEquals intArrayOf(3, 3)) - assertTrue(diagonal1.mutableBuffer.array() contentEquals - doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0)) + assertTrue( + diagonal1.mutableBuffer.array() contentEquals + doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0) + ) val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0) assertTrue(diagonal1Offset.shape contentEquals intArrayOf(4, 4)) - assertTrue(diagonal1Offset.mutableBuffer.array() contentEquals - doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0)) + assertTrue( + diagonal1Offset.mutableBuffer.array() contentEquals + doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0) + ) val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2) assertTrue(diagonal2.shape contentEquals intArrayOf(4, 2, 4)) - assertTrue(diagonal2.mutableBuffer.array() contentEquals - doubleArrayOf( - 0.0, 1.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, - 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 5.0, 0.0, - 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 6.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) + assertTrue( + diagonal2.mutableBuffer.array() contentEquals + doubleArrayOf( + 0.0, 1.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, + 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 5.0, 0.0, + 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 6.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 + ) + ) } @Test diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorBuffer.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorBuffer.kt index 4eedcb5ee..8585fdce0 100644 --- a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorBuffer.kt +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorBuffer.kt @@ -9,15 +9,14 @@ import org.jetbrains.bio.viktor.F64FlatArray import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.MutableBuffer -@Suppress("NOTHING_TO_INLINE", "OVERRIDE_BY_INLINE") @JvmInline public value class ViktorBuffer(public val flatArray: F64FlatArray) : MutableBuffer { override val size: Int get() = flatArray.size - override inline fun get(index: Int): Double = flatArray[index] + override fun get(index: Int): Double = flatArray[index] - override inline fun set(index: Int, value: Double) { + override fun set(index: Int, value: Double) { flatArray[index] = value } diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt index 1d4d6cebd..840092edc 100644 --- a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt @@ -9,15 +9,11 @@ package space.kscience.kmath.viktor import org.jetbrains.bio.viktor.F64Array import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.* -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.ExtendedFieldOps -import space.kscience.kmath.operations.NumbersAddOps -import space.kscience.kmath.operations.PowerOperations +import space.kscience.kmath.operations.* -@OptIn(UnstableKMathAPI::class, PerformancePitfall::class) -@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +@OptIn(PerformancePitfall::class) +@Suppress("OVERRIDE_BY_INLINE") public open class ViktorFieldOpsND : FieldOpsND, ExtendedFieldOps>, @@ -38,8 +34,6 @@ public open class ViktorFieldOpsND : } }.asStructure() - override fun StructureND.unaryMinus(): StructureND = -1 * this - @PerformancePitfall override fun StructureND.map(transform: DoubleField.(Double) -> Double): ViktorStructureND = F64Array(*shape).apply { @@ -74,21 +68,14 @@ public open class ViktorFieldOpsND : override fun add(left: StructureND, right: StructureND): ViktorStructureND = (left.f64Buffer + right.f64Buffer).asStructure() + override fun negate(arg: StructureND): StructureND = -1 * arg + + override fun subtract(left: StructureND, right: StructureND): ViktorStructureND = + (left.f64Buffer - right.f64Buffer).asStructure() + override fun scale(a: StructureND, value: Double): ViktorStructureND = (a.f64Buffer * value).asStructure() - override fun StructureND.plus(arg: StructureND): ViktorStructureND = - (f64Buffer + arg.f64Buffer).asStructure() - - override fun StructureND.minus(arg: StructureND): ViktorStructureND = - (f64Buffer - arg.f64Buffer).asStructure() - - override fun StructureND.times(k: Number): ViktorStructureND = - (f64Buffer * k.toDouble()).asStructure() - - override fun StructureND.plus(arg: Double): ViktorStructureND = - (f64Buffer.plus(arg)).asStructure() - override fun sin(arg: StructureND): ViktorStructureND = arg.map { sin(it) } override fun cos(arg: StructureND): ViktorStructureND = arg.map { cos(it) } override fun tan(arg: StructureND): ViktorStructureND = arg.map { tan(it) } diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt index 25ca3a10e..01032a463 100644 --- a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt @@ -10,13 +10,12 @@ import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.DefaultStrides import space.kscience.kmath.nd.MutableStructureND -@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public class ViktorStructureND(public val f64Buffer: F64Array) : MutableStructureND { override val shape: IntArray get() = f64Buffer.shape - override inline fun get(index: IntArray): Double = f64Buffer.get(*index) + override fun get(index: IntArray): Double = f64Buffer.get(*index) - override inline fun set(index: IntArray, value: Double) { + override fun set(index: IntArray, value: Double) { f64Buffer.set(*index, value = value) } @@ -26,5 +25,3 @@ public class ViktorStructureND(public val f64Buffer: F64Array) : MutableStructur } public fun F64Array.asStructure(): ViktorStructureND = ViktorStructureND(this) - -